Skip to content

Commit 92f94f4

Browse files
committed
refactor: introduce specialized builder interfaces for function and method invocations
Introduce a new hierarchical builder API to better handle different function types: - Create distinct builder interfaces: - FunctionInvokerBuilder: For Function/BiFunction implementations - MethodInvokerBuilder: For method reflection-based invocations - Establish common Builder interface for shared properties - Move builder configuration into dedicated interfaces: - Separate function object handling from constructor - Add method-specific configuration (method name, arg types, target) - Consolidate common properties in parent interface Supporting changes: - Update all AI model implementations to use new builder pattern: - Move function() call after description() for consistency - Standardize builder method ordering across implementations - Modify integration tests to demonstrate new patterns: - OpenAI, Ollama, Minimax, Moonshot, ZhiPuAI tests updated - Add examples for both static and instance methods - Update schema handling: - Set OPEN_API_SCHEMA as default for Vertex AI Gemini - Standardize schema types across providers
1 parent 4871449 commit 92f94f4

File tree

62 files changed

+634
-543
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+634
-543
lines changed

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -256,10 +256,11 @@ void functionCallTest() {
256256

257257
var promptOptions = AnthropicChatOptions.builder()
258258
.withModel(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getName())
259-
.withFunctionCallbacks(List.of(FunctionCallback.builder(new MockWeatherService())
260-
.name("getCurrentWeather")
259+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
261260
.description(
262261
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
262+
.function(new MockWeatherService())
263+
.name("getCurrentWeather")
263264
.inputType(MockWeatherService.Request.class)
264265
.build()))
265266
.build();
@@ -284,10 +285,11 @@ void streamFunctionCallTest() {
284285

285286
var promptOptions = AnthropicChatOptions.builder()
286287
.withModel(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName())
287-
.withFunctionCallbacks(List.of(FunctionCallback.builder(new MockWeatherService())
288-
.name("getCurrentWeather")
288+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
289289
.description(
290290
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
291+
.function(new MockWeatherService())
292+
.name("getCurrentWeather")
291293
.inputType(MockWeatherService.Request.class)
292294
.build()))
293295
.build();

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientMethodFunctionCallbackIT.java

Lines changed: 28 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,10 @@
2929
import org.springframework.ai.chat.client.ChatClient;
3030
import org.springframework.ai.chat.model.ChatModel;
3131
import org.springframework.ai.chat.model.ToolContext;
32-
import org.springframework.ai.model.function.MethodFunctionCallback;
32+
import org.springframework.ai.model.function.FunctionCallback;
3333
import org.springframework.beans.factory.annotation.Autowired;
3434
import org.springframework.boot.test.context.SpringBootTest;
3535
import org.springframework.test.context.ActiveProfiles;
36-
import org.springframework.util.ReflectionUtils;
3736

3837
import static org.assertj.core.api.Assertions.assertThat;
3938
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
@@ -55,13 +54,14 @@ void beforeEach() {
5554
@Test
5655
void methodGetWeatherStatic() {
5756

58-
var method = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherStatic", String.class, Unit.class);
5957
// @formatter:off
6058
String response = ChatClient.create(this.chatModel).prompt()
6159
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
62-
.functions(MethodFunctionCallback.builder()
63-
.method(method)
60+
.functions(FunctionCallback.builder()
6461
.description("Get the weather in location")
62+
.method("getWeatherStatic")
63+
.targetClass(TestFunctionClass.class)
64+
.argumentTypes(String.class, Unit.class)
6565
.build())
6666
.call()
6767
.content();
@@ -77,15 +77,14 @@ void methodTurnLightNoResponse() {
7777

7878
TestFunctionClass targetObject = new TestFunctionClass();
7979

80-
var method = ReflectionUtils.findMethod(TestFunctionClass.class, "turnLight", String.class, boolean.class);
81-
8280
// @formatter:off
8381
String response = ChatClient.create(this.chatModel).prompt()
8482
.user("Turn light on in the living room.")
85-
.functions(MethodFunctionCallback.builder()
86-
.functionObject(targetObject)
87-
.method(method)
88-
.description("Can turn lights on or off by room name")
83+
.functions(FunctionCallback.builder()
84+
.description("Turn light on in the living room.")
85+
.method("turnLight")
86+
.targetObject(targetObject)
87+
.argumentTypes(String.class, boolean.class)
8988
.build())
9089
.call()
9190
.content();
@@ -102,16 +101,14 @@ void methodGetWeatherNonStatic() {
102101

103102
TestFunctionClass targetObject = new TestFunctionClass();
104103

105-
var method = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherNonStatic", String.class,
106-
Unit.class);
107-
108104
// @formatter:off
109105
String response = ChatClient.create(this.chatModel).prompt()
110106
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
111-
.functions(MethodFunctionCallback.builder()
112-
.functionObject(targetObject)
113-
.method(method)
107+
.functions(FunctionCallback.builder()
114108
.description("Get the weather in location")
109+
.method("getWeatherNonStatic")
110+
.argumentTypes(String.class, Unit.class)
111+
.targetObject(targetObject)
115112
.build())
116113
.call()
117114
.content();
@@ -127,17 +124,15 @@ void methodGetWeatherToolContext() {
127124

128125
TestFunctionClass targetObject = new TestFunctionClass();
129126

130-
var method = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherWithContext", String.class,
131-
Unit.class, ToolContext.class);
132-
133127
// @formatter:off
134128
String response = ChatClient.create(this.chatModel).prompt()
135129
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
136-
.functions(MethodFunctionCallback.builder()
137-
.functionObject(targetObject)
138-
.method(method)
130+
.functions(FunctionCallback.builder()
139131
.description("Get the weather in location")
140-
.build())
132+
.method("getWeatherWithContext")
133+
.argumentTypes(String.class, Unit.class, ToolContext.class)
134+
.targetObject(targetObject)
135+
.build())
141136
.toolContext(Map.of("tool", "value"))
142137
.call()
143138
.content();
@@ -154,17 +149,15 @@ void methodGetWeatherToolContextButNonContextMethod() {
154149

155150
TestFunctionClass targetObject = new TestFunctionClass();
156151

157-
var method = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherNonStatic", String.class,
158-
Unit.class);
159-
160152
// @formatter:off
161153
assertThatThrownBy(() -> ChatClient.create(this.chatModel).prompt()
162154
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
163-
.functions(MethodFunctionCallback.builder()
164-
.functionObject(targetObject)
165-
.method(method)
166-
.description("Get the weather in location")
167-
.build())
155+
.functions(FunctionCallback.builder()
156+
.description("Get the weather in location")
157+
.method("getWeatherNonStatic")
158+
.argumentTypes(String.class, Unit.class)
159+
.targetObject(targetObject)
160+
.build())
168161
.toolContext(Map.of("tool", "value"))
169162
.call()
170163
.content())
@@ -178,15 +171,13 @@ void methodNoParameters() {
178171

179172
TestFunctionClass targetObject = new TestFunctionClass();
180173

181-
var method = ReflectionUtils.findMethod(TestFunctionClass.class, "turnLivingRoomLightOn");
182-
183174
// @formatter:off
184175
String response = ChatClient.create(this.chatModel).prompt()
185176
.user("Turn light on in the living room.")
186-
.functions(MethodFunctionCallback.builder()
187-
.functionObject(targetObject)
188-
.method(method)
177+
.functions(FunctionCallback.builder()
189178
.description("Can turn lights on in the Living Room")
179+
.method("turnLivingRoomLightOn")
180+
.targetObject(targetObject)
190181
.build())
191182
.call()
192183
.content();

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ void functionCallTest() {
7070

7171
var promptOptions = AzureOpenAiChatOptions.builder()
7272
.withDeploymentName(this.selectedModel)
73-
.withFunctionCallbacks(List.of(FunctionCallback.builder(new MockWeatherService())
74-
.name("getCurrentWeather")
73+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
7574
.description("Get the current weather in a given location")
75+
.function(new MockWeatherService())
76+
.name("getCurrentWeather")
7677
.inputType(MockWeatherService.Request.class)
77-
.responseConverter(response -> "" + response.temp() + response.unit())
7878
.build()))
7979
.build();
8080

@@ -95,11 +95,11 @@ void functionCallSequentialTest() {
9595

9696
var promptOptions = AzureOpenAiChatOptions.builder()
9797
.withDeploymentName(this.selectedModel)
98-
.withFunctionCallbacks(List.of(FunctionCallback.builder(new MockWeatherService())
99-
.name("getCurrentWeather")
98+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
10099
.description("Get the current weather in a given location")
100+
.function(new MockWeatherService())
101+
.name("getCurrentWeather")
101102
.inputType(MockWeatherService.Request.class)
102-
.responseConverter(response -> "" + response.temp() + response.unit())
103103
.build()))
104104
.build();
105105

@@ -118,11 +118,11 @@ void streamFunctionCallTest() {
118118

119119
var promptOptions = AzureOpenAiChatOptions.builder()
120120
.withDeploymentName(this.selectedModel)
121-
.withFunctionCallbacks(List.of(FunctionCallback.builder(new MockWeatherService())
122-
.name("getCurrentWeather")
121+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
123122
.description("Get the current weather in a given location")
123+
.function(new MockWeatherService())
124+
.name("getCurrentWeather")
124125
.inputType(MockWeatherService.Request.class)
125-
.responseConverter(response -> "" + response.temp() + response.unit())
126126
.build()))
127127
.build();
128128

@@ -156,11 +156,11 @@ void functionCallSequentialAndStreamTest() {
156156

157157
var promptOptions = AzureOpenAiChatOptions.builder()
158158
.withDeploymentName(this.selectedModel)
159-
.withFunctionCallbacks(List.of(FunctionCallback.builder(new MockWeatherService())
160-
.name("getCurrentWeather")
159+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
161160
.description("Get the current weather in a given location")
161+
.function(new MockWeatherService())
162+
.name("getCurrentWeather")
162163
.inputType(MockWeatherService.Request.class)
163-
.responseConverter(response -> "" + response.temp() + response.unit())
164164
.build()))
165165
.build();
166166

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,10 +254,11 @@ void functionCallTest() {
254254
List<Message> messages = new ArrayList<>(List.of(userMessage));
255255

256256
var promptOptions = FunctionCallingOptions.builder()
257-
.withFunctionCallbacks(List.of(FunctionCallback.builder(new MockWeatherService())
258-
.name("getCurrentWeather")
257+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
259258
.description(
260259
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
260+
.function(new MockWeatherService())
261+
.name("getCurrentWeather")
261262
.inputType(MockWeatherService.Request.class)
262263
.build()))
263264
.build();
@@ -282,10 +283,11 @@ void streamFunctionCallTest() {
282283

283284
var promptOptions = FunctionCallingOptions.builder()
284285
.withModel("anthropic.claude-3-5-sonnet-20240620-v1:0")
285-
.withFunctionCallbacks(List.of(FunctionCallback.builder(new MockWeatherService())
286-
.name("getCurrentWeather")
286+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
287287
.description(
288288
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
289+
.function(new MockWeatherService())
290+
.name("getCurrentWeather")
289291
.inputType(MockWeatherService.Request.class)
290292
.build()))
291293
.build();

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiements/BedrockConverseChatModelMain2.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,10 @@ public static void main(String[] args) {
5252
"What's the weather like in Paris? Return the temperature in Celsius.",
5353
PortableFunctionCallingOptions.builder()
5454
.withModel(modelId)
55-
.withFunctionCallbacks(List.of(FunctionCallback.builder(new MockWeatherService())
56-
.name("getCurrentWeather")
55+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
5756
.description("Get the weather in location")
57+
.function(new MockWeatherService())
58+
.name("getCurrentWeather")
5859
.inputType(MockWeatherService.Request.class)
5960
.build()))
6061
.build());

models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ public void promptOptionsTools() {
6767
var request = client.createRequest(new Prompt("Test message content",
6868
MiniMaxChatOptions.builder()
6969
.withModel("PROMPT_MODEL")
70-
.withFunctionCallbacks(List.of(FunctionCallback.builder(new MockWeatherService())
71-
.name(TOOL_FUNCTION_NAME)
70+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
7271
.description("Get the weather in location")
72+
.function(new MockWeatherService())
73+
.name(TOOL_FUNCTION_NAME)
7374
.inputType(MockWeatherService.Request.class)
74-
.responseConverter(response -> "" + response.temp() + response.unit())
7575
.build()))
7676
.build()),
7777
false);
@@ -95,11 +95,11 @@ public void defaultOptionsTools() {
9595
var client = new MiniMaxChatModel(new MiniMaxApi("TEST"),
9696
MiniMaxChatOptions.builder()
9797
.withModel("DEFAULT_MODEL")
98-
.withFunctionCallbacks(List.of(FunctionCallback.builder(new MockWeatherService())
99-
.name(TOOL_FUNCTION_NAME)
98+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
10099
.description("Get the weather in location")
100+
.function(new MockWeatherService())
101+
.name(TOOL_FUNCTION_NAME)
101102
.inputType(MockWeatherService.Request.class)
102-
.responseConverter(response -> "" + response.temp() + response.unit())
103103
.build()))
104104
.build());
105105

@@ -128,9 +128,10 @@ public void defaultOptionsTools() {
128128
// Override the default options function with one from the prompt
129129
request = client.createRequest(new Prompt("Test message content",
130130
MiniMaxChatOptions.builder()
131-
.withFunctionCallbacks(List.of(FunctionCallback.builder(new MockWeatherService())
132-
.name(TOOL_FUNCTION_NAME)
131+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
133132
.description("Overridden function description")
133+
.function(new MockWeatherService())
134+
.name(TOOL_FUNCTION_NAME)
134135
.inputType(MockWeatherService.Request.class)
135136
.build()))
136137
.build()),

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,11 @@ void functionCallTest() {
193193

194194
var promptOptions = MistralAiChatOptions.builder()
195195
.withModel(MistralAiApi.ChatModel.SMALL.getValue())
196-
.withFunctionCallbacks(List.of(FunctionCallback.builder(new MockWeatherService())
197-
.name("getCurrentWeather")
196+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
198197
.description("Get the weather in location")
198+
.function(new MockWeatherService())
199+
.name("getCurrentWeather")
199200
.inputType(MockWeatherService.Request.class)
200-
.responseConverter(response -> "" + response.temp() + response.unit())
201201
.build()))
202202
.build();
203203

@@ -217,11 +217,11 @@ void streamFunctionCallTest() {
217217

218218
var promptOptions = MistralAiChatOptions.builder()
219219
.withModel(MistralAiApi.ChatModel.SMALL.getValue())
220-
.withFunctionCallbacks(List.of(FunctionCallback.builder(new MockWeatherService())
221-
.name("getCurrentWeather")
220+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
222221
.description("Get the weather in location")
222+
.function(new MockWeatherService())
223+
.name("getCurrentWeather")
223224
.inputType(MockWeatherService.Request.class)
224-
.responseConverter(response -> "" + response.temp() + response.unit())
225225
.build()))
226226
.build();
227227

models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,11 @@ void functionCallTest() {
6363

6464
var promptOptions = MoonshotChatOptions.builder()
6565
.withModel(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
66-
.withFunctionCallbacks(List.of(FunctionCallback.builder(new MockWeatherService())
67-
.name("getCurrentWeather")
66+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
6867
.description("Get the weather in location")
68+
.function(new MockWeatherService())
69+
.name("getCurrentWeather")
6970
.inputType(MockWeatherService.Request.class)
70-
.responseConverter(response -> "" + response.temp() + response.unit())
7171
.build()))
7272
.build();
7373

@@ -87,10 +87,10 @@ void streamFunctionCallTest() {
8787
List<Message> messages = new ArrayList<>(List.of(userMessage));
8888

8989
var promptOptions = MoonshotChatOptions.builder()
90-
.withFunctionCallbacks(List.of(FunctionCallback.builder(new MockWeatherService())
91-
.name("getCurrentWeather")
90+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
9291
.description("Get the weather in location")
93-
.responseConverter(response -> "" + response.temp() + response.unit())
92+
.function(new MockWeatherService())
93+
.name("getCurrentWeather")
9494
.build()))
9595
.build();
9696

0 commit comments

Comments
 (0)