Skip to content

Commit 3ad02ea

Browse files
committed
Add tool context support to ChatClient and related classes
- Introduce toolContext to ChatClient, DefaultChatClient, and AdvisedRequest - Add methods to set and manage toolContext via the FunctionCallingOptions - Update tests to include toolContext in relevant scenarios - Implement toolContext handling in function calling options
1 parent 2babb2a commit 3ad02ea

File tree

8 files changed

+132
-16
lines changed

8 files changed

+132
-16
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import java.lang.reflect.Method;
2121
import java.util.List;
22+
import java.util.Map;
23+
import java.util.function.BiFunction;
2224
import java.util.function.Function;
2325
import java.util.stream.Collectors;
2426

@@ -27,8 +29,11 @@
2729
import org.slf4j.Logger;
2830
import org.slf4j.LoggerFactory;
2931
import org.springframework.ai.chat.client.ChatClient;
32+
import org.springframework.ai.model.function.FunctionCallback;
3033
import org.springframework.ai.openai.OpenAiTestConfiguration;
3134
import org.springframework.ai.openai.api.tool.MockWeatherService;
35+
import org.springframework.ai.openai.api.tool.MockWeatherService.Request;
36+
import org.springframework.ai.openai.api.tool.MockWeatherService.Response;
3237
import org.springframework.ai.openai.testutils.AbstractIT;
3338
import org.springframework.beans.factory.annotation.Value;
3439
import org.springframework.boot.test.context.SpringBootTest;
@@ -107,6 +112,87 @@ void defaultFunctionCallTest() {
107112
assertThat(response).contains("30", "10", "15");
108113
}
109114

115+
@Test
116+
void defaultFunctionCallTestWithToolContext() {
117+
118+
var biFunction = new BiFunction<MockWeatherService.Request, Map<String, Object>, MockWeatherService.Response>() {
119+
120+
@Override
121+
public Response apply(Request request, Map<String, Object> toolContext) {
122+
123+
assertThat(toolContext).containsEntry("sessionId", "123");
124+
125+
double temperature = 0;
126+
if (request.location().contains("Paris")) {
127+
temperature = 15;
128+
}
129+
else if (request.location().contains("Tokyo")) {
130+
temperature = 10;
131+
}
132+
else if (request.location().contains("San Francisco")) {
133+
temperature = 30;
134+
}
135+
136+
return new MockWeatherService.Response(temperature, 15, 20, 2, 53, 45, MockWeatherService.Unit.C);
137+
}
138+
139+
};
140+
141+
// @formatter:off
142+
String response = ChatClient.builder(chatModel)
143+
.defaultFunction("getCurrentWeather", "Get the weather in location", biFunction)
144+
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
145+
.defaultToolContext(Map.of("sessionId", "123"))
146+
.build()
147+
.prompt().call().content();
148+
// @formatter:on
149+
150+
logger.info("Response: {}", response);
151+
152+
assertThat(response).contains("30", "10", "15");
153+
}
154+
155+
@Test
156+
void functionCallTestWithToolContext() {
157+
158+
var biFunction = new BiFunction<MockWeatherService.Request, Map<String, Object>, MockWeatherService.Response>() {
159+
160+
@Override
161+
public Response apply(Request request, Map<String, Object> toolContext) {
162+
163+
assertThat(toolContext).containsEntry("sessionId", "123");
164+
165+
double temperature = 0;
166+
if (request.location().contains("Paris")) {
167+
temperature = 15;
168+
}
169+
else if (request.location().contains("Tokyo")) {
170+
temperature = 10;
171+
}
172+
else if (request.location().contains("San Francisco")) {
173+
temperature = 30;
174+
}
175+
176+
return new MockWeatherService.Response(temperature, 15, 20, 2, 53, 45, MockWeatherService.Unit.C);
177+
}
178+
179+
};
180+
181+
// @formatter:off
182+
String response = ChatClient.builder(chatModel)
183+
.defaultFunction("getCurrentWeather", "Get the weather in location", biFunction)
184+
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
185+
.build()
186+
.prompt()
187+
.toolContext(Map.of("sessionId", "123"))
188+
.call().content();
189+
// @formatter:on
190+
191+
logger.info("Response: {}", response);
192+
193+
assertThat(response).contains("30", "10", "15");
194+
}
195+
110196
@Test
111197
void streamFunctionCallTest() {
112198

spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ <I, O> ChatClientRequestSpec function(String name, String description, Class<I>
211211

212212
ChatClientRequestSpec functions(String... functionBeanNames);
213213

214+
ChatClientRequestSpec toolContext(Map<String, Object> toolContext);
215+
214216
ChatClientRequestSpec system(String text);
215217

216218
ChatClientRequestSpec system(Resource textResource, Charset charset);
@@ -271,6 +273,8 @@ <I, O> Builder defaultFunction(String name, String description,
271273

272274
Builder defaultFunctions(FunctionCallback... functionCallbacks);
273275

276+
Builder defaultToolContext(Map<String, Object> toolContext);
277+
274278
ChatClient build();
275279

276280
}

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -549,14 +549,14 @@ public Map<String, Object> getToolContext() {
549549
DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) {
550550
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks,
551551
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams,
552-
ccr.observationRegistry, ccr.customObservationConvention);
552+
ccr.observationRegistry, ccr.customObservationConvention, ccr.toolContext);
553553
}
554554

