diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java
index 831aedf227d..6c3e1246ca2 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java
@@ -182,6 +182,15 @@ public class OpenAiChatOptions implements FunctionCallingOptions {
 	 * Developer-defined tags and values used for filtering completions in the dashboard.
 	 */
 	private @JsonProperty("metadata") Map<String, String> metadata;
+
+	/**
+	 * Constrains effort on reasoning for reasoning models. Currently supported values are low, medium, and high.
+	 * Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response.
+	 * Optional. Defaults to medium.
+	 * Only for 'o1' models.
+	 */
+	private @JsonProperty("reasoning_effort") String reasoningEffort;
+
 	/**
 	 * OpenAI Tool Function Callbacks to register with the ChatModel.
 	 * For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution.
@@ -256,6 +265,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) {
 			.toolContext(fromOptions.getToolContext())
 			.store(fromOptions.getStore())
 			.metadata(fromOptions.getMetadata())
+			.reasoningEffort(fromOptions.getReasoningEffort())
 			.build();
 	}
 
@@ -520,6 +530,14 @@ public void setMetadata(Map<String, String> metadata) {
 		this.metadata = metadata;
 	}
+	public String getReasoningEffort() {
+		return this.reasoningEffort;
+	}
+
+	public void setReasoningEffort(String reasoningEffort) {
+		this.reasoningEffort = reasoningEffort;
+	}
+
 	@Override
 	public OpenAiChatOptions copy() {
 		return OpenAiChatOptions.fromOptions(this);
 	}
@@ -532,7 +550,7 @@ public int hashCode() {
 				this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice,
 				this.user, this.parallelToolCalls, this.functionCallbacks, this.functions, this.httpHeaders,
 				this.proxyToolCalls, this.toolContext, this.outputModalities, this.outputAudio, this.store,
-				this.metadata);
+				this.metadata, this.reasoningEffort);
 	}
 
 	@Override
@@ -563,7 +581,8 @@ public boolean equals(Object o) {
 				&& Objects.equals(this.proxyToolCalls, other.proxyToolCalls)
 				&& Objects.equals(this.outputModalities, other.outputModalities)
 				&& Objects.equals(this.outputAudio, other.outputAudio) && Objects.equals(this.store, other.store)
-				&& Objects.equals(this.metadata, other.metadata);
+				&& Objects.equals(this.metadata, other.metadata)
+				&& Objects.equals(this.reasoningEffort, other.reasoningEffort);
 	}
 
 	@Override
@@ -740,6 +759,11 @@ public Builder metadata(Map<String, String> metadata) {
 			return this;
 		}
 
+		public Builder reasoningEffort(String reasoningEffort) {
+			this.options.reasoningEffort = reasoningEffort;
+			return this;
+		}
+
 		public OpenAiChatOptions build() {
 			return this.options;
 		}
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
index a29ebcd8fcd..f73dc7b7d88 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
@@ -58,6 +58,7 @@
  * @author Mariusz Bernacki
  * @author Thomas Vitale
  * @author David Frizelle
+ * @author Alexandros Pappas
  */
 public class OpenAiApi {
 
@@ -826,7 +827,8 @@ public record ChatCompletionRequest(// @formatter:off
 			@JsonProperty("tools") List<FunctionTool> tools,
 			@JsonProperty("tool_choice") Object toolChoice,
 			@JsonProperty("parallel_tool_calls") Boolean parallelToolCalls,
-			@JsonProperty("user") String user) {
+			@JsonProperty("user") String user,
+			@JsonProperty("reasoning_effort") String reasoningEffort) {
 
 		/**
 		 * Shortcut constructor for a chat completion request with the given messages, model and temperature.
@@ -838,7 +840,7 @@ public record ChatCompletionRequest(// @formatter:off
 		public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) {
 			this(messages, model, null, null, null, null, null, null, null, null, null,
 					null, null, null, null, null, null, null, false, null, temperature, null,
-					null, null, null, null);
+					null, null, null, null, null);
 		}
 
 		/**
@@ -852,7 +854,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
 			this(messages, model, null, null, null, null, null, null, null, null, null,
 					List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null,
 					null, null, null, stream, null, null, null,
-					null, null, null, null);
+					null, null, null, null, null);
 		}
 
 		/**
@@ -867,7 +869,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
 		public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature, boolean stream) {
 			this(messages, model, null, null, null, null, null, null, null, null, null,
 					null, null, null, null, null, null, null, stream, null, temperature, null,
-					null, null, null, null);
+					null, null, null, null, null);
 		}
 
 		/**
@@ -883,7 +885,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
 				List<FunctionTool> tools, Object toolChoice) {
 			this(messages, model, null, null, null, null, null, null, null, null, null,
 					null, null, null, null, null, null, null, false, null, 0.8, null,
-					tools, toolChoice, null, null);
+					tools, toolChoice, null, null, null);
 		}
 
 		/**
@@ -896,7 +898,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
 			this(messages, null, null, null, null, null, null, null, null, null, null,
 					null, null, null, null, null, null, null, stream,
 					null, null, null,
-					null, null, null, null);
+					null, null, null, null, null);
 		}
 
 		/**
@@ -909,7 +911,7 @@ public ChatCompletionRequest streamOptions(StreamOptions streamOptions) {
 			return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs, this.topLogprobs,
 					this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty, this.responseFormat, this.seed,
 					this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP,
-					this.tools, this.toolChoice, this.parallelToolCalls, this.user);
+					this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort);
 		}
 
 		/**
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java
index a3c098f6688..8c84fd98886 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java
@@ -40,6 +40,7 @@
 /**
  * @author Christian Tzolov
  * @author Thomas Vitale
+ * @author Alexandros Pappas
  */
 @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
 public class OpenAiApiIT {
@@ -66,6 +67,25 @@ void chatCompletionStream() {
 		assertThat(response.collectList().block()).isNotNull();
 	}
 
+	@Test
+	void validateReasoningTokens() {
+		ChatCompletionMessage userMessage = new ChatCompletionMessage(
+				"If a train travels 100 miles in 2 hours, what is its average speed?", ChatCompletionMessage.Role.USER);
+		ChatCompletionRequest request = new ChatCompletionRequest(List.of(userMessage), "o1", null, null, null, null,
+				null, null, null, null, null, null, null, null, null, null, null, null, false, null, null, null, null,
+				null, null, null, "low");
+		ResponseEntity<ChatCompletion> response = this.openAiApi.chatCompletionEntity(request);
+
+		assertThat(response).isNotNull();
+		assertThat(response.getBody()).isNotNull();
+
+		OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = response.getBody()
+			.usage()
+			.completionTokenDetails();
+		assertThat(completionTokenDetails).isNotNull();
+		assertThat(completionTokenDetails.reasoningTokens()).isPositive();
+	}
+
 	@Test
 	void embeddings() {
 		ResponseEntity<EmbeddingList<Embedding>> response = this.openAiApi
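Usage note, not part of the patch: the integration test above exercises the new option through the low-level `ChatCompletionRequest`. A minimal, untested sketch of the same option going through the high-level builder added in `OpenAiChatOptions` is shown below; the `chatModel` parameter and the `Prompt`/`ChatResponse` call pattern are the standard Spring AI ones and are assumed here, not introduced by this change, and the `model("o1")` choice simply mirrors the model used in the test.

```java
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;

class ReasoningEffortExample {

	// Sketch only: assumes an OpenAiChatModel instance (e.g. the auto-configured bean) is passed in.
	ChatResponse ask(OpenAiChatModel chatModel) {
		// "low" trades some reasoning depth for faster, cheaper responses;
		// "medium" (the default) and "high" are the other values the new Javadoc lists.
		OpenAiChatOptions options = OpenAiChatOptions.builder()
			.model("o1")
			.reasoningEffort("low")
			.build();

		return chatModel.call(
				new Prompt("If a train travels 100 miles in 2 hours, what is its average speed?", options));
	}

}
```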