Skip to content

Commit 3fbc310

Browse files
committed
Allow null or empty userText if request contains Tool Response message
- Add a new test case for testing function call with advisor in BedrockConverseChatClientIT - Add a new test case for testing tool proxy function call in OpenAiChatClientProxyFunctionCallsIT - Remove unused model property from BedrockConverseProxyChatProperties
1 parent cbdb578 commit 3fbc310

File tree

6 files changed

+210
-20
lines changed

6 files changed

+210
-20
lines changed

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,23 @@ void functionCallTest() {
222222
assertThat(response).contains("30", "10", "15");
223223
}
224224

225+
@Test
226+
void functionCallWithAdvisorTest() {
227+
228+
// @formatter:off
229+
String response = ChatClient.create(this.chatModel)
230+
.prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
231+
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
232+
.advisors(new SimpleLoggerAdvisor())
233+
.call()
234+
.content();
235+
// @formatter:on
236+
237+
logger.info("Response: {}", response);
238+
239+
assertThat(response).contains("30", "10", "15");
240+
}
241+
225242
@Test
226243
void defaultFunctionCallTest() {
227244

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.chat.client;
18+
19+
import java.util.ArrayList;
20+
import java.util.List;
21+
import java.util.Map;
22+
import java.util.Optional;
23+
import java.util.Set;
24+
25+
import com.fasterxml.jackson.core.JsonProcessingException;
26+
import com.fasterxml.jackson.databind.JsonMappingException;
27+
import com.fasterxml.jackson.databind.ObjectMapper;
28+
import org.junit.jupiter.api.Test;
29+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
30+
import org.slf4j.Logger;
31+
import org.slf4j.LoggerFactory;
32+
33+
import org.springframework.ai.chat.client.ChatClient;
34+
import org.springframework.ai.chat.messages.AssistantMessage;
35+
import org.springframework.ai.chat.messages.Message;
36+
import org.springframework.ai.chat.messages.ToolResponseMessage;
37+
import org.springframework.ai.chat.messages.UserMessage;
38+
import org.springframework.ai.chat.model.ChatResponse;
39+
import org.springframework.ai.chat.model.Generation;
40+
import org.springframework.ai.model.ModelOptionsUtils;
41+
import org.springframework.ai.model.function.FunctionCallback;
42+
import org.springframework.ai.model.function.ToolCallHelper;
43+
import org.springframework.ai.openai.OpenAiChatModel;
44+
import org.springframework.ai.openai.OpenAiChatOptions;
45+
import org.springframework.ai.openai.OpenAiTestConfiguration;
46+
import org.springframework.ai.openai.api.OpenAiApi;
47+
import org.springframework.ai.openai.testutils.AbstractIT;
48+
import org.springframework.beans.factory.annotation.Autowired;
49+
import org.springframework.beans.factory.annotation.Value;
50+
import org.springframework.boot.test.context.SpringBootTest;
51+
import org.springframework.core.io.Resource;
52+
import org.springframework.test.context.ActiveProfiles;
53+
import org.springframework.util.CollectionUtils;
54+
55+
import static org.assertj.core.api.Assertions.assertThat;
56+
57+
@SpringBootTest(classes = OpenAiTestConfiguration.class)
58+
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
59+
@ActiveProfiles("logging-test")
60+
class OpenAiChatClientProxyFunctionCallsIT extends AbstractIT {
61+
62+
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatClientMultipleFunctionCallsIT.class);
63+
64+
@Value("classpath:/prompts/system-message.st")
65+
private Resource systemTextResource;
66+
67+
FunctionCallback functionDefinition = new ToolCallHelper.FunctionDefinition("getWeatherInLocation",
68+
"Get the weather in location", """
69+
{
70+
"type": "object",
71+
"properties": {
72+
"location": {
73+
"type": "string",
74+
"description": "The city and state e.g. San Francisco, CA"
75+
},
76+
"unit": {
77+
"type": "string",
78+
"enum": ["C", "F"]
79+
}
80+
},
81+
"required": ["location", "unit"]
82+
}
83+
""");
84+
85+
@Autowired
86+
private OpenAiChatModel chatModel;
87+
88+
// Helper class that reuses some of the {@link AbstractToolCallSupport} functionality
89+
// to help to implement the function call handling logic on the client side.
90+
private ToolCallHelper toolCallHelper = new ToolCallHelper();
91+
92+
// Function which will be called by the AI model.
93+
private String getWeatherInLocation(String location, String unit) {
94+
95+
double temperature = 0;
96+
97+
if (location.contains("Paris")) {
98+
temperature = 15;
99+
}
100+
else if (location.contains("Tokyo")) {
101+
temperature = 10;
102+
}
103+
else if (location.contains("San Francisco")) {
104+
temperature = 30;
105+
}
106+
107+
return String.format("The weather in %s is %s%s", location, temperature, unit);
108+
}
109+
110+
@Test
111+
void toolProxyFunctionCall() throws JsonMappingException, JsonProcessingException {
112+
113+
List<Message> messages = List
114+
.of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"));
115+
116+
boolean isToolCall = false;
117+
118+
ChatResponse chatResponse = null;
119+
120+
var chatClient = ChatClient.builder(this.chatModel).build();
121+
122+
do {
123+
124+
chatResponse = chatClient.prompt()
125+
.messages(messages)
126+
.functions(this.functionDefinition)
127+
.options(OpenAiChatOptions.builder().withProxyToolCalls(true).build())
128+
.call()
129+
.chatResponse();
130+
131+
// Note that the tool call check could be platform specific because the finish
132+
// reasons.
133+
isToolCall = this.toolCallHelper.isToolCall(chatResponse,
134+
Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
135+
OpenAiApi.ChatCompletionFinishReason.STOP.name()));
136+
137+
if (isToolCall) {
138+
139+
Optional<Generation> toolCallGeneration = chatResponse.getResults()
140+
.stream()
141+
.filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls()))
142+
.findFirst();
143+
144+
assertThat(toolCallGeneration).isNotEmpty();
145+
146+
AssistantMessage assistantMessage = toolCallGeneration.get().getOutput();
147+
148+
List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();
149+
150+
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
151+
152+
var functionName = toolCall.name();
153+
154+
assertThat(functionName).isEqualTo("getWeatherInLocation");
155+
156+
String functionArguments = toolCall.arguments();
157+
158+
@SuppressWarnings("unchecked")
159+
Map<String, String> argumentsMap = new ObjectMapper().readValue(functionArguments, Map.class);
160+
161+
String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(),
162+
argumentsMap.get("unit").toString());
163+
164+
toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), functionName,
165+
ModelOptionsUtils.toJsonString(functionResponse)));
166+
}
167+
168+
ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of());
169+
170+
messages = this.toolCallHelper.buildToolCallConversation(messages, assistantMessage,
171+
toolMessageResponse);
172+
173+
assertThat(messages).isNotEmpty();
174+
175+
// prompt = new Prompt(toolCallConversation, prompt.getOptions());
176+
}
177+
}
178+
while (isToolCall);
179+
180+
logger.info("Response: {}", chatResponse);
181+
182+
assertThat(chatResponse.getResult().getOutput().getContent()).contains("30", "10", "15");
183+
}
184+
185+
}

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ public record AdvisedRequest(
8484

8585
public AdvisedRequest {
8686
Assert.notNull(chatModel, "chatModel cannot be null");
87-
Assert.hasText(userText, "userText cannot be null or empty");
87+
Assert.isTrue(StringUtils.hasText(userText) || !CollectionUtils.isEmpty(messages),
88+
"userText cannot be null or empty unless messages are provided and contain Tool Response message.");
8889
Assert.notNull(media, "media cannot be null");
8990
Assert.noNullElements(media, "media cannot contain null elements");
9091
Assert.notNull(functionNames, "functionNames cannot be null");

spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,17 @@ void whenUserTextIsNullThenThrows() {
5454
assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), null, null, null, List.of(), List.of(),
5555
List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of()))
5656
.isInstanceOf(IllegalArgumentException.class)
57-
.hasMessage("userText cannot be null or empty");
57+
.hasMessage(
58+
"userText cannot be null or empty unless messages are provided and contain Tool Response message.");
5859
}
5960

