Skip to content

Commit b25d6e8

Browse files
ilayaperumalgmarkpollack
authored andcommitted
Refactor FunctionCallingOptions Builder
- Deprecate existing FunctionCallingOptionsBuilder - Create FunctionCallingOptions.Builder which extends ChatOptions.Builder - Create DefaultFunctionCallingOptions which extends DefaultChatOptions and implements FunctionCallingOptions to serve the default FunctionCalling options - Create DefaultFunctionCallingOptionsBuilder to build the default functioncalling options - Update the usage of functioncalling options builder to use the newly added builder including the tests Improve extensibility of DefaultChatOptionsBuilder - Enable DefaultChatOptionsBuilder to accommodate any other sub types - Introduce generics to support sub types that extend DefaultChatOptionsBuilder - Update builder methods to return the sub type - Make FunctionCallingOptions' builder()'s return type to accommodate sub types which can extend FunctionCallingOptions.Builder
1 parent 627fb79 commit b25d6e8

File tree

33 files changed

+446
-129
lines changed

33 files changed

+446
-129
lines changed

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
import org.springframework.ai.converter.MapOutputConverter;
5151
import org.springframework.ai.model.Media;
5252
import org.springframework.ai.model.function.FunctionCallback;
53-
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
53+
import org.springframework.ai.model.function.FunctionCallingOptions;
5454
import org.springframework.beans.factory.annotation.Autowired;
5555
import org.springframework.beans.factory.annotation.Value;
5656
import org.springframework.boot.SpringBootConfiguration;
@@ -259,9 +259,7 @@ void multiModalityPdfTest() throws IOException {
259259
List.of(new Media(new MimeType("application", "pdf"), pdfData)));
260260

261261
var response = this.chatModel.call(new Prompt(List.of(userMessage),
262-
PortableFunctionCallingOptions.builder()
263-
.withModel(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName())
264-
.build()));
262+
FunctionCallingOptions.builder().model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName()).build()));
265263

