diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 20b207d5c5c..df432283586 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -618,6 +618,8 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe private final List toolCallbacks = new ArrayList<>(); + private final List toolCallbackProviders = new ArrayList<>(); + private final List messages = new ArrayList<>(); private final Map userParams = new HashMap<>(); @@ -648,16 +650,17 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe /* copy constructor */ DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) { this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.userMetadata, ccr.systemText, ccr.systemParams, - ccr.systemMetadata, ccr.toolCallbacks, ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, - ccr.advisors, ccr.advisorParams, ccr.observationRegistry, ccr.observationConvention, - ccr.toolContext, ccr.templateRenderer); + ccr.systemMetadata, ccr.toolCallbacks, ccr.toolCallbackProviders, ccr.messages, ccr.toolNames, + ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams, ccr.observationRegistry, + ccr.observationConvention, ccr.toolContext, ccr.templateRenderer); } public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText, Map userParams, Map userMetadata, @Nullable String systemText, Map systemParams, Map systemMetadata, List toolCallbacks, - List messages, List toolNames, List media, @Nullable ChatOptions chatOptions, - List advisors, Map advisorParams, ObservationRegistry observationRegistry, + List toolCallbackProviders, List messages, List toolNames, + List media, @Nullable ChatOptions chatOptions, List advisors, + Map advisorParams, ObservationRegistry observationRegistry, @Nullable ChatClientObservationConvention observationConvention, Map toolContext, @Nullable TemplateRenderer templateRenderer) { @@ -667,6 +670,7 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe Assert.notNull(systemParams, "systemParams cannot be null"); Assert.notNull(systemMetadata, "systemMetadata cannot be null"); Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + Assert.notNull(toolCallbackProviders, "toolCallbackProviders cannot be null"); Assert.notNull(messages, "messages cannot be null"); Assert.notNull(toolNames, "toolNames cannot be null"); Assert.notNull(media, "media cannot be null"); @@ -689,6 +693,7 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe this.toolNames.addAll(toolNames); this.toolCallbacks.addAll(toolCallbacks); + this.toolCallbackProviders.addAll(toolCallbackProviders); this.messages.addAll(messages); this.media.addAll(media); this.advisors.addAll(advisors); @@ -755,6 +760,10 @@ public List getToolCallbacks() { return this.toolCallbacks; } + public List getToolCallbackProviders() { + return this.toolCallbackProviders; + } + public Map getToolContext() { return this.toolContext; } @@ -773,6 +782,7 @@ public Builder mutate() { .builder(this.chatModel, this.observationRegistry, this.observationConvention) .defaultTemplateRenderer(this.templateRenderer) .defaultToolCallbacks(this.toolCallbacks) + .defaultToolCallbacks(this.toolCallbackProviders.toArray(new ToolCallback[0])) .defaultToolContext(this.toolContext) .defaultToolNames(StringUtils.toStringArray(this.toolNames)); @@ -885,9 +895,7 @@ public ChatClientRequestSpec tools(Object... toolObjects) { public ChatClientRequestSpec toolCallbacks(ToolCallbackProvider... toolCallbackProviders) { Assert.notNull(toolCallbackProviders, "toolCallbackProviders cannot be null"); Assert.noNullElements(toolCallbackProviders, "toolCallbackProviders cannot contain null elements"); - for (ToolCallbackProvider toolCallbackProvider : toolCallbackProviders) { - this.toolCallbacks.addAll(List.of(toolCallbackProvider.getToolCallbacks())); - } + this.toolCallbackProviders.addAll(List.of(toolCallbackProviders)); return this; } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index a937356e543..6778dc222e5 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -65,8 +65,8 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observa Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null"); Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null"); this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), Map.of(), null, Map.of(), - Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, - customObservationConvention, Map.of(), null); + Map.of(), List.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), + observationRegistry, customObservationConvention, Map.of(), null); } public ChatClient build() { diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java index e0259c36ab0..857ca9b2ef5 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java @@ -106,9 +106,16 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient ChatOptions processedChatOptions = inputRequest.getChatOptions(); - if (processedChatOptions instanceof DefaultChatOptions defaultChatOptions) { - if (!inputRequest.getToolNames().isEmpty() || !inputRequest.getToolCallbacks().isEmpty() - || !CollectionUtils.isEmpty(inputRequest.getToolContext())) { + // If we have tool-related configuration but no tool or non-tool options, + // create ToolCallingChatOptions + if (!inputRequest.getToolNames().isEmpty() || !inputRequest.getToolCallbacks().isEmpty() + || !inputRequest.getToolCallbackProviders().isEmpty() + || !CollectionUtils.isEmpty(inputRequest.getToolContext())) { + + if (processedChatOptions == null) { + processedChatOptions = new DefaultToolCallingChatOptions(); + } + else if (processedChatOptions instanceof DefaultChatOptions defaultChatOptions) { processedChatOptions = ModelOptionsUtils.copyToTarget(defaultChatOptions, ChatOptions.class, DefaultToolCallingChatOptions.class); } @@ -120,9 +127,16 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient .mergeToolNames(new HashSet<>(inputRequest.getToolNames()), toolCallingChatOptions.getToolNames()); toolCallingChatOptions.setToolNames(toolNames); } - if (!inputRequest.getToolCallbacks().isEmpty()) { - List toolCallbacks = ToolCallingChatOptions - .mergeToolCallbacks(inputRequest.getToolCallbacks(), toolCallingChatOptions.getToolCallbacks()); + + // Lazily resolve ToolCallbackProvider instances to ToolCallback instances + List allToolCallbacks = new ArrayList<>(inputRequest.getToolCallbacks()); + for (var provider : inputRequest.getToolCallbackProviders()) { + allToolCallbacks.addAll(java.util.List.of(provider.getToolCallbacks())); + } + + if (!allToolCallbacks.isEmpty()) { + List toolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(allToolCallbacks, + toolCallingChatOptions.getToolCallbacks()); ToolCallingChatOptions.validateToolCallbacks(toolCallbacks); toolCallingChatOptions.setToolCallbacks(toolCallbacks); } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index 07adcf72b48..23b8acdaf58 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -50,6 +50,7 @@ import org.springframework.ai.converter.StructuredOutputConverter; import org.springframework.ai.template.TemplateRenderer; import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbackProvider; import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.convert.support.DefaultConversionService; @@ -61,6 +62,9 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; /** @@ -1474,15 +1478,15 @@ void buildChatClientRequestSpec() { ChatModel chatModel = mock(ChatModel.class); DefaultChatClient.DefaultChatClientRequestSpec spec = new DefaultChatClient.DefaultChatClientRequestSpec( chatModel, null, Map.of(), Map.of(), null, Map.of(), Map.of(), List.of(), List.of(), List.of(), - List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of(), null); + List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of(), null); assertThat(spec).isNotNull(); } @Test void whenChatModelIsNullThenThrow() { assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(null, null, Map.of(), Map.of(), - null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), - ObservationRegistry.NOOP, null, Map.of(), null)) + null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), + Map.of(), ObservationRegistry.NOOP, null, Map.of(), null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("chatModel cannot be null"); } @@ -1490,8 +1494,8 @@ void whenChatModelIsNullThenThrow() { @Test void whenObservationRegistryIsNullThenThrow() { assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(mock(ChatModel.class), null, - Map.of(), Map.of(), null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), null, - List.of(), Map.of(), null, null, Map.of(), null)) + Map.of(), Map.of(), null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), List.of(), + null, List.of(), Map.of(), null, null, Map.of(), null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("observationRegistry cannot be null"); } @@ -2197,6 +2201,115 @@ void whenUserConsumerWithNullParamValueThenThrow() { .hasMessage("value cannot be null"); } + @Test + void whenToolCallbackProviderThenNotEagerlyEvaluated() { + ChatModel chatModel = mock(ChatModel.class); + ToolCallbackProvider provider = mock(ToolCallbackProvider.class); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider); + + // Verify that getToolCallbacks() was NOT called during configuration + verify(provider, never()).getToolCallbacks(); + } + + @Test + void whenToolCallbackProviderThenLazilyEvaluatedOnCall() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + ToolCallbackProvider provider = mock(ToolCallbackProvider.class); + when(provider.getToolCallbacks()).thenReturn(new ToolCallback[] {}); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider); + + // Verify not called yet + verify(provider, never()).getToolCallbacks(); + + // Execute the call + spec.call().content(); + + // Verify getToolCallbacks() WAS called during execution + verify(provider, times(1)).getToolCallbacks(); + } + + @Test + void whenToolCallbackProviderThenLazilyEvaluatedOnStream() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.stream(promptCaptor.capture())) + .willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))))); + + ToolCallbackProvider provider = mock(ToolCallbackProvider.class); + when(provider.getToolCallbacks()).thenReturn(new ToolCallback[] {}); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider); + + // Verify not called yet + verify(provider, never()).getToolCallbacks(); + + // Execute the stream + spec.stream().content().blockLast(); + + // Verify getToolCallbacks() WAS called during execution + verify(provider, times(1)).getToolCallbacks(); + } + + @Test + void whenMultipleToolCallbackProvidersThenAllLazilyEvaluated() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + ToolCallbackProvider provider1 = mock(ToolCallbackProvider.class); + when(provider1.getToolCallbacks()).thenReturn(new ToolCallback[] {}); + + ToolCallbackProvider provider2 = mock(ToolCallbackProvider.class); + when(provider2.getToolCallbacks()).thenReturn(new ToolCallback[] {}); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider1, provider2); + + // Verify not called yet + verify(provider1, never()).getToolCallbacks(); + verify(provider2, never()).getToolCallbacks(); + + // Execute the call + spec.call().content(); + + // Verify both getToolCallbacks() were called during execution + verify(provider1, times(1)).getToolCallbacks(); + verify(provider2, times(1)).getToolCallbacks(); + } + + @Test + void whenToolCallbacksAndProvidersThenBothUsed() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + ToolCallbackProvider provider = mock(ToolCallbackProvider.class); + when(provider.getToolCallbacks()).thenReturn(new ToolCallback[] {}); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider); + + // Verify provider not called yet + verify(provider, never()).getToolCallbacks(); + + // Execute the call + spec.call().content(); + + // Verify provider was called during execution + verify(provider, times(1)).getToolCallbacks(); + } + record Person(String name) { }