diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 3a43ca4043b..c66ef7f6514 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -89,6 +89,7 @@ * @author Claudio Silva Junior * @author Alexandros Pappas * @author Jonghoon Park + * @author Soby Chacko * @since 1.0.0 */ public class AnthropicChatModel implements ChatModel { @@ -424,6 +425,12 @@ Prompt buildRequestPrompt(Prompt prompt) { // Merge @JsonIgnore-annotated options explicitly since they are ignored by // Jackson, used by ModelOptionsUtils. if (runtimeOptions != null) { + if (runtimeOptions.getFrequencyPenalty() != null) { + logger.warn("Frequency penalty option is ignored by the Anthropic API"); + } + if (runtimeOptions.getPresencePenalty() != null) { + logger.warn("Presence penalty option is ignored by the Anthropic API"); + } requestOptions.setHttpHeaders( mergeHttpHeaders(runtimeOptions.getHttpHeaders(), this.defaultOptions.getHttpHeaders())); requestOptions.setInternalToolExecutionEnabled( diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 0f4d136a140..2d539950c09 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -130,6 +130,7 @@ * @author Wei Jiang * @author Alexandros Pappas * @author Jihoon Kim + * @author Soby Chacko * @since 1.0.0 */ public class BedrockProxyChatModel implements ChatModel { @@ -279,19 +280,23 @@ Prompt buildRequestPrompt(Prompt prompt) { updatedRuntimeOptions = this.defaultOptions.copy(); } else { + if (runtimeOptions.getFrequencyPenalty() != null) { + logger.warn("The frequencyPenalty option is not supported by the BedrockProxyChatModel. Ignoring."); + } + if (runtimeOptions.getPresencePenalty() != null) { + logger.warn("The presencePenalty option is not supported by the BedrockProxyChatModel. Ignoring."); + } + if (runtimeOptions.getTopK() != null) { + logger.warn("The topK option is not supported by the BedrockProxyChatModel. Ignoring."); + } updatedRuntimeOptions = ToolCallingChatOptions.builder() .model(runtimeOptions.getModel() != null ? runtimeOptions.getModel() : this.defaultOptions.getModel()) - .frequencyPenalty(runtimeOptions.getFrequencyPenalty() != null ? runtimeOptions.getFrequencyPenalty() - : this.defaultOptions.getFrequencyPenalty()) .maxTokens(runtimeOptions.getMaxTokens() != null ? runtimeOptions.getMaxTokens() : this.defaultOptions.getMaxTokens()) - .presencePenalty(runtimeOptions.getPresencePenalty() != null ? runtimeOptions.getPresencePenalty() - : this.defaultOptions.getPresencePenalty()) .stopSequences(runtimeOptions.getStopSequences() != null ? runtimeOptions.getStopSequences() : this.defaultOptions.getStopSequences()) .temperature(runtimeOptions.getTemperature() != null ? runtimeOptions.getTemperature() : this.defaultOptions.getTemperature()) - .topK(runtimeOptions.getTopK() != null ? runtimeOptions.getTopK() : this.defaultOptions.getTopK()) .topP(runtimeOptions.getTopP() != null ? runtimeOptions.getTopP() : this.defaultOptions.getTopP()) .toolCallbacks(runtimeOptions.getToolCallbacks() != null ? runtimeOptions.getToolCallbacks() diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 5a11bcad999..cc7fcadfd03 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -103,6 +103,7 @@ * @author Thomas Vitale * @author Ilayaperumal Gopinathan * @author Alexandros Pappas + * @author Soby Chacko * @see ChatModel * @see StreamingChatModel * @see OpenAiApi @@ -507,6 +508,10 @@ Prompt buildRequestPrompt(Prompt prompt) { // Merge @JsonIgnore-annotated options explicitly since they are ignored by // Jackson, used by ModelOptionsUtils. if (runtimeOptions != null) { + if (runtimeOptions.getTopK() != null) { + logger.warn("topK is not supported for chat models in OpenAI"); + } + requestOptions.setHttpHeaders( mergeHttpHeaders(runtimeOptions.getHttpHeaders(), this.defaultOptions.getHttpHeaders())); requestOptions.setInternalToolExecutionEnabled( diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 8b751e23fe9..eaac6e586be 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -728,6 +728,12 @@ private GenerationConfig toGenerationConfig(VertexAiGeminiChatOptions options) { if (options.getResponseMimeType() != null) { generationConfigBuilder.setResponseMimeType(options.getResponseMimeType()); } + if (options.getFrequencyPenalty() != null) { + generationConfigBuilder.setFrequencyPenalty(options.getFrequencyPenalty().floatValue()); + } + if (options.getPresencePenalty() != null) { + generationConfigBuilder.setPresencePenalty(options.getPresencePenalty().floatValue()); + } return generationConfigBuilder.build(); } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index 0e5df922e13..68ae24a92e2 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -44,6 +44,7 @@ * @author Thomas Vitale * @author Grogdunn * @author Ilayaperumal Gopinathan + * @author Soby Chacko * @since 1.0.0 */ @JsonInclude(Include.NON_NULL) @@ -95,6 +96,16 @@ public class VertexAiGeminiChatOptions implements ToolCallingChatOptions { */ private @JsonProperty("responseMimeType") String responseMimeType; + /** + * Optional. Frequency penalties. + */ + private @JsonProperty("frequencyPenalty") Double frequencyPenalty; + + /** + * Optional. Positive penalties. + */ + private @JsonProperty("presencePenalty") Double presencePenalty; + /** * Collection of {@link ToolCallback}s to be used for tool calling in the chat * completion requests. @@ -138,6 +149,8 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr options.setTemperature(fromOptions.getTemperature()); options.setTopP(fromOptions.getTopP()); options.setTopK(fromOptions.getTopK()); + options.setFrequencyPenalty(fromOptions.getFrequencyPenalty()); + options.setPresencePenalty(fromOptions.getPresencePenalty()); options.setCandidateCount(fromOptions.getCandidateCount()); options.setMaxOutputTokens(fromOptions.getMaxOutputTokens()); options.setModel(fromOptions.getModel()); @@ -269,15 +282,21 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut } @Override - @JsonIgnore public Double getFrequencyPenalty() { - return null; + return this.frequencyPenalty; } @Override - @JsonIgnore public Double getPresencePenalty() { - return null; + return this.presencePenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; } public Boolean getGoogleSearchRetrieval() { @@ -319,6 +338,8 @@ public boolean equals(Object o) { && Objects.equals(this.stopSequences, that.stopSequences) && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP) && Objects.equals(this.topK, that.topK) && Objects.equals(this.candidateCount, that.candidateCount) + && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) + && Objects.equals(this.presencePenalty, that.presencePenalty) && Objects.equals(this.maxOutputTokens, that.maxOutputTokens) && Objects.equals(this.model, that.model) && Objects.equals(this.responseMimeType, that.responseMimeType) && Objects.equals(this.toolCallbacks, that.toolCallbacks) @@ -331,14 +352,16 @@ public boolean equals(Object o) { @Override public int hashCode() { return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount, - this.maxOutputTokens, this.model, this.responseMimeType, this.toolCallbacks, this.toolNames, - this.googleSearchRetrieval, this.safetySettings, this.internalToolExecutionEnabled, this.toolContext); + this.frequencyPenalty, this.presencePenalty, this.maxOutputTokens, this.model, this.responseMimeType, + this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.safetySettings, + this.internalToolExecutionEnabled, this.toolContext); } @Override public String toString() { return "VertexAiGeminiChatOptions{" + "stopSequences=" + this.stopSequences + ", temperature=" - + this.temperature + ", topP=" + this.topP + ", topK=" + this.topK + ", candidateCount=" + + this.temperature + ", topP=" + this.topP + ", topK=" + this.topK + ", frequencyPenalty=" + + this.frequencyPenalty + ", presencePenalty=" + this.presencePenalty + ", candidateCount=" + this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\'' + ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks + ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" + this.googleSearchRetrieval @@ -380,6 +403,16 @@ public Builder topK(Integer topK) { return this; } + public Builder frequencePenalty(Double frequencyPenalty) { + this.options.setFrequencyPenalty(frequencyPenalty); + return this; + } + + public Builder presencePenalty(Double presencePenalty) { + this.options.setPresencePenalty(presencePenalty); + return this; + } + public Builder candidateCount(Integer candidateCount) { this.options.setCandidateCount(candidateCount); return this; diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java index af061c0b01f..37d68bcf613 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -45,6 +45,7 @@ /** * @author Christian Tzolov + * @author Soby Chacko */ @ExtendWith(MockitoExtension.class) public class CreateGeminiRequestTests { @@ -79,6 +80,27 @@ public void createRequestWithChatOptions() { assertThat(request.model().getGenerationConfig().getTemperature()).isEqualTo(99.9f); } + @Test + public void createRequestWithFrequencyAndPresencePenalty() { + + var client = VertexAiGeminiChatModel.builder() + .vertexAI(this.vertexAI) + .defaultOptions(VertexAiGeminiChatOptions.builder() + .model("DEFAULT_MODEL") + .frequencePenalty(.25) + .presencePenalty(.75) + .build()) + .build(); + + GeminiRequest request = client.createGeminiRequest(client + .buildRequestPrompt(new Prompt("Test message content", VertexAiGeminiChatOptions.builder().build()))); + + assertThat(request.contents()).hasSize(1); + + assertThat(request.model().getGenerationConfig().getFrequencyPenalty()).isEqualTo(.25F); + assertThat(request.model().getGenerationConfig().getPresencePenalty()).isEqualTo(.75F); + } + @Test public void createRequestWithSystemMessage() throws MalformedURLException {