Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

package org.springframework.ai.azure.openai;

import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormatJsonSchema;
import com.azure.ai.openai.models.*;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please avoid using wildcard imports and instead explicitly import only the required classes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @dev-jonghoonpark! I will avoid using wildcard imports going forward.


import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
Expand All @@ -30,31 +30,6 @@
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.implementation.accesshelpers.ChatCompletionsOptionsAccessHelper;
import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall;
import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition;
import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinitionFunction;
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatCompletionsResponseFormat;
import com.azure.ai.openai.models.ChatCompletionStreamOptions;
import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsToolCall;
import com.azure.ai.openai.models.ChatCompletionsToolDefinition;
import com.azure.ai.openai.models.ChatMessageContentItem;
import com.azure.ai.openai.models.ChatMessageImageContentItem;
import com.azure.ai.openai.models.ChatMessageImageUrl;
import com.azure.ai.openai.models.ChatMessageTextContentItem;
import com.azure.ai.openai.models.ChatRequestAssistantMessage;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.ChatRequestToolMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.CompletionsFinishReason;
import com.azure.ai.openai.models.CompletionsUsage;
import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
import com.azure.ai.openai.models.FunctionCall;
import com.azure.core.util.BinaryData;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
Expand All @@ -63,6 +38,7 @@
import org.slf4j.LoggerFactory;
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.JsonSchema;
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.Type;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;

Expand Down Expand Up @@ -760,6 +736,14 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
mergedAzureOptions.setEnhancements(fromAzureOptions.getEnhancements() != null
? fromAzureOptions.getEnhancements() : toSpringAiOptions.getEnhancements());

ReasoningEffortValue reasoningEffort = (fromAzureOptions.getReasoningEffort() != null)
? fromAzureOptions.getReasoningEffort() : (StringUtils.hasText(toSpringAiOptions.getReasoningEffort())
? ReasoningEffortValue.fromString(toSpringAiOptions.getReasoningEffort()) : null);

if (reasoningEffort != null) {
mergedAzureOptions.setReasoningEffort(reasoningEffort);
}

return mergedAzureOptions;
}

Expand Down Expand Up @@ -849,6 +833,11 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements());
}

if (StringUtils.hasText(fromSpringAiOptions.getReasoningEffort())) {
mergedAzureOptions
.setReasoningEffort(ReasoningEffortValue.fromString(fromSpringAiOptions.getReasoningEffort()));
}

return mergedAzureOptions;
}

Expand Down Expand Up @@ -914,6 +903,10 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
copyOptions.setEnhancements(fromOptions.getEnhancements());
}

if (fromOptions.getReasoningEffort() != null) {
copyOptions.setReasoningEffort(fromOptions.getReasoningEffort());
}

