@@ -49,7 +49,7 @@
import org.springframework.ai.converter.MapOutputConverter;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.SpringBootConfiguration;
@@ -258,9 +258,7 @@ void multiModalityPdfTest() throws IOException {
List.of(new Media(new MimeType("application", "pdf"), pdfData)));

var response = this.chatModel.call(new Prompt(List.of(userMessage),
PortableFunctionCallingOptions.builder()
.withModel(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName())
.build()));
FunctionCallingOptions.builder().model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName()).build()));

assertThat(response.getResult().getOutput().getText()).containsAnyOf("Spring AI", "portable API");
}
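
Note on the change above: it is the pattern repeated throughout this diff. The nested PortableFunctionCallingOptions type and its with*-prefixed setters are replaced by the portable FunctionCallingOptions.builder() with plain setter names. A minimal before/after sketch, illustrative only and not part of the changed files (the class and method names are made up; the model id and temperature value mirror ones used elsewhere in this diff):

import org.springframework.ai.model.function.FunctionCallingOptions;

class PortableOptionsMigrationSketch {

	FunctionCallingOptions newStyle() {
		// Before (removed in this PR):
		// PortableFunctionCallingOptions.builder()
		//     .withModel("anthropic.claude-3-5-sonnet-20240620-v1:0")
		//     .withTemperature(0.7)
		//     .build();

		// After: same portable options, un-prefixed setters.
		return FunctionCallingOptions.builder()
			.model("anthropic.claude-3-5-sonnet-20240620-v1:0")
			.temperature(0.7)
			.build();
	}
}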
@@ -96,11 +96,10 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.DefaultFunctionCallingOptions;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
@@ -322,12 +321,12 @@ else if (message.getMessageType() == MessageType.TOOL) {
if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof FunctionCallingOptions) {
var functionCallingOptions = (FunctionCallingOptions) prompt.getOptions();
updatedRuntimeOptions = ((PortableFunctionCallingOptions) updatedRuntimeOptions)
updatedRuntimeOptions = ((DefaultFunctionCallingOptions) updatedRuntimeOptions)
.merge(functionCallingOptions);
}
else if (prompt.getOptions() instanceof ChatOptions) {
var chatOptions = (ChatOptions) prompt.getOptions();
updatedRuntimeOptions = ((PortableFunctionCallingOptions) updatedRuntimeOptions).merge(chatOptions);
updatedRuntimeOptions = ((DefaultFunctionCallingOptions) updatedRuntimeOptions).merge(chatOptions);
}
}

@@ -697,7 +696,7 @@ public static final class Builder {

private Duration timeout = Duration.ofMinutes(10);

private FunctionCallingOptions defaultOptions = new FunctionCallingOptionsBuilder().build();
private FunctionCallingOptions defaultOptions = new DefaultFunctionCallingOptions();

private FunctionCallbackResolver functionCallbackResolver;
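
For the defaultOptions field above, new DefaultFunctionCallingOptions() replaces new FunctionCallingOptionsBuilder().build() as the concrete portable implementation. A short sketch, illustrative only (the class and field names below are made up), of the two ways the changed files in this diff now supply a default FunctionCallingOptions:

import org.springframework.ai.model.function.DefaultFunctionCallingOptions;
import org.springframework.ai.model.function.FunctionCallingOptions;

class DefaultOptionsSketch {

	// Empty defaults, as the Builder field above now does.
	FunctionCallingOptions emptyDefaults = new DefaultFunctionCallingOptions();

	// Pre-populated defaults via the portable builder, as the Bedrock test
	// configurations later in this diff do; the model id mirrors the Anthropic
	// Claude id used in those tests.
	FunctionCallingOptions populatedDefaults = FunctionCallingOptions.builder()
		.model("anthropic.claude-3-5-sonnet-20240620-v1:0")
		.build();
}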

@@ -372,7 +372,7 @@ void multiModalityEmbeddedImage(String modelName) throws IOException {

// @formatter:off
String response = ChatClient.create(this.chatModel).prompt()
.options(FunctionCallingOptions.builder().withModel(modelName).build())
.options(FunctionCallingOptions.builder().model(modelName).build())
.user(u -> u.text("Explain what do you see on this picture?")
.media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.png")))
.call()
@@ -394,7 +394,7 @@ void multiModalityImageUrl(String modelName) throws IOException {
// @formatter:off
String response = ChatClient.create(this.chatModel).prompt()
// TODO consider adding model(...) method to ChatClient as a shortcut to
.options(FunctionCallingOptions.builder().withModel(modelName).build())
.options(FunctionCallingOptions.builder().model(modelName).build())
.user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url))
.call()
.content();
@@ -42,7 +42,7 @@ public BedrockProxyChatModel bedrockConverseChatModel() {
.withRegion(Region.US_EAST_1)
.withTimeout(Duration.ofSeconds(120))
// .withRegion(Region.US_EAST_1)
.withDefaultOptions(FunctionCallingOptions.builder().withModel(modelId).build())
.withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build())
.build();
}

@@ -41,7 +41,6 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.isA;
@@ -145,7 +144,7 @@ public void callWithToolUse() {
.build();

var result = this.chatModel.call(new Prompt("What is the weather in Paris?",
PortableFunctionCallingOptions.builder().withFunctionCallbacks(functionCallback).build()));
FunctionCallingOptions.builder().functionCallbacks(functionCallback).build()));

assertThat(result).isNotNull();
assertThat(result.getResult().getOutput().getText())
@@ -90,7 +90,7 @@ void roleTest(String modelName) {
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage),
FunctionCallingOptions.builder().withModel(modelName).build());
FunctionCallingOptions.builder().model(modelName).build());
ChatResponse response = this.chatModel.call(prompt);
assertThat(response.getResults()).hasSize(1);
assertThat(response.getMetadata().getUsage().getGenerationTokens()).isGreaterThan(0);
@@ -126,7 +126,7 @@ void testMessageHistory() {

@Test
void streamingWithTokenUsage() {
var promptOptions = FunctionCallingOptions.builder().withTemperature(0.0).build();
var promptOptions = FunctionCallingOptions.builder().temperature(0.0).build();

var prompt = new Prompt("List two colors of the Polish flag. Be brief.", promptOptions);
var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage();
@@ -252,7 +252,7 @@ void functionCallTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = FunctionCallingOptions.builder()
.withFunctionCallbacks(List.of(FunctionCallback.builder()
.functionCallbacks(List.of(FunctionCallback.builder()
.function("getCurrentWeather", new MockWeatherService())
.description(
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
@@ -279,8 +279,8 @@ void streamFunctionCallTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = FunctionCallingOptions.builder()
.withModel("anthropic.claude-3-5-sonnet-20240620-v1:0")
.withFunctionCallbacks(List.of(FunctionCallback.builder()
.model("anthropic.claude-3-5-sonnet-20240620-v1:0")
.functionCallbacks(List.of(FunctionCallback.builder()
.function("getCurrentWeather", new MockWeatherService())
.description(
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
@@ -306,7 +306,7 @@ void validateCallResponseMetadata() {
String model = "anthropic.claude-3-5-sonnet-20240620-v1:0";
// @formatter:off
ChatResponse response = ChatClient.create(this.chatModel).prompt()
.options(FunctionCallingOptions.builder().withModel(model).build())
.options(FunctionCallingOptions.builder().model(model).build())
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
.call()
.chatResponse();
@@ -321,7 +321,7 @@ void validateStreamCallResponseMetadata() {
String model = "anthropic.claude-3-5-sonnet-20240620-v1:0";
// @formatter:off
ChatResponse response = ChatClient.create(this.chatModel).prompt()
.options(FunctionCallingOptions.builder().withModel(model).build())
.options(FunctionCallingOptions.builder().model(model).build())
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
.stream()
.chatResponse()
@@ -35,7 +35,6 @@
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.beans.factory.annotation.Autowired;
@@ -68,13 +67,13 @@ void beforeEach() {

@Test
void observationForChatOperation() {
var options = PortableFunctionCallingOptions.builder()
.withModel("anthropic.claude-3-5-sonnet-20240620-v1:0")
.withMaxTokens(2048)
.withStopSequences(List.of("this-is-the-end"))
.withTemperature(0.7)
var options = FunctionCallingOptions.builder()
.model("anthropic.claude-3-5-sonnet-20240620-v1:0")
.maxTokens(2048)
.stopSequences(List.of("this-is-the-end"))
.temperature(0.7)
// .withTopK(1)
.withTopP(1.0)
.topP(1.0)
.build();

Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
@@ -90,12 +89,12 @@ void observationForChatOperation() {

@Test
void observationForStreamingChatOperation() {
var options = PortableFunctionCallingOptions.builder()
.withModel("anthropic.claude-3-5-sonnet-20240620-v1:0")
.withMaxTokens(2048)
.withStopSequences(List.of("this-is-the-end"))
.withTemperature(0.7)
.withTopP(1.0)
var options = FunctionCallingOptions.builder()
.model("anthropic.claude-3-5-sonnet-20240620-v1:0")
.maxTokens(2048)
.stopSequences(List.of("this-is-the-end"))
.temperature(0.7)
.topP(1.0)
.build();

Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
@@ -174,7 +173,7 @@ public BedrockProxyChatModel bedrockConverseChatModel(ObservationRegistry observ
.withCredentialsProvider(EnvironmentVariableCredentialsProvider.create())
.withRegion(Region.US_EAST_1)
.withObservationRegistry(observationRegistry)
.withDefaultOptions(FunctionCallingOptions.builder().withModel(modelId).build())
.withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build())
.build();
}

@@ -184,7 +184,7 @@ public BedrockProxyChatModel bedrockConverseChatModel() {
.withCredentialsProvider(EnvironmentVariableCredentialsProvider.create())
.withRegion(Region.US_EAST_1)
.withTimeout(Duration.ofSeconds(120))
.withDefaultOptions(FunctionCallingOptions.builder().withModel(modelId).build())
.withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build())
.build();
}

@@ -27,7 +27,7 @@
import org.springframework.ai.bedrock.converse.MockWeatherService;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
import org.springframework.ai.model.function.FunctionCallingOptions;

/**
* Used for reverse engineering the protocol
@@ -50,9 +50,9 @@ public static void main(String[] args) {
// "What's the weather like in San Francisco, Tokyo, and Paris? Return the
// temperature in Celsius.",
"What's the weather like in Paris? Return the temperature in Celsius.",
PortableFunctionCallingOptions.builder()
.withModel(modelId)
.withFunctionCallbacks(List.of(FunctionCallback.builder()
FunctionCallingOptions.builder()
.model(modelId)
.functionCallbacks(List.of(FunctionCallback.builder()
.function("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
@@ -25,7 +25,7 @@
import org.springframework.ai.bedrock.converse.MockWeatherService;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
import org.springframework.ai.model.function.FunctionCallingOptions;

/**
* Used for reverse engineering the protocol
@@ -48,9 +48,9 @@ public static void main(String[] args) {
// "What's the weather like in San Francisco, Tokyo, and Paris? Return the
// temperature in Celsius.",
"What's the weather like in Paris? Return the temperature in Celsius.",
PortableFunctionCallingOptions.builder()
.withModel(modelId)
.withFunctionCallbacks(List.of(FunctionCallback.builder()
FunctionCallingOptions.builder()
.model(modelId)
.functionCallbacks(List.of(FunctionCallback.builder()
.function("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
@@ -94,70 +94,70 @@ public interface ChatOptions extends ModelOptions {
* {@link ChatOptions}.
* @return Returns a new {@link ChatOptions.Builder}.
*/
static ChatOptions.Builder builder() {
static ChatOptions.Builder<? extends DefaultChatOptionsBuilder> builder() {
return new DefaultChatOptionsBuilder();
}

/**
* Builder for creating {@link ChatOptions} instance.
*/
interface Builder {
interface Builder<B extends Builder<B>> {

/**
* Builds with the model to use for the chat.
* @param model
* @return the builder
*/
Builder model(String model);
B model(String model);

/**
* Builds with the frequency penalty to use for the chat.
* @param frequencyPenalty
* @return the builder.
*/
Builder frequencyPenalty(Double frequencyPenalty);
B frequencyPenalty(Double frequencyPenalty);

/**
* Builds with the maximum number of tokens to use for the chat.
* @param maxTokens
* @return the builder.
*/
Builder maxTokens(Integer maxTokens);
B maxTokens(Integer maxTokens);

/**
* Builds with the presence penalty to use for the chat.
* @param presencePenalty
* @return the builder.
*/
Builder presencePenalty(Double presencePenalty);
B presencePenalty(Double presencePenalty);

/**
* Builds with the stop sequences to use for the chat.
* @param stopSequences
* @return the builder.
*/
Builder stopSequences(List<String> stopSequences);
B stopSequences(List<String> stopSequences);

/**
* Builds with the temperature to use for the chat.
* @param temperature
* @return the builder.
*/
Builder temperature(Double temperature);
B temperature(Double temperature);

/**
* Builds with the top K to use for the chat.
* @param topK
* @return the builder.
*/
Builder topK(Integer topK);
B topK(Integer topK);

/**
* Builds with the top P to use for the chat.
* @param topP
* @return the builder.
*/
Builder topP(Double topP);
B topP(Double topP);

/**
* Build the {@link ChatOptions}.
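
The ChatOptions.Builder change in the last file generifies the builder as Builder&lt;B extends Builder&lt;B&gt;&gt; so that every setter returns the concrete builder type, letting subtype builders (such as the function-calling options builders used throughout this diff) add their own setters while keeping fluent chaining without casts. A standalone sketch of that self-referential builder idiom; all type and method names below are illustrative, not Spring AI classes:

import java.util.List;

// Minimal, self-contained illustration of the Builder<B extends Builder<B>> idiom.
interface OptionsBuilder<B extends OptionsBuilder<B>> {

	B model(String model); // returns the concrete builder type, not the base interface

	B temperature(Double temperature);

}

final class ToolOptionsBuilder implements OptionsBuilder<ToolOptionsBuilder> {

	private String model;

	private Double temperature;

	private List<String> tools;

	public ToolOptionsBuilder model(String model) {
		this.model = model;
		return this;
	}

	public ToolOptionsBuilder temperature(Double temperature) {
		this.temperature = temperature;
		return this;
	}

	// A subtype-specific setter stays chainable because the inherited setters
	// already return ToolOptionsBuilder rather than the base interface.
	public ToolOptionsBuilder tools(List<String> tools) {
		this.tools = tools;
		return this;
	}

	public String describe() {
		return this.model + " @ " + this.temperature + " with " + this.tools;
	}

}

class SelfTypedBuilderDemo {

	public static void main(String[] args) {
		// With a non-generic Builder, model(...) would return the base type and
		// the tools(...) call below would need a cast.
		String description = new ToolOptionsBuilder()
			.model("anthropic.claude-3-5-sonnet-20240620-v1:0")
			.temperature(0.7)
			.tools(List.of("getCurrentWeather"))
			.describe();
		System.out.println(description);
	}

}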