From fb48537df17905d183eaa47b0a05b266b706bd2e Mon Sep 17 00:00:00 2001 From: Alexandros Pappas Date: Sun, 2 Mar 2025 14:05:46 +0100 Subject: [PATCH] This commit enhances `AzureOpenAiChatOptions` by: - Adding `equals` and `hashCode` methods for proper object comparison. - Implementing a deep `copy()` method, creating new instances of mutable collections (List, Set, Map, Metadata) to prevent shared state. - Adding `AzureOpenAiChatOptionsTests` to verify `copy()`, builders, setters, and default values. Signed-off-by: Alexandros Pappas --- .../azure/openai/AzureOpenAiChatOptions.java | 50 ++++- .../openai/AzureOpenAiChatOptionsTests.java | 182 ++++++++++++++++++ 2 files changed, 226 insertions(+), 6 deletions(-) create mode 100644 models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index 887c0ad6e74..f90f99f64d6 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -22,6 +22,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; @@ -46,6 +47,7 @@ * @author Thomas Vitale * @author Soby Chacko * @author Ilayaperumal Gopinathan + * @author Alexandros Pappas */ @JsonInclude(Include.NON_NULL) public class AzureOpenAiChatOptions implements ToolCallingChatOptions { @@ -250,22 +252,24 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti .maxTokens(fromOptions.getMaxTokens()) .N(fromOptions.getN()) .presencePenalty(fromOptions.getPresencePenalty() != null ? fromOptions.getPresencePenalty() : null) - .stop(fromOptions.getStop()) + .stop(fromOptions.getStop() != null ? new ArrayList<>(fromOptions.getStop()) : null) .temperature(fromOptions.getTemperature()) .topP(fromOptions.getTopP()) .user(fromOptions.getUser()) - .functionCallbacks(fromOptions.getFunctionCallbacks()) - .functions(fromOptions.getFunctions()) + .functionCallbacks(fromOptions.getFunctionCallbacks() != null + ? new ArrayList<>(fromOptions.getFunctionCallbacks()) : null) + .functions(fromOptions.getFunctions() != null ? new HashSet<>(fromOptions.getFunctions()) : null) .responseFormat(fromOptions.getResponseFormat()) .seed(fromOptions.getSeed()) .logprobs(fromOptions.isLogprobs()) .topLogprobs(fromOptions.getTopLogProbs()) .enhancements(fromOptions.getEnhancements()) - .toolContext(fromOptions.getToolContext()) + .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) .internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled()) .streamOptions(fromOptions.getStreamOptions()) - .toolCallbacks(fromOptions.getToolCallbacks()) - .toolNames(fromOptions.getToolNames()) + .toolCallbacks( + fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null) + .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .build(); } @@ -479,10 +483,44 @@ public void setStreamOptions(ChatCompletionStreamOptions streamOptions) { } @Override + @SuppressWarnings("unchecked") public AzureOpenAiChatOptions copy() { return fromOptions(this); } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof AzureOpenAiChatOptions that)) { + return false; + } + return Objects.equals(this.logitBias, that.logitBias) && Objects.equals(this.user, that.user) + && Objects.equals(this.n, that.n) && Objects.equals(this.stop, that.stop) + && Objects.equals(this.deploymentName, that.deploymentName) + && Objects.equals(this.responseFormat, that.responseFormat) + + && Objects.equals(this.toolCallbacks, that.toolCallbacks) + && Objects.equals(this.toolNames, that.toolNames) + && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) + && Objects.equals(this.logprobs, that.logprobs) && Objects.equals(this.topLogProbs, that.topLogProbs) + && Objects.equals(this.enhancements, that.enhancements) + && Objects.equals(this.streamOptions, that.streamOptions) + && Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.maxTokens, that.maxTokens) + && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) + && Objects.equals(this.presencePenalty, that.presencePenalty) + && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP); + } + + @Override + public int hashCode() { + return Objects.hash(this.logitBias, this.user, this.n, this.stop, this.deploymentName, this.responseFormat, + this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.seed, this.logprobs, + this.topLogProbs, this.enhancements, this.streamOptions, this.toolContext, this.maxTokens, + this.frequencyPenalty, this.presencePenalty, this.temperature, this.topP); + } + public static class Builder { protected AzureOpenAiChatOptions options; diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java new file mode 100644 index 00000000000..b3a8bfd6d74 --- /dev/null +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java @@ -0,0 +1,182 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.azure.openai; + +import java.util.List; +import java.util.Map; + +import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; +import com.azure.ai.openai.models.AzureChatGroundingEnhancementConfiguration; +import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration; +import com.azure.ai.openai.models.ChatCompletionStreamOptions; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link AzureOpenAiChatOptions}. + * + * @author Alexandros Pappas + */ +class AzureOpenAiChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT; + ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions(); + streamOptions.setIncludeUsage(true); + + AzureChatEnhancementConfiguration enhancements = new AzureChatEnhancementConfiguration(); + enhancements.setOcr(new AzureChatOCREnhancementConfiguration(true)); + enhancements.setGrounding(new AzureChatGroundingEnhancementConfiguration(true)); + + AzureOpenAiChatOptions options = AzureOpenAiChatOptions.builder() + .deploymentName("test-deployment") + .frequencyPenalty(0.5) + .logitBias(Map.of("token1", 1, "token2", -1)) + .maxTokens(200) + .N(2) + .presencePenalty(0.8) + .stop(List.of("stop1", "stop2")) + .temperature(0.7) + .topP(0.9) + .user("test-user") + .responseFormat(responseFormat) + .seed(12345L) + .logprobs(true) + .topLogprobs(5) + .enhancements(enhancements) + .streamOptions(streamOptions) + .build(); + + assertThat(options) + .extracting("deploymentName", "frequencyPenalty", "logitBias", "maxTokens", "n", "presencePenalty", "stop", + "temperature", "topP", "user", "responseFormat", "seed", "logprobs", "topLogProbs", "enhancements", + "streamOptions") + .containsExactly("test-deployment", 0.5, Map.of("token1", 1, "token2", -1), 200, 2, 0.8, + List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, 12345L, true, 5, enhancements, + streamOptions); + } + + @Test + void testCopy() { + AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT; + ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions(); + streamOptions.setIncludeUsage(true); + + AzureChatEnhancementConfiguration enhancements = new AzureChatEnhancementConfiguration(); + enhancements.setOcr(new AzureChatOCREnhancementConfiguration(true)); + enhancements.setGrounding(new AzureChatGroundingEnhancementConfiguration(true)); + + AzureOpenAiChatOptions originalOptions = AzureOpenAiChatOptions.builder() + .deploymentName("test-deployment") + .frequencyPenalty(0.5) + .logitBias(Map.of("token1", 1, "token2", -1)) + .maxTokens(200) + .N(2) + .presencePenalty(0.8) + .stop(List.of("stop1", "stop2")) + .temperature(0.7) + .topP(0.9) + .user("test-user") + .responseFormat(responseFormat) + .seed(12345L) + .logprobs(true) + .topLogprobs(5) + .enhancements(enhancements) + .streamOptions(streamOptions) + .build(); + + AzureOpenAiChatOptions copiedOptions = originalOptions.copy(); + + assertThat(copiedOptions).isNotSameAs(originalOptions).isEqualTo(originalOptions); + // Ensure deep copy + assertThat(copiedOptions.getStop()).isNotSameAs(originalOptions.getStop()); + assertThat(copiedOptions.getToolContext()).isNotSameAs(originalOptions.getToolContext()); + } + + @Test + void testSetters() { + AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT; + ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions(); + streamOptions.setIncludeUsage(true); + AzureChatEnhancementConfiguration enhancements = new AzureChatEnhancementConfiguration(); + + AzureOpenAiChatOptions options = new AzureOpenAiChatOptions(); + options.setDeploymentName("test-deployment"); + options.setFrequencyPenalty(0.5); + options.setLogitBias(Map.of("token1", 1, "token2", -1)); + options.setMaxTokens(200); + options.setN(2); + options.setPresencePenalty(0.8); + options.setStop(List.of("stop1", "stop2")); + options.setTemperature(0.7); + options.setTopP(0.9); + options.setUser("test-user"); + options.setResponseFormat(responseFormat); + options.setSeed(12345L); + options.setLogprobs(true); + options.setTopLogProbs(5); + options.setEnhancements(enhancements); + options.setStreamOptions(streamOptions); + + assertThat(options.getDeploymentName()).isEqualTo("test-deployment"); + options.setModel("test-model"); + assertThat(options.getDeploymentName()).isEqualTo("test-model"); + + assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); + assertThat(options.getLogitBias()).isEqualTo(Map.of("token1", 1, "token2", -1)); + assertThat(options.getMaxTokens()).isEqualTo(200); + assertThat(options.getN()).isEqualTo(2); + assertThat(options.getPresencePenalty()).isEqualTo(0.8); + assertThat(options.getStop()).isEqualTo(List.of("stop1", "stop2")); + assertThat(options.getTemperature()).isEqualTo(0.7); + assertThat(options.getTopP()).isEqualTo(0.9); + assertThat(options.getUser()).isEqualTo("test-user"); + assertThat(options.getResponseFormat()).isEqualTo(responseFormat); + assertThat(options.getSeed()).isEqualTo(12345L); + assertThat(options.isLogprobs()).isTrue(); + assertThat(options.getTopLogProbs()).isEqualTo(5); + assertThat(options.getEnhancements()).isEqualTo(enhancements); + assertThat(options.getStreamOptions()).isEqualTo(streamOptions); + assertThat(options.getModel()).isEqualTo("test-model"); + } + + @Test + void testDefaultValues() { + AzureOpenAiChatOptions options = new AzureOpenAiChatOptions(); + + assertThat(options.getDeploymentName()).isNull(); + assertThat(options.getFrequencyPenalty()).isNull(); + assertThat(options.getLogitBias()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getN()).isNull(); + assertThat(options.getPresencePenalty()).isNull(); + assertThat(options.getStop()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getUser()).isNull(); + assertThat(options.getResponseFormat()).isNull(); + assertThat(options.getSeed()).isNull(); + assertThat(options.isLogprobs()).isNull(); + assertThat(options.getTopLogProbs()).isNull(); + assertThat(options.getEnhancements()).isNull(); + assertThat(options.getStreamOptions()).isNull(); + assertThat(options.getModel()).isNull(); + } + +}