From bf84d5945e9499d01a9deff9435b2e772804afab Mon Sep 17 00:00:00 2001 From: Thomas Vitale Date: Sat, 10 Aug 2024 14:48:05 +0200 Subject: [PATCH] Streamline ChatOptions * Surface more configuration APIs to ChatOptions * Use abstraction in Observations directly instead of dedicated implementation * Simplify metadata config in observations for defined models * Improve merging of runtime and default options in OpenAI * Fix missing option in Mistral AI Relates to gh-1148 Signed-off-by: Thomas Vitale --- .../ai/anthropic/AnthropicChatOptions.java | 17 ++ .../azure/openai/AzureOpenAiChatOptions.java | 69 +++--- .../openai/AzureOpenAiEmbeddingOptions.java | 7 + .../anthropic/AnthropicChatOptions.java | 32 +++ .../anthropic3/Anthropic3ChatOptions.java | 22 ++ .../cohere/BedrockCohereChatOptions.java | 22 ++ .../cohere/BedrockCohereEmbeddingOptions.java | 3 + .../BedrockAi21Jurassic2ChatOptions.java | 102 ++++++--- .../llama/BedrockLlamaChatOptions.java | 43 +++- .../titan/BedrockTitanChatOptions.java | 39 +++- .../titan/BedrockTitanEmbeddingOptions.java | 3 + .../BedrockAi21Jurassic2ChatModelIT.java | 4 +- .../ai/minimax/MiniMaxChatOptions.java | 23 +- .../ai/minimax/MiniMaxEmbeddingOptions.java | 2 + .../ai/mistralai/MistralAiChatOptions.java | 64 +++++- .../mistralai/MistralAiEmbeddingOptions.java | 2 + .../ai/moonshot/MoonshotChatOptions.java | 33 ++- .../ai/ollama/api/OllamaOptions.java | 39 +++- .../ai/openai/OpenAiChatModel.java | 41 ++-- .../ai/openai/OpenAiChatOptions.java | 32 ++- .../ai/openai/OpenAiEmbeddingModel.java | 29 ++- .../ai/openai/OpenAiImageModel.java | 39 ++-- .../openai/api/common/OpenAiApiConstants.java | 5 + .../PostgresMlEmbeddingOptions.java | 3 + .../ai/qianfan/QianFanChatOptions.java | 17 +- .../ai/qianfan/QianFanEmbeddingOptions.java | 2 + .../ai/qianfan/QianFanImageOptions.java | 2 + .../api/StabilityAiImageOptions.java | 7 + .../gemini/VertexAiGeminiChatOptions.java | 27 ++- .../palm2/VertexAiPaLm2ChatOptions.java | 34 +++ 
.../ai/watsonx/WatsonxAiChatOptions.java | 38 +++- .../ai/zhipuai/ZhiPuAiChatOptions.java | 43 +++- .../ai/zhipuai/ZhiPuAiEmbeddingOptions.java | 2 + .../ai/zhipuai/ZhiPuAiImageOptions.java | 6 + .../ChatModelObservationContext.java | 24 ++- .../observation/ChatModelRequestOptions.java | 201 ------------------ ...DefaultChatModelObservationConvention.java | 18 +- .../ai/chat/prompt/ChatOptions.java | 25 ++- .../ai/chat/prompt/ChatOptionsBuilder.java | 124 +++++++++-- ...ltEmbeddingModelObservationConvention.java | 2 +- .../EmbeddingModelObservationContext.java | 20 +- ...efaultImageModelObservationConvention.java | 2 +- .../ImageModelObservationContext.java | 14 +- .../FunctionCallingOptionsBuilder.java | 121 +++++++++-- ...ModelCompletionObservationFilterTests.java | 23 +- ...ChatModelMeterObservationHandlerTests.java | 15 +- .../ChatModelObservationContextTests.java | 17 +- ...elPromptContentObservationFilterTests.java | 23 +- .../ChatModelRequestOptionsTests.java | 50 ----- ...ltChatModelObservationConventionTests.java | 72 ++++--- ...eddingModelObservationConventionTests.java | 24 +-- ...dingModelMeterObservationHandlerTests.java | 12 +- ...EmbeddingModelObservationContextTests.java | 14 +- ...tImageModelObservationConventionTests.java | 26 +-- .../ImageModelObservationContextTests.java | 14 +- ...elPromptContentObservationFilterTests.java | 16 +- .../ROOT/pages/api/chat/mistralai-chat.adoc | 2 +- .../modules/ROOT/pages/upgrade-notes.adoc | 7 + .../zhipuai/ZhiPuAiPropertiesTests.java | 6 - 59 files changed, 1031 insertions(+), 694 deletions(-) delete mode 100644 spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelRequestOptions.java delete mode 100644 spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelRequestOptionsTests.java diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java 
b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index 9bc3fdabe1e..14bc85835a3 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -37,6 +37,7 @@ * The options to be used when sending a chat request to the Anthropic API. * * @author Christian Tzolov + * @author Thomas Vitale * @since 1.0.0 */ @JsonInclude(Include.NON_NULL) @@ -149,6 +150,7 @@ public AnthropicChatOptions build() { } + @Override public String getModel() { return model; } @@ -157,6 +159,7 @@ public void setModel(String model) { this.model = model; } + @Override public Integer getMaxTokens() { return this.maxTokens; } @@ -173,6 +176,7 @@ public void setMetadata(ChatCompletionRequest.Metadata metadata) { this.metadata = metadata; } + @Override public List getStopSequences() { return this.stopSequences; } @@ -199,6 +203,7 @@ public void setTopP(Float topP) { this.topP = topP; } + @Override public Integer getTopK() { return this.topK; } @@ -229,6 +234,18 @@ public void setFunctions(Set functions) { this.functions = functions; } + @Override + @JsonIgnore + public Float getFrequencyPenalty() { + return null; + } + + @Override + @JsonIgnore + public Float getPresencePenalty() { + return null; + } + @Override public AnthropicChatOptions copy() { return fromOptions(this); diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index 07ac7236bf7..eec7613d98c 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -38,6 +38,7 @@ * prompt 
data. * * @author Christian Tzolov + * @author Thomas Vitale */ @JsonInclude(Include.NON_NULL) public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptions { @@ -108,7 +109,7 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio * output new topics. */ @JsonProperty(value = "presence_penalty") - private Double presencePenalty; + private Float presencePenalty; /** * A value that influences the probability of generated tokens appearing based on @@ -117,7 +118,7 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio * model repeating the same statements verbatim. */ @JsonProperty(value = "frequency_penalty") - private Double frequencyPenalty; + private Float frequencyPenalty; /** * The deployment name as defined in Azure Open AI Studio when creating a deployment @@ -182,9 +183,7 @@ public Builder withDeploymentName(String deploymentName) { } public Builder withFrequencyPenalty(Float frequencyPenalty) { - if (frequencyPenalty != null) { - this.options.frequencyPenalty = frequencyPenalty.doubleValue(); - } + this.options.frequencyPenalty = frequencyPenalty; return this; } @@ -204,9 +203,7 @@ public Builder withN(Integer n) { } public Builder withPresencePenalty(Float presencePenalty) { - if (presencePenalty != null) { - this.options.presencePenalty = presencePenalty.doubleValue(); - } + this.options.presencePenalty = presencePenalty; return this; } @@ -259,6 +256,7 @@ public AzureOpenAiChatOptions build() { } + @Override public Integer getMaxTokens() { return this.maxTokens; } @@ -291,6 +289,17 @@ public void setN(Integer n) { this.n = n; } + @Override + @JsonIgnore + public List getStopSequences() { + return getStop(); + } + + @JsonIgnore + public void setStopSequences(List stopSequences) { + setStop(stopSequences); + } + public List getStop() { return this.stop; } @@ -299,22 +308,35 @@ public void setStop(List stop) { this.stop = stop; } - public Double getPresencePenalty() { + @Override 
+ public Float getPresencePenalty() { return this.presencePenalty; } - public void setPresencePenalty(Double presencePenalty) { + public void setPresencePenalty(Float presencePenalty) { this.presencePenalty = presencePenalty; } - public Double getFrequencyPenalty() { + @Override + public Float getFrequencyPenalty() { return this.frequencyPenalty; } - public void setFrequencyPenalty(Double frequencyPenalty) { + public void setFrequencyPenalty(Float frequencyPenalty) { this.frequencyPenalty = frequencyPenalty; } + @Override + @JsonIgnore + public String getModel() { + return getDeploymentName(); + } + + @JsonIgnore + public void setModel(String model) { + setDeploymentName(model); + } + public String getDeploymentName() { return this.deploymentName; } @@ -341,17 +363,6 @@ public void setTopP(Float topP) { this.topP = topP; } - @Override - @JsonIgnore - public Integer getTopK() { - throw new UnsupportedOperationException("Unimplemented method 'getTopK'"); - } - - @JsonIgnore - public void setTopK(Integer topK) { - throw new UnsupportedOperationException("Unimplemented method 'setTopK'"); - } - @Override public List getFunctionCallbacks() { return this.functionCallbacks; @@ -378,6 +389,12 @@ public void setResponseFormat(AzureOpenAiResponseFormat responseFormat) { this.responseFormat = responseFormat; } + @Override + @JsonIgnore + public Integer getTopK() { + return null; + } + @Override public AzureOpenAiChatOptions copy() { return fromOptions(this); @@ -385,13 +402,11 @@ public AzureOpenAiChatOptions copy() { public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOptions) { return builder().withDeploymentName(fromOptions.getDeploymentName()) - .withFrequencyPenalty( - fromOptions.getFrequencyPenalty() != null ? fromOptions.getFrequencyPenalty().floatValue() : null) + .withFrequencyPenalty(fromOptions.getFrequencyPenalty() != null ? 
fromOptions.getFrequencyPenalty() : null) .withLogitBias(fromOptions.getLogitBias()) .withMaxTokens(fromOptions.getMaxTokens()) .withN(fromOptions.getN()) - .withPresencePenalty( - fromOptions.getPresencePenalty() != null ? fromOptions.getPresencePenalty().floatValue() : null) + .withPresencePenalty(fromOptions.getPresencePenalty() != null ? fromOptions.getPresencePenalty() : null) .withStop(fromOptions.getStop()) .withTemperature(fromOptions.getTemperature()) .withTopP(fromOptions.getTopP()) diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java index 48ec238166c..7713f95f633 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java @@ -17,6 +17,7 @@ import java.util.List; +import com.fasterxml.jackson.annotation.JsonIgnore; import org.springframework.ai.embedding.EmbeddingOptions; /** @@ -125,10 +126,16 @@ public AzureOpenAiEmbeddingOptions build() { } @Override + @JsonIgnore public String getModel() { return getDeploymentName(); } + @JsonIgnore + public void setModel(String model) { + setDeploymentName(model); + } + public String getUser() { return this.user; } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/AnthropicChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/AnthropicChatOptions.java index 55b37032f64..e12280dff11 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/AnthropicChatOptions.java @@ -17,6 +17,7 @@ import java.util.List; +import 
com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -26,6 +27,7 @@ /** * @author Christian Tzolov + * @author Thomas Vitale */ @JsonInclude(Include.NON_NULL) public class AnthropicChatOptions implements ChatOptions { @@ -122,6 +124,17 @@ public void setTemperature(Float temperature) { this.temperature = temperature; } + @Override + @JsonIgnore + public Integer getMaxTokens() { + return getMaxTokensToSample(); + } + + @JsonIgnore + public void setMaxTokens(Integer maxTokens) { + setMaxTokensToSample(maxTokens); + } + public Integer getMaxTokensToSample() { return this.maxTokensToSample; } @@ -148,6 +161,7 @@ public void setTopP(Float topP) { this.topP = topP; } + @Override public List getStopSequences() { return this.stopSequences; } @@ -164,6 +178,24 @@ public void setAnthropicVersion(String anthropicVersion) { this.anthropicVersion = anthropicVersion; } + @Override + @JsonIgnore + public String getModel() { + return null; + } + + @Override + @JsonIgnore + public Float getFrequencyPenalty() { + return null; + } + + @Override + @JsonIgnore + public Float getPresencePenalty() { + return null; + } + @Override public AnthropicChatOptions copy() { return fromOptions(this); diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java index 0f44c212b3f..c573df9118a 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.bedrock.anthropic3; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import 
com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; @@ -24,6 +25,7 @@ /** * @author Ben Middleton + * @author Thomas Vitale * @since 1.0.0 */ @JsonInclude(Include.NON_NULL) @@ -121,6 +123,7 @@ public void setTemperature(Float temperature) { this.temperature = temperature; } + @Override public Integer getMaxTokens() { return this.maxTokens; } @@ -147,6 +150,7 @@ public void setTopP(Float topP) { this.topP = topP; } + @Override public List getStopSequences() { return this.stopSequences; } @@ -163,6 +167,24 @@ public void setAnthropicVersion(String anthropicVersion) { this.anthropicVersion = anthropicVersion; } + @Override + @JsonIgnore + public String getModel() { + return null; + } + + @Override + @JsonIgnore + public Float getFrequencyPenalty() { + return null; + } + + @Override + @JsonIgnore + public Float getPresencePenalty() { + return null; + } + @Override public Anthropic3ChatOptions copy() { return fromOptions(this); diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java index 619aa8d7b82..4773c49da6c 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java @@ -17,6 +17,7 @@ import java.util.List; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; @@ -28,6 +29,7 @@ /** * @author Christian Tzolov + * @author Thomas Vitale * @since 0.8.0 */ @JsonInclude(Include.NON_NULL) @@ -165,6 +167,7 @@ public void setTopK(Integer topK) { this.topK = topK; } + @Override public Integer getMaxTokens() { return 
this.maxTokens; } @@ -173,6 +176,7 @@ public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } + @Override public List getStopSequences() { return this.stopSequences; } @@ -213,6 +217,24 @@ public void setTruncate(Truncate truncate) { this.truncate = truncate; } + @Override + @JsonIgnore + public String getModel() { + return null; + } + + @Override + @JsonIgnore + public Float getFrequencyPenalty() { + return null; + } + + @Override + @JsonIgnore + public Float getPresencePenalty() { + return null; + } + @Override public BedrockCohereChatOptions copy() { return fromOptions(this); diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingOptions.java index 04a2b37d607..068d704545c 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingOptions.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.bedrock.cohere; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; @@ -88,11 +89,13 @@ public void setTruncate(Truncate truncate) { } @Override + @JsonIgnore public String getModel() { return null; } @Override + @JsonIgnore public Integer getDimensions() { return null; } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatOptions.java index 723f427ca52..5c424fd226d 100644 --- 
a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatOptions.java @@ -16,14 +16,18 @@ package org.springframework.ai.bedrock.jurassic2; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.chat.prompt.ChatOptions; +import java.util.List; + /** * Request body for the /complete endpoint of the Jurassic-2 API. * * @author Ahmed Yousri + * @author Thomas Vitale * @since 1.0.0 */ @JsonInclude(JsonInclude.Include.NON_NULL) @@ -75,25 +79,25 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { * Stops decoding if any of the strings is generated. */ @JsonProperty("stopSequences") - private String[] stopSequences; + private List stopSequences; /** * Penalty object for frequency. */ @JsonProperty("frequencyPenalty") - private Penalty frequencyPenalty; + private Penalty frequencyPenaltyOptions; /** * Penalty object for presence. */ @JsonProperty("presencePenalty") - private Penalty presencePenalty; + private Penalty presencePenaltyOptions; /** * Penalty object for count. */ @JsonProperty("countPenalty") - private Penalty countPenalty; + private Penalty countPenaltyOptions; // Getters and setters @@ -133,6 +137,7 @@ public void setNumResults(Integer numResults) { * Gets the maximum number of tokens to generate per result. * @return The maximum number of tokens. */ + @Override public Integer getMaxTokens() { return maxTokens; } @@ -165,6 +170,7 @@ public void setMinTokens(Integer minTokens) { * Gets the temperature for modifying the token sampling distribution. * @return The temperature. */ + @Override public Float getTemperature() { return temperature; } @@ -182,6 +188,7 @@ public void setTemperature(Float temperature) { * mass. 
* @return The topP parameter. */ + @Override public Float getTopP() { return topP; } @@ -216,7 +223,8 @@ public void setTopK(Integer topK) { * Gets the stop sequences for stopping decoding if any of the strings is generated. * @return The stop sequences. */ - public String[] getStopSequences() { + @Override + public List getStopSequences() { return stopSequences; } @@ -224,56 +232,88 @@ public String[] getStopSequences() { * Sets the stop sequences for stopping decoding if any of the strings is generated. * @param stopSequences The stop sequences. */ - public void setStopSequences(String[] stopSequences) { + public void setStopSequences(List stopSequences) { this.stopSequences = stopSequences; } + @Override + @JsonIgnore + public Float getFrequencyPenalty() { + return getFrequencyPenaltyOptions() != null ? getFrequencyPenaltyOptions().scale() : null; + } + + @JsonIgnore + public void setFrequencyPenalty(Float frequencyPenalty) { + if (frequencyPenalty != null) { + setFrequencyPenaltyOptions(Penalty.builder().scale(frequencyPenalty).build()); + } + } + /** * Gets the frequency penalty object. * @return The frequency penalty object. */ - public Penalty getFrequencyPenalty() { - return frequencyPenalty; + public Penalty getFrequencyPenaltyOptions() { + return frequencyPenaltyOptions; } /** * Sets the frequency penalty object. - * @param frequencyPenalty The frequency penalty object. + * @param frequencyPenaltyOptions The frequency penalty object. */ - public void setFrequencyPenalty(Penalty frequencyPenalty) { - this.frequencyPenalty = frequencyPenalty; + public void setFrequencyPenaltyOptions(Penalty frequencyPenaltyOptions) { + this.frequencyPenaltyOptions = frequencyPenaltyOptions; + } + + @Override + @JsonIgnore + public Float getPresencePenalty() { + return getPresencePenaltyOptions() != null ? 
getPresencePenaltyOptions().scale() : null; + } + + @JsonIgnore + public void setPresencePenalty(Float presencePenalty) { + if (presencePenalty != null) { + setPresencePenaltyOptions(Penalty.builder().scale(presencePenalty).build()); + } } /** * Gets the presence penalty object. * @return The presence penalty object. */ - public Penalty getPresencePenalty() { - return presencePenalty; + public Penalty getPresencePenaltyOptions() { + return presencePenaltyOptions; } /** * Sets the presence penalty object. - * @param presencePenalty The presence penalty object. + * @param presencePenaltyOptions The presence penalty object. */ - public void setPresencePenalty(Penalty presencePenalty) { - this.presencePenalty = presencePenalty; + public void setPresencePenaltyOptions(Penalty presencePenaltyOptions) { + this.presencePenaltyOptions = presencePenaltyOptions; } /** * Gets the count penalty object. * @return The count penalty object. */ - public Penalty getCountPenalty() { - return countPenalty; + public Penalty getCountPenaltyOptions() { + return countPenaltyOptions; } /** * Sets the count penalty object. - * @param countPenalty The count penalty object. + * @param countPenaltyOptions The count penalty object. 
*/ - public void setCountPenalty(Penalty countPenalty) { - this.countPenalty = countPenalty; + public void setCountPenaltyOptions(Penalty countPenaltyOptions) { + this.countPenaltyOptions = countPenaltyOptions; + } + + @Override + @JsonIgnore + public String getModel() { + return null; } public static Builder builder() { @@ -314,7 +354,7 @@ public Builder withTopP(Float topP) { return this; } - public Builder withStopSequences(String[] stopSequences) { + public Builder withStopSequences(List stopSequences) { request.setStopSequences(stopSequences); return this; } @@ -324,18 +364,18 @@ public Builder withTopK(Integer topKReturn) { return this; } - public Builder withFrequencyPenalty(BedrockAi21Jurassic2ChatOptions.Penalty frequencyPenalty) { - request.setFrequencyPenalty(frequencyPenalty); + public Builder withFrequencyPenaltyOptions(BedrockAi21Jurassic2ChatOptions.Penalty frequencyPenalty) { + request.setFrequencyPenaltyOptions(frequencyPenalty); return this; } - public Builder withPresencePenalty(BedrockAi21Jurassic2ChatOptions.Penalty presencePenalty) { - request.setPresencePenalty(presencePenalty); + public Builder withPresencePenaltyOptions(BedrockAi21Jurassic2ChatOptions.Penalty presencePenalty) { + request.setPresencePenaltyOptions(presencePenalty); return this; } - public Builder withCountPenalty(BedrockAi21Jurassic2ChatOptions.Penalty countPenalty) { - request.setCountPenalty(countPenalty); + public Builder withCountPenaltyOptions(BedrockAi21Jurassic2ChatOptions.Penalty countPenalty) { + request.setCountPenaltyOptions(countPenalty); return this; } @@ -427,9 +467,9 @@ public static BedrockAi21Jurassic2ChatOptions fromOptions(BedrockAi21Jurassic2Ch .withTopP(fromOptions.getTopP()) .withTopK(fromOptions.getTopK()) .withStopSequences(fromOptions.getStopSequences()) - .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) - .withPresencePenalty(fromOptions.getPresencePenalty()) - .withCountPenalty(fromOptions.getCountPenalty()) + 
.withFrequencyPenaltyOptions(fromOptions.getFrequencyPenaltyOptions()) + .withPresencePenaltyOptions(fromOptions.getPresencePenaltyOptions()) + .withCountPenaltyOptions(fromOptions.getCountPenaltyOptions()) .build(); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java index cf0e68ea782..8d26ccb309c 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java @@ -22,8 +22,11 @@ import org.springframework.ai.chat.prompt.ChatOptions; +import java.util.List; + /** * @author Christian Tzolov + * @author Thomas Vitale */ @JsonInclude(Include.NON_NULL) public class BedrockLlamaChatOptions implements ChatOptions { @@ -74,6 +77,7 @@ public BedrockLlamaChatOptions build() { } + @Override public Float getTemperature() { return this.temperature; } @@ -82,6 +86,7 @@ public void setTemperature(Float temperature) { this.temperature = temperature; } + @Override public Float getTopP() { return this.topP; } @@ -90,6 +95,17 @@ public void setTopP(Float topP) { this.topP = topP; } + @Override + @JsonIgnore + public Integer getMaxTokens() { + return getMaxGenLen(); + } + + @JsonIgnore + public void setMaxTokens(Integer maxTokens) { + setMaxGenLen(maxTokens); + } + public Integer getMaxGenLen() { return this.maxGenLen; } @@ -100,13 +116,32 @@ public void setMaxGenLen(Integer maxGenLen) { @Override @JsonIgnore - public Integer getTopK() { - throw new UnsupportedOperationException("Unsupported option: 'TopK'"); + public String getModel() { + return null; } + @Override @JsonIgnore - public void setTopK(Integer topK) { - throw new UnsupportedOperationException("Unsupported option: 'TopK'"); + public Float getFrequencyPenalty() { + return null; + } + + 
@Override + @JsonIgnore + public Float getPresencePenalty() { + return null; + } + + @Override + @JsonIgnore + public List getStopSequences() { + return null; + } + + @Override + @JsonIgnore + public Integer getTopK() { + return null; } @Override diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatOptions.java index 364cd1bc46c..67458a5b13b 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatOptions.java @@ -17,6 +17,7 @@ import java.util.List; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -26,6 +27,7 @@ /** * @author Christian Tzolov + * @author Thomas Vitale * @since 0.8.0 */ @JsonInclude(Include.NON_NULL) @@ -87,6 +89,7 @@ public BedrockTitanChatOptions build() { } + @Override public Float getTemperature() { return temperature; } @@ -95,6 +98,7 @@ public void setTemperature(Float temperature) { this.temperature = temperature; } + @Override public Float getTopP() { return topP; } @@ -103,6 +107,17 @@ public void setTopP(Float topP) { this.topP = topP; } + @Override + @JsonIgnore + public Integer getMaxTokens() { + return getMaxTokenCount(); + } + + @JsonIgnore + public void setMaxTokens(Integer maxTokens) { + setMaxTokenCount(maxTokens); + } + public Integer getMaxTokenCount() { return maxTokenCount; } @@ -111,6 +126,7 @@ public void setMaxTokenCount(Integer maxTokenCount) { this.maxTokenCount = maxTokenCount; } + @Override public List getStopSequences() { return stopSequences; } @@ -120,12 +136,27 @@ public void setStopSequences(List stopSequences) { } @Override - public Integer getTopK() { - throw new 
UnsupportedOperationException("Bedrock Titan Chat does not support the 'TopK' option."); + @JsonIgnore + public String getModel() { + return null; } - public void setTopK(Integer topK) { - throw new UnsupportedOperationException("Bedrock Titan Chat does not support the 'TopK' option.'"); + @Override + @JsonIgnore + public Float getFrequencyPenalty() { + return null; + } + + @Override + @JsonIgnore + public Float getPresencePenalty() { + return null; + } + + @Override + @JsonIgnore + public Integer getTopK() { + return null; } @Override diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java index 51a957b7185..28757f3b78d 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.bedrock.titan; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -64,11 +65,13 @@ public void setInputType(InputType inputType) { } @Override + @JsonIgnore public String getModel() { return null; } @Override + @JsonIgnore public Integer getDimensions() { return null; } diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java index b468fbcaba1..4366f5d7a65 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java +++ 
b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java @@ -77,7 +77,7 @@ void testEmojiPenaltyFalse() { .applyToEmojis(false) .build(); BedrockAi21Jurassic2ChatOptions options = new BedrockAi21Jurassic2ChatOptions.Builder() - .withPresencePenalty(penalty) + .withPresencePenaltyOptions(penalty) .build(); UserMessage userMessage = new UserMessage("Can you express happiness using an emoji like 😄 ?"); @@ -94,7 +94,7 @@ void emojiPenaltyWhenTrueByDefaultApplyPenaltyTest() { // applyToEmojis is by default true BedrockAi21Jurassic2ChatOptions.Penalty penalty = new BedrockAi21Jurassic2ChatOptions.Penalty.Builder().build(); BedrockAi21Jurassic2ChatOptions options = new BedrockAi21Jurassic2ChatOptions.Builder() - .withPresencePenalty(penalty) + .withPresencePenaltyOptions(penalty) .build(); UserMessage userMessage = new UserMessage("Can you express happiness using an emoji like 😄?"); diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java index 6a6bfc999ea..b924c0084b4 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java @@ -39,6 +39,7 @@ * @see FunctionCallingOptions * @see ChatOptions * @author Geng Rong + * @author Thomas Vitale * @since 1.0.0 M1 */ @JsonInclude(Include.NON_NULL) @@ -236,6 +237,7 @@ public MiniMaxChatOptions build() { } + @Override public String getModel() { return this.model; } @@ -244,6 +246,7 @@ public void setModel(String model) { this.model = model; } + @Override public Float getFrequencyPenalty() { return this.frequencyPenalty; } @@ -252,6 +255,7 @@ public void setFrequencyPenalty(Float frequencyPenalty) { this.frequencyPenalty = frequencyPenalty; } + @Override public Integer getMaxTokens() 
{ return this.maxTokens; } @@ -268,6 +272,7 @@ public void setN(Integer n) { this.n = n; } + @Override public Float getPresencePenalty() { return this.presencePenalty; } @@ -292,6 +297,17 @@ public void setSeed(Integer seed) { this.seed = seed; } + @Override + @JsonIgnore + public List getStopSequences() { + return getStop(); + } + + @JsonIgnore + public void setStopSequences(List stopSequences) { + setStop(stopSequences); + } + public List getStop() { return this.stop; } @@ -356,12 +372,7 @@ public void setFunctions(Set functionNames) { @Override @JsonIgnore public Integer getTopK() { - throw new UnsupportedOperationException("Unimplemented method 'getTopK'"); - } - - @JsonIgnore - public void setTopK(Integer topK) { - throw new UnsupportedOperationException("Unimplemented method 'setTopK'"); + return null; } @Override diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingOptions.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingOptions.java index 15143b1c823..d265e2dd687 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingOptions.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingOptions.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.minimax; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; @@ -70,6 +71,7 @@ public void setModel(String model) { } @Override + @JsonIgnore public Integer getDimensions() { return null; } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index cadb5e8f90b..3a5523f2b56 100644 --- 
a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -36,6 +36,7 @@ /** * @author Ricken Bazolo * @author Christian Tzolov + * @author Thomas Vitale * @since 0.8.1 */ @JsonInclude(JsonInclude.Include.NON_NULL) @@ -85,6 +86,13 @@ public class MistralAiChatOptions implements FunctionCallingOptions, ChatOptions */ private @JsonProperty("response_format") ResponseFormat responseFormat; + /** + * Stop generation if this token is detected. Or if one of these tokens is detected + * when providing an array. + */ + @NestedConfigurationProperty + private @JsonProperty("stop") List stop; + /** * A list of tools the model may call. Currently, only functions are supported as a * tool. Use this to provide a list of functions the model may generate JSON inputs @@ -160,6 +168,11 @@ public Builder withRandomSeed(Integer randomSeed) { return this; } + public Builder withStop(List stop) { + this.options.setStop(stop); + return this; + } + public Builder withTemperature(Float temperature) { this.options.setTemperature(temperature); return this; @@ -208,6 +221,7 @@ public MistralAiChatOptions build() { } + @Override public String getModel() { return this.model; } @@ -216,6 +230,7 @@ public void setModel(String model) { this.model = model; } + @Override public Integer getMaxTokens() { return this.maxTokens; } @@ -248,6 +263,25 @@ public void setResponseFormat(ResponseFormat responseFormat) { this.responseFormat = responseFormat; } + @Override + @JsonIgnore + public List getStopSequences() { + return getStop(); + } + + @JsonIgnore + public void setStopSequences(List stopSequences) { + setStop(stopSequences); + } + + public List getStop() { + return this.stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + public void setTools(List tools) { this.tools = tools; } @@ -282,17 +316,6 @@ public void setTopP(Float 
topP) { this.topP = topP; } - @Override - @JsonIgnore - public Integer getTopK() { - throw new UnsupportedOperationException("Unsupported option: 'TopK'"); - } - - @JsonIgnore - public void setTopK(Integer topK) { - throw new UnsupportedOperationException("Unsupported option: 'TopK'"); - } - @Override public List getFunctionCallbacks() { return this.functionCallbacks; @@ -315,6 +338,24 @@ public void setFunctions(Set functions) { this.functions = functions; } + @Override + @JsonIgnore + public Float getFrequencyPenalty() { + return null; + } + + @Override + @JsonIgnore + public Float getPresencePenalty() { + return null; + } + + @Override + @JsonIgnore + public Integer getTopK() { + return null; + } + @Override public MistralAiChatOptions copy() { return fromOptions(this); @@ -328,6 +369,7 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) .withTemperature(fromOptions.getTemperature()) .withTopP(fromOptions.getTopP()) .withResponseFormat(fromOptions.getResponseFormat()) + .withStop(fromOptions.getStop()) .withTools(fromOptions.getTools()) .withToolChoice(fromOptions.getToolChoice()) .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingOptions.java index 8f66999e964..7abfa01fc81 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingOptions.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.mistralai; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; @@ -60,6 +61,7 @@ public void 
setEncodingFormat(String encodingFormat) { } @Override + @JsonIgnore public Integer getDimensions() { return null; } diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java index 739e34d1a50..7751df3266f 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java @@ -31,6 +31,7 @@ /** * @author Geng Rong + * @author Thomas Vitale */ @JsonInclude(JsonInclude.Include.NON_NULL) public class MoonshotChatOptions implements ChatOptions { @@ -229,6 +230,7 @@ public MoonshotChatOptions build() { } + @Override public String getModel() { return this.model; } @@ -237,6 +239,7 @@ public void setModel(String model) { this.model = model; } + @Override public Float getFrequencyPenalty() { return this.frequencyPenalty; } @@ -245,6 +248,7 @@ public void setFrequencyPenalty(Float frequencyPenalty) { this.frequencyPenalty = frequencyPenalty; } + @Override public Integer getMaxTokens() { return this.maxTokens; } @@ -261,6 +265,7 @@ public void setN(Integer n) { this.n = n; } + @Override public Float getPresencePenalty() { return this.presencePenalty; } @@ -269,6 +274,17 @@ public void setPresencePenalty(Float presencePenalty) { this.presencePenalty = presencePenalty; } + @Override + @JsonIgnore + public List getStopSequences() { + return getStop(); + } + + @JsonIgnore + public void setStopSequences(List stopSequences) { + setStop(stopSequences); + } + public List getStop() { return this.stop; } @@ -303,6 +319,12 @@ public void setUser(String user) { this.user = user; } + @Override + @JsonIgnore + public Integer getTopK() { + return null; + } + @Override public MoonshotChatOptions copy() { return builder().withModel(this.model) @@ -402,15 +424,4 @@ else if 
(!this.user.equals(other.user)) return true; } - @Override - @JsonIgnore - public Integer getTopK() { - throw new UnsupportedOperationException("Unimplemented method 'getTopK'"); - } - - @JsonIgnore - public void setTopK(Integer topK) { - throw new UnsupportedOperationException("Unimplemented method 'setTopK'"); - } - } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index 4485be23f9a..2105f74102b 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -635,6 +635,17 @@ public void setSeed(Integer seed) { this.seed = seed; } + @Override + @JsonIgnore + public Integer getMaxTokens() { + return getNumPredict(); + } + + @JsonIgnore + public void setMaxTokens(Integer maxTokens) { + setNumPredict(maxTokens); + } + public Integer getNumPredict() { return this.numPredict; } @@ -643,6 +654,7 @@ public void setNumPredict(Integer numPredict) { this.numPredict = numPredict; } + @Override public Integer getTopK() { return this.topK; } @@ -651,6 +663,7 @@ public void setTopK(Integer topK) { this.topK = topK; } + @Override public Float getTopP() { return this.topP; } @@ -683,6 +696,7 @@ public void setRepeatLastN(Integer repeatLastN) { this.repeatLastN = repeatLastN; } + @Override public Float getTemperature() { return this.temperature; } @@ -699,6 +713,7 @@ public void setRepeatPenalty(Float repeatPenalty) { this.repeatPenalty = repeatPenalty; } + @Override public Float getPresencePenalty() { return this.presencePenalty; } @@ -707,6 +722,7 @@ public void setPresencePenalty(Float presencePenalty) { this.presencePenalty = presencePenalty; } + @Override public Float getFrequencyPenalty() { return this.frequencyPenalty; } @@ -747,6 +763,17 @@ public void setPenalizeNewline(Boolean 
penalizeNewline) { this.penalizeNewline = penalizeNewline; } + @Override + @JsonIgnore + public List getStopSequences() { + return getStop(); + } + + @JsonIgnore + public void setStopSequences(List stopSequences) { + setStop(stopSequences); + } + public List getStop() { return this.stop; } @@ -763,11 +790,6 @@ public void setTruncate(Boolean truncate) { this.truncate = truncate; } - @Override - public Integer getDimensions() { - return null; - } - @Override public List getFunctionCallbacks() { return this.functionCallbacks; @@ -776,7 +798,6 @@ public List getFunctionCallbacks() { @Override public void setFunctionCallbacks(List functionCallbacks) { this.functionCallbacks = functionCallbacks; - } @Override @@ -789,6 +810,12 @@ public void setFunctions(Set functions) { this.functions = functions; } + @Override + @JsonIgnore + public Integer getDimensions() { + return null; + } + /** * Convert the {@link OllamaOptions} object to a {@link Map} of key/value pairs. * @return The {@link Map} of key/value pairs. 
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 3ecad08d34f..25f7e81221e 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -44,16 +44,13 @@ import org.springframework.ai.chat.observation.ChatModelObservationContext; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; -import org.springframework.ai.chat.observation.ChatModelRequestOptions; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; -import org.springframework.ai.observation.AiOperationMetadata; -import org.springframework.ai.observation.conventions.AiOperationType; -import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice; @@ -62,6 +59,7 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.MediaContent; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; +import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.openai.metadata.OpenAiUsage; import 
org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; import org.springframework.ai.retry.RetryUtils; @@ -214,7 +212,7 @@ public ChatResponse call(Prompt prompt) { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) - .operationMetadata(buildOperationMetadata()) + .provider(OpenAiApiConstants.PROVIDER_NAME) .requestOptions(buildRequestOptions(request)) .build(); @@ -287,7 +285,7 @@ public Flux stream(Prompt prompt) { final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) - .operationMetadata(buildOperationMetadata()) + .provider(OpenAiApiConstants.PROVIDER_NAME) .requestOptions(buildRequestOptions(request)) .build(); @@ -356,9 +354,7 @@ public Flux stream(Prompt prompt) { .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); // @formatter:on - return new MessageAggregator().aggregate(flux, mergedChatResponse -> { - observationContext.setResponse(mergedChatResponse); - }); + return new MessageAggregator().aggregate(flux, observationContext::setResponse); }); } @@ -370,7 +366,7 @@ private MultiValueMap getAdditionalHttpHeaders(Prompt prompt) { headers.putAll(chatOptions.getHttpHeaders()); } return CollectionUtils.toMultiValueMap( - headers.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> List.of(e.getValue())))); + headers.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> List.of(e.getValue())))); } private Generation buildGeneration(Choice choice, Map metadata) { @@ -536,22 +532,15 @@ private List getFunctionTools(Set functionNames) }).toList(); } - private AiOperationMetadata buildOperationMetadata() { - return AiOperationMetadata.builder() - .operationType(AiOperationType.CHAT.value()) - .provider(AiProvider.OPENAI.value()) - .build(); - } - - private ChatModelRequestOptions buildRequestOptions(OpenAiApi.ChatCompletionRequest request) { - return 
ChatModelRequestOptions.builder() - .model(StringUtils.hasText(request.model()) ? request.model() : "unknown") - .frequencyPenalty(request.frequencyPenalty()) - .maxTokens(request.maxTokens()) - .presencePenalty(request.presencePenalty()) - .stopSequences(request.stop()) - .temperature(request.temperature()) - .topP(request.topP()) + private ChatOptions buildRequestOptions(OpenAiApi.ChatCompletionRequest request) { + return ChatOptionsBuilder.builder() + .withModel(request.model()) + .withFrequencyPenalty(request.frequencyPenalty()) + .withMaxTokens(request.maxTokens()) + .withPresencePenalty(request.presencePenalty()) + .withStopSequences(request.stop()) + .withTemperature(request.temperature()) + .withTopP(request.topP()) .build(); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index bd8373377ad..db5e5a6aa2f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -327,6 +327,7 @@ public void setStreamUsage(Boolean enableStreamUsage) { this.streamOptions = (enableStreamUsage) ? 
StreamOptions.INCLUDE_USAGE : null; } + @Override public String getModel() { return this.model; } @@ -335,6 +336,7 @@ public void setModel(String model) { this.model = model; } + @Override public Float getFrequencyPenalty() { return this.frequencyPenalty; } @@ -367,6 +369,7 @@ public void setTopLogprobs(Integer topLogprobs) { this.topLogprobs = topLogprobs; } + @Override public Integer getMaxTokens() { return this.maxTokens; } @@ -383,6 +386,7 @@ public void setN(Integer n) { this.n = n; } + @Override public Float getPresencePenalty() { return this.presencePenalty; } @@ -415,6 +419,17 @@ public void setSeed(Integer seed) { this.seed = seed; } + @Override + @JsonIgnore + public List getStopSequences() { + return getStop(); + } + + @JsonIgnore + public void setStopSequences(List stopSequences) { + setStop(stopSequences); + } + public List getStop() { return this.stop; } @@ -500,6 +515,12 @@ public void setHttpHeaders(Map httpHeaders) { this.httpHeaders = httpHeaders; } + @Override + @JsonIgnore + public Integer getTopK() { + return null; + } + @Override public int hashCode() { final int prime = 31; @@ -646,17 +667,6 @@ else if (!this.parallelToolCalls.equals(other.parallelToolCalls)) return true; } - @Override - @JsonIgnore - public Integer getTopK() { - throw new UnsupportedOperationException("Unimplemented method 'getTopK'"); - } - - @JsonIgnore - public void setTopK(Integer topK) { - throw new UnsupportedOperationException("Unimplemented method 'setTopK'"); - } - @Override public OpenAiChatOptions copy() { return OpenAiChatOptions.fromOptions(this); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java index 2b61674c2ea..4265589ae63 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java +++ 
b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java @@ -31,11 +31,9 @@ import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation; import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.observation.AiOperationMetadata; -import org.springframework.ai.observation.conventions.AiOperationType; -import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList; +import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.openai.metadata.OpenAiUsage; import org.springframework.ai.retry.RetryUtils; import org.springframework.lang.Nullable; @@ -151,7 +149,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { var observationContext = EmbeddingModelObservationContext.builder() .embeddingRequest(request) - .operationMetadata(buildOperationMetadata()) + .provider(OpenAiApiConstants.PROVIDER_NAME) .requestOptions(requestOptions) .build(); @@ -195,25 +193,22 @@ private OpenAiApi.EmbeddingRequest> createRequest(EmbeddingRequest */ private OpenAiEmbeddingOptions mergeOptions(@Nullable EmbeddingOptions runtimeOptions, OpenAiEmbeddingOptions defaultOptions) { - if (runtimeOptions == null) { + var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, EmbeddingOptions.class, + OpenAiEmbeddingOptions.class); + + if (runtimeOptionsForProvider == null) { return defaultOptions; } return OpenAiEmbeddingOptions.builder() // Handle portable embedding options - .withModel(ModelOptionsUtils.mergeOption(runtimeOptions.getModel(), defaultOptions.getModel())) - .withDimensions( - ModelOptionsUtils.mergeOption(runtimeOptions.getDimensions(), defaultOptions.getDimensions())) + 
.withModel(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel())) + .withDimensions(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getDimensions(), + defaultOptions.getDimensions())) // Handle OpenAI specific embedding options - .withEncodingFormat(defaultOptions.getEncodingFormat()) - .withUser(defaultOptions.getUser()) - .build(); - } - - private AiOperationMetadata buildOperationMetadata() { - return AiOperationMetadata.builder() - .operationType(AiOperationType.EMBEDDING.value()) - .provider(AiProvider.OPENAI.value()) + .withEncodingFormat(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getEncodingFormat(), + defaultOptions.getEncodingFormat())) + .withUser(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getUser(), defaultOptions.getUser())) .build(); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java index 88e6656d646..3fabb9f4976 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java @@ -30,13 +30,12 @@ import org.springframework.ai.image.observation.ImageModelObservationContext; import org.springframework.ai.image.observation.ImageModelObservationDocumentation; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.observation.AiOperationMetadata; -import org.springframework.ai.observation.conventions.AiOperationType; -import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.openai.api.OpenAiImageApi; +import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; 
+import org.springframework.lang.Nullable; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; @@ -131,7 +130,7 @@ public ImageResponse call(ImagePrompt imagePrompt) { var observationContext = ImageModelObservationContext.builder() .imagePrompt(imagePrompt) - .operationMetadata(buildOperationMetadata()) + .provider(OpenAiApiConstants.PROVIDER_NAME) .requestOptions(requestImageOptions) .build(); @@ -181,30 +180,28 @@ private ImageResponse convertResponse(ResponseEntity getStopSequences() { + return getStop(); + } + + @JsonIgnore + public void setStopSequences(List stopSequences) { + setStop(stopSequences); + } + public List getStop() { return this.stop; } @@ -212,7 +227,7 @@ public void setTopP(Float topP) { @Override @JsonIgnore public Integer getTopK() { - throw new UnsupportedOperationException("Unimplemented method 'getTopK'"); + return null; } @Override diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingOptions.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingOptions.java index 48864e2ada7..672b68ab2f6 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingOptions.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingOptions.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.qianfan; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; @@ -89,6 +90,7 @@ public void setUser(String user) { } @Override + @JsonIgnore public Integer getDimensions() { return null; } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageOptions.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageOptions.java index 
bd4d3bee92b..7ddbd701393 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageOptions.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageOptions.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.qianfan; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.image.ImageOptions; @@ -170,6 +171,7 @@ public void setHeight(Integer height) { } @Override + @JsonIgnore public String getResponseFormat() { return null; } diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java index a6cccaf8ecb..4bf839e36b7 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.stabilityai.api; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.image.ImageOptions; @@ -452,10 +453,16 @@ public void setSteps(Integer steps) { } @Override + @JsonIgnore public String getStyle() { return getStylePreset(); } + @JsonIgnore + public void setStyle(String style) { + setStylePreset(style); + } + public String getStylePreset() { return stylePreset; } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index ec42584824a..f5f4f76878b 
100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -34,6 +34,7 @@ /** * @author Christian Tzolov + * @author Thomas Vitale * @since 0.8.1 */ @JsonInclude(Include.NON_NULL) @@ -173,6 +174,7 @@ public VertexAiGeminiChatOptions build() { } + @Override public List getStopSequences() { return this.stopSequences; } @@ -200,7 +202,6 @@ public void setTopP(Float topP) { } @Override - @JsonIgnore public Integer getTopK() { return (this.topK != null) ? this.topK.intValue() : null; } @@ -222,6 +223,17 @@ public void setCandidateCount(Integer candidateCount) { this.candidateCount = candidateCount; } + @Override + @JsonIgnore + public Integer getMaxTokens() { + return getMaxOutputTokens(); + } + + @JsonIgnore + public void setMaxTokens(Integer maxTokens) { + setMaxOutputTokens(maxTokens); + } + public Integer getMaxOutputTokens() { return this.maxOutputTokens; } @@ -230,6 +242,7 @@ public void setMaxOutputTokens(Integer maxOutputTokens) { this.maxOutputTokens = maxOutputTokens; } + @Override public String getModel() { return this.model; } @@ -254,6 +267,18 @@ public void setFunctions(Set functions) { this.functions = functions; } + @Override + @JsonIgnore + public Float getFrequencyPenalty() { + return null; + } + + @Override + @JsonIgnore + public Float getPresencePenalty() { + return null; + } + @Override public int hashCode() { final int prime = 31; diff --git a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatOptions.java b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatOptions.java index 06f4cfa8bee..8e271453628 100644 --- a/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatOptions.java +++ 
b/models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/VertexAiPaLm2ChatOptions.java @@ -15,14 +15,18 @@ */ package org.springframework.ai.vertexai.palm2; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.chat.prompt.ChatOptions; +import java.util.List; + /** * @author Christian Tzolov + * @author Thomas Vitale */ @JsonInclude(Include.NON_NULL) public class VertexAiPaLm2ChatOptions implements ChatOptions { @@ -127,6 +131,36 @@ public void setTopK(Integer topK) { this.topK = topK; } + @Override + @JsonIgnore + public String getModel() { + return null; + } + + @Override + @JsonIgnore + public Integer getMaxTokens() { + return null; + } + + @Override + @JsonIgnore + public List getStopSequences() { + return null; + } + + @Override + @JsonIgnore + public Float getFrequencyPenalty() { + return null; + } + + @Override + @JsonIgnore + public Float getPresencePenalty() { + return null; + } + @Override public VertexAiPaLm2ChatOptions copy() { return fromOptions(this); diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java index ba87789ff9b..1ca67cb2c96 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java @@ -34,6 +34,7 @@ * * @author Pablo Sanchidrian Herrera * @author John Jairo Moreno Rojas + * @author Thomas Vitale * @since 1.0.0 * @see watsonx.ai @@ -124,6 +125,7 @@ public class WatsonxAiChatOptions implements ChatOptions { @JsonIgnore private ObjectMapper mapper = new ObjectMapper(); + @Override public Float getTemperature() 
{ return temperature; } @@ -132,6 +134,7 @@ public void setTemperature(Float temperature) { this.temperature = temperature; } + @Override public Float getTopP() { return topP; } @@ -140,6 +143,7 @@ public void setTopP(Float topP) { this.topP = topP; } + @Override public Integer getTopK() { return topK; } @@ -156,6 +160,17 @@ public void setDecodingMethod(String decodingMethod) { this.decodingMethod = decodingMethod; } + @Override + @JsonIgnore + public Integer getMaxTokens() { + return getMaxNewTokens(); + } + + @JsonIgnore + public void setMaxTokens(Integer maxTokens) { + setMaxNewTokens(maxTokens); + } + public Integer getMaxNewTokens() { return maxNewTokens; } @@ -172,7 +187,8 @@ public void setMinNewTokens(Integer minNewTokens) { this.minNewTokens = minNewTokens; } - public List getStopSequences() { + @Override + public List getStopSequences() { return stopSequences; } @@ -180,7 +196,18 @@ public void setStopSequences(List stopSequences) { this.stopSequences = stopSequences; } - public Float getRepetitionPenalty() { + @Override + @JsonIgnore + public Float getPresencePenalty() { + return getRepetitionPenalty(); + } + + @JsonIgnore + public void setPresencePenalty(Float presencePenalty) { + setRepetitionPenalty(presencePenalty); + } + + public Float getRepetitionPenalty() { return repetitionPenalty; } @@ -196,6 +223,7 @@ public void setRandomSeed(Integer randomSeed) { this.randomSeed = randomSeed; } + @Override public String getModel() { return model; } @@ -218,6 +246,12 @@ public void addAdditionalProperty(String key, Object value) { additional.put(key, value); } + @Override + @JsonIgnore + public Float getFrequencyPenalty() { + return null; + } + public static Builder builder() { return new Builder(); } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java index 7a9c3968f02..9af8f6616db 100644 --- 
a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java @@ -36,6 +36,7 @@ * ZhiPuAiChatOptions represents the options for the ZhiPuAiChat model. * * @author Geng Rong + * @author Thomas Vitale * @since 1.0.0 M1 */ @JsonInclude(Include.NON_NULL) @@ -213,6 +214,7 @@ public ZhiPuAiChatOptions build() { } + @Override public String getModel() { return this.model; } @@ -221,6 +223,7 @@ public void setModel(String model) { this.model = model; } + @Override public Integer getMaxTokens() { return this.maxTokens; } @@ -229,6 +232,17 @@ public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } + @Override + @JsonIgnore + public List getStopSequences() { + return getStop(); + } + + @JsonIgnore + public void setStopSequences(List stopSequences) { + setStop(stopSequences); + } + public List getStop() { return this.stop; } @@ -314,6 +328,24 @@ public void setFunctions(Set functionNames) { this.functions = functionNames; } + @Override + @JsonIgnore + public Float getFrequencyPenalty() { + return null; + } + + @Override + @JsonIgnore + public Float getPresencePenalty() { + return null; + } + + @Override + @JsonIgnore + public Integer getTopK() { + return null; + } + @Override public int hashCode() { final int prime = 31; @@ -401,17 +433,6 @@ else if (!this.doSample.equals(other.doSample)) return true; } - @Override - @JsonIgnore - public Integer getTopK() { - throw new UnsupportedOperationException("Unimplemented method 'getTopK'"); - } - - @JsonIgnore - public void setTopK(Integer topK) { - throw new UnsupportedOperationException("Unimplemented method 'setTopK'"); - } - @Override public ZhiPuAiChatOptions copy() { return fromOptions(this); diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingOptions.java 
b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingOptions.java index 388384f2324..cbd75ad4e82 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingOptions.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.zhipuai; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; @@ -70,6 +71,7 @@ public void setModel(String model) { } @Override + @JsonIgnore public Integer getDimensions() { return null; } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java index 1f80d1a74b9..a6d1de3167e 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.zhipuai; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.image.ImageOptions; @@ -88,6 +89,7 @@ public ZhiPuAiImageOptions build() { } @Override + @JsonIgnore public Integer getN() { return null; } @@ -102,21 +104,25 @@ public void setModel(String model) { } @Override + @JsonIgnore public Integer getWidth() { return null; } @Override + @JsonIgnore public Integer getHeight() { return null; } @Override + @JsonIgnore public String getResponseFormat() { return null; } @Override + @JsonIgnore public String getStyle() { return null; } diff --git 
a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContext.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContext.java index 46a2368ceea..eb20f161a62 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContext.java @@ -16,9 +16,11 @@ package org.springframework.ai.chat.observation; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.observation.ModelObservationContext; import org.springframework.ai.observation.AiOperationMetadata; +import org.springframework.ai.observation.conventions.AiOperationType; import org.springframework.util.Assert; /** @@ -29,16 +31,16 @@ */ public class ChatModelObservationContext extends ModelObservationContext { - private final ChatModelRequestOptions requestOptions; + private final ChatOptions requestOptions; - ChatModelObservationContext(Prompt prompt, AiOperationMetadata operationMetadata, - ChatModelRequestOptions requestOptions) { - super(prompt, operationMetadata); + ChatModelObservationContext(Prompt prompt, String provider, ChatOptions requestOptions) { + super(prompt, + AiOperationMetadata.builder().operationType(AiOperationType.CHAT.value()).provider(provider).build()); Assert.notNull(requestOptions, "requestOptions cannot be null"); this.requestOptions = requestOptions; } - public ChatModelRequestOptions getRequestOptions() { + public ChatOptions getRequestOptions() { return this.requestOptions; } @@ -50,9 +52,9 @@ public static class Builder { private Prompt prompt; - private AiOperationMetadata operationMetadata; + private String provider; - private ChatModelRequestOptions requestOptions; + private ChatOptions requestOptions; private 
Builder() { } @@ -62,18 +64,18 @@ public Builder prompt(Prompt prompt) { return this; } - public Builder operationMetadata(AiOperationMetadata operationMetadata) { - this.operationMetadata = operationMetadata; + public Builder provider(String provider) { + this.provider = provider; return this; } - public Builder requestOptions(ChatModelRequestOptions requestOptions) { + public Builder requestOptions(ChatOptions requestOptions) { this.requestOptions = requestOptions; return this; } public ChatModelObservationContext build() { - return new ChatModelObservationContext(prompt, operationMetadata, requestOptions); + return new ChatModelObservationContext(prompt, provider, requestOptions); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelRequestOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelRequestOptions.java deleted file mode 100644 index ca19bc08f53..00000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelRequestOptions.java +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Copyright 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.springframework.ai.chat.observation; - -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.lang.Nullable; -import org.springframework.util.Assert; - -import java.util.List; - -/** - * Represents client-side options for chat model requests. - * - * @author Thomas Vitale - * @since 1.0.0 - */ -public class ChatModelRequestOptions implements ChatOptions { - - private final String model; - - @Nullable - private final Float frequencyPenalty; - - @Nullable - private final Integer maxTokens; - - @Nullable - private final Float presencePenalty; - - @Nullable - private final List stopSequences; - - @Nullable - private final Float temperature; - - @Nullable - private final Integer topK; - - @Nullable - private final Float topP; - - ChatModelRequestOptions(Builder builder) { - Assert.hasText(builder.model, "model cannot be null or empty"); - - this.model = builder.model; - this.frequencyPenalty = builder.frequencyPenalty; - this.maxTokens = builder.maxTokens; - this.presencePenalty = builder.presencePenalty; - this.stopSequences = builder.stopSequences; - this.temperature = builder.temperature; - this.topK = builder.topK; - this.topP = builder.topP; - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - - private String model; - - @Nullable - private Float frequencyPenalty; - - @Nullable - private Integer maxTokens; - - @Nullable - private Float presencePenalty; - - @Nullable - private List stopSequences; - - @Nullable - private Float temperature; - - @Nullable - private Integer topK; - - @Nullable - private Float topP; - - private Builder() { - } - - public Builder model(String model) { - this.model = model; - return this; - } - - public Builder frequencyPenalty(@Nullable Float frequencyPenalty) { - this.frequencyPenalty = frequencyPenalty; - return this; - } - - public Builder maxTokens(@Nullable Integer maxTokens) { - this.maxTokens = maxTokens; - return this; - } - - 
public Builder presencePenalty(@Nullable Float presencePenalty) { - this.presencePenalty = presencePenalty; - return this; - } - - public Builder stopSequences(@Nullable List stopSequences) { - this.stopSequences = stopSequences; - return this; - } - - public Builder temperature(@Nullable Float temperature) { - this.temperature = temperature; - return this; - } - - public Builder topK(@Nullable Integer topK) { - this.topK = topK; - return this; - } - - public Builder topP(@Nullable Float topP) { - this.topP = topP; - return this; - } - - public ChatModelRequestOptions build() { - return new ChatModelRequestOptions(this); - } - - } - - public String getModel() { - return this.model; - } - - @Nullable - public Float getFrequencyPenalty() { - return this.frequencyPenalty; - } - - @Nullable - public Integer getMaxTokens() { - return this.maxTokens; - } - - @Nullable - public Float getPresencePenalty() { - return this.presencePenalty; - } - - @Nullable - public List getStopSequences() { - return this.stopSequences; - } - - @Override - @Nullable - public Float getTemperature() { - return this.temperature; - } - - @Override - @Nullable - public Integer getTopK() { - return this.topK; - } - - @Override - @Nullable - public Float getTopP() { - return this.topP; - } - - @Override - public ChatOptions copy() { - return builder().model(this.model) - .frequencyPenalty(this.frequencyPenalty) - .maxTokens(this.maxTokens) - .presencePenalty(this.presencePenalty) - .stopSequences(this.stopSequences != null ? 
List.copyOf(this.stopSequences) : null) - .temperature(this.temperature) - .topK(this.topK) - .topP(this.topP) - .build(); - } - -} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConvention.java index 082b4a66d96..57122c883e5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConvention.java @@ -18,6 +18,7 @@ import io.micrometer.common.KeyValue; import io.micrometer.common.KeyValues; import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; import java.util.StringJoiner; @@ -29,6 +30,9 @@ */ public class DefaultChatModelObservationConvention implements ChatModelObservationConvention { + private static final KeyValue REQUEST_MODEL_NONE = KeyValue + .of(ChatModelObservationDocumentation.LowCardinalityKeyNames.REQUEST_MODEL, KeyValue.NONE_VALUE); + private static final KeyValue RESPONSE_MODEL_NONE = KeyValue .of(ChatModelObservationDocumentation.LowCardinalityKeyNames.RESPONSE_MODEL, KeyValue.NONE_VALUE); @@ -77,8 +81,11 @@ public String getName() { @Override public String getContextualName(ChatModelObservationContext context) { - return "%s %s".formatted(context.getOperationMetadata().operationType(), - context.getRequestOptions().getModel()); + if (StringUtils.hasText(context.getRequestOptions().getModel())) { + return "%s %s".formatted(context.getOperationMetadata().operationType(), + context.getRequestOptions().getModel()); + } + return context.getOperationMetadata().operationType(); } @Override @@ -98,8 +105,11 @@ protected KeyValue aiProvider(ChatModelObservationContext context) { } protected KeyValue requestModel(ChatModelObservationContext context) { - return 
KeyValue.of(ChatModelObservationDocumentation.LowCardinalityKeyNames.REQUEST_MODEL, - context.getRequestOptions().getModel()); + if (StringUtils.hasText(context.getRequestOptions().getModel())) { + return KeyValue.of(ChatModelObservationDocumentation.LowCardinalityKeyNames.REQUEST_MODEL, + context.getRequestOptions().getModel()); + } + return REQUEST_MODEL_NONE; } protected KeyValue responseModel(ChatModelObservationContext context) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java index 24545ad3c80..a217eb9407c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java @@ -16,18 +16,39 @@ package org.springframework.ai.chat.prompt; import org.springframework.ai.model.ModelOptions; +import org.springframework.lang.Nullable; + +import java.util.List; /** * The ChatOptions represent the common options, portable across different chat models. 
*/ public interface ChatOptions extends ModelOptions { - Float getTemperature(); + @Nullable + String getModel(); - Float getTopP(); + @Nullable + Float getFrequencyPenalty(); + + @Nullable + Integer getMaxTokens(); + + @Nullable + Float getPresencePenalty(); + @Nullable + List getStopSequences(); + + @Nullable + Float getTemperature(); + + @Nullable Integer getTopK(); + @Nullable + Float getTopP(); + ChatOptions copy(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java index 3b99d3a6e75..9101892eab6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java @@ -15,32 +15,80 @@ */ package org.springframework.ai.chat.prompt; +import java.util.List; + public class ChatOptionsBuilder { - private class ChatOptionsImpl implements ChatOptions { + private static class DefaultChatOptions implements ChatOptions { + + private String model; + + private Float frequencyPenalty; + + private Integer maxTokens; + + private Float presencePenalty; + + private List stopSequences; private Float temperature; + private Integer topK; + private Float topP; - private Integer topK; + @Override + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } @Override - public Float getTemperature() { - return temperature; + public Float getFrequencyPenalty() { + return frequencyPenalty; } - public void setTemperature(Float temperature) { - this.temperature = temperature; + public void setFrequencyPenalty(Float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; } @Override - public Float getTopP() { - return topP; + public Integer getMaxTokens() { + return maxTokens; } - public void setTopP(Float topP) { - this.topP = topP; + public void 
setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + @Override + public Float getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(Float presencePenalty) { + this.presencePenalty = presencePenalty; + } + + @Override + public List getStopSequences() { + return stopSequences; + } + + public void setStopSequences(List stopSequences) { + this.stopSequences = stopSequences; + } + + @Override + public Float getTemperature() { + return temperature; + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; } @Override @@ -52,14 +100,31 @@ public void setTopK(Integer topK) { this.topK = topK; } + @Override + public Float getTopP() { + return topP; + } + + public void setTopP(Float topP) { + this.topP = topP; + } + @Override public ChatOptions copy() { - return builder().withTemperature(this.temperature).withTopP(this.topP).withTopK(this.topK).build(); + return builder().withModel(this.model) + .withFrequencyPenalty(this.frequencyPenalty) + .withMaxTokens(this.maxTokens) + .withPresencePenalty(this.presencePenalty) + .withStopSequences(this.stopSequences != null ? 
List.copyOf(this.stopSequences) : null) + .withTemperature(this.temperature) + .withTopK(this.topK) + .withTopP(this.topP) + .build(); } } - private final ChatOptionsImpl options = new ChatOptionsImpl(); + private final DefaultChatOptions options = new DefaultChatOptions(); private ChatOptionsBuilder() { } @@ -68,13 +133,33 @@ public static ChatOptionsBuilder builder() { return new ChatOptionsBuilder(); } - public ChatOptionsBuilder withTemperature(Float temperature) { - options.setTemperature(temperature); + public ChatOptionsBuilder withModel(String model) { + options.setModel(model); return this; } - public ChatOptionsBuilder withTopP(Float topP) { - options.setTopP(topP); + public ChatOptionsBuilder withFrequencyPenalty(Float frequencyPenalty) { + options.setFrequencyPenalty(frequencyPenalty); + return this; + } + + public ChatOptionsBuilder withMaxTokens(Integer maxTokens) { + options.setMaxTokens(maxTokens); + return this; + } + + public ChatOptionsBuilder withPresencePenalty(Float presencePenalty) { + options.setPresencePenalty(presencePenalty); + return this; + } + + public ChatOptionsBuilder withStopSequences(List stop) { + options.setStopSequences(stop); + return this; + } + + public ChatOptionsBuilder withTemperature(Float temperature) { + options.setTemperature(temperature); return this; } @@ -83,8 +168,13 @@ public ChatOptionsBuilder withTopK(Integer topK) { return this; } + public ChatOptionsBuilder withTopP(Float topP) { + options.setTopP(topP); + return this; + } + public ChatOptions build() { return options; } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConvention.java index 1331b2027e6..10117088f42 100644 --- 
a/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConvention.java @@ -67,7 +67,7 @@ public KeyValues getLowCardinalityKeyValues(EmbeddingModelObservationContext con protected KeyValue aiOperationType(EmbeddingModelObservationContext context) { return KeyValue.of(EmbeddingModelObservationDocumentation.LowCardinalityKeyNames.AI_OPERATION_TYPE, - context.getOperationType()); + context.getOperationMetadata().operationType()); } protected KeyValue aiProvider(EmbeddingModelObservationContext context) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContext.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContext.java index 2e209d35f34..2b6b09c6771 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContext.java @@ -33,9 +33,13 @@ public class EmbeddingModelObservationContext extends ModelObservationContext stopSequences) { + this.options.setStopSequences(stopSequences); + return this; + } + + public FunctionCallingOptionsBuilder withTemperature(Float temperature) { + this.options.setTemperature(temperature); return this; } @@ -76,6 +96,11 @@ public FunctionCallingOptionsBuilder withTopK(Integer topK) { return this; } + public FunctionCallingOptionsBuilder withTopP(Float topP) { + this.options.setTopP(topP); + return this; + } + public PortableFunctionCallingOptions build() { return this.options; } @@ -86,12 +111,22 @@ public static class PortableFunctionCallingOptions implements FunctionCallingOpt private Set functions = new HashSet<>(); - private Float temperature; + private String model; - 
private Float topP; + private Float frequencyPenalty; + + private Integer maxTokens; + + private Float presencePenalty; + + private List stopSequences; + + private Float temperature; private Integer topK; + private Float topP; + @Override public List getFunctionCallbacks() { return this.functionCallbacks; @@ -113,37 +148,87 @@ public void setFunctions(Set functions) { } @Override - public Float getTemperature() { - return this.temperature; + public String getModel() { + return model; } - public void setTemperature(Float temperature) { - this.temperature = temperature; + public void setModel(String model) { + this.model = model; } @Override - public Float getTopP() { - return this.topP; + public Float getFrequencyPenalty() { + return frequencyPenalty; } - public void setTopP(Float topP) { - this.topP = topP; + public void setFrequencyPenalty(Float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + @Override + public Integer getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + @Override + public Float getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(Float presencePenalty) { + this.presencePenalty = presencePenalty; + } + + @Override + public List getStopSequences() { + return stopSequences; + } + + public void setStopSequences(List stopSequences) { + this.stopSequences = stopSequences; + } + + @Override + public Float getTemperature() { + return temperature; + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; } @Override public Integer getTopK() { - return this.topK; + return topK; } public void setTopK(Integer topK) { this.topK = topK; } + @Override + public Float getTopP() { + return topP; + } + + public void setTopP(Float topP) { + this.topP = topP; + } + @Override public ChatOptions copy() { - return new FunctionCallingOptionsBuilder().withTemperature(this.temperature) - .withTopP(this.topP) 
+ return new FunctionCallingOptionsBuilder().withModel(this.model) + .withFrequencyPenalty(this.frequencyPenalty) + .withMaxTokens(this.maxTokens) + .withPresencePenalty(this.presencePenalty) + .withStopSequences(this.stopSequences) + .withTemperature(this.temperature) .withTopK(this.topK) + .withTopP(this.topP) .withFunctions(this.functions) .withFunctionCallbacks(this.functionCallbacks) .build(); @@ -151,4 +236,4 @@ public ChatOptions copy() { } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java index 44180e8ade6..0276568dd1c 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java @@ -21,10 +21,8 @@ import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.observation.AiOperationMetadata; -import org.springframework.ai.observation.conventions.AiOperationType; -import org.springframework.ai.observation.conventions.AiProvider; import java.util.List; @@ -52,8 +50,8 @@ void whenNotSupportedObservationContextThenReturnOriginalContext() { void whenEmptyResponseThenReturnOriginalContext() { var expectedContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) - .operationMetadata(generateOperationMetadata()) - .requestOptions(ChatModelRequestOptions.builder().model("mistral").build()) + .provider("superprovider") + 
.requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); var actualContext = observationFilter.map(expectedContext); @@ -64,8 +62,8 @@ void whenEmptyResponseThenReturnOriginalContext() { void whenEmptyCompletionThenReturnOriginalContext() { var expectedContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) - .operationMetadata(generateOperationMetadata()) - .requestOptions(ChatModelRequestOptions.builder().model("mistral").build()) + .provider("superprovider") + .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); expectedContext.setResponse(new ChatResponse(List.of(new Generation(new AssistantMessage(""))))); var actualContext = observationFilter.map(expectedContext); @@ -77,8 +75,8 @@ void whenEmptyCompletionThenReturnOriginalContext() { void whenCompletionWithTextThenAugmentContext() { var originalContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) - .operationMetadata(generateOperationMetadata()) - .requestOptions(ChatModelRequestOptions.builder().model("mistral").build()) + .provider("superprovider") + .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); originalContext.setResponse(new ChatResponse(List.of(new Generation(new AssistantMessage("say please")), new Generation(new AssistantMessage("seriously, say please"))))); @@ -92,11 +90,4 @@ private Prompt generatePrompt() { return new Prompt("supercalifragilisticexpialidocious"); } - private AiOperationMetadata generateOperationMetadata() { - return AiOperationMetadata.builder() - .operationType(AiOperationType.CHAT.value()) - .provider(AiProvider.OLLAMA.value()) - .build(); - } - } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java index 8c1414c0360..08edea5d18c 100644 --- 
a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java @@ -26,8 +26,8 @@ import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.observation.AiOperationMetadata; import org.springframework.ai.observation.conventions.*; import java.util.List; @@ -70,7 +70,7 @@ void shouldCreateAllMetersDuringAnObservation() { assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()).meters()).hasSize(3); assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), AiOperationType.CHAT.value()) - .tag(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.OLLAMA.value()) + .tag(LowCardinalityKeyNames.AI_PROVIDER.asString(), "superprovider") .tag(LowCardinalityKeyNames.REQUEST_MODEL.asString(), "mistral") .tag(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), "mistral-42") .meters()).hasSize(3); @@ -88,8 +88,8 @@ void shouldCreateAllMetersDuringAnObservation() { private ChatModelObservationContext generateObservationContext() { return ChatModelObservationContext.builder() .prompt(generatePrompt()) - .operationMetadata(generateOperationMetadata()) - .requestOptions(ChatModelRequestOptions.builder().model("mistral").build()) + .provider("superprovider") + .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); } @@ -97,13 +97,6 @@ private Prompt generatePrompt() { return new Prompt("hello"); } - private AiOperationMetadata generateOperationMetadata() { - return AiOperationMetadata.builder() - .operationType(AiOperationType.CHAT.value()) - 
.provider(AiProvider.OLLAMA.value()) - .build(); - } - static class TestUsage implements Usage { @Override diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java index 342d36fb70d..e723b91263c 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java @@ -16,10 +16,8 @@ package org.springframework.ai.chat.observation; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.observation.AiOperationMetadata; -import org.springframework.ai.observation.conventions.AiOperationType; -import org.springframework.ai.observation.conventions.AiProvider; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -35,8 +33,8 @@ class ChatModelObservationContextTests { void whenMandatoryRequestOptionsThenReturn() { var observationContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) - .operationMetadata(generateOperationMetadata()) - .requestOptions(ChatModelRequestOptions.builder().model("supermodel").build()) + .provider("superprovider") + .requestOptions(ChatOptionsBuilder.builder().withModel("supermodel").build()) .build(); assertThat(observationContext).isNotNull(); @@ -46,7 +44,7 @@ void whenMandatoryRequestOptionsThenReturn() { void whenRequestOptionsIsNullThenThrow() { assertThatThrownBy(() -> ChatModelObservationContext.builder() .prompt(generatePrompt()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(null) .build()).isInstanceOf(IllegalArgumentException.class) 
.hasMessageContaining("requestOptions cannot be null"); @@ -56,11 +54,4 @@ private Prompt generatePrompt() { return new Prompt("hello"); } - private AiOperationMetadata generateOperationMetadata() { - return AiOperationMetadata.builder() - .operationType(AiOperationType.CHAT.value()) - .provider(AiProvider.OLLAMA.value()) - .build(); - } - } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java index faba1ab2eb2..92d9e0d8b43 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java @@ -20,10 +20,8 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.observation.AiOperationMetadata; -import org.springframework.ai.observation.conventions.AiOperationType; -import org.springframework.ai.observation.conventions.AiProvider; import java.util.List; @@ -51,8 +49,8 @@ void whenNotSupportedObservationContextThenReturnOriginalContext() { void whenEmptyPromptThenReturnOriginalContext() { var expectedContext = ChatModelObservationContext.builder() .prompt(new Prompt(List.of())) - .operationMetadata(generateOperationMetadata()) - .requestOptions(ChatModelRequestOptions.builder().model("mistral").build()) + .provider("superprovider") + .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); var actualContext = observationFilter.map(expectedContext); @@ -63,8 +61,8 @@ void whenEmptyPromptThenReturnOriginalContext() { void 
whenPromptWithTextThenAugmentContext() { var originalContext = ChatModelObservationContext.builder() .prompt(new Prompt("supercalifragilisticexpialidocious")) - .operationMetadata(generateOperationMetadata()) - .requestOptions(ChatModelRequestOptions.builder().model("mistral").build()) + .provider("superprovider") + .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); var augmentedContext = observationFilter.map(originalContext); @@ -77,8 +75,8 @@ void whenPromptWithMessagesThenAugmentContext() { var originalContext = ChatModelObservationContext.builder() .prompt(new Prompt(List.of(new SystemMessage("you're a chimney sweep"), new UserMessage("supercalifragilisticexpialidocious")))) - .operationMetadata(generateOperationMetadata()) - .requestOptions(ChatModelRequestOptions.builder().model("mistral").build()) + .provider("superprovider") + .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); var augmentedContext = observationFilter.map(originalContext); @@ -87,11 +85,4 @@ void whenPromptWithMessagesThenAugmentContext() { "[\"you're a chimney sweep\", \"supercalifragilisticexpialidocious\"]")); } - private AiOperationMetadata generateOperationMetadata() { - return AiOperationMetadata.builder() - .operationType(AiOperationType.CHAT.value()) - .provider(AiProvider.OLLAMA.value()) - .build(); - } - } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelRequestOptionsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelRequestOptionsTests.java deleted file mode 100644 index 18eed1f9174..00000000000 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelRequestOptionsTests.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.ai.chat.observation; - -import org.junit.jupiter.api.Test; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Unit tests for {@link ChatModelRequestOptions}. - * - * @author Thomas Vitale - */ -class ChatModelRequestOptionsTests { - - @Test - void whenMandatoryRequestOptionsThenReturn() { - var requestOptions = ChatModelRequestOptions.builder().model("rowena").build(); - - assertThat(requestOptions).isNotNull(); - } - - @Test - void whenModelIsNullThenThrow() { - assertThatThrownBy(() -> ChatModelRequestOptions.builder().build()).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("model cannot be null or empty"); - } - - @Test - void whenModelIsEmptyThenThrow() { - assertThatThrownBy(() -> ChatModelRequestOptions.builder().model("").build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("model cannot be null or empty"); - } - -} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java index 90c2b5f0a01..2984943a873 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java @@ -24,10 +24,8 @@ import 
org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.observation.AiOperationMetadata; -import org.springframework.ai.observation.conventions.AiOperationType; -import org.springframework.ai.observation.conventions.AiProvider; import java.util.List; @@ -50,53 +48,63 @@ void shouldHaveName() { } @Test - void shouldHaveContextualName() { + void contextualNameWhenModelIsDefined() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) - .operationMetadata(generateOperationMetadata()) - .requestOptions(ChatModelRequestOptions.builder().model("mistral").build()) + .provider("superprovider") + .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("chat mistral"); } + @Test + void contextualNameWhenModelIsNotDefined() { + ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + .prompt(generatePrompt()) + .provider("superprovider") + .requestOptions(ChatOptionsBuilder.builder().build()) + .build(); + assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("chat"); + } + @Test void supportsOnlyChatModelObservationContext() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) - .operationMetadata(generateOperationMetadata()) - .requestOptions(ChatModelRequestOptions.builder().model("mistral").build()) + .provider("superprovider") + .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); assertThat(this.observationConvention.supportsContext(observationContext)).isTrue(); 
assertThat(this.observationConvention.supportsContext(new Observation.Context())).isFalse(); } @Test - void shouldHaveRequiredKeyValues() { + void shouldHaveLowCardinalityKeyValuesWhenDefined() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) - .operationMetadata(generateOperationMetadata()) - .requestOptions(ChatModelRequestOptions.builder().model("mistral").build()) + .provider("superprovider") + .requestOptions(ChatOptionsBuilder.builder().withModel("mistral").build()) .build(); assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains( KeyValue.of(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), "chat"), - KeyValue.of(LowCardinalityKeyNames.AI_PROVIDER.asString(), "ollama"), + KeyValue.of(LowCardinalityKeyNames.AI_PROVIDER.asString(), "superprovider"), KeyValue.of(LowCardinalityKeyNames.REQUEST_MODEL.asString(), "mistral")); } @Test - void shouldHaveOptionalKeyValues() { + void shouldHaveKeyValuesWhenDefinedAndResponse() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(generatePrompt()) - .operationMetadata(generateOperationMetadata()) - .requestOptions(ChatModelRequestOptions.builder() - .model("mistral") - .frequencyPenalty(0.8f) - .maxTokens(200) - .presencePenalty(1.0f) - .stopSequences(List.of("addio", "bye")) - .temperature(0.5f) - .topK(1) - .topP(0.9f) + .provider("superprovider") + .requestOptions(ChatOptionsBuilder.builder() + .withModel("mistral") + .withFrequencyPenalty(0.8f) + .withMaxTokens(200) + .withPresencePenalty(1.0f) + .withStopSequences(List.of("addio", "bye")) + .withTemperature(0.5f) + .withTopK(1) + .withTopP(0.9f) .build()) .build(); observationContext.setResponse(new ChatResponse( @@ -125,14 +133,15 @@ void shouldHaveOptionalKeyValues() { } @Test - void shouldHaveMissingKeyValues() { + void shouldHaveNoneKeyValuesWhenMissing() { ChatModelObservationContext observationContext = 
ChatModelObservationContext.builder() .prompt(generatePrompt()) - .operationMetadata(generateOperationMetadata()) - .requestOptions(ChatModelRequestOptions.builder().model("mistral").build()) + .provider("superprovider") + .requestOptions(ChatOptionsBuilder.builder().build()) .build(); - assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)) - .contains(KeyValue.of(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), KeyValue.NONE_VALUE)); + assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains( + KeyValue.of(LowCardinalityKeyNames.REQUEST_MODEL.asString(), KeyValue.NONE_VALUE), + KeyValue.of(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), KeyValue.NONE_VALUE)); assertThat(this.observationConvention.getHighCardinalityKeyValues(observationContext)).contains( KeyValue.of(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), KeyValue.NONE_VALUE), KeyValue.of(HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), KeyValue.NONE_VALUE), @@ -152,13 +161,6 @@ private Prompt generatePrompt() { return new Prompt("Who let the dogs out?"); } - private AiOperationMetadata generateOperationMetadata() { - return AiOperationMetadata.builder() - .operationType(AiOperationType.CHAT.value()) - .provider(AiProvider.OLLAMA.value()) - .build(); - } - static class TestUsage implements Usage { @Override diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java index 4a5cdf0a9af..d6d4ceb7abc 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java @@ -23,9 +23,6 @@ import 
org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; -import org.springframework.ai.observation.AiOperationMetadata; -import org.springframework.ai.observation.conventions.AiOperationType; -import org.springframework.ai.observation.conventions.AiProvider; import java.util.List; import java.util.Map; @@ -53,7 +50,7 @@ void shouldHaveName() { void contextualNameWhenModelIsDefined() { EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder() .embeddingRequest(generateEmbeddingRequest()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(EmbeddingOptionsBuilder.builder().withModel("mistral").build()) .build(); assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("embedding mistral"); @@ -63,7 +60,7 @@ void contextualNameWhenModelIsDefined() { void contextualNameWhenModelIsNotDefined() { EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder() .embeddingRequest(generateEmbeddingRequest()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(EmbeddingOptionsBuilder.builder().build()) .build(); assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("embedding"); @@ -73,7 +70,7 @@ void contextualNameWhenModelIsNotDefined() { void supportsOnlyEmbeddingModelObservationContext() { EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder() .embeddingRequest(generateEmbeddingRequest()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(EmbeddingOptionsBuilder.builder().withModel("supermodel").build()) .build(); assertThat(this.observationConvention.supportsContext(observationContext)).isTrue(); @@ -84,12 +81,12 @@ void 
supportsOnlyEmbeddingModelObservationContext() { void shouldHaveLowCardinalityKeyValuesWhenDefined() { EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder() .embeddingRequest(generateEmbeddingRequest()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(EmbeddingOptionsBuilder.builder().withModel("mistral").build()) .build(); assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains( KeyValue.of(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), "embedding"), - KeyValue.of(LowCardinalityKeyNames.AI_PROVIDER.asString(), "ollama"), + KeyValue.of(LowCardinalityKeyNames.AI_PROVIDER.asString(), "superprovider"), KeyValue.of(LowCardinalityKeyNames.REQUEST_MODEL.asString(), "mistral")); } @@ -97,7 +94,7 @@ void shouldHaveLowCardinalityKeyValuesWhenDefined() { void shouldHaveLowCardinalityKeyValuesWhenDefinedAndResponse() { EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder() .embeddingRequest(generateEmbeddingRequest()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(EmbeddingOptionsBuilder.builder().withModel("mistral").withDimensions(1492).build()) .build(); observationContext.setResponse(new EmbeddingResponse(List.of(), @@ -114,7 +111,7 @@ void shouldHaveLowCardinalityKeyValuesWhenDefinedAndResponse() { void shouldHaveNoneKeyValuesWhenMissing() { EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder() .embeddingRequest(generateEmbeddingRequest()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(EmbeddingOptionsBuilder.builder().build()) .build(); assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains( @@ -130,13 +127,6 @@ private EmbeddingRequest generateEmbeddingRequest() { return new EmbeddingRequest(List.of(), 
EmbeddingOptionsBuilder.builder().build()); } - private AiOperationMetadata generateOperationMetadata() { - return AiOperationMetadata.builder() - .operationType(AiOperationType.EMBEDDING.value()) - .provider(AiProvider.OLLAMA.value()) - .build(); - } - static class TestUsage implements Usage { @Override diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java index d849b2b1ca4..560f37a55b8 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java @@ -26,7 +26,6 @@ import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; -import org.springframework.ai.observation.AiOperationMetadata; import org.springframework.ai.observation.conventions.*; import java.util.List; @@ -70,7 +69,7 @@ void shouldCreateAllMetersDuringAnObservation() { assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()).meters()).hasSize(3); assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), AiOperationType.EMBEDDING.value()) - .tag(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.OLLAMA.value()) + .tag(LowCardinalityKeyNames.AI_PROVIDER.asString(), "superprovider") .tag(LowCardinalityKeyNames.REQUEST_MODEL.asString(), "mistral") .tag(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), "mistral-42") .meters()).hasSize(3); @@ -88,7 +87,7 @@ void shouldCreateAllMetersDuringAnObservation() { private EmbeddingModelObservationContext generateObservationContext() { return 
EmbeddingModelObservationContext.builder() .embeddingRequest(generateEmbeddingRequest()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(EmbeddingOptionsBuilder.builder().withModel("mistral").build()) .build(); } @@ -97,13 +96,6 @@ private EmbeddingRequest generateEmbeddingRequest() { return new EmbeddingRequest(List.of(), EmbeddingOptionsBuilder.builder().build()); } - private AiOperationMetadata generateOperationMetadata() { - return AiOperationMetadata.builder() - .operationType(AiOperationType.EMBEDDING.value()) - .provider(AiProvider.OLLAMA.value()) - .build(); - } - static class TestUsage implements Usage { @Override diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContextTests.java index d7d46be96cb..8c3bbb0cc66 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContextTests.java @@ -18,9 +18,6 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; -import org.springframework.ai.observation.AiOperationMetadata; -import org.springframework.ai.observation.conventions.AiOperationType; -import org.springframework.ai.observation.conventions.AiProvider; import java.util.List; @@ -38,7 +35,7 @@ class EmbeddingModelObservationContextTests { void whenMandatoryRequestOptionsThenReturn() { var observationContext = EmbeddingModelObservationContext.builder() .embeddingRequest(generateEmbeddingRequest()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(EmbeddingOptionsBuilder.builder().withModel("supermodel").build()) 
.build(); @@ -49,7 +46,7 @@ void whenMandatoryRequestOptionsThenReturn() { void whenRequestOptionsIsNullThenThrow() { assertThatThrownBy(() -> EmbeddingModelObservationContext.builder() .embeddingRequest(generateEmbeddingRequest()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(null) .build()).isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("requestOptions cannot be null"); @@ -59,11 +56,4 @@ private EmbeddingRequest generateEmbeddingRequest() { return new EmbeddingRequest(List.of(), EmbeddingOptionsBuilder.builder().build()); } - private AiOperationMetadata generateOperationMetadata() { - return AiOperationMetadata.builder() - .operationType(AiOperationType.EMBEDDING.value()) - .provider(AiProvider.OLLAMA.value()) - .build(); - } - } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java index d2ae8466e8a..10ddafbc1c0 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/DefaultImageModelObservationConventionTests.java @@ -20,10 +20,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.image.ImageOptionsBuilder; import org.springframework.ai.image.ImagePrompt; -import org.springframework.ai.observation.AiOperationMetadata; import org.springframework.ai.observation.conventions.AiObservationAttributes; -import org.springframework.ai.observation.conventions.AiOperationType; -import org.springframework.ai.observation.conventions.AiProvider; import static org.assertj.core.api.Assertions.assertThat; @@ -45,7 +42,7 @@ void shouldHaveName() { void contextualNameWhenModelIsDefined() { ImageModelObservationContext observationContext = 
ImageModelObservationContext.builder() .imagePrompt(generateImagePrompt()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(ImageOptionsBuilder.builder().withModel("mistral").build()) .build(); assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("image mistral"); @@ -55,7 +52,7 @@ void contextualNameWhenModelIsDefined() { void contextualNameWhenModelIsNotDefined() { ImageModelObservationContext observationContext = ImageModelObservationContext.builder() .imagePrompt(generateImagePrompt()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(ImageOptionsBuilder.builder().build()) .build(); assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("image"); @@ -65,7 +62,7 @@ void contextualNameWhenModelIsNotDefined() { void supportsOnlyImageModelObservationContext() { ImageModelObservationContext observationContext = ImageModelObservationContext.builder() .imagePrompt(generateImagePrompt()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(ImageOptionsBuilder.builder().withModel("mistral").build()) .build(); assertThat(this.observationConvention.supportsContext(observationContext)).isTrue(); @@ -76,12 +73,12 @@ void supportsOnlyImageModelObservationContext() { void shouldHaveLowCardinalityKeyValuesWhenDefined() { ImageModelObservationContext observationContext = ImageModelObservationContext.builder() .imagePrompt(generateImagePrompt()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(ImageOptionsBuilder.builder().withModel("mistral").build()) .build(); assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains( KeyValue.of(AiObservationAttributes.AI_OPERATION_TYPE.value(), "image"), - KeyValue.of(AiObservationAttributes.AI_PROVIDER.value(), "ollama"), + 
KeyValue.of(AiObservationAttributes.AI_PROVIDER.value(), "superprovider"), KeyValue.of(AiObservationAttributes.REQUEST_MODEL.value(), "mistral")); } @@ -89,7 +86,7 @@ void shouldHaveLowCardinalityKeyValuesWhenDefined() { void shouldHaveHighCardinalityKeyValuesWhenDefined() { ImageModelObservationContext observationContext = ImageModelObservationContext.builder() .imagePrompt(generateImagePrompt()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(ImageOptionsBuilder.builder() .withModel("mistral") .withN(1) @@ -110,7 +107,7 @@ void shouldHaveHighCardinalityKeyValuesWhenDefined() { void shouldHaveNoneKeyValuesWhenMissing() { ImageModelObservationContext observationContext = ImageModelObservationContext.builder() .imagePrompt(generateImagePrompt()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(ImageOptionsBuilder.builder().build()) .build(); @@ -126,11 +123,4 @@ private ImagePrompt generateImagePrompt() { return new ImagePrompt("here comes the sun"); } - private AiOperationMetadata generateOperationMetadata() { - return AiOperationMetadata.builder() - .operationType(AiOperationType.IMAGE.value()) - .provider(AiProvider.OLLAMA.value()) - .build(); - } - -} \ No newline at end of file +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelObservationContextTests.java index febdc94357b..ac59c321ad3 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelObservationContextTests.java @@ -18,9 +18,6 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.image.ImageOptionsBuilder; import org.springframework.ai.image.ImagePrompt; -import 
org.springframework.ai.observation.AiOperationMetadata; -import org.springframework.ai.observation.conventions.AiOperationType; -import org.springframework.ai.observation.conventions.AiProvider; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -36,7 +33,7 @@ class ImageModelObservationContextTests { void whenMandatoryRequestOptionsThenReturn() { var observationContext = ImageModelObservationContext.builder() .imagePrompt(generateImagePrompt()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(ImageOptionsBuilder.builder().withModel("supersun").build()) .build(); @@ -47,7 +44,7 @@ void whenMandatoryRequestOptionsThenReturn() { void whenRequestOptionsIsNullThenThrow() { assertThatThrownBy(() -> ImageModelObservationContext.builder() .imagePrompt(generateImagePrompt()) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(null) .build()).isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("requestOptions cannot be null"); @@ -57,11 +54,4 @@ private ImagePrompt generateImagePrompt() { return new ImagePrompt("here comes the sun"); } - private AiOperationMetadata generateOperationMetadata() { - return AiOperationMetadata.builder() - .operationType(AiOperationType.IMAGE.value()) - .provider(AiProvider.OLLAMA.value()) - .build(); - } - } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelPromptContentObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelPromptContentObservationFilterTests.java index 122f9783798..7fc11e39d5c 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelPromptContentObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/image/observation/ImageModelPromptContentObservationFilterTests.java @@ -21,10 +21,7 @@ 
import org.springframework.ai.image.ImageMessage; import org.springframework.ai.image.ImageOptionsBuilder; import org.springframework.ai.image.ImagePrompt; -import org.springframework.ai.observation.AiOperationMetadata; import org.springframework.ai.observation.conventions.AiObservationAttributes; -import org.springframework.ai.observation.conventions.AiOperationType; -import org.springframework.ai.observation.conventions.AiProvider; import java.util.List; @@ -51,7 +48,7 @@ void whenNotSupportedObservationContextThenReturnOriginalContext() { void whenEmptyPromptThenReturnOriginalContext() { var expectedContext = ImageModelObservationContext.builder() .imagePrompt(new ImagePrompt("")) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(ImageOptionsBuilder.builder().withModel("mistral").build()) .build(); var actualContext = observationFilter.map(expectedContext); @@ -63,7 +60,7 @@ void whenEmptyPromptThenReturnOriginalContext() { void whenPromptWithTextThenAugmentContext() { var originalContext = ImageModelObservationContext.builder() .imagePrompt(new ImagePrompt("supercalifragilisticexpialidocious")) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(ImageOptionsBuilder.builder().withModel("mistral").build()) .build(); var augmentedContext = observationFilter.map(originalContext); @@ -77,7 +74,7 @@ void whenPromptWithMessagesThenAugmentContext() { var originalContext = ImageModelObservationContext.builder() .imagePrompt(new ImagePrompt(List.of(new ImageMessage("you're a chimney sweep"), new ImageMessage("supercalifragilisticexpialidocious")))) - .operationMetadata(generateOperationMetadata()) + .provider("superprovider") .requestOptions(ImageOptionsBuilder.builder().withModel("mistral").build()) .build(); var augmentedContext = observationFilter.map(originalContext); @@ -87,11 +84,4 @@ void whenPromptWithMessagesThenAugmentContext() { "[\"you're a chimney sweep\", 
\"supercalifragilisticexpialidocious\"]")); } - private AiOperationMetadata generateOperationMetadata() { - return AiOperationMetadata.builder() - .operationType(AiOperationType.IMAGE.value()) - .provider(AiProvider.OLLAMA.value()) - .build(); - } - } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc index 9b93541f83a..4d246676ceb 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc @@ -97,7 +97,7 @@ The prefix `spring.ai.mistralai.chat` is the property prefix that lets you confi | spring.ai.mistralai.chat.options.maxTokens | The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length. | - | spring.ai.mistralai.chat.options.safePrompt | Indicates whether to inject a security prompt before all conversations. | false | spring.ai.mistralai.chat.options.randomSeed | This feature is in Beta. If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. | - -| spring.ai.mistralai.chat.options.stop | Up to 4 sequences where the API will stop generating further tokens. | - +| spring.ai.mistralai.chat.options.stop | Stop generation if this token is detected. Or if one of these tokens is detected when providing an array. | - | spring.ai.mistralai.chat.options.topP | An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. 
| - | spring.ai.mistralai.chat.options.responseFormat | An object specifying the format that the model must output. Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.| - | spring.ai.mistralai.chat.options.tools | A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. | - diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc index bd8ca9cd60c..0224c328db1 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc @@ -15,6 +15,13 @@ The following are the vector stores that currently don't support the `initialize 2. Pinecone 3. Weaviate +* In Bedrock Jurassic 2, the chat options `countPenalty`, `frequencyPenalty`, and `presencePenalty` +have been renamed to `countPenaltyOptions`, `frequencyPenaltyOptions`, and `presencePenaltyOptions`. +Furthermore, the type of the chat option `stopSequences` has been changed from `String[]` to `List<String>`. + +* In Azure OpenAI, the type of the chat options `frequencyPenalty` and `presencePenalty` +has been changed from `Double` to `Float`, for consistency with all the other implementations. + == Upgrading to 1.0.0.M1 On our march to release 1.0.0 M1 we have made several breaking changes. Apologies, it is for the best! 
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiPropertiesTests.java index d155393294c..6d7cd9afb1e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiPropertiesTests.java @@ -205,13 +205,7 @@ public void chatOptionsTest() { "spring.ai.zhipuai.base-url=TEST_BASE_URL", "spring.ai.zhipuai.chat.options.model=MODEL_XYZ", - "spring.ai.zhipuai.chat.options.frequencyPenalty=-1.5", - "spring.ai.zhipuai.chat.options.logitBias.myTokenId=-5", "spring.ai.zhipuai.chat.options.maxTokens=123", - "spring.ai.zhipuai.chat.options.n=10", - "spring.ai.zhipuai.chat.options.presencePenalty=0", - "spring.ai.zhipuai.chat.options.responseFormat.type=json", - "spring.ai.zhipuai.chat.options.seed=66", "spring.ai.zhipuai.chat.options.stop=boza,koza", "spring.ai.zhipuai.chat.options.temperature=0.55", "spring.ai.zhipuai.chat.options.topP=0.56",