Skip to content

Commit fb402a6

Browse files
andresssantos authored and ilayaperumalg committed
Support reasoning_effort in AzureOpenAiChatOptions
* Adds reasoningEffort field to AzureOpenAiChatOptions builder, copy, equals, hashCode
* Propagates value to ChatCompletionsOptions

Fixes #2703

Signed-off-by: Andres da Silva Santos <[email protected]>
1 parent 69c081c commit fb402a6

File tree

3 files changed

+60
-11
lines changed

3 files changed

+60
-11
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
package org.springframework.ai.azure.openai;
1818

19-
import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormat;
20-
import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormatJsonSchema;
2119
import java.util.ArrayList;
2220
import java.util.Base64;
2321
import java.util.Collections;
@@ -31,14 +29,16 @@
3129
import com.azure.ai.openai.OpenAIClientBuilder;
3230
import com.azure.ai.openai.implementation.accesshelpers.ChatCompletionsOptionsAccessHelper;
3331
import com.azure.ai.openai.models.ChatChoice;
32+
import com.azure.ai.openai.models.ChatCompletionStreamOptions;
3433
import com.azure.ai.openai.models.ChatCompletions;
3534
import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall;
3635
import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition;
3736
import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinitionFunction;
3837
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
38+
import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormat;
39+
import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormatJsonSchema;
3940
import com.azure.ai.openai.models.ChatCompletionsOptions;
4041
import com.azure.ai.openai.models.ChatCompletionsResponseFormat;
41-
import com.azure.ai.openai.models.ChatCompletionStreamOptions;
4242
import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat;
4343
import com.azure.ai.openai.models.ChatCompletionsToolCall;
4444
import com.azure.ai.openai.models.ChatCompletionsToolDefinition;
@@ -55,17 +55,18 @@
5555
import com.azure.ai.openai.models.CompletionsUsage;
5656
import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
5757
import com.azure.ai.openai.models.FunctionCall;
58+
import com.azure.ai.openai.models.ReasoningEffortValue;
5859
import com.azure.core.util.BinaryData;
5960
import io.micrometer.observation.Observation;
6061
import io.micrometer.observation.ObservationRegistry;
6162
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
6263
import org.slf4j.Logger;
6364
import org.slf4j.LoggerFactory;
64-
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.JsonSchema;
65-
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.Type;
6665
import reactor.core.publisher.Flux;
6766
import reactor.core.scheduler.Schedulers;
6867

68+
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.JsonSchema;
69+
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.Type;
6970
import org.springframework.ai.chat.messages.AssistantMessage;
7071
import org.springframework.ai.chat.messages.Message;
7172
import org.springframework.ai.chat.messages.ToolResponseMessage;
@@ -77,7 +78,6 @@
7778
import org.springframework.ai.chat.metadata.PromptMetadata;
7879
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
7980
import org.springframework.ai.chat.metadata.Usage;
80-
import org.springframework.ai.support.UsageCalculator;
8181
import org.springframework.ai.chat.model.ChatModel;
8282
import org.springframework.ai.chat.model.ChatResponse;
8383
import org.springframework.ai.chat.model.Generation;
@@ -96,9 +96,11 @@
9696
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
9797
import org.springframework.ai.model.tool.ToolExecutionResult;
9898
import org.springframework.ai.observation.conventions.AiProvider;
99+
import org.springframework.ai.support.UsageCalculator;
99100
import org.springframework.ai.tool.definition.ToolDefinition;
100101
import org.springframework.util.Assert;
101102
import org.springframework.util.CollectionUtils;
103+
import org.springframework.util.StringUtils;
102104