266264
assertThat(response.getResult().getOutput().getText()).containsAnyOf("Spring AI", "portable API");
267265
}

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,10 @@
9696
import org.springframework.ai.chat.prompt.Prompt;
9797
import org.springframework.ai.model.Media;
9898
import org.springframework.ai.model.ModelOptionsUtils;
99+
import org.springframework.ai.model.function.DefaultFunctionCallingOptions;
99100
import org.springframework.ai.model.function.FunctionCallback;
100101
import org.springframework.ai.model.function.FunctionCallbackResolver;
101102
import org.springframework.ai.model.function.FunctionCallingOptions;
102-
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder;
103-
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
104103
import org.springframework.ai.observation.conventions.AiProvider;
105104
import org.springframework.util.Assert;
106105
import org.springframework.util.CollectionUtils;
@@ -322,12 +321,12 @@ else if (message.getMessageType() == MessageType.TOOL) {
322321
if (prompt.getOptions() != null) {
323322
if (prompt.getOptions() instanceof FunctionCallingOptions) {
324323
var functionCallingOptions = (FunctionCallingOptions) prompt.getOptions();
325-
updatedRuntimeOptions = ((PortableFunctionCallingOptions) updatedRuntimeOptions)
324+
updatedRuntimeOptions = ((DefaultFunctionCallingOptions) updatedRuntimeOptions)
326325
.merge(functionCallingOptions);
327326
}
328327
else if (prompt.getOptions() instanceof ChatOptions) {
329328
var chatOptions = (ChatOptions) prompt.getOptions();
330-
updatedRuntimeOptions = ((PortableFunctionCallingOptions) updatedRuntimeOptions).merge(chatOptions);
329+
updatedRuntimeOptions = ((DefaultFunctionCallingOptions) updatedRuntimeOptions).merge(chatOptions);
331330
}
332331
}
333332

@@ -697,7 +696,7 @@ public static final class Builder {
697696

698697
private Duration timeout = Duration.ofMinutes(10);
699698

700-
private FunctionCallingOptions defaultOptions = new FunctionCallingOptionsBuilder().build();
699+
private FunctionCallingOptions defaultOptions = new DefaultFunctionCallingOptions();
701700

702701
private FunctionCallbackResolver functionCallbackResolver;
703702

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ void multiModalityEmbeddedImage(String modelName) throws IOException {
372372

373373
// @formatter:off
374374
String response = ChatClient.create(this.chatModel).prompt()
375-
.options(FunctionCallingOptions.builder().withModel(modelName).build())
375+
.options(FunctionCallingOptions.builder().model(modelName).build())
376376
.user(u -> u.text("Explain what do you see on this picture?")
377377
.media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.png")))
378378
.call()
@@ -394,7 +394,7 @@ void multiModalityImageUrl(String modelName) throws IOException {
394394
// @formatter:off
395395
String response = ChatClient.create(this.chatModel).prompt()
396396
// TODO consider adding model(...) method to ChatClient as a shortcut to
397-
.options(FunctionCallingOptions.builder().withModel(modelName).build())
397+
.options(FunctionCallingOptions.builder().model(modelName).build())
398398
.user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url))
399399
.call()
400400
.content();

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public BedrockProxyChatModel bedrockConverseChatModel() {
4242
.withRegion(Region.US_EAST_1)
4343
.withTimeout(Duration.ofSeconds(120))
4444
// .withRegion(Region.US_EAST_1)
45-
.withDefaultOptions(FunctionCallingOptions.builder().withModel(modelId).build())
45+
.withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build())
4646
.build();
4747
}
4848

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
import org.springframework.ai.chat.prompt.Prompt;
4242
import org.springframework.ai.model.function.FunctionCallback;
4343
import org.springframework.ai.model.function.FunctionCallingOptions;
44-
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
4544

4645
import static org.assertj.core.api.Assertions.assertThat;
4746
import static org.mockito.ArgumentMatchers.isA;
@@ -145,7 +144,7 @@ public void callWithToolUse() {
145144
.build();
146145

147146
var result = this.chatModel.call(new Prompt("What is the weather in Paris?",
148-
PortableFunctionCallingOptions.builder().withFunctionCallbacks(functionCallback).build()));
147+
FunctionCallingOptions.builder().functionCallbacks(functionCallback).build()));
149148

150149
assertThat(result).isNotNull();
151150
assertThat(result.getResult().getOutput().getText())

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ void roleTest(String modelName) {
9090
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
9191
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
9292
Prompt prompt = new Prompt(List.of(userMessage, systemMessage),
93-
FunctionCallingOptions.builder().withModel(modelName).build());
93+
FunctionCallingOptions.builder().model(modelName).build());
9494
ChatResponse response = this.chatModel.call(prompt);
9595
assertThat(response.getResults()).hasSize(1);
9696
assertThat(response.getMetadata().getUsage().getGenerationTokens()).isGreaterThan(0);
@@ -126,7 +126,7 @@ void testMessageHistory() {
126126

127127
@Test
128128
void streamingWithTokenUsage() {
129-
var promptOptions = FunctionCallingOptions.builder().withTemperature(0.0).build();
129+
var promptOptions = FunctionCallingOptions.builder().temperature(0.0).build();
130130

131131
var prompt = new Prompt("List two colors of the Polish flag. Be brief.", promptOptions);
132132
var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage();
@@ -252,7 +252,7 @@ void functionCallTest() {
252252
List<Message> messages = new ArrayList<>(List.of(userMessage));
253253

254254
var promptOptions = FunctionCallingOptions.builder()
255-
.withFunctionCallbacks(List.of(FunctionCallback.builder()
255+
.functionCallbacks(List.of(FunctionCallback.builder()
256256
.function("getCurrentWeather", new MockWeatherService())
257257
.description(
258258
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
@@ -279,8 +279,8 @@ void streamFunctionCallTest() {
279279
List<Message> messages = new ArrayList<>(List.of(userMessage));
280280

281281
var promptOptions = FunctionCallingOptions.builder()
282-
.withModel("anthropic.claude-3-5-sonnet-20240620-v1:0")
283-
.withFunctionCallbacks(List.of(FunctionCallback.builder()
282+
.model("anthropic.claude-3-5-sonnet-20240620-v1:0")
283+
.functionCallbacks(List.of(FunctionCallback.builder()
284284
.function("getCurrentWeather", new MockWeatherService())
285285
.description(
286286
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
@@ -306,7 +306,7 @@ void validateCallResponseMetadata() {
306306
String model = "anthropic.claude-3-5-sonnet-20240620-v1:0";
307307
// @formatter:off
308308
ChatResponse response = ChatClient.create(this.chatModel).prompt()
309-
.options(FunctionCallingOptions.builder().withModel(model).build())
309+
.options(FunctionCallingOptions.builder().model(model).build())
310310
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
311311
.call()
312312
.chatResponse();
@@ -321,7 +321,7 @@ void validateStreamCallResponseMetadata() {
321321
String model = "anthropic.claude-3-5-sonnet-20240620-v1:0";
322322
// @formatter:off
323323
ChatResponse response = ChatClient.create(this.chatModel).prompt()
324-
.options(FunctionCallingOptions.builder().withModel(model).build())
324+
.options(FunctionCallingOptions.builder().model(model).build())
325325
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
326326
.stream()
327327
.chatResponse()

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
3636
import org.springframework.ai.chat.prompt.Prompt;
3737
import org.springframework.ai.model.function.FunctionCallingOptions;
38-
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
3938
import org.springframework.ai.observation.conventions.AiOperationType;
4039
import org.springframework.ai.observation.conventions.AiProvider;
4140
import org.springframework.beans.factory.annotation.Autowired;
@@ -68,13 +67,13 @@ void beforeEach() {
6867

6968
@Test
7069
void observationForChatOperation() {
71-
var options = PortableFunctionCallingOptions.builder()
72-
.withModel("anthropic.claude-3-5-sonnet-20240620-v1:0")
73-
.withMaxTokens(2048)
74-
.withStopSequences(List.of("this-is-the-end"))
75-
.withTemperature(0.7)
70+
var options = FunctionCallingOptions.builder()
71+
.model("anthropic.claude-3-5-sonnet-20240620-v1:0")
72+
.maxTokens(2048)
73+
.stopSequences(List.of("this-is-the-end"))
74+
.temperature(0.7)
7675
// .withTopK(1)
77-
.withTopP(1.0)
76+
.topP(1.0)
7877
.build();
7978

8079
Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
@@ -90,12 +89,12 @@ void observationForChatOperation() {
9089

9190
@Test
9291
void observationForStreamingChatOperation() {
93-
var options = PortableFunctionCallingOptions.builder()
94-
.withModel("anthropic.claude-3-5-sonnet-20240620-v1:0")
95-
.withMaxTokens(2048)
96-
.withStopSequences(List.of("this-is-the-end"))
97-
.withTemperature(0.7)
98-
.withTopP(1.0)
92+
var options = FunctionCallingOptions.builder()
93+
.model("anthropic.claude-3-5-sonnet-20240620-v1:0")
94+
.maxTokens(2048)
95+
.stopSequences(List.of("this-is-the-end"))
96+
.temperature(0.7)
97+
.topP(1.0)
9998
.build();
10099

101100
Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
@@ -174,7 +173,7 @@ public BedrockProxyChatModel bedrockConverseChatModel(ObservationRegistry observ
174173
.withCredentialsProvider(EnvironmentVariableCredentialsProvider.create())
175174
.withRegion(Region.US_EAST_1)
176175
.withObservationRegistry(observationRegistry)
177-
.withDefaultOptions(FunctionCallingOptions.builder().withModel(modelId).build())
176+
.withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build())
178177
.build();
179178
}
180179

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ public BedrockProxyChatModel bedrockConverseChatModel() {
184184
.withCredentialsProvider(EnvironmentVariableCredentialsProvider.create())
185185
.withRegion(Region.US_EAST_1)
186186
.withTimeout(Duration.ofSeconds(120))
187-
.withDefaultOptions(FunctionCallingOptions.builder().withModel(modelId).build())
187+
.withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build())
188188
.build();
189189
}
190190

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain2.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import org.springframework.ai.bedrock.converse.MockWeatherService;
2828
import org.springframework.ai.chat.prompt.Prompt;
2929
import org.springframework.ai.model.function.FunctionCallback;
30-
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
30+
import org.springframework.ai.model.function.FunctionCallingOptions;
3131

3232
/**
3333
* Used for reverse engineering the protocol
@@ -50,9 +50,9 @@ public static void main(String[] args) {
5050
// "What's the weather like in San Francisco, Tokyo, and Paris? Return the
5151
// temperature in Celsius.",
5252
"What's the weather like in Paris? Return the temperature in Celsius.",
53-
PortableFunctionCallingOptions.builder()
54-
.withModel(modelId)
55-
.withFunctionCallbacks(List.of(FunctionCallback.builder()
53+
FunctionCallingOptions.builder()
54+
.model(modelId)
55+
.functionCallbacks(List.of(FunctionCallback.builder()
5656
.function("getCurrentWeather", new MockWeatherService())
5757
.description("Get the weather in location")
5858
.inputType(MockWeatherService.Request.class)

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain3.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import org.springframework.ai.bedrock.converse.MockWeatherService;
2626
import org.springframework.ai.chat.prompt.Prompt;
2727
import org.springframework.ai.model.function.FunctionCallback;
28-
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
28+
import org.springframework.ai.model.function.FunctionCallingOptions;
2929

3030
/**
3131
* Used for reverse engineering the protocol
@@ -48,9 +48,9 @@ public static void main(String[] args) {
4848
// "What's the weather like in San Francisco, Tokyo, and Paris? Return the
4949
// temperature in Celsius.",
5050
"What's the weather like in Paris? Return the temperature in Celsius.",
51-
PortableFunctionCallingOptions.builder()
52-
.withModel(modelId)
53-
.withFunctionCallbacks(List.of(FunctionCallback.builder()
51+
FunctionCallingOptions.builder()
52+
.model(modelId)
53+
.functionCallbacks(List.of(FunctionCallback.builder()
5454
.function("getCurrentWeather", new MockWeatherService())
5555
.description("Get the weather in location")
5656
.inputType(MockWeatherService.Request.class)

0 commit comments

Comments
 (0)