diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index a94115a9b99..3ff5cc70f2c 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -16,8 +16,8 @@ package org.springframework.ai.azure.openai; -import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormat; -import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormatJsonSchema; +import com.azure.ai.openai.models.*; + import java.util.ArrayList; import java.util.Base64; import java.util.Collections; @@ -30,31 +30,6 @@ import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.ai.openai.implementation.accesshelpers.ChatCompletionsOptionsAccessHelper; -import com.azure.ai.openai.models.ChatChoice; -import com.azure.ai.openai.models.ChatCompletions; -import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall; -import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition; -import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinitionFunction; -import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; -import com.azure.ai.openai.models.ChatCompletionsOptions; -import com.azure.ai.openai.models.ChatCompletionsResponseFormat; -import com.azure.ai.openai.models.ChatCompletionStreamOptions; -import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat; -import com.azure.ai.openai.models.ChatCompletionsToolCall; -import com.azure.ai.openai.models.ChatCompletionsToolDefinition; -import com.azure.ai.openai.models.ChatMessageContentItem; -import com.azure.ai.openai.models.ChatMessageImageContentItem; -import com.azure.ai.openai.models.ChatMessageImageUrl; -import com.azure.ai.openai.models.ChatMessageTextContentItem; -import com.azure.ai.openai.models.ChatRequestAssistantMessage; -import com.azure.ai.openai.models.ChatRequestMessage; -import com.azure.ai.openai.models.ChatRequestSystemMessage; -import com.azure.ai.openai.models.ChatRequestToolMessage; -import com.azure.ai.openai.models.ChatRequestUserMessage; -import com.azure.ai.openai.models.CompletionsFinishReason; -import com.azure.ai.openai.models.CompletionsUsage; -import com.azure.ai.openai.models.ContentFilterResultsForPrompt; -import com.azure.ai.openai.models.FunctionCall; import com.azure.core.util.BinaryData; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; @@ -63,6 +38,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.JsonSchema; import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.Type; +import org.springframework.util.StringUtils; import reactor.core.publisher.Flux; import reactor.core.scheduler.Schedulers; @@ -760,6 +736,14 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions, mergedAzureOptions.setEnhancements(fromAzureOptions.getEnhancements() != null ? fromAzureOptions.getEnhancements() : toSpringAiOptions.getEnhancements()); + ReasoningEffortValue reasoningEffort = (fromAzureOptions.getReasoningEffort() != null) + ? fromAzureOptions.getReasoningEffort() : (StringUtils.hasText(toSpringAiOptions.getReasoningEffort()) + ? ReasoningEffortValue.fromString(toSpringAiOptions.getReasoningEffort()) : null); + + if (reasoningEffort != null) { + mergedAzureOptions.setReasoningEffort(reasoningEffort); + } + return mergedAzureOptions; } @@ -849,6 +833,11 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions, mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements()); } + if (StringUtils.hasText(fromSpringAiOptions.getReasoningEffort())) { + mergedAzureOptions + .setReasoningEffort(ReasoningEffortValue.fromString(fromSpringAiOptions.getReasoningEffort())); + } + return mergedAzureOptions; } @@ -914,6 +903,10 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) { copyOptions.setEnhancements(fromOptions.getEnhancements()); } + if (fromOptions.getReasoningEffort() != null) { + copyOptions.setReasoningEffort(fromOptions.getReasoningEffort()); + } + return copyOptions; } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index b1b785c0cf2..da442b4ad4d 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -207,6 +207,15 @@ public class AzureOpenAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean enableStreamUsage; + /** + * Constrains effort on reasoning for reasoning models. Currently supported values are + * low, medium, and high. Reducing reasoning effort can result in faster responses and + * fewer tokens used on reasoning in a response. Optional. Defaults to medium. Only + * for reasoning models. + */ + @JsonProperty("reasoning_effort") + private String reasoningEffort; + @Override @JsonIgnore public List getToolCallbacks() { @@ -268,6 +277,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .responseFormat(fromOptions.getResponseFormat()) .streamUsage(fromOptions.getStreamUsage()) + .reasoningEffort(fromOptions.getReasoningEffort()) .seed(fromOptions.getSeed()) .logprobs(fromOptions.isLogprobs()) .topLogprobs(fromOptions.getTopLogProbs()) @@ -408,6 +418,14 @@ public void setStreamUsage(Boolean enableStreamUsage) { this.enableStreamUsage = enableStreamUsage; } + public String getReasoningEffort() { + return this.reasoningEffort; + } + + public void setReasoningEffort(String reasoningEffort) { + this.reasoningEffort = reasoningEffort; + } + @Override @JsonIgnore public Integer getTopK() { @@ -490,6 +508,7 @@ public boolean equals(Object o) { && Objects.equals(this.enhancements, that.enhancements) && Objects.equals(this.streamOptions, that.streamOptions) && Objects.equals(this.enableStreamUsage, that.enableStreamUsage) + && Objects.equals(this.reasoningEffort, that.reasoningEffort) && Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.maxTokens, that.maxTokens) && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) && Objects.equals(this.presencePenalty, that.presencePenalty) @@ -500,8 +519,9 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(this.logitBias, this.user, this.n, this.stop, this.deploymentName, this.responseFormat, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.seed, this.logprobs, - this.topLogProbs, this.enhancements, this.streamOptions, this.enableStreamUsage, this.toolContext, - this.maxTokens, this.frequencyPenalty, this.presencePenalty, this.temperature, this.topP); + this.topLogProbs, this.enhancements, this.streamOptions, this.reasoningEffort, this.enableStreamUsage, + this.toolContext, this.maxTokens, this.frequencyPenalty, this.presencePenalty, this.temperature, + this.topP); } public static class Builder { @@ -576,6 +596,11 @@ public Builder streamUsage(Boolean enableStreamUsage) { return this; } + public Builder reasoningEffort(String reasoningEffort) { + this.options.reasoningEffort = reasoningEffort; + return this; + } + public Builder seed(Long seed) { this.options.seed = seed; return this; diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java index 60568f540c4..789635d358e 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java @@ -59,6 +59,7 @@ void testBuilderWithAllFields() { .user("test-user") .responseFormat(responseFormat) .streamUsage(true) + .reasoningEffort("low") .seed(12345L) .logprobs(true) .topLogprobs(5) @@ -68,10 +69,10 @@ void testBuilderWithAllFields() { assertThat(options) .extracting("deploymentName", "frequencyPenalty", "logitBias", "maxTokens", "n", "presencePenalty", "stop", - "temperature", "topP", "user", "responseFormat", "streamUsage", "seed", "logprobs", "topLogProbs", - "enhancements", "streamOptions") + "temperature", "topP", "user", "responseFormat", "streamUsage", "reasoningEffort", "seed", + "logprobs", "topLogProbs", "enhancements", "streamOptions") .containsExactly("test-deployment", 0.5, Map.of("token1", 1, "token2", -1), 200, 2, 0.8, - List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, true, 12345L, true, 5, + List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, true, "low", 12345L, true, 5, enhancements, streamOptions); } @@ -100,6 +101,7 @@ void testCopy() { .user("test-user") .responseFormat(responseFormat) .streamUsage(true) + .reasoningEffort("low") .seed(12345L) .logprobs(true) .topLogprobs(5) @@ -137,6 +139,7 @@ void testSetters() { options.setUser("test-user"); options.setResponseFormat(responseFormat); options.setStreamUsage(true); + options.setReasoningEffort("low"); options.setSeed(12345L); options.setLogprobs(true); options.setTopLogProbs(5); @@ -158,6 +161,7 @@ void testSetters() { assertThat(options.getUser()).isEqualTo("test-user"); assertThat(options.getResponseFormat()).isEqualTo(responseFormat); assertThat(options.getStreamUsage()).isTrue(); + assertThat(options.getReasoningEffort()).isEqualTo("low"); assertThat(options.getSeed()).isEqualTo(12345L); assertThat(options.isLogprobs()).isTrue(); assertThat(options.getTopLogProbs()).isEqualTo(5); @@ -182,6 +186,7 @@ void testDefaultValues() { assertThat(options.getUser()).isNull(); assertThat(options.getResponseFormat()).isNull(); assertThat(options.getStreamUsage()).isNull(); + assertThat(options.getReasoningEffort()).isNull(); assertThat(options.getSeed()).isNull(); assertThat(options.isLogprobs()).isNull(); assertThat(options.getTopLogProbs()).isNull();