groupToolUses.ts 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. import type { BetaToolUseBlock } from '@anthropic-ai/sdk/resources/beta/messages/messages.mjs'
  2. import type { ContentBlockParam, ToolResultBlockParam } from '@anthropic-ai/sdk/resources/messages/messages.mjs'
  3. import type { Tools } from '../Tool.js'
  4. import type {
  5. GroupedToolUseMessage,
  6. NormalizedAssistantMessage,
  7. NormalizedMessage,
  8. NormalizedUserMessage,
  9. ProgressMessage,
  10. RenderableMessage,
  11. } from '../types/message.js'
  12. export type MessageWithoutProgress = Exclude<NormalizedMessage, ProgressMessage>
  13. export type GroupingResult = {
  14. messages: RenderableMessage[]
  15. }
  16. // Cache the set of tool names that support grouped rendering, keyed by the
  17. // tools array reference. The tools array is stable across renders (only
  18. // replaced on MCP connect/disconnect), so this avoids rebuilding the set on
  19. // every call. WeakMap lets old entries be GC'd when the array is replaced.
  20. const GROUPING_CACHE = new WeakMap<Tools, Set<string>>()
  21. function getToolsWithGrouping(tools: Tools): Set<string> {
  22. let cached = GROUPING_CACHE.get(tools)
  23. if (!cached) {
  24. cached = new Set(tools.filter(t => t.renderGroupedToolUse).map(t => t.name))
  25. GROUPING_CACHE.set(tools, cached)
  26. }
  27. return cached
  28. }
  29. function getToolUseInfo(
  30. msg: MessageWithoutProgress,
  31. ): { messageId: string; toolUseId: string; toolName: string } | null {
  32. if (msg.type === 'assistant' && msg.message?.content && Array.isArray(msg.message.content) && (msg.message.content[0] as { type?: string })?.type === 'tool_use') {
  33. const content = msg.message.content[0] as unknown as { type: 'tool_use'; id: string; name: string; [key: string]: unknown }
  34. return {
  35. messageId: msg.message.id as string,
  36. toolUseId: content.id,
  37. toolName: content.name,
  38. }
  39. }
  40. return null
  41. }
  42. /**
  43. * Groups tool uses by message.id (same API response) if the tool supports grouped rendering.
  44. * Only groups 2+ tools of the same type from the same message.
  45. * Also collects corresponding tool_results and attaches them to the grouped message.
  46. * When verbose is true, skips grouping so messages render at original positions.
  47. */
  48. export function applyGrouping(
  49. messages: MessageWithoutProgress[],
  50. tools: Tools,
  51. verbose: boolean = false,
  52. ): GroupingResult {
  53. // In verbose mode, don't group - each message renders at its original position
  54. if (verbose) {
  55. return {
  56. messages: messages as RenderableMessage[],
  57. }
  58. }
  59. const toolsWithGrouping = getToolsWithGrouping(tools)
  60. // First pass: group tool uses by message.id + tool name
  61. const groups = new Map<
  62. string,
  63. NormalizedAssistantMessage<BetaToolUseBlock>[]
  64. >()
  65. for (const msg of messages) {
  66. const info = getToolUseInfo(msg)
  67. if (info && toolsWithGrouping.has(info.toolName)) {
  68. const key = `${info.messageId}:${info.toolName}`
  69. const group = groups.get(key) ?? []
  70. group.push(msg as NormalizedAssistantMessage<BetaToolUseBlock>)
  71. groups.set(key, group)
  72. }
  73. }
  74. // Identify valid groups (2+ items) and collect their tool use IDs
  75. const validGroups = new Map<
  76. string,
  77. NormalizedAssistantMessage<BetaToolUseBlock>[]
  78. >()
  79. const groupedToolUseIds = new Set<string>()
  80. for (const [key, group] of groups) {
  81. if (group.length >= 2) {
  82. validGroups.set(key, group)
  83. for (const msg of group) {
  84. const info = getToolUseInfo(msg)
  85. if (info) {
  86. groupedToolUseIds.add(info.toolUseId)
  87. }
  88. }
  89. }
  90. }
  91. // Collect result messages for grouped tool_uses
  92. // Map from tool_use_id to the user message containing that result
  93. const resultsByToolUseId = new Map<string, NormalizedUserMessage>()
  94. for (const msg of messages) {
  95. if (msg.type === 'user' && msg.message?.content && Array.isArray(msg.message.content)) {
  96. for (const content of msg.message.content) {
  97. if (
  98. (content as { type?: string }).type === 'tool_result' &&
  99. groupedToolUseIds.has((content as { tool_use_id: string }).tool_use_id)
  100. ) {
  101. resultsByToolUseId.set((content as { tool_use_id: string }).tool_use_id, msg as NormalizedUserMessage)
  102. }
  103. }
  104. }
  105. }
  106. // Second pass: build output, emitting each group only once
  107. const result: RenderableMessage[] = []
  108. const emittedGroups = new Set<string>()
  109. for (const msg of messages) {
  110. const info = getToolUseInfo(msg)
  111. if (info) {
  112. const key = `${info.messageId}:${info.toolName}`
  113. const group = validGroups.get(key)
  114. if (group) {
  115. if (!emittedGroups.has(key)) {
  116. emittedGroups.add(key)
  117. const firstMsg = group[0]!
  118. // Collect results for this group
  119. const results: NormalizedUserMessage[] = []
  120. for (const assistantMsg of group) {
  121. const toolUseId = (
  122. assistantMsg.message.content[0] as { id: string }
  123. ).id
  124. const resultMsg = resultsByToolUseId.get(toolUseId)
  125. if (resultMsg) {
  126. results.push(resultMsg)
  127. }
  128. }
  129. const groupedMessage: GroupedToolUseMessage = {
  130. type: 'grouped_tool_use',
  131. toolName: info.toolName,
  132. messages: group,
  133. results,
  134. displayMessage: firstMsg,
  135. uuid: `grouped-${firstMsg.uuid}`,
  136. timestamp: firstMsg.timestamp,
  137. messageId: info.messageId,
  138. }
  139. result.push(groupedMessage)
  140. }
  141. continue
  142. }
  143. }
  144. // Skip user messages whose tool_results are all grouped
  145. if (msg.type === 'user' && msg.message?.content && Array.isArray(msg.message.content)) {
  146. const toolResults = (msg.message.content as Array<ContentBlockParam>).filter(
  147. (c): c is ToolResultBlockParam => c.type === 'tool_result',
  148. )
  149. if (toolResults.length > 0) {
  150. const allGrouped = toolResults.every(tr =>
  151. groupedToolUseIds.has(tr.tool_use_id),
  152. )
  153. if (allGrouped) {
  154. continue
  155. }
  156. }
  157. }
  158. result.push(msg as RenderableMessage)
  159. }
  160. return { messages: result }
  161. }