Skip to content

Commit 2b9a7cb

Browse files
committed
refactor: Improve error handling and validation in DefaultFunctionCallbackBuilder
- Add deprecation warnings for older ChatClient API methods - Update function callback usage in integration tests - Add deprecation warnings for older ChatClient API methods
1 parent ca7936f commit 2b9a7cb

File tree

10 files changed

+451
-40
lines changed

10 files changed

+451
-40
lines changed

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,11 @@ void functionCallTest() {
211211

212212
// @formatter:off
213213
String response = ChatClient.create(this.chatModel).prompt()
214-
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius."))
215-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
214+
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
215+
.functions(FunctionCallback.builder()
216+
.function("getCurrentWeather", new MockWeatherService())
217+
.inputType(MockWeatherService.Request.class)
218+
.build())
216219
.call()
217220
.content();
218221
// @formatter:on
@@ -246,7 +249,11 @@ void defaultFunctionCallTest() {
246249

247250
// @formatter:off
248251
String response = ChatClient.builder(this.chatModel)
249-
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
252+
.defaultFunctions(FunctionCallback.builder()
253+
.description("Get the weather in location")
254+
.function("getCurrentWeather", new MockWeatherService())
255+
.inputType(MockWeatherService.Request.class)
256+
.build())
250257
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius."))
251258
.build()
252259
.prompt()
@@ -265,7 +272,11 @@ void streamFunctionCallTest() {
265272
// @formatter:off
266273
Flux<String> response = ChatClient.create(this.chatModel).prompt()
267274
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
268-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
275+
.functions(FunctionCallback.builder()
276+
.description("Get the weather in location")
277+
.function("getCurrentWeather", new MockWeatherService())
278+
.inputType(MockWeatherService.Request.class)
279+
.build())
269280
.stream()
270281
.content();
271282
// @formatter:on

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.springframework.ai.chat.model.ChatResponse;
3838
import org.springframework.ai.converter.BeanOutputConverter;
3939
import org.springframework.ai.converter.ListOutputConverter;
40+
import org.springframework.ai.model.function.FunctionCallback;
4041
import org.springframework.ai.model.function.FunctionCallingOptions;
4142
import org.springframework.beans.factory.annotation.Autowired;
4243
import org.springframework.beans.factory.annotation.Value;
@@ -212,7 +213,11 @@ void functionCallTest() {
212213
// @formatter:off
213214
String response = ChatClient.create(this.chatModel)
214215
.prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
215-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
216+
.functions(FunctionCallback.builder()
217+
.description("Get the weather in location")
218+
.function("getCurrentWeather", new MockWeatherService())
219+
.inputType(MockWeatherService.Request.class)
220+
.build())
216221
.call()
217222
.content();
218223
// @formatter:on
@@ -228,7 +233,11 @@ void functionCallWithAdvisorTest() {
228233
// @formatter:off
229234
String response = ChatClient.create(this.chatModel)
230235
.prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
231-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
236+
.functions(FunctionCallback.builder()
237+
.description("Get the weather in location")
238+
.function("getCurrentWeather", new MockWeatherService())
239+
.inputType(MockWeatherService.Request.class)
240+
.build())
232241
.advisors(new SimpleLoggerAdvisor())
233242
.call()
234243
.content();
@@ -244,7 +253,11 @@ void defaultFunctionCallTest() {
244253

245254
// @formatter:off
246255
String response = ChatClient.builder(this.chatModel)
247-
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
256+
.defaultFunctions(FunctionCallback.builder()
257+
.description("Get the weather in location")
258+
.function("getCurrentWeather", new MockWeatherService())
259+
.inputType(MockWeatherService.Request.class)
260+
.build())
248261
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius."))
249262
.build()
250263
.prompt()
@@ -263,7 +276,11 @@ void streamFunctionCallTest() {
263276
// @formatter:off
264277
Flux<String> response = ChatClient.create(this.chatModel).prompt()
265278
.user("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
266-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
279+
.functions(FunctionCallback.builder()
280+
.description("Get the weather in location")
281+
.function("getCurrentWeather", new MockWeatherService())
282+
.inputType(MockWeatherService.Request.class)
283+
.build())
267284
.stream()
268285
.content();
269286
// @formatter:on
@@ -280,7 +297,11 @@ void singularStreamFunctionCallTest() {
280297
// @formatter:off
281298
Flux<String> response = ChatClient.create(this.chatModel).prompt()
282299
.user("What's the weather like in Paris? Return the temperature in Celsius.")
283-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
300+
.functions(FunctionCallback.builder()
301+
.description("Get the weather in location")
302+
.function("getCurrentWeather", new MockWeatherService())
303+
.inputType(MockWeatherService.Request.class)
304+
.build())
284305
.stream()
285306
.content();
286307
// @formatter:on

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.springframework.ai.converter.ListOutputConverter;
3535
import org.springframework.ai.mistralai.api.MistralAiApi;
3636
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice;
37+
import org.springframework.ai.model.function.FunctionCallback;
3738
import org.springframework.beans.factory.annotation.Autowired;
3839
import org.springframework.beans.factory.annotation.Value;
3940
import org.springframework.boot.test.context.SpringBootTest;
@@ -224,7 +225,11 @@ void functionCallTest() {
224225
String response = ChatClient.create(this.chatModel).prompt()
225226
.options(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).withToolChoice(ToolChoice.AUTO).build())
226227
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius."))
227-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
228+
.functions(FunctionCallback.builder()
229+
.description("Get the weather in location")
230+
.function("getCurrentWeather", new MockWeatherService())
231+
.inputType(MockWeatherService.Request.class)
232+
.build())
228233
.call()
229234
.content();
230235
// @formatter:on
@@ -242,7 +247,11 @@ void defaultFunctionCallTest() {
242247
// @formatter:off
243248
String response = ChatClient.builder(this.chatModel)
244249
.defaultOptions(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).build())
245-
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
250+
.defaultFunctions(FunctionCallback.builder()
251+
.description("Get the weather in location")
252+
.function("getCurrentWeather", new MockWeatherService())
253+
.inputType(MockWeatherService.Request.class)
254+
.build())
246255
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius."))
247256
.build()
248257
.prompt().call().content();
@@ -262,7 +271,11 @@ void streamFunctionCallTest() {
262271
Flux<String> response = ChatClient.create(this.chatModel).prompt()
263272
.options(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).build())
264273
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")
265-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
274+
.functions(FunctionCallback.builder()
275+
.description("Get the weather in location")
276+
.function("getCurrentWeather", new MockWeatherService())
277+
.inputType(MockWeatherService.Request.class)
278+
.build())
266279
.stream()
267280
.content();
268281
// @formatter:on

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
import org.springframework.ai.chat.client.ChatClient;
3333
import org.springframework.ai.chat.model.ToolContext;
34+
import org.springframework.ai.model.function.FunctionCallback;
3435
import org.springframework.ai.openai.OpenAiTestConfiguration;
3536
import org.springframework.ai.openai.api.tool.MockWeatherService;
3637
import org.springframework.ai.openai.api.tool.MockWeatherService.Request;
@@ -83,7 +84,11 @@ void turnFunctionsOnAndOffTest() {
8384
// @formatter:off
8485
response = chatClientBuilder.build().prompt()
8586
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
86-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
87+
.functions(FunctionCallback.builder()
88+
.description("Get the weather in location")
89+
.function("getCurrentWeather", new MockWeatherService())
90+
.inputType(MockWeatherService.Request.class)
91+
.build())
8792
.call()
8893
.content();
8994
// @formatter:on
@@ -110,7 +115,11 @@ void defaultFunctionCallTest() {
110115

111116
// @formatter:off
112117
String response = ChatClient.builder(this.chatModel)
113-
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
118+
.defaultFunctions(FunctionCallback.builder()
119+
.description("Get the weather in location")
120+
.function("getCurrentWeather", new MockWeatherService())
121+
.inputType(MockWeatherService.Request.class)
122+
.build())
114123
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
115124
.build()
116125
.prompt().call().content();
@@ -149,7 +158,11 @@ else if (request.location().contains("San Francisco")) {
149158

150159
// @formatter:off
151160
String response = ChatClient.builder(this.chatModel)
152-
.defaultFunction("getCurrentWeather", "Get the weather in location", biFunction)
161+
.defaultFunctions(FunctionCallback.builder()
162+
.description("Get the weather in location")
163+
.function("getCurrentWeather", biFunction)
164+
.inputType(MockWeatherService.Request.class)
165+
.build())
153166
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
154167
.defaultToolContext(Map.of("sessionId", "123"))
155168
.build()
@@ -189,7 +202,11 @@ else if (request.location().contains("San Francisco")) {
189202

190203
// @formatter:off
191204
String response = ChatClient.builder(this.chatModel)
192-
.defaultFunction("getCurrentWeather", "Get the weather in location", biFunction)
205+
.defaultFunctions(FunctionCallback.builder()
206+
.description("Get the weather in location")
207+
.function("getCurrentWeather", biFunction)
208+
.inputType(MockWeatherService.Request.class)
209+
.build())
193210
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
194211
.build()
195212
.prompt()
@@ -208,7 +225,11 @@ void streamFunctionCallTest() {
208225
// @formatter:off
209226
Flux<String> response = ChatClient.create(this.chatModel).prompt()
210227
.user("What's the weather like in San Francisco, Tokyo, and Paris?")
211-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
228+
.functions(FunctionCallback.builder()
229+
.description("Get the weather in location")
230+
.function("getCurrentWeather", new MockWeatherService())
231+
.inputType(MockWeatherService.Request.class)
232+
.build())
212233
.stream()
213234
.content();
214235
// @formatter:on

spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,14 +212,26 @@ interface ChatClientRequestSpec {
212212

213213
<T extends ChatOptions> ChatClientRequestSpec options(T options);
214214

215+
/**
216+
* @deprecated use {@link #function(FunctionCallback)} instead.
217+
*/
218+
@Deprecated
215219
<I, O> ChatClientRequestSpec function(String name, String description,
216220
java.util.function.Function<I, O> function);
217221

222+
/**
223+
* @deprecated use {@link #function(FunctionCallback)} instead.
224+
*/
225+
@Deprecated
218226
<I, O> ChatClientRequestSpec function(String name, String description,
219227
java.util.function.BiFunction<I, ToolContext, O> function);
220228

221229
<I, O> ChatClientRequestSpec functions(FunctionCallback... functionCallbacks);
222230

231+
/**
232+
* @deprecated use {@link #function(FunctionCallback)} instead.
233+
*/
234+
@Deprecated
223235
<I, O> ChatClientRequestSpec function(String name, String description, Class<I> inputType,
224236
java.util.function.Function<I, O> function);
225237

@@ -278,8 +290,16 @@ interface Builder {
278290

279291
Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer);
280292

293+
/**
294+
* @deprecated use {@link #defaultFunction(FunctionCallback)} instead.
295+
*/
296+
@Deprecated
281297
<I, O> Builder defaultFunction(String name, String description, java.util.function.Function<I, O> function);
282298

299+
/**
300+
* @deprecated use {@link #defaultFunction(FunctionCallback)} instead.
301+
*/
302+
@Deprecated
283303
<I, O> Builder defaultFunction(String name, String description,
284304
java.util.function.BiFunction<I, ToolContext, O> function);
285305

spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallbackBuilder.java

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package org.springframework.ai.model.function;
1717

1818
import java.lang.reflect.Type;
19+
import java.util.Arrays;
1920
import java.util.function.BiFunction;
2021
import java.util.function.Function;
2122

@@ -43,20 +44,34 @@
4344
* @author Christian Tzolov
4445
* @since 1.0.0
4546
*/
46-
4747
public class DefaultFunctionCallbackBuilder implements FunctionCallback.Builder {
4848

4949
private final static Logger logger = LoggerFactory.getLogger(DefaultFunctionCallbackBuilder.class);
5050

51+
/**
52+
* The description of the function callback. Used to hint the LLM model about the
53+
* tool's purpose and when to use it.
54+
*/
5155
private String description;
5256

57+
/**
58+
* The schema type to use for the input type schema generation. The default is JSON
59+
* Schema. Note: Vertex AI requires the input type schema to be in Open API schema
60+
*/
5361
private SchemaType schemaType = SchemaType.JSON_SCHEMA;
5462

55-
// By default the response is converted to a JSON string.
63+
/**
64+
* The function to convert the response object to a string. The default is to convert
65+
* the response to a JSON string.
66+
*/
5667
private Function<Object, String> responseConverter = input -> (input instanceof String) ? "" + input
5768
: ModelOptionsUtils.toJsonString(input);
5869

59-
// optional
70+
/**
71+
* (Optional) Instead of generating the input type schema from the input type or
72+
* method argument types, you can provide the schema directly. This will override the
73+
* generated schema.
74+
*/
6075
private String inputTypeSchema;
6176

6277
private ObjectMapper objectMapper = JsonMapper.builder()
@@ -127,13 +142,15 @@ public class FunctionInvokerBuilderImpl<I, O> implements FunctionInvokerBuilder<
127142

128143
private FunctionInvokerBuilderImpl(String name, BiFunction<I, ToolContext, O> biFunction) {
129144
Assert.hasText(name, "Name must not be empty");
145+
Assert.notNull(biFunction, "BiFunction must not be null");
130146
this.name = name;
131147
this.biFunction = biFunction;
132148
this.function = null;
133149
}
134150

135151
private FunctionInvokerBuilderImpl(String name, Function<I, O> function) {
136152
Assert.hasText(name, "Name must not be empty");
153+
Assert.notNull(function, "Function must not be null");
137154
this.name = name;
138155
this.biFunction = null;
139156
this.function = function;
@@ -227,6 +244,8 @@ public FunctionCallback build() {
227244
Assert.isTrue(this.targetClass != null || this.targetObject != null,
228245
"Target class or object must not be null");
229246
var method = ReflectionUtils.findMethod(targetClass, methodName, argumentTypes);
247+
Assert.notNull(method,
248+
"Method: '" + methodName + "' with arguments:" + Arrays.toString(argumentTypes) + " not found!");
230249
return new MethodFunctionCallback(this.targetObject, method, this.getDescription(), objectMapper, this.name,
231250
responseConverter);
232251
}

spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import org.springframework.ai.chat.model.Generation;
4141
import org.springframework.ai.chat.prompt.Prompt;
4242
import org.springframework.ai.model.Media;
43+
import org.springframework.ai.model.function.FunctionCallback;
4344
import org.springframework.ai.model.function.FunctionCallingOptions;
4445
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder;
4546
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
@@ -217,7 +218,11 @@ void mutateDefaults() {
217218
.param("param1", "value1")
218219
.param("param2", "value2"))
219220
.defaultFunctions("fun1", "fun2")
220-
.defaultFunction("fun3", "fun3description", mockFunction)
221+
.defaultFunctions(FunctionCallback.builder()
222+
.description("fun3description")
223+
.function("fun3", mockFunction)
224+
.inputType(String.class)
225+
.build())
221226
.defaultUser(u -> u.text("Default user text {uparam1}, {uparam2}")
222227
.param("uparam1", "value1")
223228
.param("uparam2", "value2")
@@ -344,7 +349,11 @@ void mutatePrompt() {
344349
.param("param1", "value1")
345350
.param("param2", "value2"))
346351
.defaultFunctions("fun1", "fun2")
347-
.defaultFunction("fun3", "fun3description", mockFunction)
352+
.defaultFunctions(FunctionCallback.builder()
353+
.description("fun3description")
354+
.function("fun3", mockFunction)
355+
.inputType(String.class)
356+
.build())
348357
.defaultUser(u -> u.text("Default user text {uparam1}, {uparam2}")
349358
.param("uparam1", "value1")
350359
.param("uparam2", "value2")

0 commit comments

Comments
 (0)