Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,8 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe

private final List<ToolCallback> toolCallbacks = new ArrayList<>();

private final List<ToolCallbackProvider> toolCallbackProviders = new ArrayList<>();

private final List<Message> messages = new ArrayList<>();

private final Map<String, Object> userParams = new HashMap<>();
Expand Down Expand Up @@ -648,16 +650,17 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
/* copy constructor */
DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) {
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.userMetadata, ccr.systemText, ccr.systemParams,
ccr.systemMetadata, ccr.toolCallbacks, ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions,
ccr.advisors, ccr.advisorParams, ccr.observationRegistry, ccr.observationConvention,
ccr.toolContext, ccr.templateRenderer);
ccr.systemMetadata, ccr.toolCallbacks, ccr.toolCallbackProviders, ccr.messages, ccr.toolNames,
ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams, ccr.observationRegistry,
ccr.observationConvention, ccr.toolContext, ccr.templateRenderer);
}

public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText,
Map<String, Object> userParams, Map<String, Object> userMetadata, @Nullable String systemText,
Map<String, Object> systemParams, Map<String, Object> systemMetadata, List<ToolCallback> toolCallbacks,
List<Message> messages, List<String> toolNames, List<Media> media, @Nullable ChatOptions chatOptions,
List<Advisor> advisors, Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
List<ToolCallbackProvider> toolCallbackProviders, List<Message> messages, List<String> toolNames,
List<Media> media, @Nullable ChatOptions chatOptions, List<Advisor> advisors,
Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention observationConvention, Map<String, Object> toolContext,
@Nullable TemplateRenderer templateRenderer) {

Expand All @@ -667,6 +670,7 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe
Assert.notNull(systemParams, "systemParams cannot be null");
Assert.notNull(systemMetadata, "systemMetadata cannot be null");
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
Assert.notNull(toolCallbackProviders, "toolCallbackProviders cannot be null");
Assert.notNull(messages, "messages cannot be null");
Assert.notNull(toolNames, "toolNames cannot be null");
Assert.notNull(media, "media cannot be null");
Expand All @@ -689,6 +693,7 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe

this.toolNames.addAll(toolNames);
this.toolCallbacks.addAll(toolCallbacks);
this.toolCallbackProviders.addAll(toolCallbackProviders);
this.messages.addAll(messages);
this.media.addAll(media);
this.advisors.addAll(advisors);
Expand Down Expand Up @@ -755,6 +760,10 @@ public List<ToolCallback> getToolCallbacks() {
return this.toolCallbacks;
}

public List<ToolCallbackProvider> getToolCallbackProviders() {
return this.toolCallbackProviders;
}

public Map<String, Object> getToolContext() {
return this.toolContext;
}
Expand All @@ -773,6 +782,7 @@ public Builder mutate() {
.builder(this.chatModel, this.observationRegistry, this.observationConvention)
.defaultTemplateRenderer(this.templateRenderer)
.defaultToolCallbacks(this.toolCallbacks)
.defaultToolCallbacks(this.toolCallbackProviders.toArray(new ToolCallback[0]))
.defaultToolContext(this.toolContext)
.defaultToolNames(StringUtils.toStringArray(this.toolNames));

Expand Down Expand Up @@ -885,9 +895,7 @@ public ChatClientRequestSpec tools(Object... toolObjects) {
public ChatClientRequestSpec toolCallbacks(ToolCallbackProvider... toolCallbackProviders) {
Assert.notNull(toolCallbackProviders, "toolCallbackProviders cannot be null");
Assert.noNullElements(toolCallbackProviders, "toolCallbackProviders cannot contain null elements");
for (ToolCallbackProvider toolCallbackProvider : toolCallbackProviders) {
this.toolCallbacks.addAll(List.of(toolCallbackProvider.getToolCallbacks()));
}
this.toolCallbackProviders.addAll(List.of(toolCallbackProviders));
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observa
Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null");
Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null");
this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), Map.of(), null, Map.of(),
Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry,
customObservationConvention, Map.of(), null);
Map.of(), List.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(),
observationRegistry, customObservationConvention, Map.of(), null);
}

public ChatClient build() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,16 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient

ChatOptions processedChatOptions = inputRequest.getChatOptions();