return copyOptions;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,15 @@ public class AzureOpenAiChatOptions implements ToolCallingChatOptions {
@JsonIgnore
private Boolean enableStreamUsage;

/**
* Constrains effort on reasoning for reasoning models. Currently supported values are
* low, medium, and high. Reducing reasoning effort can result in faster responses and
* fewer tokens used on reasoning in a response. Optional. Defaults to medium. Only
* for reasoning models.
*/
@JsonProperty("reasoning_effort")
private String reasoningEffort;

@Override
@JsonIgnore
public List<ToolCallback> getToolCallbacks() {
Expand Down Expand Up @@ -268,6 +277,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
.responseFormat(fromOptions.getResponseFormat())
.streamUsage(fromOptions.getStreamUsage())
.reasoningEffort(fromOptions.getReasoningEffort())
.seed(fromOptions.getSeed())
.logprobs(fromOptions.isLogprobs())
.topLogprobs(fromOptions.getTopLogProbs())
Expand Down Expand Up @@ -408,6 +418,14 @@ public void setStreamUsage(Boolean enableStreamUsage) {
this.enableStreamUsage = enableStreamUsage;
}

/**
 * Returns the reasoning effort constraint for reasoning models.
 * @return the configured reasoning effort ("low", "medium", or "high" per the field
 * documentation), or {@code null} if not set
 */
public String getReasoningEffort() {
return this.reasoningEffort;
}

/**
 * Sets the reasoning effort constraint for reasoning models.
 * @param reasoningEffort the reasoning effort ("low", "medium", or "high" per the
 * field documentation); {@code null} clears the setting. Not validated here —
 * conversion to the Azure enum happens when options are merged into
 * {@code ChatCompletionsOptions}.
 */
public void setReasoningEffort(String reasoningEffort) {
this.reasoningEffort = reasoningEffort;
}

@Override
@JsonIgnore
public Integer getTopK() {
Expand Down Expand Up @@ -490,6 +508,7 @@ public boolean equals(Object o) {
&& Objects.equals(this.enhancements, that.enhancements)
&& Objects.equals(this.streamOptions, that.streamOptions)
&& Objects.equals(this.enableStreamUsage, that.enableStreamUsage)
&& Objects.equals(this.reasoningEffort, that.reasoningEffort)
&& Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.maxTokens, that.maxTokens)
&& Objects.equals(this.frequencyPenalty, that.frequencyPenalty)
&& Objects.equals(this.presencePenalty, that.presencePenalty)
Expand All @@ -500,8 +519,9 @@ public boolean equals(Object o) {
/**
 * Hash over the same fields this class compares in {@code equals} (the visible
 * equals clauses cover these fields; argument order here differs from the equals
 * clause order, which is irrelevant to the equals/hashCode contract).
 * NOTE(review): whenever a new option field (e.g. {@code reasoningEffort}) is added,
 * it must be added to equals, hashCode, and {@code fromOptions} together.
 */
public int hashCode() {
return Objects.hash(this.logitBias, this.user, this.n, this.stop, this.deploymentName, this.responseFormat,
this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.seed, this.logprobs,
this.topLogProbs, this.enhancements, this.streamOptions, this.reasoningEffort, this.enableStreamUsage,
this.toolContext, this.maxTokens, this.frequencyPenalty, this.presencePenalty, this.temperature,
this.topP);
}

public static class Builder {
Expand Down Expand Up @@ -576,6 +596,11 @@ public Builder streamUsage(Boolean enableStreamUsage) {
return this;
}

/**
 * Configures the reasoning effort ("low", "medium", or "high") for reasoning
 * models on the options being built. Delegates to the public setter so the
 * assignment logic lives in one place.
 * @param reasoningEffort the reasoning effort value; may be {@code null}
 * @return this builder, for chaining
 */
public Builder reasoningEffort(String reasoningEffort) {
	this.options.setReasoningEffort(reasoningEffort);
	return this;
}

public Builder seed(Long seed) {
this.options.seed = seed;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ void testBuilderWithAllFields() {
.user("test-user")
.responseFormat(responseFormat)
.streamUsage(true)
.reasoningEffort("low")
.seed(12345L)
.logprobs(true)
.topLogprobs(5)
Expand All @@ -68,10 +69,10 @@ void testBuilderWithAllFields() {

assertThat(options)
.extracting("deploymentName", "frequencyPenalty", "logitBias", "maxTokens", "n", "presencePenalty", "stop",
"temperature", "topP", "user", "responseFormat", "streamUsage", "seed", "logprobs", "topLogProbs",
"enhancements", "streamOptions")
"temperature", "topP", "user", "responseFormat", "streamUsage", "reasoningEffort", "seed",
"logprobs", "topLogProbs", "enhancements", "streamOptions")
.containsExactly("test-deployment", 0.5, Map.of("token1", 1, "token2", -1), 200, 2, 0.8,
List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, true, 12345L, true, 5,
List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, true, "low", 12345L, true, 5,
enhancements, streamOptions);
}

Expand Down Expand Up @@ -100,6 +101,7 @@ void testCopy() {
.user("test-user")
.responseFormat(responseFormat)
.streamUsage(true)
.reasoningEffort("low")
.seed(12345L)
.logprobs(true)
.topLogprobs(5)
Expand Down Expand Up @@ -137,6 +139,7 @@ void testSetters() {
options.setUser("test-user");
options.setResponseFormat(responseFormat);
options.setStreamUsage(true);
options.setReasoningEffort("low");
options.setSeed(12345L);
options.setLogprobs(true);
options.setTopLogProbs(5);
Expand All @@ -158,6 +161,7 @@ void testSetters() {
assertThat(options.getUser()).isEqualTo("test-user");
assertThat(options.getResponseFormat()).isEqualTo(responseFormat);
assertThat(options.getStreamUsage()).isTrue();
assertThat(options.getReasoningEffort()).isEqualTo("low");
assertThat(options.getSeed()).isEqualTo(12345L);
assertThat(options.isLogprobs()).isTrue();
assertThat(options.getTopLogProbs()).isEqualTo(5);
Expand All @@ -182,6 +186,7 @@ void testDefaultValues() {
assertThat(options.getUser()).isNull();
assertThat(options.getResponseFormat()).isNull();
assertThat(options.getStreamUsage()).isNull();
assertThat(options.getReasoningEffort()).isNull();
assertThat(options.getSeed()).isNull();
assertThat(options.isLogprobs()).isNull();
assertThat(options.getTopLogProbs()).isNull();
Expand Down