Skip to content

Commit eb3a91d

Browse files
committed
Parallel Tool Execution
Closes gh-4254 Signed-off-by: Rafael Cunha <[email protected]>
1 parent ad2e1bc commit eb3a91d

File tree

1 file changed

+83
-59
lines changed

1 file changed

+83
-59
lines changed

spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java

Lines changed: 83 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
import java.util.HashMap;
2121
import java.util.List;
2222
import java.util.Map;
23+
import java.util.Queue;
2324
import java.util.Optional;
25+
import java.util.concurrent.CompletableFuture;
26+
import java.util.concurrent.ConcurrentLinkedDeque;
2427

2528
import io.micrometer.observation.ObservationRegistry;
2629
import org.slf4j.Logger;
@@ -44,6 +47,10 @@
4447
import org.springframework.ai.tool.observation.ToolCallingObservationDocumentation;
4548
import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver;
4649
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
50+
import org.springframework.core.task.TaskExecutor;
51+
import org.springframework.core.task.support.ContextPropagatingTaskDecorator;
52+
import org.springframework.lang.Nullable;
53+
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
4754
import org.springframework.util.Assert;
4855
import org.springframework.util.CollectionUtils;
4956

@@ -71,6 +78,8 @@ public final class DefaultToolCallingManager implements ToolCallingManager {
7178
private static final ToolExecutionExceptionProcessor DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR
7279
= DefaultToolExecutionExceptionProcessor.builder().build();
7380

81+
private static final TaskExecutor DEFAULT_TASK_EXECUTOR = buildDefaultTaskExecutor();
82+
7483
// @formatter:on
7584

7685
private final ObservationRegistry observationRegistry;
@@ -79,17 +88,20 @@ public final class DefaultToolCallingManager implements ToolCallingManager {
7988

8089
private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor;
8190

91+
private final TaskExecutor taskExecutor;
92+
8293
private ToolCallingObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
8394

8495
public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCallbackResolver toolCallbackResolver,
85-
ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
96+
ToolExecutionExceptionProcessor toolExecutionExceptionProcessor, @Nullable TaskExecutor taskExecutor) {
8697
Assert.notNull(observationRegistry, "observationRegistry cannot be null");
8798
Assert.notNull(toolCallbackResolver, "toolCallbackResolver cannot be null");
8899
Assert.notNull(toolExecutionExceptionProcessor, "toolCallExceptionConverter cannot be null");
89100

90101
this.observationRegistry = observationRegistry;
91102
this.toolCallbackResolver = toolCallbackResolver;
92103
this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
104+
this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor();
93105
}
94106

95107
@Override
@@ -173,64 +185,59 @@ private static List<Message> buildConversationHistoryBeforeToolExecution(Prompt
173185
*/
174186
private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMessage assistantMessage,
175187
ToolContext toolContext) {
176-
List<ToolCallback> toolCallbacks = List.of();
177-
if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
178-
toolCallbacks = toolCallingChatOptions.getToolCallbacks();
179-
}
180-
181-
List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();
182-
183-
Boolean returnDirect = null;
184-
185-
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
186-
187-
logger.debug("Executing tool call: {}", toolCall.name());
188-
189-
String toolName = toolCall.name();
190-
String toolInputArguments = toolCall.arguments();
191-
192-
ToolCallback toolCallback = toolCallbacks.stream()
193-
.filter(tool -> toolName.equals(tool.getToolDefinition().name()))
194-
.findFirst()
195-
.orElseGet(() -> this.toolCallbackResolver.resolve(toolName));
196-
197-
if (toolCallback == null) {
198-
throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
199-
}
188+
final List<ToolCallback> toolCallbacks = (prompt
189+
.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions)
190+
? toolCallingChatOptions.getToolCallbacks() : List.of();
200191

201-
if (returnDirect == null) {
202-
returnDirect = toolCallback.getToolMetadata().returnDirect();
203-
}
204-
else {
205-
returnDirect = returnDirect && toolCallback.getToolMetadata().returnDirect();
206-
}
207-
208-
ToolCallingObservationContext observationContext = ToolCallingObservationContext.builder()
209-
.toolDefinition(toolCallback.getToolDefinition())
210-
.toolMetadata(toolCallback.getToolMetadata())
211-
.toolCallArguments(toolInputArguments)
212-
.build();
213-
214-
String toolCallResult = ToolCallingObservationDocumentation.TOOL_CALL
215-
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
216-
this.observationRegistry)
217-
.observe(() -> {
218-
String toolResult;
219-
try {
220-
toolResult = toolCallback.call(toolInputArguments, toolContext);
221-
}
222-
catch (ToolExecutionException ex) {
223-
toolResult = this.toolExecutionExceptionProcessor.process(ex);
224-
}
225-
observationContext.setToolCallResult(toolResult);
226-
return toolResult;
227-
});
228-
229-
toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName,
230-
toolCallResult != null ? toolCallResult : ""));
231-
}
232-
233-
return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), returnDirect);
192+
final Queue<Boolean> toolsReturnDirect = new ConcurrentLinkedDeque<>();
193+
List<ToolResponseMessage.ToolResponse> toolResponses = assistantMessage.getToolCalls()
194+
.stream()
195+
.map(toolCall -> CompletableFuture.supplyAsync(() -> {
196+
logger.debug("Executing tool call: {}", toolCall.name());
197+
198+
String toolName = toolCall.name();
199+
String toolInputArguments = toolCall.arguments();
200+
201+
ToolCallback toolCallback = toolCallbacks.stream()
202+
.filter(tool -> toolName.equals(tool.getToolDefinition().name()))
203+
.findFirst()
204+
.orElseGet(() -> this.toolCallbackResolver.resolve(toolName));
205+
206+
if (toolCallback == null) {
207+
throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
208+
}
209+
210+
toolsReturnDirect.add(toolCallback.getToolMetadata().returnDirect());
211+
212+
ToolCallingObservationContext observationContext = ToolCallingObservationContext.builder()
213+
.toolDefinition(toolCallback.getToolDefinition())
214+
.toolMetadata(toolCallback.getToolMetadata())
215+
.toolCallArguments(toolInputArguments)
216+
.build();
217+
218+
String toolCallResult = ToolCallingObservationDocumentation.TOOL_CALL
219+
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
220+
this.observationRegistry)
221+
.observe(() -> {
222+
String toolResult;
223+
try {
224+
toolResult = toolCallback.call(toolInputArguments, toolContext);
225+
}
226+
catch (ToolExecutionException ex) {
227+
toolResult = this.toolExecutionExceptionProcessor.process(ex);
228+
}
229+
observationContext.setToolCallResult(toolResult);
230+
return toolResult;
231+
});
232+
233+
return new ToolResponseMessage.ToolResponse(toolCall.id(), toolName,
234+
toolCallResult != null ? toolCallResult : "");
235+
}, this.taskExecutor))
236+
.map(CompletableFuture::join)
237+
.toList();
238+
239+
return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()),
240+
toolsReturnDirect.stream().allMatch(Boolean::booleanValue));
234241
}
235242

