Skip to content

Commit c13f222

Browse files
committed
Code review improvements
Closes gh-4254 Signed-off-by: Rafael Cunha <[email protected]>
1 parent bca6afd commit c13f222

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@ public final class DefaultToolCallingManager implements ToolCallingManager {
7878
private static final ToolExecutionExceptionProcessor DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR
7979
= DefaultToolExecutionExceptionProcessor.builder().build();
8080

81-
private static final TaskExecutor DEFAULT_TASK_EXECUTOR = buildDefaultTaskExecutor();
82-
8381
// @formatter:on
8482

8583
private final ObservationRegistry observationRegistry;
@@ -101,7 +99,7 @@ public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCa
10199
this.observationRegistry = observationRegistry;
102100
this.toolCallbackResolver = toolCallbackResolver;
103101
this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
104-
this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor();
102+
this.taskExecutor = taskExecutor != null ? taskExecutor : this.buildDefaultTaskExecutor();
105103
}
106104

107105
@Override
@@ -190,7 +188,7 @@ private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMess
190188
? toolCallingChatOptions.getToolCallbacks() : List.of();
191189

192190
final Queue<Boolean> toolsReturnDirect = new ConcurrentLinkedDeque<>();
193-
List<ToolResponseMessage.ToolResponse> toolResponses = assistantMessage.getToolCalls()
191+
List<CompletableFuture<ToolResponseMessage.ToolResponse>> futuresToolResponses = assistantMessage.getToolCalls()
194192
.stream()
195193
.map(toolCall -> CompletableFuture.supplyAsync(() -> {
196194
logger.debug("Executing tool call: {}", toolCall.name());
@@ -233,9 +231,13 @@ private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMess
233231
return new ToolResponseMessage.ToolResponse(toolCall.id(), toolName,
234232
toolCallResult != null ? toolCallResult : "");
235233
}, this.taskExecutor))
236-
.map(CompletableFuture::join)
237234
.toList();
238235

236+
final List<ToolResponseMessage.ToolResponse> toolResponses = CompletableFuture
237+
.allOf(futuresToolResponses.toArray(new CompletableFuture[0]))
238+
.thenApply(result -> futuresToolResponses.stream().map(CompletableFuture::join).toList())
239+
.join();
240+
239241
return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()),
240242
toolsReturnDirect.stream().allMatch(Boolean::booleanValue));
241243
}
@@ -252,9 +254,9 @@ public void setObservationConvention(ToolCallingObservationConvention observatio
252254
this.observationConvention = observationConvention;
253255
}
254256

255-
private static TaskExecutor buildDefaultTaskExecutor() {
257+
private TaskExecutor buildDefaultTaskExecutor() {
256258
ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
257-
taskExecutor.setThreadNamePrefix("ai-toll-calling-");
259+
taskExecutor.setThreadNamePrefix("ai-tool-calling-");
258260
taskExecutor.setCorePoolSize(4);
259261
taskExecutor.setMaxPoolSize(16);
260262
taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator());
@@ -277,7 +279,7 @@ public final static class Builder {
277279

278280
private ToolExecutionExceptionProcessor toolExecutionExceptionProcessor = DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR;
279281

280-
private TaskExecutor taskExecutor = DEFAULT_TASK_EXECUTOR;
282+
private TaskExecutor taskExecutor;
281283

282284
private Builder() {
283285
}

0 commit comments

Comments
 (0)