Skip to content

Commit 0b14cee

Browse files
committed
refactor(ai): improve function callback API and response handling
- Simplify function callback builder API by requiring name parameter upfront - Add configurable response converter for method function callbacks and improve string response handling by avoiding unnecessary JSON conversion - Add optional custom name support for method callbacks - Fix text, documentation and code style
1 parent 5f09c21 commit 0b14cee

File tree

56 files changed

+171
-270
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+171
-270
lines changed

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,7 @@ void functionCallTest() {
259259
.withFunctionCallbacks(List.of(FunctionCallback.builder()
260260
.description(
261261
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
262-
.function(new MockWeatherService())
263-
.name("getCurrentWeather")
262+
.function("getCurrentWeather", new MockWeatherService())
264263
.inputType(MockWeatherService.Request.class)
265264
.build()))
266265
.build();
@@ -288,8 +287,7 @@ void streamFunctionCallTest() {
288287
.withFunctionCallbacks(List.of(FunctionCallback.builder()
289288
.description(
290289
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
291-
.function(new MockWeatherService())
292-
.name("getCurrentWeather")
290+
.function("getCurrentWeather", new MockWeatherService())
293291
.inputType(MockWeatherService.Request.class)
294292
.build()))
295293
.build();

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ void functionCallTest() {
7272
.withDeploymentName(this.selectedModel)
7373
.withFunctionCallbacks(List.of(FunctionCallback.builder()
7474
.description("Get the current weather in a given location")
75-
.function(new MockWeatherService())
76-
.name("getCurrentWeather")
75+
.function("getCurrentWeather", new MockWeatherService())
7776
.inputType(MockWeatherService.Request.class)
7877
.build()))
7978
.build();
@@ -97,8 +96,7 @@ void functionCallSequentialTest() {
9796
.withDeploymentName(this.selectedModel)
9897
.withFunctionCallbacks(List.of(FunctionCallback.builder()
9998
.description("Get the current weather in a given location")
100-
.function(new MockWeatherService())
101-
.name("getCurrentWeather")
99+
.function("getCurrentWeather", new MockWeatherService())
102100
.inputType(MockWeatherService.Request.class)
103101
.build()))
104102
.build();
@@ -120,8 +118,7 @@ void streamFunctionCallTest() {
120118
.withDeploymentName(this.selectedModel)
121119
.withFunctionCallbacks(List.of(FunctionCallback.builder()
122120
.description("Get the current weather in a given location")
123-
.function(new MockWeatherService())
124-
.name("getCurrentWeather")
121+
.function("getCurrentWeather", new MockWeatherService())
125122
.inputType(MockWeatherService.Request.class)
126123
.build()))
127124
.build();
@@ -158,8 +155,7 @@ void functionCallSequentialAndStreamTest() {
158155
.withDeploymentName(this.selectedModel)
159156
.withFunctionCallbacks(List.of(FunctionCallback.builder()
160157
.description("Get the current weather in a given location")
161-
.function(new MockWeatherService())
162-
.name("getCurrentWeather")
158+
.function("getCurrentWeather", new MockWeatherService())
163159
.inputType(MockWeatherService.Request.class)
164160
.build()))
165161
.build();

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,7 @@ void functionCallTest() {
257257
.withFunctionCallbacks(List.of(FunctionCallback.builder()
258258
.description(
259259
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
260-
.function(new MockWeatherService())
261-
.name("getCurrentWeather")
260+
.function("getCurrentWeather", new MockWeatherService())
262261
.inputType(MockWeatherService.Request.class)
263262
.build()))
264263
.build();
@@ -286,8 +285,7 @@ void streamFunctionCallTest() {
286285
.withFunctionCallbacks(List.of(FunctionCallback.builder()
287286
.description(
288287
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
289-
.function(new MockWeatherService())
290-
.name("getCurrentWeather")
288+
.function("getCurrentWeather", new MockWeatherService())
291289
.inputType(MockWeatherService.Request.class)
292290
.build()))
293291
.build();

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiements/BedrockConverseChatModelMain2.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ public static void main(String[] args) {
5454
.withModel(modelId)
5555
.withFunctionCallbacks(List.of(FunctionCallback.builder()
5656
.description("Get the weather in location")
57-
.function(new MockWeatherService())
58-
.name("getCurrentWeather")
57+
.function("getCurrentWeather", new MockWeatherService())
5958
.inputType(MockWeatherService.Request.class)
6059
.build()))
6160
.build());

models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,7 @@ public void promptOptionsTools() {
6969
.withModel("PROMPT_MODEL")
7070
.withFunctionCallbacks(List.of(FunctionCallback.builder()
7171
.description("Get the weather in location")
72-
.function(new MockWeatherService())
73-
.name(TOOL_FUNCTION_NAME)
72+
.function(TOOL_FUNCTION_NAME, new MockWeatherService())
7473
.inputType(MockWeatherService.Request.class)
7574
.build()))
7675
.build()),
@@ -97,8 +96,7 @@ public void defaultOptionsTools() {
9796
.withModel("DEFAULT_MODEL")
9897
.withFunctionCallbacks(List.of(FunctionCallback.builder()
9998
.description("Get the weather in location")
100-
.function(new MockWeatherService())
101-
.name(TOOL_FUNCTION_NAME)
99+
.function(TOOL_FUNCTION_NAME, new MockWeatherService())
102100
.inputType(MockWeatherService.Request.class)
103101
.build()))
104102
.build());
@@ -130,8 +128,7 @@ public void defaultOptionsTools() {
130128
MiniMaxChatOptions.builder()
131129
.withFunctionCallbacks(List.of(FunctionCallback.builder()
132130
.description("Overridden function description")
133-
.function(new MockWeatherService())
134-
.name(TOOL_FUNCTION_NAME)
131+
.function(TOOL_FUNCTION_NAME, new MockWeatherService())
135132
.inputType(MockWeatherService.Request.class)
136133
.build()))
137134
.build()),

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,7 @@ void functionCallTest() {
195195
.withModel(MistralAiApi.ChatModel.SMALL.getValue())
196196
.withFunctionCallbacks(List.of(FunctionCallback.builder()
197197
.description("Get the weather in location")
198-
.function(new MockWeatherService())
199-
.name("getCurrentWeather")
198+
.function("getCurrentWeather", new MockWeatherService())
200199
.inputType(MockWeatherService.Request.class)
201200
.build()))
202201
.build();
@@ -219,8 +218,7 @@ void streamFunctionCallTest() {
219218
.withModel(MistralAiApi.ChatModel.SMALL.getValue())
220219
.withFunctionCallbacks(List.of(FunctionCallback.builder()
221220
.description("Get the weather in location")
222-
.function(new MockWeatherService())
223-
.name("getCurrentWeather")
221+
.function("getCurrentWeather", new MockWeatherService())
224222
.inputType(MockWeatherService.Request.class)
225223
.build()))
226224
.build();

models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,7 @@ void functionCallTest() {
6565
.withModel(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
6666
.withFunctionCallbacks(List.of(FunctionCallback.builder()
6767
.description("Get the weather in location")
68-
.function(new MockWeatherService())
69-
.name("getCurrentWeather")
68+
.function("getCurrentWeather", new MockWeatherService())
7069
.inputType(MockWeatherService.Request.class)
7170
.build()))
7271
.build();
@@ -89,8 +88,7 @@ void streamFunctionCallTest() {
8988
var promptOptions = MoonshotChatOptions.builder()
9089
.withFunctionCallbacks(List.of(FunctionCallback.builder()
9190
.description("Get the weather in location")
92-
.function(new MockWeatherService())
93-
.name("getCurrentWeather")
91+
.function("getCurrentWeather", new MockWeatherService())
9492
.build()))
9593
.build();
9694

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ void functionCallTest() {
6666
.withFunctionCallbacks(List.of(FunctionCallback.builder()
6767
.description(
6868
"Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
69-
.function(new MockWeatherService())
70-
.name("getCurrentWeather")
69+
.function("getCurrentWeather", new MockWeatherService())
7170
.inputType(MockWeatherService.Request.class)
7271
.build()))
7372
.build();
@@ -92,8 +91,7 @@ void streamFunctionCallTest() {
9291
.withFunctionCallbacks(List.of(FunctionCallback.builder()
9392
.description(
9493
"Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
95-
.function(new MockWeatherService())
96-
.name("getCurrentWeather")
94+
.function("getCurrentWeather", new MockWeatherService())
9795
.inputType(MockWeatherService.Request.class)
9896
.build()))
9997
.build();

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,7 @@ public void promptOptionsTools() {
6969
.withModel("PROMPT_MODEL")
7070
.withFunctionCallbacks(List.of(FunctionCallback.builder()
7171
.description("Get the weather in location")
72-
.function(new MockWeatherService())
73-
.name(TOOL_FUNCTION_NAME)
72+
.function(TOOL_FUNCTION_NAME, new MockWeatherService())
7473
.inputType(MockWeatherService.Request.class)
7574
.build()))
7675
.build()),
@@ -97,8 +96,7 @@ public void defaultOptionsTools() {
9796
.withModel("DEFAULT_MODEL")
9897
.withFunctionCallbacks(List.of(FunctionCallback.builder()
9998
.description("Get the weather in location")
100-
.function(new MockWeatherService())
101-
.name(TOOL_FUNCTION_NAME)
99+
.function(TOOL_FUNCTION_NAME, new MockWeatherService())
102100
.inputType(MockWeatherService.Request.class)
103101
.build()))
104102
.build());
@@ -130,8 +128,7 @@ public void defaultOptionsTools() {
130128
OpenAiChatOptions.builder()
131129
.withFunctionCallbacks(List.of(FunctionCallback.builder()
132130
.description("Overridden function description")
133-
.function(new MockWeatherService())
134-
.name(TOOL_FUNCTION_NAME)
131+
.function(TOOL_FUNCTION_NAME, new MockWeatherService())
135132
.inputType(MockWeatherService.Request.class)
136133
.build()))
137134
.build()),

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,7 @@ void functionCallTest() {
6565
.withModel(OpenAiApi.ChatModel.GPT_4_O.getValue())
6666
.withFunctionCallbacks(List.of(FunctionCallback.builder()
6767
.description("Get the weather in location")
68-
.function(new MockWeatherService())
69-
.name("getCurrentWeather")
68+
.function("getCurrentWeather", new MockWeatherService())
7069
.inputType(MockWeatherService.Request.class)
7170
.build()))
7271
.build());
@@ -102,8 +101,7 @@ else if (request.location().contains("San Francisco")) {
102101
.withModel(OpenAiApi.ChatModel.GPT_4_O.getValue())
103102
.withFunctionCallbacks(List.of(FunctionCallback.builder()
104103
.description("Get the weather in location")
105-
.function(biFunction)
106-
.name("getCurrentWeather")
104+
.function("getCurrentWeather", biFunction)
107105
.inputType(MockWeatherService.Request.class)
108106
.build()))
109107
.withToolContext(Map.of("sessionId", "123"))
@@ -129,8 +127,7 @@ void streamFunctionCallTest() {
129127
streamFunctionCallTest(OpenAiChatOptions.builder()
130128
.withFunctionCallbacks(List.of((FunctionCallback.builder()
131129
.description("Get the weather in location")
132-
.function(new MockWeatherService())
133-
.name("getCurrentWeather")
130+
.function("getCurrentWeather", new MockWeatherService())
134131
.inputType(MockWeatherService.Request.class)
135132
// .responseConverter(response -> "" + response.temp() + response.unit())
136133
.build())))
@@ -166,10 +163,8 @@ else if (request.location().contains("San Francisco")) {
166163
OpenAiChatOptions promptOptions = OpenAiChatOptions.builder()
167164
.withFunctionCallbacks(List.of((FunctionCallback.builder()
168165
.description("Get the weather in location")
169-
.function(biFunction)
170-
.name("getCurrentWeather")
166+
.function("getCurrentWeather", biFunction)
171167
.inputType(MockWeatherService.Request.class)
172-
// .responseConverter(response -> "" + response.temp() + response.unit())
173168
.build())))
174169
.withToolContext(Map.of("sessionId", "123"))
175170
.build();

0 commit comments

Comments
 (0)