236243
private List<Message> buildConversationHistoryAfterToolExecution(List<Message> previousMessages,
@@ -245,6 +252,16 @@ public void setObservationConvention(ToolCallingObservationConvention observatio
245252
this.observationConvention = observationConvention;
246253
}
247254

255+
private static TaskExecutor buildDefaultTaskExecutor() {
256+
ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
257+
taskExecutor.setThreadNamePrefix("ai-toll-calling-");
258+
taskExecutor.setCorePoolSize(4);
259+
taskExecutor.setMaxPoolSize(16);
260+
taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator());
261+
taskExecutor.initialize();
262+
return taskExecutor;
263+
}
264+
248265
public static Builder builder() {
249266
return new Builder();
250267
}
@@ -260,6 +277,8 @@ public final static class Builder {
260277

261278
private ToolExecutionExceptionProcessor toolExecutionExceptionProcessor = DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR;
262279

280+
private TaskExecutor taskExecutor = DEFAULT_TASK_EXECUTOR;
281+
263282
private Builder() {
264283
}
265284

@@ -279,9 +298,14 @@ public Builder toolExecutionExceptionProcessor(
279298
return this;
280299
}
281300

301+
public Builder taskExecutor(TaskExecutor taskExecutor) {
302+
this.taskExecutor = taskExecutor;
303+
return this;
304+
}
305+
282306
public DefaultToolCallingManager build() {
283307
return new DefaultToolCallingManager(this.observationRegistry, this.toolCallbackResolver,
284-
this.toolExecutionExceptionProcessor);
308+
this.toolExecutionExceptionProcessor, taskExecutor);
285309
}
286310

287311
}

0 commit comments

Comments
 (0)