if (processedChatOptions instanceof DefaultChatOptions defaultChatOptions) {
if (!inputRequest.getToolNames().isEmpty() || !inputRequest.getToolCallbacks().isEmpty()
|| !CollectionUtils.isEmpty(inputRequest.getToolContext())) {
// If we have tool-related configuration but no tool or non-tool options,
// create ToolCallingChatOptions
if (!inputRequest.getToolNames().isEmpty() || !inputRequest.getToolCallbacks().isEmpty()
|| !inputRequest.getToolCallbackProviders().isEmpty()
|| !CollectionUtils.isEmpty(inputRequest.getToolContext())) {

if (processedChatOptions == null) {
processedChatOptions = new DefaultToolCallingChatOptions();
}
else if (processedChatOptions instanceof DefaultChatOptions defaultChatOptions) {
processedChatOptions = ModelOptionsUtils.copyToTarget(defaultChatOptions, ChatOptions.class,
DefaultToolCallingChatOptions.class);
}
Expand All @@ -120,9 +127,16 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient
.mergeToolNames(new HashSet<>(inputRequest.getToolNames()), toolCallingChatOptions.getToolNames());
toolCallingChatOptions.setToolNames(toolNames);
}
if (!inputRequest.getToolCallbacks().isEmpty()) {
List<ToolCallback> toolCallbacks = ToolCallingChatOptions
.mergeToolCallbacks(inputRequest.getToolCallbacks(), toolCallingChatOptions.getToolCallbacks());

// Lazily resolve ToolCallbackProvider instances to ToolCallback instances
List<ToolCallback> allToolCallbacks = new ArrayList<>(inputRequest.getToolCallbacks());
for (var provider : inputRequest.getToolCallbackProviders()) {
allToolCallbacks.addAll(java.util.List.of(provider.getToolCallbacks()));
}

if (!allToolCallbacks.isEmpty()) {
List<ToolCallback> toolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(allToolCallbacks,
toolCallingChatOptions.getToolCallbacks());
ToolCallingChatOptions.validateToolCallbacks(toolCallbacks);
toolCallingChatOptions.setToolCallbacks(toolCallbacks);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.springframework.ai.converter.StructuredOutputConverter;
import org.springframework.ai.template.TemplateRenderer;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.convert.support.DefaultConversionService;
Expand All @@ -61,6 +62,9 @@
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
Expand Down Expand Up @@ -1474,24 +1478,24 @@ void buildChatClientRequestSpec() {
ChatModel chatModel = mock(ChatModel.class);
DefaultChatClient.DefaultChatClientRequestSpec spec = new DefaultChatClient.DefaultChatClientRequestSpec(
chatModel, null, Map.of(), Map.of(), null, Map.of(), Map.of(), List.of(), List.of(), List.of(),
List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of(), null);
List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of(), null);
assertThat(spec).isNotNull();
}

@Test
void whenChatModelIsNullThenThrow() {
assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(null, null, Map.of(), Map.of(),
null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(),
ObservationRegistry.NOOP, null, Map.of(), null))
null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(),
Map.of(), ObservationRegistry.NOOP, null, Map.of(), null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("chatModel cannot be null");
}

@Test
void whenObservationRegistryIsNullThenThrow() {
assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(mock(ChatModel.class), null,
Map.of(), Map.of(), null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), null,
List.of(), Map.of(), null, null, Map.of(), null))
Map.of(), Map.of(), null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), List.of(),
null, List.of(), Map.of(), null, null, Map.of(), null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("observationRegistry cannot be null");
}
Expand Down Expand Up @@ -2197,6 +2201,115 @@ void whenUserConsumerWithNullParamValueThenThrow() {
.hasMessage("value cannot be null");
}

@Test
void whenToolCallbackProviderThenNotEagerlyEvaluated() {
ChatModel chatModel = mock(ChatModel.class);
ToolCallbackProvider provider = mock(ToolCallbackProvider.class);

ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider);

// Verify that getToolCallbacks() was NOT called during configuration
verify(provider, never()).getToolCallbacks();
}

@Test
void whenToolCallbackProviderThenLazilyEvaluatedOnCall() {
ChatModel chatModel = mock(ChatModel.class);
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
given(chatModel.call(promptCaptor.capture()))
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))));

ToolCallbackProvider provider = mock(ToolCallbackProvider.class);
when(provider.getToolCallbacks()).thenReturn(new ToolCallback[] {});

ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider);

// Verify not called yet
verify(provider, never()).getToolCallbacks();

// Execute the call
spec.call().content();

// Verify getToolCallbacks() WAS called during execution
verify(provider, times(1)).getToolCallbacks();
}

@Test
void whenToolCallbackProviderThenLazilyEvaluatedOnStream() {
ChatModel chatModel = mock(ChatModel.class);
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
given(chatModel.stream(promptCaptor.capture()))
.willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))));

ToolCallbackProvider provider = mock(ToolCallbackProvider.class);
when(provider.getToolCallbacks()).thenReturn(new ToolCallback[] {});

ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider);

// Verify not called yet
verify(provider, never()).getToolCallbacks();

// Execute the stream
spec.stream().content().blockLast();

// Verify getToolCallbacks() WAS called during execution
verify(provider, times(1)).getToolCallbacks();
}

@Test
void whenMultipleToolCallbackProvidersThenAllLazilyEvaluated() {
ChatModel chatModel = mock(ChatModel.class);
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
given(chatModel.call(promptCaptor.capture()))
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))));

ToolCallbackProvider provider1 = mock(ToolCallbackProvider.class);
when(provider1.getToolCallbacks()).thenReturn(new ToolCallback[] {});

ToolCallbackProvider provider2 = mock(ToolCallbackProvider.class);
when(provider2.getToolCallbacks()).thenReturn(new ToolCallback[] {});

ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider1, provider2);

// Verify not called yet
verify(provider1, never()).getToolCallbacks();
verify(provider2, never()).getToolCallbacks();

// Execute the call
spec.call().content();

// Verify both getToolCallbacks() were called during execution
verify(provider1, times(1)).getToolCallbacks();
verify(provider2, times(1)).getToolCallbacks();
}

@Test
void whenToolCallbacksAndProvidersThenBothUsed() {
ChatModel chatModel = mock(ChatModel.class);
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
given(chatModel.call(promptCaptor.capture()))
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))));

ToolCallbackProvider provider = mock(ToolCallbackProvider.class);
when(provider.getToolCallbacks()).thenReturn(new ToolCallback[] {});

ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("test").toolCallbacks(provider);

// Verify provider not called yet
verify(provider, never()).getToolCallbacks();

// Execute the call
spec.call().content();

// Verify provider was called during execution
verify(provider, times(1)).getToolCallbacks();
}

record Person(String name) {
}

Expand Down