103105
/**
104106
* {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by
@@ -761,6 +763,14 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
761763
mergedAzureOptions.setEnhancements(fromAzureOptions.getEnhancements() != null
762764
? fromAzureOptions.getEnhancements() : toSpringAiOptions.getEnhancements());
763765

766+
ReasoningEffortValue reasoningEffort = (fromAzureOptions.getReasoningEffort() != null)
767+
? fromAzureOptions.getReasoningEffort() : (StringUtils.hasText(toSpringAiOptions.getReasoningEffort())
768+
? ReasoningEffortValue.fromString(toSpringAiOptions.getReasoningEffort()) : null);
769+
770+
if (reasoningEffort != null) {
771+
mergedAzureOptions.setReasoningEffort(reasoningEffort);
772+
}
773+
764774
return mergedAzureOptions;
765775
}
766776

@@ -850,6 +860,11 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
850860
mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements());
851861
}
852862

863+
if (StringUtils.hasText(fromSpringAiOptions.getReasoningEffort())) {
864+
mergedAzureOptions
865+
.setReasoningEffort(ReasoningEffortValue.fromString(fromSpringAiOptions.getReasoningEffort()));
866+
}
867+
853868
return mergedAzureOptions;
854869
}
855870

@@ -915,6 +930,10 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
915930
copyOptions.setEnhancements(fromOptions.getEnhancements());
916931
}
917932

933+
if (fromOptions.getReasoningEffort() != null) {
934+
copyOptions.setReasoningEffort(fromOptions.getReasoningEffort());
935+
}
936+
918937
return copyOptions;
919938
}
920939

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,15 @@ public class AzureOpenAiChatOptions implements ToolCallingChatOptions {
207207
@JsonIgnore
208208
private Boolean enableStreamUsage;
209209

210+
/**
211+
* Constrains effort on reasoning for reasoning models. Currently supported values are
212+
* low, medium, and high. Reducing reasoning effort can result in faster responses and
213+
* fewer tokens used on reasoning in a response. Optional. Defaults to medium. Only
214+
* for reasoning models.
215+
*/
216+
@JsonProperty("reasoning_effort")
217+
private String reasoningEffort;
218+
210219
@Override
211220
@JsonIgnore
212221
public List<ToolCallback> getToolCallbacks() {
@@ -268,6 +277,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
268277
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
269278
.responseFormat(fromOptions.getResponseFormat())
270279
.streamUsage(fromOptions.getStreamUsage())
280+
.reasoningEffort(fromOptions.getReasoningEffort())
271281
.seed(fromOptions.getSeed())
272282
.logprobs(fromOptions.isLogprobs())
273283
.topLogprobs(fromOptions.getTopLogProbs())
@@ -408,6 +418,14 @@ public void setStreamUsage(Boolean enableStreamUsage) {
408418
this.enableStreamUsage = enableStreamUsage;
409419
}
410420

421+
public String getReasoningEffort() {
422+
return this.reasoningEffort;
423+
}
424+
425+
public void setReasoningEffort(String reasoningEffort) {
426+
this.reasoningEffort = reasoningEffort;
427+
}
428+
411429
@Override
412430
@JsonIgnore
413431
public Integer getTopK() {
@@ -490,6 +508,7 @@ public boolean equals(Object o) {
490508
&& Objects.equals(this.enhancements, that.enhancements)
491509
&& Objects.equals(this.streamOptions, that.streamOptions)
492510
&& Objects.equals(this.enableStreamUsage, that.enableStreamUsage)
511+
&& Objects.equals(this.reasoningEffort, that.reasoningEffort)
493512
&& Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.maxTokens, that.maxTokens)
494513
&& Objects.equals(this.frequencyPenalty, that.frequencyPenalty)
495514
&& Objects.equals(this.presencePenalty, that.presencePenalty)
@@ -500,8 +519,9 @@ public boolean equals(Object o) {
500519
public int hashCode() {
501520
return Objects.hash(this.logitBias, this.user, this.n, this.stop, this.deploymentName, this.responseFormat,
502521
this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.seed, this.logprobs,
503-
this.topLogProbs, this.enhancements, this.streamOptions, this.enableStreamUsage, this.toolContext,
504-
this.maxTokens, this.frequencyPenalty, this.presencePenalty, this.temperature, this.topP);
522+
this.topLogProbs, this.enhancements, this.streamOptions, this.reasoningEffort, this.enableStreamUsage,
523+
this.toolContext, this.maxTokens, this.frequencyPenalty, this.presencePenalty, this.temperature,
524+
this.topP);
505525
}
506526

507527
public static class Builder {
@@ -576,6 +596,11 @@ public Builder streamUsage(Boolean enableStreamUsage) {
576596
return this;
577597
}
578598

599+
public Builder reasoningEffort(String reasoningEffort) {
600+
this.options.reasoningEffort = reasoningEffort;
601+
return this;
602+
}
603+
579604
public Builder seed(Long seed) {
580605
this.options.seed = seed;
581606
return this;

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ void testBuilderWithAllFields() {
5959
.user("test-user")
6060
.responseFormat(responseFormat)
6161
.streamUsage(true)
62+
.reasoningEffort("low")
6263
.seed(12345L)
6364
.logprobs(true)
6465
.topLogprobs(5)
@@ -68,10 +69,10 @@ void testBuilderWithAllFields() {
6869

6970
assertThat(options)
7071
.extracting("deploymentName", "frequencyPenalty", "logitBias", "maxTokens", "n", "presencePenalty", "stop",
71-
"temperature", "topP", "user", "responseFormat", "streamUsage", "seed", "logprobs", "topLogProbs",
72-
"enhancements", "streamOptions")
72+
"temperature", "topP", "user", "responseFormat", "streamUsage", "reasoningEffort", "seed",
73+
"logprobs", "topLogProbs", "enhancements", "streamOptions")
7374
.containsExactly("test-deployment", 0.5, Map.of("token1", 1, "token2", -1), 200, 2, 0.8,
74-
List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, true, 12345L, true, 5,
75+
List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, true, "low", 12345L, true, 5,
7576
enhancements, streamOptions);
7677
}
7778

@@ -100,6 +101,7 @@ void testCopy() {
100101
.user("test-user")
101102
.responseFormat(responseFormat)
102103
.streamUsage(true)
104+
.reasoningEffort("low")
103105
.seed(12345L)
104106
.logprobs(true)
105107
.topLogprobs(5)
@@ -137,6 +139,7 @@ void testSetters() {
137139
options.setUser("test-user");
138140
options.setResponseFormat(responseFormat);
139141
options.setStreamUsage(true);
142+
options.setReasoningEffort("low");
140143
options.setSeed(12345L);
141144
options.setLogprobs(true);
142145
options.setTopLogProbs(5);
@@ -158,6 +161,7 @@ void testSetters() {
158161
assertThat(options.getUser()).isEqualTo("test-user");
159162
assertThat(options.getResponseFormat()).isEqualTo(responseFormat);
160163
assertThat(options.getStreamUsage()).isTrue();
164+
assertThat(options.getReasoningEffort()).isEqualTo("low");
161165
assertThat(options.getSeed()).isEqualTo(12345L);
162166
assertThat(options.isLogprobs()).isTrue();
163167
assertThat(options.getTopLogProbs()).isEqualTo(5);
@@ -182,6 +186,7 @@ void testDefaultValues() {
182186
assertThat(options.getUser()).isNull();
183187
assertThat(options.getResponseFormat()).isNull();
184188
assertThat(options.getStreamUsage()).isNull();
189+
assertThat(options.getReasoningEffort()).isNull();
185190
assertThat(options.getSeed()).isNull();
186191
assertThat(options.isLogprobs()).isNull();
187192
assertThat(options.getTopLogProbs()).isNull();

0 commit comments

Comments
 (0)