Skip to content

Commit d781199

Browse files
sobychacko authored and markpollack committed
GH-1378: Add parameter warnings and implement penalty options for Vertex AI Gemini
Fixes: #1378

- Add warnings for unsupported frequency/presence penalty in Anthropic API
- Add warnings for unsupported frequency/presence penalty and topK in BedrockProxyChatModel
- Add warning for unsupported topK in OpenAI chat models
- Implement frequency and presence penalty parameters for Vertex AI Gemini model
- Update Vertex AI Gemini options and tests

Signed-off-by: Soby Chacko <[email protected]>
1 parent 0e15197 commit d781199

File tree

6 files changed

+91
-13
lines changed

6 files changed

+91
-13
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
* @author Claudio Silva Junior
9090
* @author Alexandros Pappas
9191
* @author Jonghoon Park
92+
* @author Soby Chacko
9293
* @since 1.0.0
9394
*/
9495
public class AnthropicChatModel implements ChatModel {
@@ -424,6 +425,12 @@ Prompt buildRequestPrompt(Prompt prompt) {
424425
// Merge @JsonIgnore-annotated options explicitly since they are ignored by
425426
// Jackson, used by ModelOptionsUtils.
426427
if (runtimeOptions != null) {
428+
if (runtimeOptions.getFrequencyPenalty() != null) {
429+
logger.warn("The frequencyPenalty option is not supported by Anthropic API. Ignoring.");
430+
}
431+
if (runtimeOptions.getPresencePenalty() != null) {
432+
logger.warn("The presencePenalty option is not supported by Anthropic API. Ignoring.");
433+
}
427434
requestOptions.setHttpHeaders(
428435
mergeHttpHeaders(runtimeOptions.getHttpHeaders(), this.defaultOptions.getHttpHeaders()));
429436
requestOptions.setInternalToolExecutionEnabled(

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@
130130
* @author Wei Jiang
131131
* @author Alexandros Pappas
132132
* @author Jihoon Kim
133+
* @author Soby Chacko
133134
* @since 1.0.0
134135
*/
135136
public class BedrockProxyChatModel implements ChatModel {
@@ -279,19 +280,23 @@ Prompt buildRequestPrompt(Prompt prompt) {
279280
updatedRuntimeOptions = this.defaultOptions.copy();
280281
}
281282
else {
283+
if (runtimeOptions.getFrequencyPenalty() != null) {
284+
logger.warn("The frequencyPenalty option is not supported by BedrockProxyChatModel. Ignoring.");
285+
}
286+
if (runtimeOptions.getPresencePenalty() != null) {
287+
logger.warn("The presencePenalty option is not supported by BedrockProxyChatModel. Ignoring.");
288+
}
289+
if (runtimeOptions.getTopK() != null) {
290+
logger.warn("The topK option is not supported by BedrockProxyChatModel. Ignoring.");
291+
}
282292
updatedRuntimeOptions = ToolCallingChatOptions.builder()
283293
.model(runtimeOptions.getModel() != null ? runtimeOptions.getModel() : this.defaultOptions.getModel())
284-
.frequencyPenalty(runtimeOptions.getFrequencyPenalty() != null ? runtimeOptions.getFrequencyPenalty()
285-
: this.defaultOptions.getFrequencyPenalty())
286294
.maxTokens(runtimeOptions.getMaxTokens() != null ? runtimeOptions.getMaxTokens()
287295
: this.defaultOptions.getMaxTokens())
288-
.presencePenalty(runtimeOptions.getPresencePenalty() != null ? runtimeOptions.getPresencePenalty()
289-
: this.defaultOptions.getPresencePenalty())
290296
.stopSequences(runtimeOptions.getStopSequences() != null ? runtimeOptions.getStopSequences()
291297
: this.defaultOptions.getStopSequences())
292298
.temperature(runtimeOptions.getTemperature() != null ? runtimeOptions.getTemperature()
293299
: this.defaultOptions.getTemperature())
294-
.topK(runtimeOptions.getTopK() != null ? runtimeOptions.getTopK() : this.defaultOptions.getTopK())
295300
.topP(runtimeOptions.getTopP() != null ? runtimeOptions.getTopP() : this.defaultOptions.getTopP())
296301

297302
.toolCallbacks(runtimeOptions.getToolCallbacks() != null ? runtimeOptions.getToolCallbacks()

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@
103103
* @author Thomas Vitale
104104
* @author Ilayaperumal Gopinathan
105105
* @author Alexandros Pappas
106+
* @author Soby Chacko
106107
* @see ChatModel
107108
* @see StreamingChatModel
108109
* @see OpenAiApi
@@ -507,6 +508,10 @@ Prompt buildRequestPrompt(Prompt prompt) {
507508
// Merge @JsonIgnore-annotated options explicitly since they are ignored by
508509
// Jackson, used by ModelOptionsUtils.
509510
if (runtimeOptions != null) {
511+
if (runtimeOptions.getTopK() != null) {
512+
logger.warn("The topK option is not supported by OpenAI chat models. Ignoring.");
513+
}
514+
510515
requestOptions.setHttpHeaders(
511516
mergeHttpHeaders(runtimeOptions.getHttpHeaders(), this.defaultOptions.getHttpHeaders()));
512517
requestOptions.setInternalToolExecutionEnabled(

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,12 @@ private GenerationConfig toGenerationConfig(VertexAiGeminiChatOptions options) {
728728
if (options.getResponseMimeType() != null) {
729729
generationConfigBuilder.setResponseMimeType(options.getResponseMimeType());
730730
}
731+
if (options.getFrequencyPenalty() != null) {
732+
generationConfigBuilder.setFrequencyPenalty(options.getFrequencyPenalty().floatValue());
733+
}
734+
if (options.getPresencePenalty() != null) {
735+
generationConfigBuilder.setPresencePenalty(options.getPresencePenalty().floatValue());
736+
}
731737

732738
return generationConfigBuilder.build();
733739
}

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
* @author Thomas Vitale
4545
* @author Grogdunn
4646
* @author Ilayaperumal Gopinathan
47+
* @author Soby Chacko
4748
* @since 1.0.0
4849
*/
4950
@JsonInclude(Include.NON_NULL)
@@ -95,6 +96,16 @@ public class VertexAiGeminiChatOptions implements ToolCallingChatOptions {
9596
*/
9697
private @JsonProperty("responseMimeType") String responseMimeType;
9798

99+
/**
100+
* Optional. Frequency penalties.
101+
*/
102+
private @JsonProperty("frequencyPenalty") Double frequencyPenalty;
103+
104+
/**
105+
* Optional. Presence penalties.
106+
*/
107+
private @JsonProperty("presencePenalty") Double presencePenalty;
108+
98109
/**
99110
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
100111
* completion requests.
@@ -138,6 +149,8 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr
138149
options.setTemperature(fromOptions.getTemperature());
139150
options.setTopP(fromOptions.getTopP());
140151
options.setTopK(fromOptions.getTopK());
152+
options.setFrequencyPenalty(fromOptions.getFrequencyPenalty());
153+
options.setPresencePenalty(fromOptions.getPresencePenalty());
141154
options.setCandidateCount(fromOptions.getCandidateCount());
142155
options.setMaxOutputTokens(fromOptions.getMaxOutputTokens());
143156
options.setModel(fromOptions.getModel());
@@ -269,15 +282,21 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut
269282
}
270283

271284
@Override
272-
@JsonIgnore
273285
public Double getFrequencyPenalty() {
274-
return null;
286+
return this.frequencyPenalty;
275287
}
276288

277289
@Override
278-
@JsonIgnore
279290
public Double getPresencePenalty() {
280-
return null;
291+
return this.presencePenalty;
292+
}
293+
294+
/**
* Stores the frequency penalty on these options. The value is assigned as given,
* with no validation or conversion.
* @param frequencyPenalty the frequency penalty to store
*/
public void setFrequencyPenalty(Double frequencyPenalty) {
295+
this.frequencyPenalty = frequencyPenalty;
296+
}
297+
298+
/**
* Stores the presence penalty on these options. The value is assigned as given,
* with no validation or conversion.
* @param presencePenalty the presence penalty to store
*/
public void setPresencePenalty(Double presencePenalty) {
299+
this.presencePenalty = presencePenalty;
281300
}
282301

283302
public Boolean getGoogleSearchRetrieval() {
@@ -319,6 +338,8 @@ public boolean equals(Object o) {
319338
&& Objects.equals(this.stopSequences, that.stopSequences)
320339
&& Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP)
321340
&& Objects.equals(this.topK, that.topK) && Objects.equals(this.candidateCount, that.candidateCount)
341+
&& Objects.equals(this.frequencyPenalty, that.frequencyPenalty)
342+
&& Objects.equals(this.presencePenalty, that.presencePenalty)
322343
&& Objects.equals(this.maxOutputTokens, that.maxOutputTokens) && Objects.equals(this.model, that.model)
323344
&& Objects.equals(this.responseMimeType, that.responseMimeType)
324345
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
@@ -331,14 +352,16 @@ public boolean equals(Object o) {
331352
@Override
332353
public int hashCode() {
333354
return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount,
334-
this.maxOutputTokens, this.model, this.responseMimeType, this.toolCallbacks, this.toolNames,
335-
this.googleSearchRetrieval, this.safetySettings, this.internalToolExecutionEnabled, this.toolContext);
355+
this.frequencyPenalty, this.presencePenalty, this.maxOutputTokens, this.model, this.responseMimeType,
356+
this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.safetySettings,
357+
this.internalToolExecutionEnabled, this.toolContext);
336358
}
337359

338360
@Override
339361
public String toString() {
340362
return "VertexAiGeminiChatOptions{" + "stopSequences=" + this.stopSequences + ", temperature="
341-
+ this.temperature + ", topP=" + this.topP + ", topK=" + this.topK + ", candidateCount="
363+
+ this.temperature + ", topP=" + this.topP + ", topK=" + this.topK + ", frequencyPenalty="
364+
+ this.frequencyPenalty + ", presencePenalty=" + this.presencePenalty + ", candidateCount="
342365
+ this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\''
343366
+ ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks
344367
+ ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" + this.googleSearchRetrieval
@@ -380,6 +403,16 @@ public Builder topK(Integer topK) {
380403
return this;
381404
}
382405

406+
public Builder frequencePenalty(Double frequencyPenalty) {
407+
this.options.setFrequencyPenalty(frequencyPenalty);
408+
return this;
409+
}
410+
411+
/**
* Sets the presence penalty on the options being built.
* @param presencePenalty the presence penalty, forwarded unchanged to
* {@code setPresencePenalty} on the underlying options
* @return this builder, to allow call chaining
*/
public Builder presencePenalty(Double presencePenalty) {
412+
this.options.setPresencePenalty(presencePenalty);
413+
return this;
414+
}
415+
383416
public Builder candidateCount(Integer candidateCount) {
384417
this.options.setCandidateCount(candidateCount);
385418
return this;

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -45,6 +45,7 @@
4545

4646
/**
4747
* @author Christian Tzolov
48+
* @author Soby Chacko
4849
*/
4950
@ExtendWith(MockitoExtension.class)
5051
public class CreateGeminiRequestTests {
@@ -79,6 +80,27 @@ public void createRequestWithChatOptions() {
7980
assertThat(request.model().getGenerationConfig().getTemperature()).isEqualTo(99.9f);
8081
}
8182

83+
// Verifies that frequency and presence penalties configured as default options on
// the chat model appear in the generation config of the created Gemini request,
// even though the prompt itself carries an empty set of runtime options.
@Test
84+
public void createRequestWithFrequencyAndPresencePenalty() {
85+
86+
var client = VertexAiGeminiChatModel.builder()
87+
.vertexAI(this.vertexAI)
88+
.defaultOptions(VertexAiGeminiChatOptions.builder()
89+
.model("DEFAULT_MODEL")
90+
.frequencePenalty(.25)
91+
.presencePenalty(.75)
92+
.build())
93+
.build();
94+
95+
GeminiRequest request = client.createGeminiRequest(client
96+
.buildRequestPrompt(new Prompt("Test message content", VertexAiGeminiChatOptions.builder().build())));
97+
98+
assertThat(request.contents()).hasSize(1);
99+
100+
assertThat(request.model().getGenerationConfig().getFrequencyPenalty()).isEqualTo(.25F);
101+
assertThat(request.model().getGenerationConfig().getPresencePenalty()).isEqualTo(.75F);
102+
}
103+
82104
@Test
83105
public void createRequestWithSystemMessage() throws MalformedURLException {
84106

0 commit comments

Comments
 (0)