555555
public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map<String, Object> userParams,
556556
String systemText, Map<String, Object> systemParams, List<FunctionCallback> functionCallbacks,
557557
List<Message> messages, List<String> functionNames, List<Media> media, ChatOptions chatOptions,
558558
List<Advisor> advisors, Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
559-
ChatClientObservationConvention customObservationConvention) {
559+
ChatClientObservationConvention customObservationConvention, Map<String, Object> toolContext) {
560560

561561
this.chatModel = chatModel;
562562
this.chatOptions = chatOptions != null ? chatOptions.copy()
@@ -575,6 +575,7 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map<St
575575
this.advisorParams.putAll(advisorParams);
576576
this.observationRegistry = observationRegistry;
577577
this.customObservationConvention = customObservationConvention;
578+
this.toolContext.putAll(toolContext);
578579

579580
// @formatter:off
580581
// At the stack bottom add the non-streaming and streaming model call advisors.
@@ -639,6 +640,7 @@ public Builder mutate() {
639640
// workaround to set the missing fields.
640641
builder.defaultRequest.getMessages().addAll(this.messages);
641642
builder.defaultRequest.getFunctionCallbacks().addAll(this.functionCallbacks);
643+
builder.defaultRequest.getToolContext().putAll(this.toolContext);
642644

643645
return builder;
644646
}
@@ -735,6 +737,12 @@ public ChatClientRequestSpec functions(FunctionCallback... functionCallbacks) {
735737
return this;
736738
}
737739

740+
public ChatClientRequestSpec toolContext(Map<String, Object> toolContext) {
741+
Assert.notNull(toolContext, "the toolContext must be non-null");
742+
this.toolContext.putAll(toolContext);
743+
return this;
744+
}
745+
738746
public ChatClientRequestSpec system(String text) {
739747
Assert.notNull(text, "the text must be non-null");
740748
this.systemText = text;
@@ -827,7 +835,8 @@ private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inpu
827835
return new AdvisedRequest(inputRequest.chatModel, inputRequest.userText, inputRequest.systemText,
828836
inputRequest.chatOptions, inputRequest.media, inputRequest.functionNames,
829837
inputRequest.functionCallbacks, inputRequest.messages, inputRequest.userParams,
830-
inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext);
838+
inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext,
839+
inputRequest.toolContext);
831840
}
832841

833842
public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest,
@@ -837,7 +846,8 @@ public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(Advise
837846
advisedRequest.userParams(), advisedRequest.systemText(), advisedRequest.systemParams(),
838847
advisedRequest.functionCallbacks(), advisedRequest.messages(), advisedRequest.functionNames(),
839848
advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(),
840-
advisedRequest.advisorParams(), observationRegistry, customObservationConvention);
849+
advisedRequest.advisorParams(), observationRegistry, customObservationConvention,
850+
advisedRequest.toolContext());
841851
}
842852

843853
// Prompt

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observa
6161
Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null");
6262
this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(),
6363
List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry,
64-
customObservationConvention);
64+
customObservationConvention, Map.of());
6565
}
6666