6061
@Test
6162
void whenUserTextIsEmptyThenThrows() {
6263
assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "", null, null, List.of(), List.of(),
6364
List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of()))
6465
.isInstanceOf(IllegalArgumentException.class)
65-
.hasMessage("userText cannot be null or empty");
66+
.hasMessage(
67+
"userText cannot be null or empty unless messages are provided and contain Tool Response message.");
6668
}
6769

6870
@Test

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatProperties.java

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
package org.springframework.ai.autoconfigure.bedrock.converse;
1818

19-
import org.springframework.ai.bedrock.converse.BedrockProxyChatModel;
2019
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
2120
import org.springframework.boot.context.properties.ConfigurationProperties;
2221
import org.springframework.boot.context.properties.NestedConfigurationProperty;
@@ -38,12 +37,6 @@ public class BedrockConverseProxyChatProperties {
3837
*/
3938
private boolean enabled = true;
4039

41-
/**
42-
* The generative id to use. See the {@link BedrockProxyChatModel} for the supported
43-
* models.
44-
*/
45-
private String model = "anthropic.claude-3-5-sonnet-20240620-v1:0";
46-
4740
@NestedConfigurationProperty
4841
private PortableFunctionCallingOptions options = PortableFunctionCallingOptions.builder()
4942
.withTemperature(0.7)
@@ -59,14 +52,6 @@ public void setEnabled(boolean enabled) {
5952
this.enabled = enabled;
6053
}
6154

62-
public String getModel() {
63-
return this.model;
64-
}
65-
66-
public void setModel(String model) {
67-
this.model = model;
68-
}
69-
7055
public PortableFunctionCallingOptions getOptions() {
7156
return this.options;
7257
}

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatPropertiesTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ public void chatCompletionDisabled() {
7171
.run(context -> assertThat(context.getBeansOfType(BedrockConverseProxyChatProperties.class)).isNotEmpty());
7272

7373
// Explicitly enable the chat auto-configuration.
74-
new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.converse.chat..enabled=true")
74+
new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.converse.chat.enabled=true")
7575
.withConfiguration(AutoConfigurations.of(BedrockConverseProxyChatAutoConfiguration.class))
7676
.run(context -> assertThat(context.getBeansOfType(BedrockConverseProxyChatProperties.class)).isNotEmpty());
7777

7878
// Explicitly disable the chat auto-configuration.
79-
new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.converse.chat..enabled=false")
79+
new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.converse.chat.enabled=false")
8080
.withConfiguration(AutoConfigurations.of(BedrockConverseProxyChatAutoConfiguration.class))
8181
.run(context -> assertThat(context.getBeansOfType(BedrockConverseProxyChatProperties.class)).isEmpty());
8282
}

0 commit comments

Comments
 (0)