Skip to content

Commit 96c5358

Browse files
committed
fix(providers): support parallel agent tool calls, consolidate utils
1 parent 41bb7e8 commit 96c5358

File tree

22 files changed

+1387
-1262
lines changed

22 files changed

+1387
-1262
lines changed

apps/sim/providers/anthropic/index.ts

Lines changed: 222 additions & 155 deletions
Large diffs are not rendered by default.

apps/sim/providers/azure-openai/index.ts

Lines changed: 96 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,11 @@ export const azureOpenAIProvider: ProviderConfig = {
303303
usedForcedTools = firstCheckResult.usedForcedTools
304304

305305
while (iterationCount < MAX_TOOL_ITERATIONS) {
306+
// Extract text content FIRST, before checking for tool calls
307+
if (currentResponse.choices[0]?.message?.content) {
308+
content = currentResponse.choices[0].message.content
309+
}
310+
306311
// Check for tool calls
307312
const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls
308313
if (!toolCallsInResponse || toolCallsInResponse.length === 0) {
@@ -316,85 +321,111 @@ export const azureOpenAIProvider: ProviderConfig = {
316321
// Track time for tool calls in this batch
317322
const toolsStartTime = Date.now()
318323

319-
// Process each tool call
320-
for (const toolCall of toolCallsInResponse) {
324+
// Execute all tool calls in parallel using Promise.allSettled for resilience
325+
const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => {
326+
const toolCallStartTime = Date.now()
327+
const toolName = toolCall.function.name
328+
321329
try {
322-
const toolName = toolCall.function.name
323330
const toolArgs = JSON.parse(toolCall.function.arguments)
324-
325-
// Get the tool from the tools registry
326331
const tool = request.tools?.find((t) => t.id === toolName)
327-
if (!tool) continue
328332

329-
// Execute the tool
330-
const toolCallStartTime = Date.now()
333+
if (!tool) return null
331334

332335
const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request)
333-
334336
const result = await executeTool(toolName, executionParams, true)
335337
const toolCallEndTime = Date.now()
336-
const toolCallDuration = toolCallEndTime - toolCallStartTime
337338

338-
// Add to time segments for both success and failure
339-
timeSegments.push({
340-
type: 'tool',
341-
name: toolName,
339+
return {
340+
toolCall,
341+
toolName,
342+
toolParams,
343+
result,
342344
startTime: toolCallStartTime,
343345
endTime: toolCallEndTime,
344-
duration: toolCallDuration,
345-
})
346-
347-
// Prepare result content for the LLM
348-
let resultContent: any
349-
if (result.success) {
350-
toolResults.push(result.output)
351-
resultContent = result.output
352-
} else {
353-
// Include error information so LLM can respond appropriately
354-
resultContent = {
355-
error: true,
356-
message: result.error || 'Tool execution failed',
357-
tool: toolName,
358-
}
346+
duration: toolCallEndTime - toolCallStartTime,
359347
}
360-
361-
toolCalls.push({
362-
name: toolName,
363-
arguments: toolParams,
364-
startTime: new Date(toolCallStartTime).toISOString(),
365-
endTime: new Date(toolCallEndTime).toISOString(),
366-
duration: toolCallDuration,
367-
result: resultContent,
368-
success: result.success,
369-
})
370-
371-
// Add the tool call and result to messages (both success and failure)
372-
currentMessages.push({
373-
role: 'assistant',
374-
content: null,
375-
tool_calls: [
376-
{
377-
id: toolCall.id,
378-
type: 'function',
379-
function: {
380-
name: toolName,
381-
arguments: toolCall.function.arguments,
382-
},
383-
},
384-
],
385-
})
386-
387-
currentMessages.push({
388-
role: 'tool',
389-
tool_call_id: toolCall.id,
390-
content: JSON.stringify(resultContent),
391-
})
392348
} catch (error) {
393-
logger.error('Error processing tool call:', {
394-
error,
395-
toolName: toolCall?.function?.name,
396-
})
349+
const toolCallEndTime = Date.now()
350+
logger.error('Error processing tool call:', { error, toolName })
351+
352+
return {
353+
toolCall,
354+
toolName,
355+
toolParams: {},
356+
result: {
357+
success: false,
358+
output: undefined,
359+
error: error instanceof Error ? error.message : 'Tool execution failed',
360+
},
361+
startTime: toolCallStartTime,
362+
endTime: toolCallEndTime,
363+
duration: toolCallEndTime - toolCallStartTime,
364+
}
365+
}
366+
})
367+
368+
const executionResults = await Promise.allSettled(toolExecutionPromises)
369+
370+
// Add ONE assistant message with ALL tool calls BEFORE processing results
371+
currentMessages.push({
372+
role: 'assistant',
373+
content: null,
374+
tool_calls: toolCallsInResponse.map((tc) => ({
375+
id: tc.id,
376+
type: 'function',
377+
function: {
378+
name: tc.function.name,
379+
arguments: tc.function.arguments,
380+
},
381+
})),
382+
})
383+
384+
// Process results in order to maintain consistency
385+
for (const settledResult of executionResults) {
386+
if (settledResult.status === 'rejected' || !settledResult.value) continue
387+
388+
const { toolCall, toolName, toolParams, result, startTime, endTime, duration } =
389+
settledResult.value
390+
391+
// Add to time segments
392+
timeSegments.push({
393+
type: 'tool',
394+
name: toolName,
395+
startTime: startTime,
396+
endTime: endTime,
397+
duration: duration,
398+
})
399+
400+
// Prepare result content for the LLM
401+
let resultContent: any
402+
if (result.success) {
403+
toolResults.push(result.output)
404+
resultContent = result.output
405+
} else {
406+
resultContent = {
407+
error: true,
408+
message: result.error || 'Tool execution failed',
409+
tool: toolName,
410+
}
397411
}
412+
413+
toolCalls.push({
414+
name: toolName,
415+
arguments: toolParams,
416+
startTime: new Date(startTime).toISOString(),
417+
endTime: new Date(endTime).toISOString(),
418+
duration: duration,
419+
result: resultContent,
420+
success: result.success,
421+
})
422+
423+
// Add tool result message
424+
currentMessages.push({
425+
role: 'tool',
426+
tool_call_id: toolCall.id,
427+
content: JSON.stringify(resultContent),
428+
})
398429
}
399430

400431
// Calculate tool call time for this iteration
Lines changed: 18 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,37 @@
11
import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
22
import type { CompletionUsage } from 'openai/resources/completions'
33
import type { Stream } from 'openai/streaming'
4-
import { createLogger, type Logger } from '@/lib/logs/console/logger'
5-
import { trackForcedToolUsage } from '@/providers/utils'
6-
7-
const logger = createLogger('AzureOpenAIUtils')
4+
import type { Logger } from '@/lib/logs/console/logger'
5+
import { checkForForcedToolUsageOpenAI, createOpenAICompatibleStream } from '@/providers/utils'
86

7+
/**
8+
* Creates a ReadableStream from an Azure OpenAI streaming response.
9+
* Uses the shared OpenAI-compatible streaming utility.
10+
*/
911
export function createReadableStreamFromAzureOpenAIStream(
1012
azureOpenAIStream: Stream<ChatCompletionChunk>,
1113
onComplete?: (content: string, usage: CompletionUsage) => void
1214
): ReadableStream {
13-
let fullContent = ''
14-
let promptTokens = 0
15-
let completionTokens = 0
16-
let totalTokens = 0
17-
18-
return new ReadableStream({
19-
async start(controller) {
20-
try {
21-
for await (const chunk of azureOpenAIStream) {
22-
if (chunk.usage) {
23-
promptTokens = chunk.usage.prompt_tokens ?? 0
24-
completionTokens = chunk.usage.completion_tokens ?? 0
25-
totalTokens = chunk.usage.total_tokens ?? 0
26-
}
27-
28-
const content = chunk.choices[0]?.delta?.content || ''
29-
if (content) {
30-
fullContent += content
31-
controller.enqueue(new TextEncoder().encode(content))
32-
}
33-
}
34-
35-
if (onComplete) {
36-
if (promptTokens === 0 && completionTokens === 0) {
37-
logger.warn('Azure OpenAI stream completed without usage data')
38-
}
39-
onComplete(fullContent, {
40-
prompt_tokens: promptTokens,
41-
completion_tokens: completionTokens,
42-
total_tokens: totalTokens || promptTokens + completionTokens,
43-
})
44-
}
45-
46-
controller.close()
47-
} catch (error) {
48-
controller.error(error)
49-
}
50-
},
51-
})
15+
return createOpenAICompatibleStream(azureOpenAIStream, 'Azure OpenAI', onComplete)
5216
}
5317

5418
/**
55-
* Helper function to check for forced tool usage in responses
19+
* Checks if a forced tool was used in an Azure OpenAI response.
20+
* Uses the shared OpenAI-compatible forced tool usage helper.
5621
*/
5722
export function checkForForcedToolUsage(
5823
response: any,
5924
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any },
60-
logger: Logger,
25+
_logger: Logger,
6126
forcedTools: string[],
6227
usedForcedTools: string[]
6328
): { hasUsedForcedTool: boolean; usedForcedTools: string[] } {
64-
let hasUsedForcedTool = false
65-
let updatedUsedForcedTools = [...usedForcedTools]
66-
67-
if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) {
68-
const toolCallsResponse = response.choices[0].message.tool_calls
69-
const result = trackForcedToolUsage(
70-
toolCallsResponse,
71-
toolChoice,
72-
logger,
73-
'azure-openai',
74-
forcedTools,
75-
updatedUsedForcedTools
76-
)
77-
hasUsedForcedTool = result.hasUsedForcedTool
78-
updatedUsedForcedTools = result.usedForcedTools
79-
}
80-
81-
return { hasUsedForcedTool, usedForcedTools: updatedUsedForcedTools }
29+
return checkForForcedToolUsageOpenAI(
30+
response,
31+
toolChoice,
32+
'Azure OpenAI',
33+
forcedTools,
34+
usedForcedTools,
35+
_logger
36+
)
8237
}

0 commit comments

Comments
 (0)