Skip to content

Commit 101bbc4

Browse files
committed
fix: Fixed the issue where tool call information was lost when using DefaultChatOptions.
Signed-off-by: Sun Yuhan <[email protected]>
1 parent ea995df commit 101bbc4

File tree

2 files changed

+80
-0
lines changed

2 files changed

+80
-0
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727
import org.springframework.ai.chat.messages.SystemMessage;
2828
import org.springframework.ai.chat.messages.UserMessage;
2929
import org.springframework.ai.chat.prompt.ChatOptions;
30+
import org.springframework.ai.chat.prompt.DefaultChatOptions;
3031
import org.springframework.ai.chat.prompt.Prompt;
3132
import org.springframework.ai.chat.prompt.PromptTemplate;
33+
import org.springframework.ai.model.tool.DefaultToolCallingChatOptions;
3234
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3335
import org.springframework.ai.tool.ToolCallback;
3436
import org.springframework.util.Assert;
@@ -39,6 +41,7 @@
3941
* Utilities for supporting the {@link DefaultChatClient} implementation.
4042
*
4143
* @author Thomas Vitale
44+
* @author Sun Yuhan
4245
* @since 1.0.0
4346
*/
4447
final class DefaultChatClientUtils {
@@ -94,6 +97,23 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient
9497
*/
9598

9699
ChatOptions processedChatOptions = inputRequest.getChatOptions();
100+
101+
if (processedChatOptions instanceof DefaultChatOptions defaultChatOptions) {
102+
if (!inputRequest.getToolNames().isEmpty() || !inputRequest.getToolCallbacks().isEmpty()
103+
|| !CollectionUtils.isEmpty(inputRequest.getToolContext())) {
104+
processedChatOptions = DefaultToolCallingChatOptions.builder()
105+
.model(defaultChatOptions.getModel())
106+
.frequencyPenalty(defaultChatOptions.getFrequencyPenalty())
107+
.maxTokens(defaultChatOptions.getMaxTokens())
108+
.presencePenalty(defaultChatOptions.getPresencePenalty())
109+
.stopSequences(defaultChatOptions.getStopSequences())
110+
.temperature(defaultChatOptions.getTemperature())
111+
.topK(defaultChatOptions.getTopK())
112+
.topP(defaultChatOptions.getTopP())
113+
.build();
114+
}
115+
}
116+
97117
if (processedChatOptions instanceof ToolCallingChatOptions toolCallingChatOptions) {
98118
if (!inputRequest.getToolNames().isEmpty()) {
99119
Set<String> toolNames = ToolCallingChatOptions

spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.springframework.ai.chat.messages.SystemMessage;
2727
import org.springframework.ai.chat.messages.UserMessage;
2828
import org.springframework.ai.chat.model.ChatModel;
29+
import org.springframework.ai.chat.prompt.DefaultChatOptions;
2930
import org.springframework.ai.content.Media;
3031
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3132
import org.springframework.ai.template.TemplateRenderer;
@@ -43,6 +44,7 @@
4344
* Unit tests for {@link DefaultChatClientUtils}.
4445
*
4546
* @author Thomas Vitale
47+
* @author Sun Yuhan
4648
*/
4749
class DefaultChatClientUtilsTests {
4850

@@ -322,6 +324,64 @@ void whenToolContextAndChatOptionsAreProvidedThenTheValuesAreMerged() {
322324
.containsAllEntriesOf(toolContext2);
323325
}
324326

327+
@Test
328+
void whenToolNamesAndChatOptionsAreDefaultChatOptions() {
329+
Set<String> toolNames1 = Set.of("toolA", "toolB");
330+
DefaultChatOptions chatOptions = new DefaultChatOptions();
331+
ChatModel chatModel = mock(ChatModel.class);
332+
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
333+
.create(chatModel)
334+
.prompt()
335+
.options(chatOptions)
336+
.toolNames(toolNames1.toArray(new String[0]));
337+
338+
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
339+
340+
assertThat(result).isNotNull();
341+
assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class);
342+
ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions();
343+
assertThat(resultOptions).isNotNull();
344+
assertThat(resultOptions.getToolNames()).containsExactlyInAnyOrderElementsOf(toolNames1);
345+
}
346+
347+
@Test
348+
void whenToolCallbacksAndChatOptionsAreDefaultChatOptions() {
349+
ToolCallback toolCallback1 = new TestToolCallback("tool1");
350+
DefaultChatOptions chatOptions = new DefaultChatOptions();
351+
ChatModel chatModel = mock(ChatModel.class);
352+
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
353+
.create(chatModel)
354+
.prompt()
355+
.options(chatOptions)
356+
.toolCallbacks(toolCallback1);
357+
358+
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
359+
360+
assertThat(result).isNotNull();
361+
assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class);
362+
ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions();
363+
assertThat(resultOptions).isNotNull();
364+
assertThat(resultOptions.getToolCallbacks()).containsExactlyInAnyOrder(toolCallback1);
365+
}
366+
367+
@Test
368+
void whenToolContextAndChatOptionsAreDefaultChatOptions() {
369+
Map<String, Object> toolContext1 = Map.of("key1", "value1");
370+
DefaultChatOptions chatOptions = new DefaultChatOptions();
371+
ChatModel chatModel = mock(ChatModel.class);
372+
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
373+
.create(chatModel)
374+
.prompt()
375+
.options(chatOptions)
376+
.toolContext(toolContext1);
377+
378+
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
379+
assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class);
380+
ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions();
381+
assertThat(resultOptions).isNotNull();
382+
assertThat(resultOptions.getToolContext()).containsAllEntriesOf(toolContext1);
383+
}
384+
325385
@Test
326386
void whenAdvisorParamsAreProvidedThenTheyAreAddedToContext() {
327387
Map<String, Object> advisorParams = Map.of("key1", "value1", "key2", "value2");

0 commit comments

Comments
 (0)