
Commit ccb37fe

feat: Add reasoningEffort parameter to OpenAI API and Chat Options
This commit introduces the `reasoningEffort` parameter to the OpenAI API integration, allowing control over the reasoning effort used by models like `o1-mini`.

Changes:
- Adds `reasoningEffort` field to `OpenAiApi.ChatCompletionRequest`.
- Adds `reasoningEffort` field and builder method to `OpenAiChatOptions`.

Signed-off-by: Alexandros Pappas <[email protected]>
1 parent 54463e6 commit ccb37fe

File tree

3 files changed, +54 -8 lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java

Lines changed: 26 additions & 2 deletions
@@ -182,6 +182,15 @@ public class OpenAiChatOptions implements FunctionCallingOptions {
 	 * Developer-defined tags and values used for filtering completions in the <a href="https://platform.openai.com/chat-completions">dashboard</a>.
 	 */
 	private @JsonProperty("metadata") Map<String, String> metadata;
+
+	/**
+	 * Constrains effort on reasoning for reasoning models. Currently supported values are low, medium, and high.
+	 * Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response.
+	 * Optional. Defaults to medium.
+	 * Only for 'o1' models.
+	 */
+	private @JsonProperty("reasoning_effort") String reasoningEffort;
+
 	/**
 	 * OpenAI Tool Function Callbacks to register with the ChatModel.
 	 * For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution.
@@ -256,6 +265,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) {
 			.toolContext(fromOptions.getToolContext())
 			.store(fromOptions.getStore())
 			.metadata(fromOptions.getMetadata())
+			.reasoningEffort(fromOptions.getReasoningEffort())
 			.build();
 	}

@@ -520,6 +530,14 @@ public void setMetadata(Map<String, String> metadata) {
 		this.metadata = metadata;
 	}

+	public String getReasoningEffort() {
+		return this.reasoningEffort;
+	}
+
+	public void setReasoningEffort(String reasoningEffort) {
+		this.reasoningEffort = reasoningEffort;
+	}
+
 	@Override
 	public OpenAiChatOptions copy() {
 		return OpenAiChatOptions.fromOptions(this);
@@ -532,7 +550,7 @@ public int hashCode() {
 				this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice,
 				this.user, this.parallelToolCalls, this.functionCallbacks, this.functions, this.httpHeaders,
 				this.proxyToolCalls, this.toolContext, this.outputModalities, this.outputAudio, this.store,
-				this.metadata);
+				this.metadata, this.reasoningEffort);
 	}

 	@Override
@@ -563,7 +581,8 @@ public boolean equals(Object o) {
 				&& Objects.equals(this.proxyToolCalls, other.proxyToolCalls)
 				&& Objects.equals(this.outputModalities, other.outputModalities)
 				&& Objects.equals(this.outputAudio, other.outputAudio) && Objects.equals(this.store, other.store)
-				&& Objects.equals(this.metadata, other.metadata);
+				&& Objects.equals(this.metadata, other.metadata)
+				&& Objects.equals(this.reasoningEffort, other.reasoningEffort);
 	}

 	@Override
@@ -740,6 +759,11 @@ public Builder metadata(Map<String, String> metadata) {
 			return this;
 		}

+		public Builder reasoningEffort(String reasoningEffort) {
+			this.options.reasoningEffort = reasoningEffort;
+			return this;
+		}
+
 		public OpenAiChatOptions build() {
 			return this.options;
 		}
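For reference, a minimal usage sketch of the new builder method (not part of this commit): the model name, prompt text, and the `model(...)` builder call are illustrative assumptions, and propagating the option into the outgoing request relies on the existing OpenAiChatModel option handling, which this commit does not change.

import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatOptions;

// Build chat options that request low reasoning effort (supported values: low, medium, high).
OpenAiChatOptions options = OpenAiChatOptions.builder()
	.model("o1-mini")          // illustrative reasoning-capable model
	.reasoningEffort("low")    // builder method added in this commit
	.build();

// Attach the options to a prompt in the usual Spring AI way; invoking the
// ChatModel with this prompt is outside the scope of this commit.
Prompt prompt = new Prompt("If a train travels 100 miles in 2 hours, what is its average speed?", options);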

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 8 additions & 6 deletions
@@ -58,6 +58,7 @@
  * @author Mariusz Bernacki
  * @author Thomas Vitale
  * @author David Frizelle
+ * @author Alexandros Pappas
  */
 public class OpenAiApi {

@@ -804,6 +805,7 @@ public record ChatCompletionRequest(// @formatter:off
 		@JsonProperty("messages") List<ChatCompletionMessage> messages,
 		@JsonProperty("model") String model,
 		@JsonProperty("store") Boolean store,
+		@JsonProperty("reasoning_effort") String reasoningEffort,
 		@JsonProperty("metadata") Map<String, String> metadata,
 		@JsonProperty("frequency_penalty") Double frequencyPenalty,
 		@JsonProperty("logit_bias") Map<String, Integer> logitBias,
@@ -836,7 +838,7 @@ public record ChatCompletionRequest(// @formatter:off
 	 * @param temperature What sampling temperature to use, between 0 and 1.
 	 */
 	public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) {
-		this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null,
+		this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null, null,
 				null, null, null, false, null, temperature, null,
 				null, null, null, null);
 	}
@@ -849,7 +851,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
 	 * @param audio Parameters for audio output. Required when audio output is requested with outputModalities: ["audio"].
 	 */
 	public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, AudioParameters audio, boolean stream) {
-		this(messages, model, null, null, null, null, null, null,
+		this(messages, model, null, null, null, null, null, null, null,
 				null, null, null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null,
 				null, null, null, stream, null, null, null,
 				null, null, null, null);
@@ -865,7 +867,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
 	 * as they become available, with the stream terminated by a data: [DONE] message.
 	 */
 	public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature, boolean stream) {
-		this(messages, model, null, null, null, null, null, null, null, null, null,
+		this(messages, model, null, null, null, null, null, null, null, null, null, null,
 				null, null, null, null, null, null, null, stream, null, temperature, null,
 				null, null, null, null);
 	}
@@ -881,7 +883,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
 	 */
 	public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
 			List<FunctionTool> tools, Object toolChoice) {
-		this(messages, model, null, null, null, null, null, null, null, null, null,
+		this(messages, model, null, null, null, null, null, null, null, null, null, null,
 				null, null, null, null, null, null, null, false, null, 0.8, null,
 				tools, toolChoice, null, null);
 	}
@@ -894,7 +896,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
 	 * as they become available, with the stream terminated by a data: [DONE] message.
 	 */
 	public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
-		this(messages, null, null, null, null, null, null, null, null, null, null,
+		this(messages, null, null, null, null, null, null, null, null, null, null, null,
 				null, null, null, null, null, null, null, stream, null, null, null,
 				null, null, null, null);
 	}
@@ -906,7 +908,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean strea
 	 * @return A new {@link ChatCompletionRequest} with the specified stream options.
 	 */
 	public ChatCompletionRequest streamOptions(StreamOptions streamOptions) {
-		return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs,
+		return new ChatCompletionRequest(this.messages, this.model, this.store, this.reasoningEffort, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs,
 				this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty,
 				this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP,
 				this.tools, this.toolChoice, this.parallelToolCalls, this.user);

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java

Lines changed: 20 additions & 0 deletions
@@ -40,6 +40,7 @@
 /**
  * @author Christian Tzolov
  * @author Thomas Vitale
+ * @author Alexandros Pappas
  */
 @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
 public class OpenAiApiIT {
@@ -66,6 +67,25 @@ void chatCompletionStream() {
 		assertThat(response.collectList().block()).isNotNull();
 	}

+	@Test
+	void validateReasoningTokens() {
+		ChatCompletionMessage userMessage = new ChatCompletionMessage(
+				"If a train travels 100 miles in 2 hours, what is its average speed?", ChatCompletionMessage.Role.USER);
+		ChatCompletionRequest request = new ChatCompletionRequest(List.of(userMessage), "o1", null,
+				"low", null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, false,
+				null, null, null, null, null, null, null);
+		ResponseEntity<ChatCompletion> response = this.openAiApi.chatCompletionEntity(request);
+
+		assertThat(response).isNotNull();
+		assertThat(response.getBody()).isNotNull();
+
+		OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = response.getBody()
+			.usage()
+			.completionTokenDetails();
+		assertThat(completionTokenDetails).isNotNull();
+		assertThat(completionTokenDetails.reasoningTokens()).isPositive();
+	}
+
 	@Test
 	void embeddings() {
 		ResponseEntity<EmbeddingList<Embedding>> response = this.openAiApi
