2020import java .util .HashMap ;
2121import java .util .List ;
2222import java .util .Map ;
23+ import java .util .Queue ;
2324import java .util .Optional ;
25+ import java .util .concurrent .CompletableFuture ;
26+ import java .util .concurrent .ConcurrentLinkedDeque ;
2427
2528import io .micrometer .observation .ObservationRegistry ;
2629import org .slf4j .Logger ;
4447import org .springframework .ai .tool .observation .ToolCallingObservationDocumentation ;
4548import org .springframework .ai .tool .resolution .DelegatingToolCallbackResolver ;
4649import org .springframework .ai .tool .resolution .ToolCallbackResolver ;
50+ import org .springframework .core .task .TaskExecutor ;
51+ import org .springframework .core .task .support .ContextPropagatingTaskDecorator ;
52+ import org .springframework .lang .Nullable ;
53+ import org .springframework .scheduling .concurrent .ThreadPoolTaskExecutor ;
4754import org .springframework .util .Assert ;
4855import org .springframework .util .CollectionUtils ;
4956
@@ -71,6 +78,8 @@ public final class DefaultToolCallingManager implements ToolCallingManager {
7178 private static final ToolExecutionExceptionProcessor DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR
7279 = DefaultToolExecutionExceptionProcessor .builder ().build ();
7380
81+ private static final TaskExecutor DEFAULT_TASK_EXECUTOR = buildDefaultTaskExecutor ();
82+
7483 // @formatter:on
7584
7685 private final ObservationRegistry observationRegistry ;
@@ -79,17 +88,20 @@ public final class DefaultToolCallingManager implements ToolCallingManager {
7988
8089 private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor ;
8190
91+ private final TaskExecutor taskExecutor ;
92+
8293 private ToolCallingObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION ;
8394
8495 public DefaultToolCallingManager (ObservationRegistry observationRegistry , ToolCallbackResolver toolCallbackResolver ,
85- ToolExecutionExceptionProcessor toolExecutionExceptionProcessor ) {
96+ ToolExecutionExceptionProcessor toolExecutionExceptionProcessor , @ Nullable TaskExecutor taskExecutor ) {
8697 Assert .notNull (observationRegistry , "observationRegistry cannot be null" );
8798 Assert .notNull (toolCallbackResolver , "toolCallbackResolver cannot be null" );
8899 Assert .notNull (toolExecutionExceptionProcessor , "toolCallExceptionConverter cannot be null" );
89100
90101 this .observationRegistry = observationRegistry ;
91102 this .toolCallbackResolver = toolCallbackResolver ;
92103 this .toolExecutionExceptionProcessor = toolExecutionExceptionProcessor ;
104+ this .taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor ();
93105 }
94106
95107 @ Override
@@ -173,64 +185,59 @@ private static List<Message> buildConversationHistoryBeforeToolExecution(Prompt
173185 */
174186 private InternalToolExecutionResult executeToolCall (Prompt prompt , AssistantMessage assistantMessage ,
175187 ToolContext toolContext ) {
176- List <ToolCallback > toolCallbacks = List .of ();
177- if (prompt .getOptions () instanceof ToolCallingChatOptions toolCallingChatOptions ) {
178- toolCallbacks = toolCallingChatOptions .getToolCallbacks ();
179- }
180-
181- List <ToolResponseMessage .ToolResponse > toolResponses = new ArrayList <>();
182-
183- Boolean returnDirect = null ;
184-
185- for (AssistantMessage .ToolCall toolCall : assistantMessage .getToolCalls ()) {
186-
187- logger .debug ("Executing tool call: {}" , toolCall .name ());
188-
189- String toolName = toolCall .name ();
190- String toolInputArguments = toolCall .arguments ();
191-
192- ToolCallback toolCallback = toolCallbacks .stream ()
193- .filter (tool -> toolName .equals (tool .getToolDefinition ().name ()))
194- .findFirst ()
195- .orElseGet (() -> this .toolCallbackResolver .resolve (toolName ));
196-
197- if (toolCallback == null ) {
198- throw new IllegalStateException ("No ToolCallback found for tool name: " + toolName );
199- }
188+ final List <ToolCallback > toolCallbacks = (prompt
189+ .getOptions () instanceof ToolCallingChatOptions toolCallingChatOptions )
190+ ? toolCallingChatOptions .getToolCallbacks () : List .of ();
200191
201- if (returnDirect == null ) {
202- returnDirect = toolCallback .getToolMetadata ().returnDirect ();
203- }
204- else {
205- returnDirect = returnDirect && toolCallback .getToolMetadata ().returnDirect ();
206- }
207-
208- ToolCallingObservationContext observationContext = ToolCallingObservationContext .builder ()
209- .toolDefinition (toolCallback .getToolDefinition ())
210- .toolMetadata (toolCallback .getToolMetadata ())
211- .toolCallArguments (toolInputArguments )
212- .build ();
213-
214- String toolCallResult = ToolCallingObservationDocumentation .TOOL_CALL
215- .observation (this .observationConvention , DEFAULT_OBSERVATION_CONVENTION , () -> observationContext ,
216- this .observationRegistry )
217- .observe (() -> {
218- String toolResult ;
219- try {
220- toolResult = toolCallback .call (toolInputArguments , toolContext );
221- }
222- catch (ToolExecutionException ex ) {
223- toolResult = this .toolExecutionExceptionProcessor .process (ex );
224- }
225- observationContext .setToolCallResult (toolResult );
226- return toolResult ;
227- });
228-
229- toolResponses .add (new ToolResponseMessage .ToolResponse (toolCall .id (), toolName ,
230- toolCallResult != null ? toolCallResult : "" ));
231- }
232-
233- return new InternalToolExecutionResult (new ToolResponseMessage (toolResponses , Map .of ()), returnDirect );
192+ final Queue <Boolean > toolsReturnDirect = new ConcurrentLinkedDeque <>();
193+ List <ToolResponseMessage .ToolResponse > toolResponses = assistantMessage .getToolCalls ()
194+ .stream ()
195+ .map (toolCall -> CompletableFuture .supplyAsync (() -> {
196+ logger .debug ("Executing tool call: {}" , toolCall .name ());
197+
198+ String toolName = toolCall .name ();
199+ String toolInputArguments = toolCall .arguments ();
200+
201+ ToolCallback toolCallback = toolCallbacks .stream ()
202+ .filter (tool -> toolName .equals (tool .getToolDefinition ().name ()))
203+ .findFirst ()
204+ .orElseGet (() -> this .toolCallbackResolver .resolve (toolName ));
205+
206+ if (toolCallback == null ) {
207+ throw new IllegalStateException ("No ToolCallback found for tool name: " + toolName );
208+ }
209+
210+ toolsReturnDirect .add (toolCallback .getToolMetadata ().returnDirect ());
211+
212+ ToolCallingObservationContext observationContext = ToolCallingObservationContext .builder ()
213+ .toolDefinition (toolCallback .getToolDefinition ())
214+ .toolMetadata (toolCallback .getToolMetadata ())
215+ .toolCallArguments (toolInputArguments )
216+ .build ();
217+
218+ String toolCallResult = ToolCallingObservationDocumentation .TOOL_CALL
219+ .observation (this .observationConvention , DEFAULT_OBSERVATION_CONVENTION , () -> observationContext ,
220+ this .observationRegistry )
221+ .observe (() -> {
222+ String toolResult ;
223+ try {
224+ toolResult = toolCallback .call (toolInputArguments , toolContext );
225+ }
226+ catch (ToolExecutionException ex ) {
227+ toolResult = this .toolExecutionExceptionProcessor .process (ex );
228+ }
229+ observationContext .setToolCallResult (toolResult );
230+ return toolResult ;
231+ });
232+
233+ return new ToolResponseMessage .ToolResponse (toolCall .id (), toolName ,
234+ toolCallResult != null ? toolCallResult : "" );
235+ }, this .taskExecutor ))
236+ .map (CompletableFuture ::join )
237+ .toList ();
238+
239+ return new InternalToolExecutionResult (new ToolResponseMessage (toolResponses , Map .of ()),
240+ toolsReturnDirect .stream ().allMatch (Boolean ::booleanValue ));
234241 }
235242
236243 private List <Message > buildConversationHistoryAfterToolExecution (List <Message > previousMessages ,
@@ -245,6 +252,16 @@ public void setObservationConvention(ToolCallingObservationConvention observatio
245252 this .observationConvention = observationConvention ;
246253 }
247254
255+ private static TaskExecutor buildDefaultTaskExecutor () {
256+ ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor ();
257+ taskExecutor .setThreadNamePrefix ("ai-toll-calling-" );
258+ taskExecutor .setCorePoolSize (4 );
259+ taskExecutor .setMaxPoolSize (16 );
260+ taskExecutor .setTaskDecorator (new ContextPropagatingTaskDecorator ());
261+ taskExecutor .initialize ();
262+ return taskExecutor ;
263+ }
264+
248265 public static Builder builder () {
249266 return new Builder ();
250267 }
@@ -260,6 +277,8 @@ public final static class Builder {
260277
261278 private ToolExecutionExceptionProcessor toolExecutionExceptionProcessor = DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR ;
262279
280+ private TaskExecutor taskExecutor = DEFAULT_TASK_EXECUTOR ;
281+
263282 private Builder () {
264283 }
265284
@@ -279,9 +298,14 @@ public Builder toolExecutionExceptionProcessor(
279298 return this ;
280299 }
281300
301+ public Builder taskExecutor (TaskExecutor taskExecutor ) {
302+ this .taskExecutor = taskExecutor ;
303+ return this ;
304+ }
305+
282306 public DefaultToolCallingManager build () {
283307 return new DefaultToolCallingManager (this .observationRegistry , this .toolCallbackResolver ,
284- this .toolExecutionExceptionProcessor );
308+ this .toolExecutionExceptionProcessor , taskExecutor );
285309 }
286310
287311 }
0 commit comments