Skip to content

Commit 1aa3e3c

Browse files
committed
Add three request body parameters to Mistral AI Chat Completion
Add presence_penalty, frequency_penalty, and n parameters Following Mistral AI API specifications as referenced in https://docs.mistral.ai/api/#tag/chat Signed-off-by: Seunghyeon Ji <[email protected]>
1 parent 2dad482 commit 1aa3e3c

File tree

2 files changed

+78
-16
lines changed

2 files changed

+78
-16
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,28 @@ public class MistralAiChatOptions implements ToolCallingChatOptions {
101101
*/
102102
private @JsonProperty("stop") List<String> stop;
103103

104+
/**
105+
* Number between -2.0 and 2.0. frequency_penalty penalizes the repetition of words
106+
* based on their frequency in the generated text. A higher frequency penalty discourages
107+
* the model from repeating words that have already appeared frequently in the
108+
* output, promoting diversity and reducing repetition.
109+
*/
110+
private @JsonProperty("frequency_penalty") Double frequencyPenalty;
111+
112+
/**
113+
* Number between -2.0 and 2.0. presence_penalty determines how much the model
114+
* penalizes the repetition of words or phrases. A higher presence penalty encourages
115+
* the model to use a wider variety of words and phrases, making the output more
116+
* diverse and creative.
117+
*/
118+
private @JsonProperty("presence_penalty") Double presencePenalty;
119+
120+
/**
121+
* Number of completions to return for each request, input tokens are only billed
122+
* once.
123+
*/
124+
private @JsonProperty("n") Integer n;
125+
104126
/**
105127
* A list of tools the model may call. Currently, only functions are supported as a
106128
* tool. Use this to provide a list of functions the model may generate JSON inputs
@@ -151,6 +173,9 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions)
151173
.topP(fromOptions.getTopP())
152174
.responseFormat(fromOptions.getResponseFormat())
153175
.stop(fromOptions.getStop())
176+
.frequencyPenalty(fromOptions.getFrequencyPenalty())
177+
.presencePenalty(fromOptions.getPresencePenalty())
178+
.N(fromOptions.getN())
154179
.tools(fromOptions.getTools())
155180
.toolChoice(fromOptions.getToolChoice())
156181
.toolCallbacks(fromOptions.getToolCallbacks())
@@ -255,6 +280,32 @@ public void setTopP(Double topP) {
255280
this.topP = topP;
256281
}
257282

283+
@Override
284+
public Double getFrequencyPenalty() {
285+
return this.frequencyPenalty;
286+
}
287+
288+
public void setFrequencyPenalty(Double frequencyPenalty) {
289+
this.frequencyPenalty = frequencyPenalty;
290+
}
291+
292+
@Override
293+
public Double getPresencePenalty() {
294+
return this.presencePenalty;
295+
}
296+
297+
public void setPresencePenalty(Double presencePenalty) {
298+
this.presencePenalty = presencePenalty;
299+
}
300+
301+
public Integer getN() {
302+
return this.n;
303+
}
304+
305+
public void setN(Integer n) {
306+
this.n = n;
307+
}
308+
258309
@Override
259310
@JsonIgnore
260311
public List<FunctionCallback> getToolCallbacks() {
@@ -325,18 +376,6 @@ public void setFunctions(Set<String> functionNames) {
325376
this.setToolNames(functionNames);
326377
}
327378

328-
@Override
329-
@JsonIgnore
330-
public Double getFrequencyPenalty() {
331-
return null;
332-
}
333-
334-
@Override
335-
@JsonIgnore
336-
public Double getPresencePenalty() {
337-
return null;
338-
}
339-
340379
@Override
341380
@JsonIgnore
342381
public Integer getTopK() {
@@ -376,8 +415,8 @@ public MistralAiChatOptions copy() {
376415
@Override
377416
public int hashCode() {
378417
return Objects.hash(this.model, this.temperature, this.topP, this.maxTokens, this.safePrompt, this.randomSeed,
379-
this.responseFormat, this.stop, this.tools, this.toolChoice, this.toolCallbacks, this.tools,
380-
this.internalToolExecutionEnabled, this.toolContext);
418+
this.responseFormat, this.stop, this.frequencyPenalty, this.presencePenalty, this.n, this.tools,
419+
this.toolChoice, this.toolCallbacks, this.tools, this.internalToolExecutionEnabled, this.toolContext);
381420
}
382421

383422
@Override
@@ -397,6 +436,8 @@ public boolean equals(Object obj) {
397436
&& Objects.equals(this.safePrompt, other.safePrompt)
398437
&& Objects.equals(this.randomSeed, other.randomSeed)
399438
&& Objects.equals(this.responseFormat, other.responseFormat) && Objects.equals(this.stop, other.stop)
439+
&& Objects.equals(this.frequencyPenalty, other.frequencyPenalty)
440+
&& Objects.equals(this.presencePenalty, other.presencePenalty) && Objects.equals(this.n, other.n)
400441
&& Objects.equals(this.tools, other.tools) && Objects.equals(this.toolChoice, other.toolChoice)
401442
&& Objects.equals(this.toolCallbacks, other.toolCallbacks)
402443
&& Objects.equals(this.toolNames, other.toolNames)
@@ -438,6 +479,21 @@ public Builder stop(List<String> stop) {
438479
return this;
439480
}
440481

482+
public Builder frequencyPenalty(Double frequencyPenalty) {
483+
this.options.frequencyPenalty = frequencyPenalty;
484+
return this;
485+
}
486+
487+
public Builder presencePenalty(Double presencePenalty) {
488+
this.options.presencePenalty = presencePenalty;
489+
return this;
490+
}
491+
492+
public Builder N(Integer n) {
493+
this.options.n = n;
494+
return this;
495+
}
496+
441497
public Builder temperature(Double temperature) {
442498
this.options.setTemperature(temperature);
443499
return this;

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ void observationForChatOperation() {
7474
.stop(List.of("this-is-the-end"))
7575
.temperature(0.7)
7676
.topP(1.0)
77+
.presencePenalty(0.0)
78+
.frequencyPenalty(0.0)
79+
.N(2)
7780
.build();
7881

7982
Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
@@ -95,6 +98,9 @@ void observationForStreamingChatOperation() {
9598
.stop(List.of("this-is-the-end"))
9699
.temperature(0.7)
97100
.topP(1.0)
101+
.presencePenalty(0.0)
102+
.frequencyPenalty(0.0)
103+
.N(2)
98104
.build();
99105

100106
Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
@@ -133,9 +139,9 @@ private void validate(ChatResponseMetadata responseMetadata) {
133139
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(),
134140
StringUtils.hasText(responseMetadata.getModel()) ? responseMetadata.getModel()
135141
: KeyValue.NONE_VALUE)
136-
.doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString())
142+
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), "0.0")
143+
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(), "0.0")
137144
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048")
138-
.doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString())
139145
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(),
140146
"[\"this-is-the-end\"]")
141147
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7")

0 commit comments

Comments
 (0)