Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

Expand All @@ -27,8 +29,11 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.openai.OpenAiTestConfiguration;
import org.springframework.ai.openai.api.tool.MockWeatherService;
import org.springframework.ai.openai.api.tool.MockWeatherService.Request;
import org.springframework.ai.openai.api.tool.MockWeatherService.Response;
import org.springframework.ai.openai.testutils.AbstractIT;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.test.context.SpringBootTest;
Expand Down Expand Up @@ -107,6 +112,87 @@ void defaultFunctionCallTest() {
assertThat(response).contains("30", "10", "15");
}

@Test
void defaultFunctionCallTestWithToolContext() {

var biFunction = new BiFunction<MockWeatherService.Request, Map<String, Object>, MockWeatherService.Response>() {

@Override
public Response apply(Request request, Map<String, Object> toolContext) {

assertThat(toolContext).containsEntry("sessionId", "123");

double temperature = 0;
if (request.location().contains("Paris")) {
temperature = 15;
}
else if (request.location().contains("Tokyo")) {
temperature = 10;
}
else if (request.location().contains("San Francisco")) {
temperature = 30;
}

return new MockWeatherService.Response(temperature, 15, 20, 2, 53, 45, MockWeatherService.Unit.C);
}

};

// @formatter:off
String response = ChatClient.builder(chatModel)
.defaultFunction("getCurrentWeather", "Get the weather in location", biFunction)
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
.defaultToolContext(Map.of("sessionId", "123"))
.build()
.prompt().call().content();
// @formatter:on

logger.info("Response: {}", response);

assertThat(response).contains("30", "10", "15");
}

@Test
void functionCallTestWithToolContext() {

var biFunction = new BiFunction<MockWeatherService.Request, Map<String, Object>, MockWeatherService.Response>() {

@Override
public Response apply(Request request, Map<String, Object> toolContext) {

assertThat(toolContext).containsEntry("sessionId", "123");

double temperature = 0;
if (request.location().contains("Paris")) {
temperature = 15;
}
else if (request.location().contains("Tokyo")) {
temperature = 10;
}
else if (request.location().contains("San Francisco")) {
temperature = 30;
}

return new MockWeatherService.Response(temperature, 15, 20, 2, 53, 45, MockWeatherService.Unit.C);
}

};

// @formatter:off
String response = ChatClient.builder(chatModel)
.defaultFunction("getCurrentWeather", "Get the weather in location", biFunction)
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
.build()
.prompt()
.toolContext(Map.of("sessionId", "123"))
.call().content();
// @formatter:on

logger.info("Response: {}", response);

assertThat(response).contains("30", "10", "15");
}

@Test
void streamFunctionCallTest() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ <I, O> ChatClientRequestSpec function(String name, String description, Class<I>

ChatClientRequestSpec functions(String... functionBeanNames);

ChatClientRequestSpec toolContext(Map<String, Object> toolContext);

ChatClientRequestSpec system(String text);

ChatClientRequestSpec system(Resource textResource, Charset charset);
Expand Down Expand Up @@ -271,6 +273,8 @@ <I, O> Builder defaultFunction(String name, String description,

Builder defaultFunctions(FunctionCallback... functionCallbacks);

Builder defaultToolContext(Map<String, Object> toolContext);

ChatClient build();

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -549,14 +549,14 @@ public Map<String, Object> getToolContext() {
DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) {
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks,
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams,
ccr.observationRegistry, ccr.customObservationConvention);
ccr.observationRegistry, ccr.customObservationConvention, ccr.toolContext);
}

public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map<String, Object> userParams,
String systemText, Map<String, Object> systemParams, List<FunctionCallback> functionCallbacks,
List<Message> messages, List<String> functionNames, List<Media> media, ChatOptions chatOptions,
List<Advisor> advisors, Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
ChatClientObservationConvention customObservationConvention) {
ChatClientObservationConvention customObservationConvention, Map<String, Object> toolContext) {

this.chatModel = chatModel;
this.chatOptions = chatOptions != null ? chatOptions.copy()
Expand All @@ -575,6 +575,7 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map<St
this.advisorParams.putAll(advisorParams);
this.observationRegistry = observationRegistry;
this.customObservationConvention = customObservationConvention;
this.toolContext.putAll(toolContext);

// @formatter:off
// At the stack bottom add the non-streaming and streaming model call advisors.
Expand Down Expand Up @@ -639,6 +640,7 @@ public Builder mutate() {
// workaround to set the missing fields.
builder.defaultRequest.getMessages().addAll(this.messages);
builder.defaultRequest.getFunctionCallbacks().addAll(this.functionCallbacks);
builder.defaultRequest.getToolContext().putAll(this.toolContext);

return builder;
}
Expand Down Expand Up @@ -735,6 +737,12 @@ public ChatClientRequestSpec functions(FunctionCallback... functionCallbacks) {
return this;
}

public ChatClientRequestSpec toolContext(Map<String, Object> toolContext) {
Assert.notNull(toolContext, "the toolContext must be non-null");
this.toolContext.putAll(toolContext);
return this;
}

public ChatClientRequestSpec system(String text) {
Assert.notNull(text, "the text must be non-null");
this.systemText = text;
Expand Down Expand Up @@ -827,7 +835,8 @@ private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inpu
return new AdvisedRequest(inputRequest.chatModel, inputRequest.userText, inputRequest.systemText,
inputRequest.chatOptions, inputRequest.media, inputRequest.functionNames,
inputRequest.functionCallbacks, inputRequest.messages, inputRequest.userParams,
inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext);
inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext,
inputRequest.toolContext);
}

