Commit 1ea7694
Support reasoning_effort in AzureOpenAiChatOptions
* Adds reasoningEffort field to AzureOpenAiChatOptions builder, copy, equals, hashCode
* Propagates value to ChatCompletionsOptions

Fixes #2703

Signed-off-by: Andres da Silva Santos <[email protected]>
1 parent d619e25 commit 1ea7694


3 files changed: +55 -32 lines changed

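Before the per-file diffs, a minimal sketch of how the new option might be set on a request. This is not part of the commit: it assumes an already-configured AzureOpenAiChatModel, the existing Spring AI Prompt/ChatResponse API, and a hypothetical reasoning-model deployment name.

import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;

class ReasoningEffortUsageSketch {

    // The AzureOpenAiChatModel wiring (OpenAIClientBuilder, credentials, default options) is assumed to exist.
    ChatResponse callWithLowEffort(AzureOpenAiChatModel chatModel) {
        AzureOpenAiChatOptions options = AzureOpenAiChatOptions.builder()
            .deploymentName("o3-mini")   // hypothetical reasoning-model deployment
            .reasoningEffort("low")      // option introduced by this commit: low, medium, or high
            .build();
        return chatModel.call(new Prompt("Summarize the release notes.", options));
    }
}

Per-request values set this way reach the Azure SDK's ChatCompletionsOptions through the merge logic changed in AzureOpenAiChatModel.java below.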

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 20 additions & 27 deletions

@@ -16,8 +16,8 @@
 
 package org.springframework.ai.azure.openai;
 
-import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormat;
-import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormatJsonSchema;
+import com.azure.ai.openai.models.*;
+
 import java.util.ArrayList;
 import java.util.Base64;
 import java.util.Collections;
@@ -30,31 +30,6 @@
 import com.azure.ai.openai.OpenAIClient;
 import com.azure.ai.openai.OpenAIClientBuilder;
 import com.azure.ai.openai.implementation.accesshelpers.ChatCompletionsOptionsAccessHelper;
-import com.azure.ai.openai.models.ChatChoice;
-import com.azure.ai.openai.models.ChatCompletions;
-import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall;
-import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition;
-import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinitionFunction;
-import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
-import com.azure.ai.openai.models.ChatCompletionsOptions;
-import com.azure.ai.openai.models.ChatCompletionsResponseFormat;
-import com.azure.ai.openai.models.ChatCompletionStreamOptions;
-import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat;
-import com.azure.ai.openai.models.ChatCompletionsToolCall;
-import com.azure.ai.openai.models.ChatCompletionsToolDefinition;
-import com.azure.ai.openai.models.ChatMessageContentItem;
-import com.azure.ai.openai.models.ChatMessageImageContentItem;
-import com.azure.ai.openai.models.ChatMessageImageUrl;
-import com.azure.ai.openai.models.ChatMessageTextContentItem;
-import com.azure.ai.openai.models.ChatRequestAssistantMessage;
-import com.azure.ai.openai.models.ChatRequestMessage;
-import com.azure.ai.openai.models.ChatRequestSystemMessage;
-import com.azure.ai.openai.models.ChatRequestToolMessage;
-import com.azure.ai.openai.models.ChatRequestUserMessage;
-import com.azure.ai.openai.models.CompletionsFinishReason;
-import com.azure.ai.openai.models.CompletionsUsage;
-import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
-import com.azure.ai.openai.models.FunctionCall;
 import com.azure.core.util.BinaryData;
 import io.micrometer.observation.Observation;
 import io.micrometer.observation.ObservationRegistry;
@@ -63,6 +38,7 @@
 import org.slf4j.LoggerFactory;
 import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.JsonSchema;
 import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.Type;
+import org.springframework.util.StringUtils;
 import reactor.core.publisher.Flux;
 import reactor.core.scheduler.Schedulers;
 
@@ -760,6 +736,14 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
         mergedAzureOptions.setEnhancements(fromAzureOptions.getEnhancements() != null
                 ? fromAzureOptions.getEnhancements() : toSpringAiOptions.getEnhancements());
 
+        ReasoningEffortValue reasoningEffort = (fromAzureOptions.getReasoningEffort() != null)
+                ? fromAzureOptions.getReasoningEffort() : (StringUtils.hasText(toSpringAiOptions.getReasoningEffort())
+                        ? ReasoningEffortValue.fromString(toSpringAiOptions.getReasoningEffort()) : null);
+
+        if (reasoningEffort != null) {
+            mergedAzureOptions.setReasoningEffort(reasoningEffort);
+        }
+
         return mergedAzureOptions;
     }
 
@@ -849,6 +833,11 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
             mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements());
         }
 
+        if (StringUtils.hasText(fromSpringAiOptions.getReasoningEffort())) {
+            mergedAzureOptions
+                .setReasoningEffort(ReasoningEffortValue.fromString(fromSpringAiOptions.getReasoningEffort()));
+        }
+
         return mergedAzureOptions;
     }
 
@@ -914,6 +903,10 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
             copyOptions.setEnhancements(fromOptions.getEnhancements());
        }
 
+        if (fromOptions.getReasoningEffort() != null) {
+            copyOptions.setReasoningEffort(fromOptions.getReasoningEffort());
+        }
+
         return copyOptions;
     }
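Both merge paths above give precedence to a reasoning effort already present on the runtime Azure options and fall back to the Spring AI option only when it has text. A standalone sketch of that resolution (the same logic as the hunks above, not additional commit code), assuming ReasoningEffortValue and StringUtils as imported in the diff:

import com.azure.ai.openai.models.ReasoningEffortValue;
import org.springframework.util.StringUtils;

final class ReasoningEffortMergeSketch {

    // Runtime (Azure) value wins; otherwise parse the Spring AI default; otherwise leave unset.
    static ReasoningEffortValue resolve(ReasoningEffortValue fromAzureOptions, String fromSpringAiOptions) {
        if (fromAzureOptions != null) {
            return fromAzureOptions;
        }
        return StringUtils.hasText(fromSpringAiOptions)
                ? ReasoningEffortValue.fromString(fromSpringAiOptions) : null;
    }
}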

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java

Lines changed: 27 additions & 2 deletions

@@ -207,6 +207,15 @@ public class AzureOpenAiChatOptions implements ToolCallingChatOptions {
     @JsonIgnore
     private Boolean enableStreamUsage;
 
+    /**
+     * Constrains effort on reasoning for reasoning models. Currently supported values are
+     * low, medium, and high. Reducing reasoning effort can result in faster responses and
+     * fewer tokens used on reasoning in a response. Optional. Defaults to medium. Only
+     * for reasoning models.
+     */
+    @JsonProperty("reasoning_effort")
+    private String reasoningEffort;
+
     @Override
     @JsonIgnore
     public List<ToolCallback> getToolCallbacks() {
@@ -268,6 +277,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
             .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
             .responseFormat(fromOptions.getResponseFormat())
             .streamUsage(fromOptions.getStreamUsage())
+            .reasoningEffort(fromOptions.getReasoningEffort())
             .seed(fromOptions.getSeed())
             .logprobs(fromOptions.isLogprobs())
             .topLogprobs(fromOptions.getTopLogProbs())
@@ -408,6 +418,14 @@ public void setStreamUsage(Boolean enableStreamUsage) {
         this.enableStreamUsage = enableStreamUsage;
     }
 
+    public String getReasoningEffort() {
+        return this.reasoningEffort;
+    }
+
+    public void setReasoningEffort(String reasoningEffort) {
+        this.reasoningEffort = reasoningEffort;
+    }
+
     @Override
     @JsonIgnore
     public Integer getTopK() {
@@ -490,6 +508,7 @@ public boolean equals(Object o) {
             && Objects.equals(this.enhancements, that.enhancements)
             && Objects.equals(this.streamOptions, that.streamOptions)
             && Objects.equals(this.enableStreamUsage, that.enableStreamUsage)
+            && Objects.equals(this.reasoningEffort, that.reasoningEffort)
             && Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.maxTokens, that.maxTokens)
             && Objects.equals(this.frequencyPenalty, that.frequencyPenalty)
             && Objects.equals(this.presencePenalty, that.presencePenalty)
@@ -500,8 +519,9 @@ public boolean equals(Object o) {
     public int hashCode() {
         return Objects.hash(this.logitBias, this.user, this.n, this.stop, this.deploymentName, this.responseFormat,
                 this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.seed, this.logprobs,
-                this.topLogProbs, this.enhancements, this.streamOptions, this.enableStreamUsage, this.toolContext,
-                this.maxTokens, this.frequencyPenalty, this.presencePenalty, this.temperature, this.topP);
+                this.topLogProbs, this.enhancements, this.streamOptions, this.reasoningEffort, this.enableStreamUsage,
+                this.toolContext, this.maxTokens, this.frequencyPenalty, this.presencePenalty, this.temperature,
+                this.topP);
     }
 
     public static class Builder {
@@ -576,6 +596,11 @@ public Builder streamUsage(Boolean enableStreamUsage) {
             return this;
         }
 
+        public Builder reasoningEffort(String reasoningEffort) {
+            this.options.reasoningEffort = reasoningEffort;
+            return this;
+        }
+
         public Builder seed(Long seed) {
             this.options.seed = seed;
             return this;

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java

Lines changed: 8 additions & 3 deletions

@@ -59,6 +59,7 @@ void testBuilderWithAllFields() {
             .user("test-user")
             .responseFormat(responseFormat)
             .streamUsage(true)
+            .reasoningEffort("low")
             .seed(12345L)
             .logprobs(true)
             .topLogprobs(5)
@@ -68,10 +69,10 @@
 
         assertThat(options)
             .extracting("deploymentName", "frequencyPenalty", "logitBias", "maxTokens", "n", "presencePenalty", "stop",
-                    "temperature", "topP", "user", "responseFormat", "streamUsage", "seed", "logprobs", "topLogProbs",
-                    "enhancements", "streamOptions")
+                    "temperature", "topP", "user", "responseFormat", "streamUsage", "reasoningEffort", "seed",
+                    "logprobs", "topLogProbs", "enhancements", "streamOptions")
             .containsExactly("test-deployment", 0.5, Map.of("token1", 1, "token2", -1), 200, 2, 0.8,
-                    List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, true, 12345L, true, 5,
+                    List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, true, "low", 12345L, true, 5,
                     enhancements, streamOptions);
     }
 
@@ -100,6 +101,7 @@ void testCopy() {
             .user("test-user")
             .responseFormat(responseFormat)
             .streamUsage(true)
+            .reasoningEffort("low")
             .seed(12345L)
             .logprobs(true)
             .topLogprobs(5)
@@ -137,6 +139,7 @@ void testSetters() {
         options.setUser("test-user");
         options.setResponseFormat(responseFormat);
         options.setStreamUsage(true);
+        options.setReasoningEffort("low");
         options.setSeed(12345L);
         options.setLogprobs(true);
         options.setTopLogProbs(5);
@@ -158,6 +161,7 @@ void testSetters() {
         assertThat(options.getUser()).isEqualTo("test-user");
         assertThat(options.getResponseFormat()).isEqualTo(responseFormat);
         assertThat(options.getStreamUsage()).isTrue();
+        assertThat(options.getReasoningEffort()).isEqualTo("low");
         assertThat(options.getSeed()).isEqualTo(12345L);
         assertThat(options.isLogprobs()).isTrue();
         assertThat(options.getTopLogProbs()).isEqualTo(5);
@@ -182,6 +186,7 @@ void testDefaultValues() {
         assertThat(options.getUser()).isNull();
         assertThat(options.getResponseFormat()).isNull();
        assertThat(options.getStreamUsage()).isNull();
+        assertThat(options.getReasoningEffort()).isNull();
         assertThat(options.getSeed()).isNull();
         assertThat(options.isLogprobs()).isNull();
         assertThat(options.getTopLogProbs()).isNull();
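In the same AssertJ style as the tests above, a minimal sketch (not one of the tests in this commit) of checking that the new field survives fromOptions, which backs the copy path exercised by testCopy:

        // Hypothetical assertion; relies only on builder(), fromOptions(...) and getReasoningEffort() from this commit.
        AzureOpenAiChatOptions original = AzureOpenAiChatOptions.builder()
            .reasoningEffort("high")
            .build();

        AzureOpenAiChatOptions copied = AzureOpenAiChatOptions.fromOptions(original);
        assertThat(copied.getReasoningEffort()).isEqualTo("high");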
