From bdc2778d6c04f7a7cdc3e110659006d787da701b Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Thu, 12 Dec 2024 22:39:06 +0000 Subject: [PATCH 1/2] Refactor FunctionCallingOptions Builder - Deprecate existing FunctionCallingOptionsBuilder - Create FunctionCallingOptions.Builder which extends ChatOptions.Builder - Create DefaultFunctionCallingOptions which extends DefaultChatOptions and implements FunctionCallingOptions to serve the default FunctionCalling options - Create DefaultFunctionCallingOptionsBuilder to build the default functioncalling options - Update the usage of functioncalling options builder to use the newly added builder including the tests Resolves #1874 --- .../ai/anthropic/AnthropicChatModelIT.java | 6 +- .../converse/BedrockProxyChatModel.java | 9 +- .../converse/BedrockConverseChatClientIT.java | 4 +- .../BedrockConverseTestConfiguration.java | 2 +- .../BedrockConverseUsageAggregationTests.java | 3 +- .../converse/BedrockProxyChatModelIT.java | 14 +- .../BedrockProxyChatModelObservationIT.java | 27 ++- .../client/BedrockNovaChatClientIT.java | 2 +- .../BedrockConverseChatModelMain2.java | 8 +- .../BedrockConverseChatModelMain3.java | 8 +- .../ai/chat/prompt/ChatOptions.java | 18 +- .../prompt/DefaultChatOptionsBuilder.java | 18 +- .../DefaultFunctionCallingOptions.java | 173 ++++++++++++++++++ .../DefaultFunctionCallingOptionsBuilder.java | 128 +++++++++++++ .../model/function/FunctionCallingHelper.java | 3 +- .../function/FunctionCallingOptions.java | 72 +++++++- .../FunctionCallingOptionsBuilder.java | 6 + .../ai/chat/ChatBuilderTests.java | 10 +- .../ai/chat/client/ChatClientTest.java | 7 +- .../BedrockConverseProxyChatProperties.java | 16 +- .../tool/FunctionCallWithFunctionBeanIT.java | 4 +- .../tool/FunctionCallWithFunctionBeanIT.java | 4 +- .../tool/FunctionCallWithFunctionBeanIT.java | 7 +- .../FunctionCallWithPromptFunctionIT.java | 2 +- ...nctionCallbackWithPlainFunctionBeanIT.java | 5 +- .../tool/WeatherServicePromptIT.java | 5 +- ...nctionCallbackWithPlainFunctionBeanIT.java | 5 +- .../ollama/tool/OllamaFunctionCallbackIT.java | 6 +- ...nctionCallbackWithPlainFunctionBeanIT.java | 9 +- .../tool/FunctionCallWithFunctionBeanIT.java | 4 +- ...nctionCallbackWithPlainFunctionBeanIT.java | 5 +- .../tool/FunctionCallbackContextKotlinIT.kt | 2 +- .../ollama/tool/FunctionCallbackKotlinIT.kt | 2 +- 33 files changed, 474 insertions(+), 120 deletions(-) create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptions.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilder.java diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java index aa62746bc7d..4546ccf0a6c 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java @@ -49,7 +49,7 @@ import org.springframework.ai.converter.MapOutputConverter; import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; +import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -258,9 +258,7 @@ void multiModalityPdfTest() throws IOException { List.of(new Media(new MimeType("application", "pdf"), pdfData))); var response = this.chatModel.call(new Prompt(List.of(userMessage), - PortableFunctionCallingOptions.builder() - .withModel(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName()) - .build())); + FunctionCallingOptions.builder().model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName()).build())); assertThat(response.getResult().getOutput().getText()).containsAnyOf("Spring AI", "portable API"); } diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index a1bbd10c0c4..116093b09bf 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -96,11 +96,10 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.Media; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.DefaultFunctionCallingOptions; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackResolver; import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -322,12 +321,12 @@ else if (message.getMessageType() == MessageType.TOOL) { if (prompt.getOptions() != null) { if (prompt.getOptions() instanceof FunctionCallingOptions) { var functionCallingOptions = (FunctionCallingOptions) prompt.getOptions(); - updatedRuntimeOptions = ((PortableFunctionCallingOptions) updatedRuntimeOptions) + updatedRuntimeOptions = ((DefaultFunctionCallingOptions) updatedRuntimeOptions) .merge(functionCallingOptions); } else if (prompt.getOptions() instanceof ChatOptions) { var chatOptions = (ChatOptions) prompt.getOptions(); - updatedRuntimeOptions = ((PortableFunctionCallingOptions) updatedRuntimeOptions).merge(chatOptions); + updatedRuntimeOptions = ((DefaultFunctionCallingOptions) updatedRuntimeOptions).merge(chatOptions); } } @@ -697,7 +696,7 @@ public static final class Builder { private Duration timeout = Duration.ofMinutes(10); - private FunctionCallingOptions defaultOptions = new FunctionCallingOptionsBuilder().build(); + private FunctionCallingOptions defaultOptions = new DefaultFunctionCallingOptions(); private FunctionCallbackResolver functionCallbackResolver; diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java index 99d26464418..86febbe1f0a 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java @@ -372,7 +372,7 @@ void multiModalityEmbeddedImage(String modelName) throws IOException { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() - .options(FunctionCallingOptions.builder().withModel(modelName).build()) + .options(FunctionCallingOptions.builder().model(modelName).build()) .user(u -> u.text("Explain what do you see on this picture?") .media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.png"))) .call() @@ -394,7 +394,7 @@ void multiModalityImageUrl(String modelName) throws IOException { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() // TODO consider adding model(...) method to ChatClient as a shortcut to - .options(FunctionCallingOptions.builder().withModel(modelName).build()) + .options(FunctionCallingOptions.builder().model(modelName).build()) .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url)) .call() .content(); diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java index eaf220fb284..f2537e95b0b 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java @@ -42,7 +42,7 @@ public BedrockProxyChatModel bedrockConverseChatModel() { .withRegion(Region.US_EAST_1) .withTimeout(Duration.ofSeconds(120)) // .withRegion(Region.US_EAST_1) - .withDefaultOptions(FunctionCallingOptions.builder().withModel(modelId).build()) + .withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build()) .build(); } diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java index e4d74743523..f319d442389 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java @@ -41,7 +41,6 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.isA; @@ -145,7 +144,7 @@ public void callWithToolUse() { .build(); var result = this.chatModel.call(new Prompt("What is the weather in Paris?", - PortableFunctionCallingOptions.builder().withFunctionCallbacks(functionCallback).build())); + FunctionCallingOptions.builder().functionCallbacks(functionCallback).build())); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getText()) diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java index aea0ee18803..6a4b84c0f23 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java @@ -90,7 +90,7 @@ void roleTest(String modelName) { SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage), - FunctionCallingOptions.builder().withModel(modelName).build()); + FunctionCallingOptions.builder().model(modelName).build()); ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getMetadata().getUsage().getGenerationTokens()).isGreaterThan(0); @@ -126,7 +126,7 @@ void testMessageHistory() { @Test void streamingWithTokenUsage() { - var promptOptions = FunctionCallingOptions.builder().withTemperature(0.0).build(); + var promptOptions = FunctionCallingOptions.builder().temperature(0.0).build(); var prompt = new Prompt("List two colors of the Polish flag. Be brief.", promptOptions); var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage(); @@ -252,7 +252,7 @@ void functionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = FunctionCallingOptions.builder() - .withFunctionCallbacks(List.of(FunctionCallback.builder() + .functionCallbacks(List.of(FunctionCallback.builder() .function("getCurrentWeather", new MockWeatherService()) .description( "Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.") @@ -279,8 +279,8 @@ void streamFunctionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = FunctionCallingOptions.builder() - .withModel("anthropic.claude-3-5-sonnet-20240620-v1:0") - .withFunctionCallbacks(List.of(FunctionCallback.builder() + .model("anthropic.claude-3-5-sonnet-20240620-v1:0") + .functionCallbacks(List.of(FunctionCallback.builder() .function("getCurrentWeather", new MockWeatherService()) .description( "Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.") @@ -306,7 +306,7 @@ void validateCallResponseMetadata() { String model = "anthropic.claude-3-5-sonnet-20240620-v1:0"; // @formatter:off ChatResponse response = ChatClient.create(this.chatModel).prompt() - .options(FunctionCallingOptions.builder().withModel(model).build()) + .options(FunctionCallingOptions.builder().model(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .call() .chatResponse(); @@ -321,7 +321,7 @@ void validateStreamCallResponseMetadata() { String model = "anthropic.claude-3-5-sonnet-20240620-v1:0"; // @formatter:off ChatResponse response = ChatClient.create(this.chatModel).prompt() - .options(FunctionCallingOptions.builder().withModel(model).build()) + .options(FunctionCallingOptions.builder().model(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .stream() .chatResponse() diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java index 07c79bdae27..88ed54fc4b8 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java @@ -35,7 +35,6 @@ import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.ai.observation.conventions.AiOperationType; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.beans.factory.annotation.Autowired; @@ -68,13 +67,13 @@ void beforeEach() { @Test void observationForChatOperation() { - var options = PortableFunctionCallingOptions.builder() - .withModel("anthropic.claude-3-5-sonnet-20240620-v1:0") - .withMaxTokens(2048) - .withStopSequences(List.of("this-is-the-end")) - .withTemperature(0.7) + var options = FunctionCallingOptions.builder() + .model("anthropic.claude-3-5-sonnet-20240620-v1:0") + .maxTokens(2048) + .stopSequences(List.of("this-is-the-end")) + .temperature(0.7) // .withTopK(1) - .withTopP(1.0) + .topP(1.0) .build(); Prompt prompt = new Prompt("Why does a raven look like a desk?", options); @@ -90,12 +89,12 @@ void observationForChatOperation() { @Test void observationForStreamingChatOperation() { - var options = PortableFunctionCallingOptions.builder() - .withModel("anthropic.claude-3-5-sonnet-20240620-v1:0") - .withMaxTokens(2048) - .withStopSequences(List.of("this-is-the-end")) - .withTemperature(0.7) - .withTopP(1.0) + var options = FunctionCallingOptions.builder() + .model("anthropic.claude-3-5-sonnet-20240620-v1:0") + .maxTokens(2048) + .stopSequences(List.of("this-is-the-end")) + .temperature(0.7) + .topP(1.0) .build(); Prompt prompt = new Prompt("Why does a raven look like a desk?", options); @@ -174,7 +173,7 @@ public BedrockProxyChatModel bedrockConverseChatModel(ObservationRegistry observ .withCredentialsProvider(EnvironmentVariableCredentialsProvider.create()) .withRegion(Region.US_EAST_1) .withObservationRegistry(observationRegistry) - .withDefaultOptions(FunctionCallingOptions.builder().withModel(modelId).build()) + .withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build()) .build(); } diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java index afdd0237994..326a90425a9 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java @@ -184,7 +184,7 @@ public BedrockProxyChatModel bedrockConverseChatModel() { .withCredentialsProvider(EnvironmentVariableCredentialsProvider.create()) .withRegion(Region.US_EAST_1) .withTimeout(Duration.ofSeconds(120)) - .withDefaultOptions(FunctionCallingOptions.builder().withModel(modelId).build()) + .withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build()) .build(); } diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain2.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain2.java index bcd6dfffab5..4851d621e8e 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain2.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain2.java @@ -27,7 +27,7 @@ import org.springframework.ai.bedrock.converse.MockWeatherService; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; +import org.springframework.ai.model.function.FunctionCallingOptions; /** * Used for reverse engineering the protocol @@ -50,9 +50,9 @@ public static void main(String[] args) { // "What's the weather like in San Francisco, Tokyo, and Paris? Return the // temperature in Celsius.", "What's the weather like in Paris? Return the temperature in Celsius.", - PortableFunctionCallingOptions.builder() - .withModel(modelId) - .withFunctionCallbacks(List.of(FunctionCallback.builder() + FunctionCallingOptions.builder() + .model(modelId) + .functionCallbacks(List.of(FunctionCallback.builder() .function("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain3.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain3.java index 00e9b61c2ad..b869e4c1704 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain3.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain3.java @@ -25,7 +25,7 @@ import org.springframework.ai.bedrock.converse.MockWeatherService; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; +import org.springframework.ai.model.function.FunctionCallingOptions; /** * Used for reverse engineering the protocol @@ -48,9 +48,9 @@ public static void main(String[] args) { // "What's the weather like in San Francisco, Tokyo, and Paris? Return the // temperature in Celsius.", "What's the weather like in Paris? Return the temperature in Celsius.", - PortableFunctionCallingOptions.builder() - .withModel(modelId) - .withFunctionCallbacks(List.of(FunctionCallback.builder() + FunctionCallingOptions.builder() + .model(modelId) + .functionCallbacks(List.of(FunctionCallback.builder() .function("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java index 81af1b7f349..01e2f7805d4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java @@ -101,63 +101,63 @@ static ChatOptions.Builder builder() { /** * Builder for creating {@link ChatOptions} instance. */ - interface Builder { + interface Builder> { /** * Builds with the model to use for the chat. * @param model * @return the builder */ - Builder model(String model); + B model(String model); /** * Builds with the frequency penalty to use for the chat. * @param frequencyPenalty * @return the builder. */ - Builder frequencyPenalty(Double frequencyPenalty); + B frequencyPenalty(Double frequencyPenalty); /** * Builds with the maximum number of tokens to use for the chat. * @param maxTokens * @return the builder. */ - Builder maxTokens(Integer maxTokens); + B maxTokens(Integer maxTokens); /** * Builds with the presence penalty to use for the chat. * @param presencePenalty * @return the builder. */ - Builder presencePenalty(Double presencePenalty); + B presencePenalty(Double presencePenalty); /** * Builds with the stop sequences to use for the chat. * @param stopSequences * @return the builder. */ - Builder stopSequences(List stopSequences); + B stopSequences(List stopSequences); /** * Builds with the temperature to use for the chat. * @param temperature * @return the builder. */ - Builder temperature(Double temperature); + B temperature(Double temperature); /** * Builds with the top K to use for the chat. * @param topK * @return the builder. */ - Builder topK(Integer topK); + B topK(Integer topK); /** * Builds with the top P to use for the chat. * @param topP * @return the builder. */ - Builder topP(Double topP); + B topP(Double topP); /** * Build the {@link ChatOptions}. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java index d85b60478ba..7e067517e9f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java @@ -21,46 +21,46 @@ /** * Implementation of {@link ChatOptions.Builder} to create {@link DefaultChatOptions}. */ -public class DefaultChatOptionsBuilder implements ChatOptions.Builder { +public class DefaultChatOptionsBuilder implements ChatOptions.Builder { private final DefaultChatOptions options = new DefaultChatOptions(); - public ChatOptions.Builder model(String model) { + public DefaultChatOptionsBuilder model(String model) { this.options.setModel(model); return this; } - public ChatOptions.Builder frequencyPenalty(Double frequencyPenalty) { + public DefaultChatOptionsBuilder frequencyPenalty(Double frequencyPenalty) { this.options.setFrequencyPenalty(frequencyPenalty); return this; } - public ChatOptions.Builder maxTokens(Integer maxTokens) { + public DefaultChatOptionsBuilder maxTokens(Integer maxTokens) { this.options.setMaxTokens(maxTokens); return this; } - public ChatOptions.Builder presencePenalty(Double presencePenalty) { + public DefaultChatOptionsBuilder presencePenalty(Double presencePenalty) { this.options.setPresencePenalty(presencePenalty); return this; } - public ChatOptions.Builder stopSequences(List stop) { + public DefaultChatOptionsBuilder stopSequences(List stop) { this.options.setStopSequences(stop); return this; } - public ChatOptions.Builder temperature(Double temperature) { + public DefaultChatOptionsBuilder temperature(Double temperature) { this.options.setTemperature(temperature); return this; } - public ChatOptions.Builder topK(Integer topK) { + public DefaultChatOptionsBuilder topK(Integer topK) { this.options.setTopK(topK); return this; } - public ChatOptions.Builder topP(Double topP) { + public DefaultChatOptionsBuilder topP(Double topP) { this.options.setTopP(topP); return this; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptions.java new file mode 100644 index 00000000000..4a117b1cbe7 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptions.java @@ -0,0 +1,173 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.function; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.DefaultChatOptions; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +/** + * Default implementation of {@link FunctionCallingOptions}. + * + * @author Christian Tzolov + * @author Thomas Vitale + * @author Ilayaperumal Gopinathan + */ +public class DefaultFunctionCallingOptions extends DefaultChatOptions implements FunctionCallingOptions { + + private List functionCallbacks = new ArrayList<>(); + + private Set functions = new HashSet<>(); + + private Boolean proxyToolCalls = false; + + private Map context = new HashMap<>(); + + public static FunctionCallingOptions.Builder builder() { + return new DefaultFunctionCallingOptionsBuilder(); + } + + @Override + public List getFunctionCallbacks() { + return Collections.unmodifiableList(this.functionCallbacks); + } + + public void setFunctionCallbacks(List functionCallbacks) { + Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); + this.functionCallbacks = new ArrayList<>(functionCallbacks); + } + + @Override + public Set getFunctions() { + return Collections.unmodifiableSet(this.functions); + } + + public void setFunctions(Set functions) { + Assert.notNull(functions, "Functions must not be null"); + this.functions = new HashSet<>(functions); + } + + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + + public Map getToolContext() { + return Collections.unmodifiableMap(this.context); + } + + public void setToolContext(Map context) { + Assert.notNull(context, "Context must not be null"); + this.context = new HashMap<>(context); + } + + @Override + public FunctionCallingOptions copy() { + return FunctionCallingOptions.builder() + .model(this.getModel()) + .frequencyPenalty(this.getFrequencyPenalty()) + .maxTokens(this.getMaxTokens()) + .presencePenalty(this.getPresencePenalty()) + .stopSequences(this.getStopSequences() != null ? new ArrayList<>(this.getStopSequences()) : null) + .temperature(this.getTemperature()) + .topK(this.getTopK()) + .topP(this.getTopP()) + .functions(new HashSet<>(this.functions)) + .functionCallbacks(new ArrayList<>(this.functionCallbacks)) + .proxyToolCalls(this.proxyToolCalls) + .toolContext(new HashMap<>(this.getToolContext())) + .build(); + } + + public FunctionCallingOptions merge(FunctionCallingOptions options) { + + var builder = FunctionCallingOptions.builder() + .model(StringUtils.hasText(options.getModel()) ? options.getModel() : this.getModel()) + .frequencyPenalty( + options.getFrequencyPenalty() != null ? options.getFrequencyPenalty() : this.getFrequencyPenalty()) + .maxTokens(options.getMaxTokens() != null ? options.getMaxTokens() : this.getMaxTokens()) + .presencePenalty( + options.getPresencePenalty() != null ? options.getPresencePenalty() : this.getPresencePenalty()) + .stopSequences(options.getStopSequences() != null ? options.getStopSequences() : this.getStopSequences()) + .temperature(options.getTemperature() != null ? options.getTemperature() : this.getTemperature()) + .topK(options.getTopK() != null ? options.getTopK() : this.getTopK()) + .topP(options.getTopP() != null ? options.getTopP() : this.getTopP()); + + builder.proxyToolCalls(options.getProxyToolCalls() != null ? options.getProxyToolCalls() : this.proxyToolCalls); + + Set functions = new HashSet<>(); + if (!CollectionUtils.isEmpty(this.functions)) { + functions.addAll(this.functions); + } + if (!CollectionUtils.isEmpty(options.getFunctions())) { + functions.addAll(options.getFunctions()); + } + builder.functions(functions); + + List functionCallbacks = new ArrayList<>(); + if (!CollectionUtils.isEmpty(this.functionCallbacks)) { + functionCallbacks.addAll(this.functionCallbacks); + } + if (!CollectionUtils.isEmpty(options.getFunctionCallbacks())) { + functionCallbacks.addAll(options.getFunctionCallbacks()); + } + builder.functionCallbacks(functionCallbacks); + + Map context = new HashMap<>(); + if (!CollectionUtils.isEmpty(this.context)) { + context.putAll(this.context); + } + if (!CollectionUtils.isEmpty(options.getToolContext())) { + context.putAll(options.getToolContext()); + } + builder.toolContext(context); + + return builder.build(); + } + + public FunctionCallingOptions merge(ChatOptions options) { + + var builder = FunctionCallingOptions.builder() + .model(StringUtils.hasText(options.getModel()) ? options.getModel() : this.getModel()) + .frequencyPenalty( + options.getFrequencyPenalty() != null ? options.getFrequencyPenalty() : this.getFrequencyPenalty()) + .maxTokens(options.getMaxTokens() != null ? options.getMaxTokens() : this.getMaxTokens()) + .presencePenalty( + options.getPresencePenalty() != null ? options.getPresencePenalty() : this.getPresencePenalty()) + .stopSequences(options.getStopSequences() != null ? options.getStopSequences() : this.getStopSequences()) + .temperature(options.getTemperature() != null ? options.getTemperature() : this.getTemperature()) + .topK(options.getTopK() != null ? options.getTopK() : this.getTopK()) + .topP(options.getTopP() != null ? options.getTopP() : this.getTopP()); + + return builder.build(); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilder.java new file mode 100644 index 00000000000..53718e44cc3 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilder.java @@ -0,0 +1,128 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.function; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.springframework.util.Assert; + +/** + * Default implementation of {@link FunctionCallingOptions.Builder}. + * + * @author Christian Tzolov + * @author Thomas Vitale + * @author Ilayaperumal Gopinathan + */ +public class DefaultFunctionCallingOptionsBuilder implements FunctionCallingOptions.Builder { + + private final DefaultFunctionCallingOptions functionCallingOptions = new DefaultFunctionCallingOptions(); + + public FunctionCallingOptions.Builder model(String model) { + this.functionCallingOptions.setModel(model); + return this; + } + + public FunctionCallingOptions.Builder frequencyPenalty(Double frequencyPenalty) { + this.functionCallingOptions.setFrequencyPenalty(frequencyPenalty); + return this; + } + + public FunctionCallingOptions.Builder maxTokens(Integer maxTokens) { + this.functionCallingOptions.setMaxTokens(maxTokens); + return this; + } + + public FunctionCallingOptions.Builder presencePenalty(Double presencePenalty) { + this.functionCallingOptions.setPresencePenalty(presencePenalty); + return this; + } + + public FunctionCallingOptions.Builder stopSequences(List stopSequences) { + this.functionCallingOptions.setStopSequences(stopSequences); + return this; + } + + public FunctionCallingOptions.Builder temperature(Double temperature) { + this.functionCallingOptions.setTemperature(temperature); + return this; + } + + public FunctionCallingOptions.Builder topK(Integer topK) { + this.functionCallingOptions.setTopK(topK); + return this; + } + + public FunctionCallingOptions.Builder topP(Double topP) { + this.functionCallingOptions.setTopP(topP); + return this; + } + + public FunctionCallingOptions.Builder functionCallbacks(List functionCallbacks) { + this.functionCallingOptions.setFunctionCallbacks(functionCallbacks); + return this; + } + + public FunctionCallingOptions.Builder functionCallbacks(FunctionCallback... functionCallbacks) { + Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); + this.functionCallingOptions.setFunctionCallbacks(List.of(functionCallbacks)); + return this; + } + + public FunctionCallingOptions.Builder functions(Set functions) { + this.functionCallingOptions.setFunctions(functions); + return this; + } + + public FunctionCallingOptions.Builder function(String function) { + Assert.notNull(function, "Function must not be null"); + var set = new HashSet<>(this.functionCallingOptions.getFunctions()); + set.add(function); + this.functionCallingOptions.setFunctions(set); + return this; + } + + public FunctionCallingOptions.Builder proxyToolCalls(Boolean proxyToolCalls) { + this.functionCallingOptions.setProxyToolCalls(proxyToolCalls); + return this; + } + + public FunctionCallingOptions.Builder toolContext(Map context) { + Assert.notNull(context, "Tool context must not be null"); + Map newContext = new HashMap<>(this.functionCallingOptions.getToolContext()); + newContext.putAll(context); + this.functionCallingOptions.setToolContext(newContext); + return this; + } + + public FunctionCallingOptions.Builder toolContext(String key, Object value) { + Assert.notNull(key, "Key must not be null"); + Assert.notNull(value, "Value must not be null"); + Map newContext = new HashMap<>(this.functionCallingOptions.getToolContext()); + newContext.put(key, value); + this.functionCallingOptions.setToolContext(newContext); + return this; + } + + public FunctionCallingOptions build() { + return this.functionCallingOptions; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingHelper.java index 7da7e87314d..4f50e9b582e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingHelper.java @@ -34,7 +34,6 @@ import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.util.CollectionUtils; /** @@ -45,7 +44,7 @@ public class FunctionCallingHelper extends AbstractToolCallSupport { public FunctionCallingHelper() { - this(null, PortableFunctionCallingOptions.builder().build(), List.of()); + this(null, FunctionCallingOptions.builder().build(), List.of()); } public FunctionCallingHelper(FunctionCallbackResolver functionCallbackResolver, diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java index 876c5289ab7..f30c7643816 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java @@ -27,15 +27,16 @@ * calling behavior of the ChatModel. * * @author Christian Tzolov + * @author Ilayaperumal Gopinathan */ public interface FunctionCallingOptions extends ChatOptions { /** - * @return Returns FunctionCallingOptionsBuilder to create a new instance of - * FunctionCallingOptions. + * @return Returns {@link DefaultFunctionCallingOptionsBuilder} to create a new + * instance of {@link FunctionCallingOptions}. */ - static FunctionCallingOptionsBuilder builder() { - return new FunctionCallingOptionsBuilder(); + static FunctionCallingOptions.Builder builder() { + return new DefaultFunctionCallingOptionsBuilder(); } /** @@ -83,4 +84,67 @@ default void setProxyToolCalls(Boolean proxyToolCalls) { void setToolContext(Map tooContext); + /** + * Builder for creating {@link FunctionCallingOptions} instance. + */ + interface Builder extends ChatOptions.Builder { + + /** + * The list of Function Callbacks to be registered with the Chat model. + * @param functionCallbacks the list of Function Callbacks. + * @return the FunctionCallOptions Builder. + */ + Builder functionCallbacks(List functionCallbacks); + + /** + * The Function Callbacks to be registered with the Chat model. + * @param functionCallbacks the function callbacks. + * @return the FunctionCallOptions Builder. + */ + Builder functionCallbacks(FunctionCallback... functionCallbacks); + + /** + * {@link Set} of function names to be registered with the Chat model. + * @param functions the {@link Set} of function names + * @return the FunctionCallOptions Builder. + */ + Builder functions(Set functions); + + /** + * The function name to be registered with the chat model. + * @param function the name of the function. + * @return the FunctionCallOptions Builder. + */ + Builder function(String function); + + /** + * Boolean flag to indicate if the proxy ToolCalls is enabled. + * @param proxyToolCalls boolean value to enable proxy ToolCalls. + * @return the FunctionCallOptions Builder. + */ + Builder proxyToolCalls(Boolean proxyToolCalls); + + /** + * Add a {@link Map} of context values into tool context. + * @param context the map representing the tool context. + * @return the FunctionCallOptions Builder. + */ + Builder toolContext(Map context); + + /** + * Add a specific key/value pair to the tool context. + * @param key the key to use. + * @param value the corresponding value. + * @return the FunctionCallOptions Builder. + */ + Builder toolContext(String key, Object value); + + /** + * Builds the {@link FunctionCallingOptions}. + * @return the FunctionCalling options. + */ + FunctionCallingOptions build(); + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java index a4ceca490ef..f831a135110 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java @@ -34,10 +34,12 @@ * permits options portability between different AI providers that support * function-calling. * + * @deprecated Use {@link FunctionCallingOptions.Builder} instead. * @author Christian Tzolov * @author Thomas Vitale * @since 0.8.1 */ +@Deprecated(forRemoval = true, since = "1.0.0-M5") public class FunctionCallingOptionsBuilder { private final PortableFunctionCallingOptions options; @@ -136,6 +138,10 @@ public PortableFunctionCallingOptions build() { return this.options; } + /** + * @deprecated use {@link DefaultFunctionCallingOptions} instead. + */ + @Deprecated(forRemoval = true, since = "1.0.0-M5") public static class PortableFunctionCallingOptions implements FunctionCallingOptions { private List functionCallbacks = new ArrayList<>(); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java index bcdaf87761b..fd15f846573 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java @@ -86,11 +86,11 @@ void createFunctionCallingOptionTest() { functionCallbacks.add(cb); FunctionCallingOptions options = FunctionCallingOptions.builder() - .withFunctionCallbacks(functionCallbacks) - .withFunctions(functions) - .withTopK(topK) - .withTopP(topP) - .withTemperature(temperature) + .functionCallbacks(functionCallbacks) + .functions(functions) + .topK(topK) + .topP(topP) + .temperature(temperature) .build(); // Callback Functions diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java index 1254ec5db01..60cb699228c 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java @@ -40,10 +40,9 @@ import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.Media; +import org.springframework.ai.model.function.DefaultFunctionCallingOptions; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.util.MimeTypeUtils; @@ -199,7 +198,7 @@ void defaultSystemTextLambda() { @Test void mutateDefaults() { - PortableFunctionCallingOptions options = new FunctionCallingOptionsBuilder().build(); + FunctionCallingOptions options = new DefaultFunctionCallingOptions(); given(this.chatModel.getDefaultOptions()).willReturn(options); given(this.chatModel.call(this.promptCaptor.capture())) @@ -331,7 +330,7 @@ void mutateDefaults() { @Test void mutatePrompt() { - PortableFunctionCallingOptions options = new FunctionCallingOptionsBuilder().build(); + FunctionCallingOptions options = new DefaultFunctionCallingOptions(); given(this.chatModel.getDefaultOptions()).willReturn(options); given(this.chatModel.call(this.promptCaptor.capture())) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatProperties.java index 88f5068025a..bca0e2fba0d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatProperties.java @@ -16,7 +16,7 @@ package org.springframework.ai.autoconfigure.bedrock.converse; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; +import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.util.Assert; @@ -38,10 +38,10 @@ public class BedrockConverseProxyChatProperties { private boolean enabled = true; @NestedConfigurationProperty - private PortableFunctionCallingOptions options = PortableFunctionCallingOptions.builder() - .withTemperature(0.7) - .withMaxTokens(300) - .withTopK(10) + private FunctionCallingOptions options = FunctionCallingOptions.builder() + .temperature(0.7) + .maxTokens(300) + .topK(10) .build(); public boolean isEnabled() { @@ -52,12 +52,12 @@ public void setEnabled(boolean enabled) { this.enabled = enabled; } - public PortableFunctionCallingOptions getOptions() { + public FunctionCallingOptions getOptions() { return this.options; } - public void setOptions(PortableFunctionCallingOptions options) { - Assert.notNull(options, "PortableFunctionCallingOptions must not be null"); + public void setOptions(FunctionCallingOptions options) { + Assert.notNull(options, "FunctionCallingOptions must not be null"); this.options = options; } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java index 5a03f27ee76..57a3abfbb31 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java @@ -33,7 +33,7 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; +import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; @@ -96,7 +96,7 @@ void functionCallWithPortableFunctionCallingOptions() { "What's the weather like in San Francisco, in Paris, France and in Tokyo, Japan? Return the temperature in Celsius."); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - PortableFunctionCallingOptions.builder().withFunction("weatherFunction").build())); + FunctionCallingOptions.builder().function("weatherFunction").build())); logger.info("Response: {}", response); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java index 2917adaf2ce..bd43c5f7eb6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java @@ -31,7 +31,7 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; +import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; @@ -96,7 +96,7 @@ void functionCallWithPortableFunctionCallingOptions() { "What's the weather like in San Francisco, Paris and in Tokyo? Use Multi-turn function calling."); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - PortableFunctionCallingOptions.builder().withFunction("weatherFunction").build())); + FunctionCallingOptions.builder().function("weatherFunction").build())); logger.info("Response: {}", response); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/tool/FunctionCallWithFunctionBeanIT.java index 4adff902279..3cf71bec9dd 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/tool/FunctionCallWithFunctionBeanIT.java @@ -33,7 +33,6 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; @@ -65,14 +64,14 @@ void functionCallTest() { "What's the weather like in San Francisco, in Paris, France and in Tokyo, Japan? Return the temperature in Celsius."); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - PortableFunctionCallingOptions.builder().withFunction("weatherFunction").build())); + FunctionCallingOptions.builder().function("weatherFunction").build())); logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), - FunctionCallingOptions.builder().withFunction("weatherFunction3").build())); + FunctionCallingOptions.builder().function("weatherFunction3").build())); logger.info("Response: {}", response); @@ -94,7 +93,7 @@ void functionStreamTest() { "What's the weather like in San Francisco, in Paris, France and in Tokyo, Japan? Return the temperature in Celsius."); Flux responses = chatModel.stream(new Prompt(List.of(userMessage), - FunctionCallingOptions.builder().withFunction("weatherFunction").build())); + FunctionCallingOptions.builder().function("weatherFunction").build())); String content = responses.collectList() .block() diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/tool/FunctionCallWithPromptFunctionIT.java index ed8c7b58ef2..5066c0a8950 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/tool/FunctionCallWithPromptFunctionIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/tool/FunctionCallWithPromptFunctionIT.java @@ -57,7 +57,7 @@ void functionCallTest() { "What's the weather like in San Francisco, in Paris and in Tokyo? Return the temperature in Celsius."); var promptOptions = FunctionCallingOptions.builder() - .withFunctionCallbacks(List.of(FunctionCallback.builder() + .functionCallbacks(List.of(FunctionCallback.builder() .function("CurrentWeatherService", new MockWeatherService()) .description("Get the weather in location. Return temperature in 36°F or 36°C format.") .inputType(MockWeatherService.Request.class) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java index 0c597d09bc7..6de078a40b6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java @@ -35,7 +35,6 @@ import org.springframework.ai.minimax.MiniMaxChatModel; import org.springframework.ai.minimax.MiniMaxChatOptions; import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -98,8 +97,8 @@ void functionCallWithPortableFunctionCallingOptions() { UserMessage userMessage = new UserMessage( "What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius."); - PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() - .withFunction("weatherFunction") + FunctionCallingOptions functionOptions = FunctionCallingOptions.builder() + .function("weatherFunction") .build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java index 588c089b890..3ee61974605 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java @@ -39,7 +39,6 @@ import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -98,8 +97,8 @@ void functionCallWithPortableFunctionCallingOptions() { UserMessage userMessage = new UserMessage("What's the weather like in Paris? Use Celsius."); - PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() - .withFunctionCallbacks(List.of(FunctionCallback.builder() + FunctionCallingOptions functionOptions = FunctionCallingOptions.builder() + .functionCallbacks(List.of(FunctionCallback.builder() .function("CurrentWeatherService", new MyWeatherService()) .description("Get the current weather in requested location") .inputType(MyWeatherService.Request.class) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java index 0456b9e65b9..2fc7782bab0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -34,7 +34,6 @@ import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.ai.moonshot.MoonshotChatModel; import org.springframework.ai.moonshot.MoonshotChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; @@ -98,8 +97,8 @@ void functionCallWithPortableFunctionCallingOptions() { UserMessage userMessage = new UserMessage( "What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius"); - PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() - .withFunction("weatherFunction") + FunctionCallingOptions functionOptions = FunctionCallingOptions.builder() + .function("weatherFunction") .build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/OllamaFunctionCallbackIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/OllamaFunctionCallbackIT.java index 716e6bcc042..b3c7353541b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/OllamaFunctionCallbackIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/OllamaFunctionCallbackIT.java @@ -20,7 +20,6 @@ import java.util.stream.Collectors; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,7 +34,6 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; @@ -120,9 +118,7 @@ void functionCallWithPortableFunctionCallingOptions() { UserMessage userMessage = new UserMessage( "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations."); - PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() - .withFunction("WeatherInfo") - .build(); + FunctionCallingOptions functionOptions = FunctionCallingOptions.builder().function("WeatherInfo").build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java index fb02d9bf6ec..27753c29dd7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -41,7 +41,6 @@ import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi.ChatModel; @@ -156,8 +155,8 @@ void trainScheduler() { UserMessage userMessage = new UserMessage( "Please schedule a train from San Francisco to Los Angeles on 2023-12-25"); - PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() - .withFunction("trainReservation") + FunctionCallingOptions functionOptions = FunctionCallingOptions.builder() + .function("trainReservation") .build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); @@ -267,8 +266,8 @@ void functionCallWithPortableFunctionCallingOptions() { // Test weatherFunction UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); - PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() - .withFunction("weatherFunction") + FunctionCallingOptions functionOptions = FunctionCallingOptions.builder() + .function("weatherFunction") .build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java index 530e564eb22..9da432b564d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java @@ -28,7 +28,7 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; +import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; @@ -109,7 +109,7 @@ void functionCallWithPortableFunctionCallingOptions() { """); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - PortableFunctionCallingOptions.builder().withFunction("weatherFunction").build())); + FunctionCallingOptions.builder().function("weatherFunction").build())); logger.info("Response: {}", response); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java index a83c29c65a8..9e3ae2106d3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -34,7 +34,6 @@ import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.ai.zhipuai.ZhiPuAiChatModel; import org.springframework.ai.zhipuai.ZhiPuAiChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; @@ -98,8 +97,8 @@ void functionCallWithPortableFunctionCallingOptions() { UserMessage userMessage = new UserMessage( "What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius."); - PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() - .withFunction("weatherFunction") + FunctionCallingOptions functionOptions = FunctionCallingOptions.builder() + .function("weatherFunction") .build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/kotlin/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackContextKotlinIT.kt b/spring-ai-spring-boot-autoconfigure/src/test/kotlin/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackContextKotlinIT.kt index 0d171acc2e7..1187e2b1511 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/kotlin/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackContextKotlinIT.kt +++ b/spring-ai-spring-boot-autoconfigure/src/test/kotlin/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackContextKotlinIT.kt @@ -89,7 +89,7 @@ class FunctionCallbackResolverKotlinIT : BaseOllamaIT() { "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.") val functionOptions = FunctionCallingOptions.builder() - .withFunction("weatherInfo") + .function("weatherInfo") .build() val response = chatModel.call(Prompt(listOf(userMessage), functionOptions)); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/kotlin/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackKotlinIT.kt b/spring-ai-spring-boot-autoconfigure/src/test/kotlin/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackKotlinIT.kt index 3376f24c06e..aeb27dbc310 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/kotlin/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackKotlinIT.kt +++ b/spring-ai-spring-boot-autoconfigure/src/test/kotlin/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackKotlinIT.kt @@ -89,7 +89,7 @@ class FunctionCallbackKotlinIT : BaseOllamaIT() { "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.") val functionOptions = FunctionCallingOptions.builder() - .withFunction("WeatherInfo") + .function("WeatherInfo") .build() val response = chatModel.call(Prompt(listOf(userMessage), functionOptions)); From 0f6705bb96bcd60d01390e316b09d396a1ac3b3f Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Fri, 13 Dec 2024 23:18:31 +0000 Subject: [PATCH 2/2] Improve extensibility of DefaultChatOptionsBuilder - Enable DefaultChatOptionsBuilder to accommodate any other sub types - Introduce generics to support sub types that extend DefaultChatOptionsBuilder - Update builder methods to return the sub type - Make FunctionCallingOptions' builder()'s return type to accommodate sub types which can extend FunctionCallingOptions.Builder --- .../ai/chat/prompt/ChatOptions.java | 2 +- .../prompt/DefaultChatOptionsBuilder.java | 38 ++++++------ .../DefaultFunctionCallingOptions.java | 4 -- .../DefaultFunctionCallingOptionsBuilder.java | 59 ++++--------------- .../function/FunctionCallingOptions.java | 18 +++--- 5 files changed, 42 insertions(+), 79 deletions(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java index 01e2f7805d4..8cc03c5eda6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java @@ -94,7 +94,7 @@ public interface ChatOptions extends ModelOptions { * {@link ChatOptions}. * @return Returns a new {@link ChatOptions.Builder}. */ - static ChatOptions.Builder builder() { + static ChatOptions.Builder builder() { return new DefaultChatOptionsBuilder(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java index 7e067517e9f..1d84a704337 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java @@ -21,48 +21,52 @@ /** * Implementation of {@link ChatOptions.Builder} to create {@link DefaultChatOptions}. */ -public class DefaultChatOptionsBuilder implements ChatOptions.Builder { +public class DefaultChatOptionsBuilder> implements ChatOptions.Builder { private final DefaultChatOptions options = new DefaultChatOptions(); - public DefaultChatOptionsBuilder model(String model) { + protected T self() { + return (T) this; + } + + public T model(String model) { this.options.setModel(model); - return this; + return self(); } - public DefaultChatOptionsBuilder frequencyPenalty(Double frequencyPenalty) { + public T frequencyPenalty(Double frequencyPenalty) { this.options.setFrequencyPenalty(frequencyPenalty); - return this; + return self(); } - public DefaultChatOptionsBuilder maxTokens(Integer maxTokens) { + public T maxTokens(Integer maxTokens) { this.options.setMaxTokens(maxTokens); - return this; + return self(); } - public DefaultChatOptionsBuilder presencePenalty(Double presencePenalty) { + public T presencePenalty(Double presencePenalty) { this.options.setPresencePenalty(presencePenalty); - return this; + return self(); } - public DefaultChatOptionsBuilder stopSequences(List stop) { + public T stopSequences(List stop) { this.options.setStopSequences(stop); - return this; + return self(); } - public DefaultChatOptionsBuilder temperature(Double temperature) { + public T temperature(Double temperature) { this.options.setTemperature(temperature); - return this; + return self(); } - public DefaultChatOptionsBuilder topK(Integer topK) { + public T topK(Integer topK) { this.options.setTopK(topK); - return this; + return self(); } - public DefaultChatOptionsBuilder topP(Double topP) { + public T topP(Double topP) { this.options.setTopP(topP); - return this; + return self(); } public ChatOptions build() { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptions.java index 4a117b1cbe7..180f8ddfb7f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptions.java @@ -47,10 +47,6 @@ public class DefaultFunctionCallingOptions extends DefaultChatOptions implements private Map context = new HashMap<>(); - public static FunctionCallingOptions.Builder builder() { - return new DefaultFunctionCallingOptionsBuilder(); - } - @Override public List getFunctionCallbacks() { return Collections.unmodifiableList(this.functionCallbacks); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilder.java index 53718e44cc3..3d6d124150a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilder.java @@ -22,6 +22,7 @@ import java.util.Map; import java.util.Set; +import org.springframework.ai.chat.prompt.DefaultChatOptionsBuilder; import org.springframework.util.Assert; /** @@ -31,67 +32,29 @@ * @author Thomas Vitale * @author Ilayaperumal Gopinathan */ -public class DefaultFunctionCallingOptionsBuilder implements FunctionCallingOptions.Builder { +public class DefaultFunctionCallingOptionsBuilder + extends DefaultChatOptionsBuilder + implements FunctionCallingOptions.Builder { private final DefaultFunctionCallingOptions functionCallingOptions = new DefaultFunctionCallingOptions(); - public FunctionCallingOptions.Builder model(String model) { - this.functionCallingOptions.setModel(model); - return this; - } - - public FunctionCallingOptions.Builder frequencyPenalty(Double frequencyPenalty) { - this.functionCallingOptions.setFrequencyPenalty(frequencyPenalty); - return this; - } - - public FunctionCallingOptions.Builder maxTokens(Integer maxTokens) { - this.functionCallingOptions.setMaxTokens(maxTokens); - return this; - } - - public FunctionCallingOptions.Builder presencePenalty(Double presencePenalty) { - this.functionCallingOptions.setPresencePenalty(presencePenalty); - return this; - } - - public FunctionCallingOptions.Builder stopSequences(List stopSequences) { - this.functionCallingOptions.setStopSequences(stopSequences); - return this; - } - - public FunctionCallingOptions.Builder temperature(Double temperature) { - this.functionCallingOptions.setTemperature(temperature); - return this; - } - - public FunctionCallingOptions.Builder topK(Integer topK) { - this.functionCallingOptions.setTopK(topK); - return this; - } - - public FunctionCallingOptions.Builder topP(Double topP) { - this.functionCallingOptions.setTopP(topP); - return this; - } - - public FunctionCallingOptions.Builder functionCallbacks(List functionCallbacks) { + public DefaultFunctionCallingOptionsBuilder functionCallbacks(List functionCallbacks) { this.functionCallingOptions.setFunctionCallbacks(functionCallbacks); return this; } - public FunctionCallingOptions.Builder functionCallbacks(FunctionCallback... functionCallbacks) { + public DefaultFunctionCallingOptionsBuilder functionCallbacks(FunctionCallback... functionCallbacks) { Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); this.functionCallingOptions.setFunctionCallbacks(List.of(functionCallbacks)); return this; } - public FunctionCallingOptions.Builder functions(Set functions) { + public DefaultFunctionCallingOptionsBuilder functions(Set functions) { this.functionCallingOptions.setFunctions(functions); return this; } - public FunctionCallingOptions.Builder function(String function) { + public DefaultFunctionCallingOptionsBuilder function(String function) { Assert.notNull(function, "Function must not be null"); var set = new HashSet<>(this.functionCallingOptions.getFunctions()); set.add(function); @@ -99,12 +62,12 @@ public FunctionCallingOptions.Builder function(String function) { return this; } - public FunctionCallingOptions.Builder proxyToolCalls(Boolean proxyToolCalls) { + public DefaultFunctionCallingOptionsBuilder proxyToolCalls(Boolean proxyToolCalls) { this.functionCallingOptions.setProxyToolCalls(proxyToolCalls); return this; } - public FunctionCallingOptions.Builder toolContext(Map context) { + public DefaultFunctionCallingOptionsBuilder toolContext(Map context) { Assert.notNull(context, "Tool context must not be null"); Map newContext = new HashMap<>(this.functionCallingOptions.getToolContext()); newContext.putAll(context); @@ -112,7 +75,7 @@ public FunctionCallingOptions.Builder toolContext(Map context) { return this; } - public FunctionCallingOptions.Builder toolContext(String key, Object value) { + public DefaultFunctionCallingOptionsBuilder toolContext(String key, Object value) { Assert.notNull(key, "Key must not be null"); Assert.notNull(value, "Value must not be null"); Map newContext = new HashMap<>(this.functionCallingOptions.getToolContext()); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java index f30c7643816..0859df140aa 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java @@ -35,7 +35,7 @@ public interface FunctionCallingOptions extends ChatOptions { * @return Returns {@link DefaultFunctionCallingOptionsBuilder} to create a new * instance of {@link FunctionCallingOptions}. */ - static FunctionCallingOptions.Builder builder() { + static FunctionCallingOptions.Builder builder() { return new DefaultFunctionCallingOptionsBuilder(); } @@ -87,49 +87,49 @@ default void setProxyToolCalls(Boolean proxyToolCalls) { /** * Builder for creating {@link FunctionCallingOptions} instance. */ - interface Builder extends ChatOptions.Builder { + interface Builder> extends ChatOptions.Builder { /** * The list of Function Callbacks to be registered with the Chat model. * @param functionCallbacks the list of Function Callbacks. * @return the FunctionCallOptions Builder. */ - Builder functionCallbacks(List functionCallbacks); + T functionCallbacks(List functionCallbacks); /** * The Function Callbacks to be registered with the Chat model. * @param functionCallbacks the function callbacks. * @return the FunctionCallOptions Builder. */ - Builder functionCallbacks(FunctionCallback... functionCallbacks); + T functionCallbacks(FunctionCallback... functionCallbacks); /** * {@link Set} of function names to be registered with the Chat model. * @param functions the {@link Set} of function names * @return the FunctionCallOptions Builder. */ - Builder functions(Set functions); + T functions(Set functions); /** * The function name to be registered with the chat model. * @param function the name of the function. * @return the FunctionCallOptions Builder. */ - Builder function(String function); + T function(String function); /** * Boolean flag to indicate if the proxy ToolCalls is enabled. * @param proxyToolCalls boolean value to enable proxy ToolCalls. * @return the FunctionCallOptions Builder. */ - Builder proxyToolCalls(Boolean proxyToolCalls); + T proxyToolCalls(Boolean proxyToolCalls); /** * Add a {@link Map} of context values into tool context. * @param context the map representing the tool context. * @return the FunctionCallOptions Builder. */ - Builder toolContext(Map context); + T toolContext(Map context); /** * Add a specific key/value pair to the tool context. @@ -137,7 +137,7 @@ interface Builder extends ChatOptions.Builder { * @param value the corresponding value. * @return the FunctionCallOptions Builder. */ - Builder toolContext(String key, Object value); + T toolContext(String key, Object value); /** * Builds the {@link FunctionCallingOptions}.