public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest,
Expand All @@ -837,7 +846,8 @@ public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(Advise
advisedRequest.userParams(), advisedRequest.systemText(), advisedRequest.systemParams(),
advisedRequest.functionCallbacks(), advisedRequest.messages(), advisedRequest.functionNames(),
advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(),
advisedRequest.advisorParams(), observationRegistry, customObservationConvention);
advisedRequest.advisorParams(), observationRegistry, customObservationConvention,
advisedRequest.toolContext());
}

// Prompt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observa
Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null");
this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(),
List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry,
customObservationConvention);
customObservationConvention, Map.of());
}

public ChatClient build() {
Expand Down Expand Up @@ -157,4 +157,9 @@ public Builder defaultFunctions(FunctionCallback... functionCallbacks) {
return this;
}

public Builder defaultToolContext(Map<String, Object> toolContext) {
this.defaultRequest.toolContext(toolContext);
return this;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
public record AdvisedRequest(ChatModel chatModel, String userText, String systemText, ChatOptions chatOptions,
List<Media> media, List<String> functionNames, List<FunctionCallback> functionCallbacks, List<Message> messages,
Map<String, Object> userParams, Map<String, Object> systemParams, List<Advisor> advisors,
Map<String, Object> advisorParams, Map<String, Object> adviseContext) {
Map<String, Object> advisorParams, Map<String, Object> adviseContext, Map<String, Object> toolContext) {

public AdvisedRequest updateContext(Function<Map<String, Object>, Map<String, Object>> contextTransform) {
return from(this)
Expand All @@ -83,6 +83,7 @@ public static Builder from(AdvisedRequest from) {
builder.advisors = from.advisors;
builder.advisorParams = from.advisorParams;
builder.adviseContext = from.adviseContext;
builder.toolContext = from.toolContext;

return builder;
}
Expand All @@ -93,6 +94,8 @@ public static Builder builder() {

public static class Builder {

public Map<String, Object> toolContext;

private ChatModel chatModel;

private String userText = "";
Expand Down Expand Up @@ -154,6 +157,11 @@ public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
return this;
}

public Builder withToolContext(Map<String, Object> toolContext) {
this.toolContext = toolContext;
return this;
}

public Builder withMessages(List<Message> messages) {
this.messages = messages;
return this;
Expand Down Expand Up @@ -187,7 +195,7 @@ public Builder withAdviseContext(Map<String, Object> adviseContext) {
public AdvisedRequest build() {
return new AdvisedRequest(chatModel, this.userText, this.systemText, this.chatOptions, this.media,
this.functionNames, this.functionCallbacks, this.messages, this.userParams, this.systemParams,
this.advisors, this.advisorParams, this.adviseContext);
this.advisors, this.advisorParams, this.adviseContext, this.toolContext);
}

}
Expand Down Expand Up @@ -228,6 +236,9 @@ public Prompt toPrompt() {
if (!this.functionCallbacks().isEmpty()) {
functionCallingOptions.setFunctionCallbacks(this.functionCallbacks());
}
if (!CollectionUtils.isEmpty(this.toolContext())) {
functionCallingOptions.setToolContext(this.toolContext());
}
}

return new Prompt(messages, this.chatOptions());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ void whenEmptyInputContentThenReturnOriginalContext() {
ChatClientObservationConvention customObservationConvention = null;

var request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(),
List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention);
List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention,
Map.of());

var expectedContext = ChatClientObservationContext.builder().withRequest(request).build();

Expand All @@ -76,7 +77,7 @@ void whenWithTextThenAugmentContext() {

var request = new DefaultChatClientRequestSpec(chatModel, "sample user text", Map.of("up1", "upv1"),
"sample system text", Map.of("sp1", "sp1v"), List.of(), List.of(), List.of(), List.of(), null,
List.of(), Map.of(), observationRegistry, customObservationConvention);
List.of(), Map.of(), observationRegistry, customObservationConvention, Map.of());

var originalContext = ChatClientObservationContext.builder().withRequest(request).build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class ChatClientObservationContextTests {
void whenMandatoryRequestOptionsThenReturn() {

var request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(),
List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null);
List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of());

var observationContext = ChatClientObservationContext.builder().withRequest(request).withStream(true).build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,19 @@
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.RequestResponseAdvisor;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames;
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.observation.conventions.SpringAiKind;

import io.micrometer.common.KeyValue;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.observation.conventions.SpringAiKind;

/**
* Unit tests for {@link DefaultChatClientObservationConvention}.
Expand All @@ -60,7 +59,7 @@ class DefaultChatClientObservationConventionTests {
@BeforeEach
public void beforeEach() {
request = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(),
List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null);
List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of());
}

@Test
Expand Down Expand Up @@ -160,7 +159,7 @@ void shouldHaveOptionalKeyValues() {
List.of(dummyFunction("functionCallback1"), dummyFunction("functionCallback2")), List.of(),
List.of("function1", "function2"), List.of(), null,
List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2")), Map.of("advParam1", "advisorParam1Value"),
ObservationRegistry.NOOP, null);
ObservationRegistry.NOOP, null, Map.of());

ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
.withRequest(request)
Expand Down