Skip to content

Commit 5cd31ae

Browse files
committed
Migrate the openai, vertex
1 parent 97a4476 commit 5cd31ae

File tree

8 files changed

+86
-115
lines changed

8 files changed

+86
-115
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@
3737
import org.springframework.ai.chat.model.ChatResponse;
3838
import org.springframework.ai.converter.BeanOutputConverter;
3939
import org.springframework.ai.converter.ListOutputConverter;
40-
import org.springframework.ai.model.function.FunctionCallback;
4140
import org.springframework.ai.openai.OpenAiChatOptions;
4241
import org.springframework.ai.openai.OpenAiTestConfiguration;
4342
import org.springframework.ai.openai.api.OpenAiApi;
4443
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters;
4544
import org.springframework.ai.openai.api.tool.MockWeatherService;
4645
import org.springframework.ai.openai.testutils.AbstractIT;
46+
import org.springframework.ai.tool.function.FunctionToolCallback;
4747
import org.springframework.beans.factory.annotation.Value;
4848
import org.springframework.boot.test.context.SpringBootTest;
4949
import org.springframework.core.ParameterizedTypeReference;
@@ -246,16 +246,13 @@ void beanStreamOutputConverterRecords() {
246246
@Test
247247
void functionCallTest() {
248248

249-
FunctionCallback functionCallback = FunctionCallback.builder()
250-
.function("getCurrentWeather", new MockWeatherService())
251-
.description("Get the weather in location")
252-
.inputType(MockWeatherService.Request.class)
253-
.build();
254-
255249
// @formatter:off
256250
String response = ChatClient.create(this.chatModel).prompt()
257251
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
258-
.functions(functionCallback)
252+
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
253+
.description("Get the weather in location")
254+
.inputType(MockWeatherService.Request.class)
255+
.build())
259256
.call()
260257
.content();
261258
// @formatter:on
@@ -270,8 +267,7 @@ void defaultFunctionCallTest() {
270267

271268
// @formatter:off
272269
String response = ChatClient.builder(this.chatModel)
273-
.defaultFunctions(FunctionCallback.builder()
274-
.function("getCurrentWeather", new MockWeatherService())
270+
.defaultTools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
275271
.description("Get the weather in location")
276272
.inputType(MockWeatherService.Request.class)
277273
.build())
@@ -291,8 +287,7 @@ void streamFunctionCallTest() {
291287
// @formatter:off
292288
Flux<String> response = ChatClient.create(this.chatModel).prompt()
293289
.user("What's the weather like in San Francisco, Tokyo, and Paris?")
294-
.functions(FunctionCallback.builder()
295-
.function("getCurrentWeather", new MockWeatherService())
290+
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
296291
.description("Get the weather in location")
297292
.inputType(MockWeatherService.Request.class)
298293
.build())

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMethodInvokingFunctionCallbackIT.java

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,13 @@
2828
import org.springframework.ai.chat.client.ChatClient;
2929
import org.springframework.ai.chat.model.ChatModel;
3030
import org.springframework.ai.chat.model.ToolContext;
31-
import org.springframework.ai.model.function.FunctionCallback;
3231
import org.springframework.ai.openai.OpenAiTestConfiguration;
32+
import org.springframework.ai.tool.definition.ToolDefinition;
33+
import org.springframework.ai.tool.method.MethodToolCallback;
3334
import org.springframework.beans.factory.annotation.Autowired;
3435
import org.springframework.boot.test.context.SpringBootTest;
3536
import org.springframework.test.context.ActiveProfiles;
37+
import org.springframework.util.ReflectionUtils;
3638

3739
import static org.assertj.core.api.Assertions.assertThat;
3840
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
@@ -57,13 +59,18 @@ void beforeEach() {
5759

5860
@Test
5961
void methodGetWeatherStatic() {
62+
63+
var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherStatic", String.class,
64+
Unit.class);
65+
6066
// @formatter:off
6167
String response = ChatClient.create(this.chatModel).prompt()
6268
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
63-
.functions(FunctionCallback.builder()
64-
.method("getWeatherStatic", String.class, Unit.class)
65-
.description("Get the weather in location")
66-
.targetClass(TestFunctionClass.class)
69+
.tools(MethodToolCallback.builder()
70+
.toolDefinition(ToolDefinition.builder(toolMethod)
71+
.description("Get the weather in location")
72+
.build())
73+
.toolMethod(toolMethod)
6774
.build())
6875
.call()
6976
.content();
@@ -79,13 +86,17 @@ void methodTurnLightNoResponse() {
7986

8087
TestFunctionClass targetObject = new TestFunctionClass();
8188

89+
var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "turnLight", String.class, boolean.class);
90+
8291
// @formatter:off
8392
String response = ChatClient.create(this.chatModel).prompt()
8493
.user("Turn light on in the living room.")
85-
.functions(FunctionCallback.builder()
86-
.method("turnLight", String.class, boolean.class)
87-
.description("Can turn lights on or off by room name")
88-
.targetObject(targetObject)
94+
.tools(MethodToolCallback.builder()
95+
.toolDefinition(ToolDefinition.builder(toolMethod)
96+
.description("Can turn lights on or off by room name")
97+
.build())
98+
.toolMethod(toolMethod)
99+
.toolObject(targetObject)
89100
.build())
90101
.call()
91102
.content();
@@ -102,13 +113,18 @@ void methodGetWeatherNonStatic() {
102113

103114
TestFunctionClass targetObject = new TestFunctionClass();
104115

116+
var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherNonStatic", String.class,
117+
Unit.class);
118+
105119
// @formatter:off
106120
String response = ChatClient.create(this.chatModel).prompt()
107121
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
108-
.functions(FunctionCallback.builder()
109-
.method("getWeatherNonStatic", String.class, Unit.class)
110-
.description("Get the weather in location")
111-
.targetObject(targetObject)
122+
.tools(MethodToolCallback.builder()
123+
.toolDefinition(ToolDefinition.builder(toolMethod)
124+
.description("Get the weather in location")
125+
.build())
126+
.toolMethod(toolMethod)
127+
.toolObject(targetObject)
112128
.build())
113129
.call()
114130
.content();
@@ -124,13 +140,18 @@ void methodGetWeatherToolContext() {
124140

125141
TestFunctionClass targetObject = new TestFunctionClass();
126142

143+
var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherWithContext", String.class,
144+
Unit.class, ToolContext.class);
145+
127146
// @formatter:off
128147
String response = ChatClient.create(this.chatModel).prompt()
129148
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
130-
.functions(FunctionCallback.builder()
131-
.method("getWeatherWithContext", String.class, Unit.class, ToolContext.class)
132-
.description("Get the weather in location")
133-
.targetObject(targetObject)
149+
.tools(MethodToolCallback.builder()
150+
.toolDefinition(ToolDefinition.builder(toolMethod)
151+
.description("Get the weather in location")
152+
.build())
153+
.toolMethod(toolMethod)
154+
.toolObject(targetObject)
134155
.build())
135156
.toolContext(Map.of("tool", "value"))
136157
.call()
@@ -148,19 +169,24 @@ void methodGetWeatherToolContextButNonContextMethod() {
148169

149170
TestFunctionClass targetObject = new TestFunctionClass();
150171

172+
var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherNonStatic", String.class,
173+
Unit.class);
174+
151175
// @formatter:off
152176
assertThatThrownBy(() -> ChatClient.create(this.chatModel).prompt()
153177
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
154-
.functions(FunctionCallback.builder()
155-
.method("getWeatherNonStatic", String.class, Unit.class)
156-
.description("Get the weather in location")
157-
.targetObject(targetObject)
178+
.tools(MethodToolCallback.builder()
179+
.toolDefinition(ToolDefinition.builder(toolMethod)
180+
.description("Get the weather in location")
181+
.build())
182+
.toolMethod(toolMethod)
183+
.toolObject(targetObject)
158184
.build())
159185
.toolContext(Map.of("tool", "value"))
160186
.call()
161187
.content())
162188
.isInstanceOf(IllegalArgumentException.class)
163-
.hasMessage("Configured method does not accept ToolContext as input parameter!");
189+
.hasMessage("ToolContext is not supported by the method as an argument");
164190
// @formatter:on
165191
}
166192

@@ -169,13 +195,17 @@ void methodNoParameters() {
169195

170196
TestFunctionClass targetObject = new TestFunctionClass();
171197

198+
var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "turnLivingRoomLightOn");
199+
172200
// @formatter:off
173201
String response = ChatClient.create(this.chatModel).prompt()
174202
.user("Turn light on in the living room.")
175-
.functions(FunctionCallback.builder()
176-
.method("turnLivingRoomLightOn")
177-
.description("Can turn lights on in the Living Room")
178-
.targetObject(targetObject)
203+
.tools(MethodToolCallback.builder()
204+
.toolDefinition(ToolDefinition.builder(toolMethod)
205+
.description("Can turn lights on in the Living Room")
206+
.build())
207+
.toolMethod(toolMethod)
208+
.toolObject(targetObject)
179209
.build())
180210
.call()
181211
.content();

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@
3131

3232
import org.springframework.ai.chat.client.ChatClient;
3333
import org.springframework.ai.chat.model.ToolContext;
34-
import org.springframework.ai.model.function.FunctionCallback;
3534
import org.springframework.ai.openai.OpenAiTestConfiguration;
3635
import org.springframework.ai.openai.api.tool.MockWeatherService;
3736
import org.springframework.ai.openai.api.tool.MockWeatherService.Request;
3837
import org.springframework.ai.openai.api.tool.MockWeatherService.Response;
3938
import org.springframework.ai.openai.testutils.AbstractIT;
39+
import org.springframework.ai.tool.function.FunctionToolCallback;
4040
import org.springframework.beans.factory.annotation.Value;
4141
import org.springframework.boot.test.context.SpringBootTest;
4242
import org.springframework.core.io.Resource;
@@ -84,8 +84,7 @@ void turnFunctionsOnAndOffTest() {
8484
// @formatter:off
8585
response = chatClientBuilder.build().prompt()
8686
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
87-
.functions(FunctionCallback.builder()
88-
.function("getCurrentWeather", new MockWeatherService())
87+
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
8988
.description("Get the weather in location")
9089
.inputType(MockWeatherService.Request.class)
9190
.build())
@@ -115,8 +114,7 @@ void defaultFunctionCallTest() {
115114

116115
// @formatter:off
117116
String response = ChatClient.builder(this.chatModel)
118-
.defaultFunctions(FunctionCallback.builder()
119-
.function("getCurrentWeather", new MockWeatherService())
117+
.defaultTools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
120118
.description("Get the weather in location")
121119
.inputType(MockWeatherService.Request.class)
122120
.build())
@@ -158,8 +156,7 @@ else if (request.location().contains("San Francisco")) {
158156

159157
// @formatter:off
160158
String response = ChatClient.builder(this.chatModel)
161-
.defaultFunctions(FunctionCallback.builder()
162-
.function("getCurrentWeather", biFunction)
159+
.defaultTools(FunctionToolCallback.builder("getCurrentWeather", biFunction)
163160
.description("Get the weather in location")
164161
.inputType(MockWeatherService.Request.class)
165162
.build())
@@ -202,8 +199,7 @@ else if (request.location().contains("San Francisco")) {
202199

203200
// @formatter:off
204201
String response = ChatClient.builder(this.chatModel)
205-
.defaultFunctions(FunctionCallback.builder()
206-
.function("getCurrentWeather", biFunction)
202+
.defaultTools(FunctionToolCallback.builder("getCurrentWeather", biFunction)
207203
.description("Get the weather in location")
208204
.inputType(MockWeatherService.Request.class)
209205
.build())
@@ -225,8 +221,7 @@ void streamFunctionCallTest() {
225221
// @formatter:off
226222
Flux<String> response = ChatClient.create(this.chatModel).prompt()
227223
.user("What's the weather like in San Francisco, Tokyo, and Paris?")
228-
.functions(FunctionCallback.builder()
229-
.function("getCurrentWeather", new MockWeatherService())
224+
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
230225
.description("Get the weather in location")
231226
.inputType(MockWeatherService.Request.class)
232227
.build())
@@ -255,8 +250,7 @@ void functionCallWithExplicitInputType() throws NoSuchMethodException {
255250

256251
String content = chatClient.prompt()
257252
.user("What's the weather like in Shanghai?")
258-
.functions(FunctionCallback.builder()
259-
.function("currentTemp", function)
253+
.tools(FunctionToolCallback.builder("currentTemp", function)
260254
.description("get current temp")
261255
.inputType(MyFunction.Req.class)
262256
.build())

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientProxyFunctionCallsIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ void toolProxyFunctionCall() throws JsonMappingException, JsonProcessingExceptio
123123

124124
chatResponse = chatClient.prompt()
125125
.messages(messages)
126-
.functions(this.functionDefinition)
126+
.tools(this.functionDefinition)
127127
.options(OpenAiChatOptions.builder().proxyToolCalls(true).build())
128128
.call()
129129
.chatResponse();

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java

Lines changed: 10 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,8 @@
3535
import org.springframework.ai.chat.model.ChatResponse;
3636
import org.springframework.ai.chat.model.Generation;
3737
import org.springframework.ai.chat.prompt.Prompt;
38-
import org.springframework.ai.model.function.FunctionCallback;
39-
import org.springframework.ai.model.function.FunctionCallback.SchemaType;
4038
import org.springframework.ai.tool.function.FunctionToolCallback;
39+
import org.springframework.ai.util.json.JsonSchemaGenerator;
4140
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel;
4241
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions;
4342
import org.springframework.beans.factory.annotation.Autowired;
@@ -101,46 +100,6 @@ public void functionCallExplicitOpenApiSchema() {
101100
@Test
102101
public void functionCallTestInferredOpenApiSchema() {
103102

104-
UserMessage userMessage = new UserMessage("What's the weather like in Paris? Use Celsius units.");
105-
106-
List<Message> messages = new ArrayList<>(List.of(userMessage));
107-
108-
var promptOptions = VertexAiGeminiChatOptions.builder()
109-
.model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH)
110-
.functionCallbacks(List.of(
111-
FunctionCallback.builder()
112-
.function("get_current_weather", new MockWeatherService())
113-
.schemaType(SchemaType.OPEN_API_SCHEMA)
114-
.description("Get the current weather in a given location.")
115-
.inputType(MockWeatherService.Request.class)
116-
.build(),
117-
FunctionCallback.builder()
118-
.function("get_payment_status", new PaymentStatus())
119-
.schemaType(SchemaType.OPEN_API_SCHEMA)
120-
.description(
121-
"Retrieves the payment status for transaction. For example what is the payment status for transaction 700?")
122-
.inputType(PaymentInfoRequest.class)
123-
.build()))
124-
.build();
125-
126-
ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions));
127-
128-
logger.info("Response: {}", response);
129-
130-
assertThat(response.getResult().getOutput().getText()).containsAnyOf("15.0", "15");
131-
132-
ChatResponse response2 = this.chatModel
133-
.call(new Prompt("What is the payment status for transaction 696?", promptOptions));
134-
135-
logger.info("Response: {}", response2);
136-
137-
assertThat(response2.getResult().getOutput().getText()).containsIgnoringCase("transaction 696 is PAYED");
138-
139-
}
140-
141-
@Test
142-
public void functionCallTestInferredOpenApiSchema2() {
143-
144103
UserMessage userMessage = new UserMessage(
145104
"What's the weather like in San Francisco, Paris and in Tokyo? Return the temperature in Celsius.");
146105

@@ -149,15 +108,15 @@ public void functionCallTestInferredOpenApiSchema2() {
149108
var promptOptions = VertexAiGeminiChatOptions.builder()
150109
.model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH)
151110
.functionCallbacks(List.of(
152-
FunctionCallback.builder()
153-
.function("get_current_weather", new MockWeatherService())
154-
.schemaType(SchemaType.OPEN_API_SCHEMA)
111+
FunctionToolCallback.builder("get_current_weather", new MockWeatherService())
112+
.inputSchema(JsonSchemaGenerator.generateForType(MockWeatherService.Request.class,
113+
JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES))
155114
.description("Get the current weather in a given location.")
156115
.inputType(MockWeatherService.Request.class)
157116
.build(),
158-
FunctionCallback.builder()
159-
.function("get_payment_status", new PaymentStatus())
160-
.schemaType(SchemaType.OPEN_API_SCHEMA)
117+
FunctionToolCallback.builder("get_payment_status", new PaymentStatus())
118+
.inputSchema(JsonSchemaGenerator.generateForType(PaymentInfoRequest.class,
119+
JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES))
161120
.description(
162121
"Retrieves the payment status for transaction. For example what is the payment status for transaction 700?")
163122
.inputType(PaymentInfoRequest.class)
@@ -189,9 +148,9 @@ public void functionCallTestInferredOpenApiSchemaStream() {
189148

190149
var promptOptions = VertexAiGeminiChatOptions.builder()
191150
.model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH)
192-
.functionCallbacks(List.of(FunctionCallback.builder()
193-
.function("getCurrentWeather", new MockWeatherService())
194-
.schemaType(SchemaType.OPEN_API_SCHEMA)
151+
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
152+
.inputSchema(JsonSchemaGenerator.generateForType(MockWeatherService.Request.class,
153+
JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES))
195154
.description("Get the current weather in a given location")
196155
.inputType(MockWeatherService.Request.class)
197156
.build()))

0 commit comments

Comments
 (0)