diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java index 0edbbcf326a..a148de5ab98 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java @@ -19,6 +19,8 @@ import java.lang.reflect.Method; import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Collectors; @@ -27,8 +29,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.openai.OpenAiTestConfiguration; import org.springframework.ai.openai.api.tool.MockWeatherService; +import org.springframework.ai.openai.api.tool.MockWeatherService.Request; +import org.springframework.ai.openai.api.tool.MockWeatherService.Response; import org.springframework.ai.openai.testutils.AbstractIT; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; @@ -107,6 +112,87 @@ void defaultFunctionCallTest() { assertThat(response).contains("30", "10", "15"); } + @Test + void defaultFunctionCallTestWithToolContext() { + + var biFunction = new BiFunction, MockWeatherService.Response>() { + + @Override + public Response apply(Request request, Map toolContext) { + + assertThat(toolContext).containsEntry("sessionId", "123"); + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new MockWeatherService.Response(temperature, 15, 20, 2, 53, 45, MockWeatherService.Unit.C); + } + + }; + + // @formatter:off + String response = ChatClient.builder(chatModel) + .defaultFunction("getCurrentWeather", "Get the weather in location", biFunction) + .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) + .defaultToolContext(Map.of("sessionId", "123")) + .build() + .prompt().call().content(); + // @formatter:on + + logger.info("Response: {}", response); + + assertThat(response).contains("30", "10", "15"); + } + + @Test + void functionCallTestWithToolContext() { + + var biFunction = new BiFunction, MockWeatherService.Response>() { + + @Override + public Response apply(Request request, Map toolContext) { + + assertThat(toolContext).containsEntry("sessionId", "123"); + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new MockWeatherService.Response(temperature, 15, 20, 2, 53, 45, MockWeatherService.Unit.C); + } + + }; + + // @formatter:off + String response = ChatClient.builder(chatModel) + .defaultFunction("getCurrentWeather", "Get the weather in location", biFunction) + .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) + .build() + .prompt() + .toolContext(Map.of("sessionId", "123")) + .call().content(); + // @formatter:on + + logger.info("Response: {}", response); + + assertThat(response).contains("30", "10", "15"); + } + @Test void streamFunctionCallTest() { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index 81d96ffd3c4..6de79e79200 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -211,6 +211,8 @@ ChatClientRequestSpec function(String name, String description, Class ChatClientRequestSpec functions(String... functionBeanNames); + ChatClientRequestSpec toolContext(Map toolContext); + ChatClientRequestSpec system(String text); ChatClientRequestSpec system(Resource textResource, Charset charset); @@ -271,6 +273,8 @@ Builder defaultFunction(String name, String description, Builder defaultFunctions(FunctionCallback... functionCallbacks); + Builder defaultToolContext(Map toolContext); + ChatClient build(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index ab38447c416..377055ad60d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -549,14 +549,14 @@ public Map getToolContext() { DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) { this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks, ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams, - ccr.observationRegistry, ccr.customObservationConvention); + ccr.observationRegistry, ccr.customObservationConvention, ccr.toolContext); } public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map userParams, String systemText, Map systemParams, List functionCallbacks, List messages, List functionNames, List media, ChatOptions chatOptions, List advisors, Map advisorParams, ObservationRegistry observationRegistry, - ChatClientObservationConvention customObservationConvention) { + ChatClientObservationConvention customObservationConvention, Map toolContext) { this.chatModel = chatModel; this.chatOptions = chatOptions != null ? chatOptions.copy() @@ -575,6 +575,7 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map toolContext) { + Assert.notNull(toolContext, "the toolContext must be non-null"); + this.toolContext.putAll(toolContext); + return this; + } + public ChatClientRequestSpec system(String text) { Assert.notNull(text, "the text must be non-null"); this.systemText = text; @@ -827,7 +835,8 @@ private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inpu return new AdvisedRequest(inputRequest.chatModel, inputRequest.userText, inputRequest.systemText, inputRequest.chatOptions, inputRequest.media, inputRequest.functionNames, inputRequest.functionCallbacks, inputRequest.messages, inputRequest.userParams, - inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext); + inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext, + inputRequest.toolContext); } public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest, @@ -837,7 +846,8 @@ public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(Advise advisedRequest.userParams(), advisedRequest.systemText(), advisedRequest.systemParams(), advisedRequest.functionCallbacks(), advisedRequest.messages(), advisedRequest.functionNames(), advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(), - advisedRequest.advisorParams(), observationRegistry, customObservationConvention); + advisedRequest.advisorParams(), observationRegistry, customObservationConvention, + advisedRequest.toolContext()); } // Prompt diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 93af46815c3..f876d97e3c9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -61,7 +61,7 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observa Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null"); this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, - customObservationConvention); + customObservationConvention, Map.of()); } public ChatClient build() { @@ -157,4 +157,9 @@ public Builder defaultFunctions(FunctionCallback... functionCallbacks) { return this; } + public Builder defaultToolContext(Map toolContext) { + this.defaultRequest.toolContext(toolContext); + return this; + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java index 8630c697e40..2a8ea7aa2af 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java @@ -60,7 +60,7 @@ public record AdvisedRequest(ChatModel chatModel, String userText, String systemText, ChatOptions chatOptions, List media, List functionNames, List functionCallbacks, List messages, Map userParams, Map systemParams, List advisors, - Map advisorParams, Map adviseContext) { + Map advisorParams, Map adviseContext, Map toolContext) { public AdvisedRequest updateContext(Function, Map> contextTransform) { return from(this) @@ -83,6 +83,7 @@ public static Builder from(AdvisedRequest from) { builder.advisors = from.advisors; builder.advisorParams = from.advisorParams; builder.adviseContext = from.adviseContext; + builder.toolContext = from.toolContext; return builder; } @@ -93,6 +94,8 @@ public static Builder builder() { public static class Builder { + public Map toolContext; + private ChatModel chatModel; private String userText = ""; @@ -154,6 +157,11 @@ public Builder withFunctionCallbacks(List functionCallbacks) { return this; } + public Builder withToolContext(Map toolContext) { + this.toolContext = toolContext; + return this; + } + public Builder withMessages(List messages) { this.messages = messages; return this; @@ -187,7 +195,7 @@ public Builder withAdviseContext(Map adviseContext) { public AdvisedRequest build() { return new AdvisedRequest(chatModel, this.userText, this.systemText, this.chatOptions, this.media, this.functionNames, this.functionCallbacks, this.messages, this.userParams, this.systemParams, - this.advisors, this.advisorParams, this.adviseContext); + this.advisors, this.advisorParams, this.adviseContext, this.toolContext); } } @@ -228,6 +236,9 @@ public Prompt toPrompt() { if (!this.functionCallbacks().isEmpty()) { functionCallingOptions.setFunctionCallbacks(this.functionCallbacks()); } + if (!CollectionUtils.isEmpty(this.toolContext())) { + functionCallingOptions.setToolContext(this.toolContext()); + } } return new Prompt(messages, this.chatOptions()); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java index 5480129157f..2b189fdf797 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java @@ -60,7 +60,8 @@ void whenEmptyInputContentThenReturnOriginalContext() { ChatClientObservationConvention customObservationConvention = null; var request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), - List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention); + List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention, + Map.of()); var expectedContext = ChatClientObservationContext.builder().withRequest(request).build(); @@ -76,7 +77,7 @@ void whenWithTextThenAugmentContext() { var request = new DefaultChatClientRequestSpec(chatModel, "sample user text", Map.of("up1", "upv1"), "sample system text", Map.of("sp1", "sp1v"), List.of(), List.of(), List.of(), List.of(), null, - List.of(), Map.of(), observationRegistry, customObservationConvention); + List.of(), Map.of(), observationRegistry, customObservationConvention, Map.of()); var originalContext = ChatClientObservationContext.builder().withRequest(request).build(); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java index 61f5f00942a..0cc401e8735 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java @@ -45,7 +45,7 @@ class ChatClientObservationContextTests { void whenMandatoryRequestOptionsThenReturn() { var request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), - List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null); + List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of()); var observationContext = ChatClientObservationContext.builder().withRequest(request).withStream(true).build(); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java index 3a470d70864..5809cc19dd3 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java @@ -26,20 +26,19 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.RequestResponseAdvisor; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames; -import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.observation.conventions.SpringAiKind; import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; -import org.springframework.ai.observation.conventions.AiProvider; -import org.springframework.ai.observation.conventions.SpringAiKind; /** * Unit tests for {@link DefaultChatClientObservationConvention}. @@ -60,7 +59,7 @@ class DefaultChatClientObservationConventionTests { @BeforeEach public void beforeEach() { request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), - List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null); + List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of()); } @Test @@ -160,7 +159,7 @@ void shouldHaveOptionalKeyValues() { List.of(dummyFunction("functionCallback1"), dummyFunction("functionCallback2")), List.of(), List.of("function1", "function2"), List.of(), null, List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2")), Map.of("advParam1", "advisorParam1Value"), - ObservationRegistry.NOOP, null); + ObservationRegistry.NOOP, null, Map.of()); ChatClientObservationContext observationContext = ChatClientObservationContext.builder() .withRequest(request)