6767
public ChatClient build() {
@@ -157,4 +157,9 @@ public Builder defaultFunctions(FunctionCallback... functionCallbacks) {
157157
return this;
158158
}
159159

160+
public Builder defaultToolContext(Map<String, Object> toolContext) {
161+
this.defaultRequest.toolContext(toolContext);
162+
return this;
163+
}
164+
160165
}

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
public record AdvisedRequest(ChatModel chatModel, String userText, String systemText, ChatOptions chatOptions,
6161
List<Media> media, List<String> functionNames, List<FunctionCallback> functionCallbacks, List<Message> messages,
6262
Map<String, Object> userParams, Map<String, Object> systemParams, List<Advisor> advisors,
63-
Map<String, Object> advisorParams, Map<String, Object> adviseContext) {
63+
Map<String, Object> advisorParams, Map<String, Object> adviseContext, Map<String, Object> toolContext) {
6464

6565
public AdvisedRequest updateContext(Function<Map<String, Object>, Map<String, Object>> contextTransform) {
6666
return from(this)
@@ -83,6 +83,7 @@ public static Builder from(AdvisedRequest from) {
8383
builder.advisors = from.advisors;
8484
builder.advisorParams = from.advisorParams;
8585
builder.adviseContext = from.adviseContext;
86+
builder.toolContext = from.toolContext;
8687

8788
return builder;
8889
}
@@ -93,6 +94,8 @@ public static Builder builder() {
9394

9495
public static class Builder {
9596

97+
public Map<String, Object> toolContext;
98+
9699
private ChatModel chatModel;
97100

98101
private String userText = "";
@@ -154,6 +157,11 @@ public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
154157
return this;
155158
}
156159

160+
public Builder withToolContext(Map<String, Object> toolContext) {
161+
this.toolContext = toolContext;
162+
return this;
163+
}
164+
157165
public Builder withMessages(List<Message> messages) {
158166
this.messages = messages;
159167
return this;
@@ -187,7 +195,7 @@ public Builder withAdviseContext(Map<String, Object> adviseContext) {
187195
public AdvisedRequest build() {
188196
return new AdvisedRequest(chatModel, this.userText, this.systemText, this.chatOptions, this.media,
189197
this.functionNames, this.functionCallbacks, this.messages, this.userParams, this.systemParams,
190-
this.advisors, this.advisorParams, this.adviseContext);
198+
this.advisors, this.advisorParams, this.adviseContext, this.toolContext);
191199
}
192200

193201
}
@@ -228,6 +236,9 @@ public Prompt toPrompt() {
228236
if (!this.functionCallbacks().isEmpty()) {
229237
functionCallingOptions.setFunctionCallbacks(this.functionCallbacks());
230238
}
239+
if (!CollectionUtils.isEmpty(this.toolContext())) {
240+
functionCallingOptions.setToolContext(this.toolContext());
241+
}
231242
}
232243

233244
return new Prompt(messages, this.chatOptions());

spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ void whenEmptyInputContentThenReturnOriginalContext() {
6060
ChatClientObservationConvention customObservationConvention = null;
6161

6262
var request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(),
63-
List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention);
63+
List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention,
64+
Map.of());
6465

6566
var expectedContext = ChatClientObservationContext.builder().withRequest(request).build();
6667

@@ -76,7 +77,7 @@ void whenWithTextThenAugmentContext() {
7677

7778
var request = new DefaultChatClientRequestSpec(chatModel, "sample user text", Map.of("up1", "upv1"),
7879
"sample system text", Map.of("sp1", "sp1v"), List.of(), List.of(), List.of(), List.of(), null,
79-
List.of(), Map.of(), observationRegistry, customObservationConvention);
80+
List.of(), Map.of(), observationRegistry, customObservationConvention, Map.of());
8081

8182
var originalContext = ChatClientObservationContext.builder().withRequest(request).build();
8283

spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class ChatClientObservationContextTests {
4545
void whenMandatoryRequestOptionsThenReturn() {
4646

4747
var request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(),
48-
List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null);
48+
List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of());
4949

5050
var observationContext = ChatClientObservationContext.builder().withRequest(request).withStream(true).build();
5151

spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,19 @@
2626
import org.mockito.Mock;
2727
import org.mockito.junit.jupiter.MockitoExtension;
2828
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
29-
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
3029
import org.springframework.ai.chat.client.RequestResponseAdvisor;
30+
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
3131
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames;
3232
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames;
33-
import org.springframework.ai.chat.metadata.Usage;
3433
import org.springframework.ai.chat.model.ChatModel;
3534
import org.springframework.ai.chat.model.ChatResponse;
3635
import org.springframework.ai.model.function.FunctionCallback;
36+
import org.springframework.ai.observation.conventions.AiProvider;
37+
import org.springframework.ai.observation.conventions.SpringAiKind;
3738

3839
import io.micrometer.common.KeyValue;
3940
import io.micrometer.observation.Observation;
4041
import io.micrometer.observation.ObservationRegistry;
41-
import org.springframework.ai.observation.conventions.AiProvider;
42-
import org.springframework.ai.observation.conventions.SpringAiKind;
4342

4443
/**
4544
* Unit tests for {@link DefaultChatClientObservationConvention}.
@@ -60,7 +59,7 @@ class DefaultChatClientObservationConventionTests {
6059
@BeforeEach
6160
public void beforeEach() {
6261
request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(),
63-
List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null);
62+
List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of());
6463
}
6564

6665
@Test
@@ -160,7 +159,7 @@ void shouldHaveOptionalKeyValues() {
160159
List.of(dummyFunction("functionCallback1"), dummyFunction("functionCallback2")), List.of(),
161160
List.of("function1", "function2"), List.of(), null,
162161
List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2")), Map.of("advParam1", "advisorParam1Value"),
163-
ObservationRegistry.NOOP, null);
162+
ObservationRegistry.NOOP, null, Map.of());
164163

165164
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
166165
.withRequest(request)

0 commit comments

Comments
 (0)