|  | 
|  | 1 | +/* | 
|  | 2 | + * Copyright 2025-2025 the original author or authors. | 
|  | 3 | + * | 
|  | 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 5 | + * you may not use this file except in compliance with the License. | 
|  | 6 | + * You may obtain a copy of the License at | 
|  | 7 | + * | 
|  | 8 | + *      https://www.apache.org/licenses/LICENSE-2.0 | 
|  | 9 | + * | 
|  | 10 | + * Unless required by applicable law or agreed to in writing, software | 
|  | 11 | + * distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 13 | + * See the License for the specific language governing permissions and | 
|  | 14 | + * limitations under the License. | 
|  | 15 | + */ | 
|  | 16 | + | 
|  | 17 | +package org.springframework.ai.openai; | 
|  | 18 | + | 
|  | 19 | +import java.util.ArrayList; | 
|  | 20 | +import java.util.HashMap; | 
|  | 21 | +import java.util.List; | 
|  | 22 | +import java.util.Map; | 
|  | 23 | + | 
|  | 24 | +import static org.assertj.core.api.Assertions.assertThat; | 
|  | 25 | +import org.junit.jupiter.api.Test; | 
|  | 26 | +import static org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters.Voice.ALLOY; | 
|  | 27 | + | 
|  | 28 | +import org.springframework.ai.openai.api.OpenAiApi; | 
|  | 29 | +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters; | 
|  | 30 | +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.StreamOptions; | 
|  | 31 | +import org.springframework.ai.openai.api.ResponseFormat; | 
|  | 32 | + | 
|  | 33 | +/** | 
|  | 34 | + * Tests for {@link OpenAiChatOptions}. | 
|  | 35 | + * | 
|  | 36 | + * @author Alexandros Pappas | 
|  | 37 | + */ | 
|  | 38 | +class OpenAiChatOptionsTests { | 
|  | 39 | + | 
|  | 40 | +	@Test | 
|  | 41 | +	void testBuilderWithAllFields() { | 
|  | 42 | +		Map<String, Integer> logitBias = new HashMap<>(); | 
|  | 43 | +		logitBias.put("token1", 1); | 
|  | 44 | +		logitBias.put("token2", -1); | 
|  | 45 | + | 
|  | 46 | +		List<String> outputModalities = List.of("text", "audio"); | 
|  | 47 | +		AudioParameters outputAudio = new AudioParameters(ALLOY, AudioParameters.AudioResponseFormat.MP3); | 
|  | 48 | +		ResponseFormat responseFormat = new ResponseFormat(); | 
|  | 49 | +		StreamOptions streamOptions = StreamOptions.INCLUDE_USAGE; | 
|  | 50 | +		List<String> stopSequences = List.of("stop1", "stop2"); | 
|  | 51 | +		List<OpenAiApi.FunctionTool> tools = new ArrayList<>(); | 
|  | 52 | +		Object toolChoice = "auto"; | 
|  | 53 | +		Map<String, String> metadata = Map.of("key1", "value1"); | 
|  | 54 | +		Map<String, Object> toolContext = Map.of("keyA", "valueA"); | 
|  | 55 | + | 
|  | 56 | +		OpenAiChatOptions options = OpenAiChatOptions.builder() | 
|  | 57 | +			.model("test-model") | 
|  | 58 | +			.frequencyPenalty(0.5) | 
|  | 59 | +			.logitBias(logitBias) | 
|  | 60 | +			.logprobs(true) | 
|  | 61 | +			.topLogprobs(5) | 
|  | 62 | +			.maxTokens(100) | 
|  | 63 | +			.maxCompletionTokens(50) | 
|  | 64 | +			.N(2) | 
|  | 65 | +			.outputModalities(outputModalities) | 
|  | 66 | +			.outputAudio(outputAudio) | 
|  | 67 | +			.presencePenalty(0.8) | 
|  | 68 | +			.responseFormat(responseFormat) | 
|  | 69 | +			.streamUsage(true) | 
|  | 70 | +			.seed(12345) | 
|  | 71 | +			.stop(stopSequences) | 
|  | 72 | +			.temperature(0.7) | 
|  | 73 | +			.topP(0.9) | 
|  | 74 | +			.tools(tools) | 
|  | 75 | +			.toolChoice(toolChoice) | 
|  | 76 | +			.user("test-user") | 
|  | 77 | +			.parallelToolCalls(true) | 
|  | 78 | +			.store(false) | 
|  | 79 | +			.metadata(metadata) | 
|  | 80 | +			.reasoningEffort("medium") | 
|  | 81 | +			.proxyToolCalls(false) | 
|  | 82 | +			.httpHeaders(Map.of("header1", "value1")) | 
|  | 83 | +			.toolContext(toolContext) | 
|  | 84 | +			.build(); | 
|  | 85 | + | 
|  | 86 | +		assertThat(options) | 
|  | 87 | +			.extracting("model", "frequencyPenalty", "logitBias", "logprobs", "topLogprobs", "maxTokens", | 
|  | 88 | +					"maxCompletionTokens", "n", "outputModalities", "outputAudio", "presencePenalty", "responseFormat", | 
|  | 89 | +					"streamOptions", "seed", "stop", "temperature", "topP", "tools", "toolChoice", "user", | 
|  | 90 | +					"parallelToolCalls", "store", "metadata", "reasoningEffort", "proxyToolCalls", "httpHeaders", | 
|  | 91 | +					"toolContext") | 
|  | 92 | +			.containsExactly("test-model", 0.5, logitBias, true, 5, 100, 50, 2, outputModalities, outputAudio, 0.8, | 
|  | 93 | +					responseFormat, streamOptions, 12345, stopSequences, 0.7, 0.9, tools, toolChoice, "test-user", true, | 
|  | 94 | +					false, metadata, "medium", false, Map.of("header1", "value1"), toolContext); | 
|  | 95 | + | 
|  | 96 | +		assertThat(options.getStreamUsage()).isTrue(); | 
|  | 97 | +		assertThat(options.getStreamOptions()).isEqualTo(StreamOptions.INCLUDE_USAGE); | 
|  | 98 | + | 
|  | 99 | +	} | 
|  | 100 | + | 
|  | 101 | +	@Test | 
|  | 102 | +	void testCopy() { | 
|  | 103 | +		Map<String, Integer> logitBias = new HashMap<>(); | 
|  | 104 | +		logitBias.put("token1", 1); | 
|  | 105 | + | 
|  | 106 | +		List<String> outputModalities = List.of("text"); | 
|  | 107 | +		AudioParameters outputAudio = new AudioParameters(ALLOY, AudioParameters.AudioResponseFormat.MP3); | 
|  | 108 | +		ResponseFormat responseFormat = new ResponseFormat(); | 
|  | 109 | + | 
|  | 110 | +		List<String> stopSequences = List.of("stop1"); | 
|  | 111 | +		List<OpenAiApi.FunctionTool> tools = new ArrayList<>(); | 
|  | 112 | +		Object toolChoice = "none"; | 
|  | 113 | +		Map<String, String> metadata = Map.of("key1", "value1"); | 
|  | 114 | + | 
|  | 115 | +		OpenAiChatOptions originalOptions = OpenAiChatOptions.builder() | 
|  | 116 | +			.model("test-model") | 
|  | 117 | +			.frequencyPenalty(0.5) | 
|  | 118 | +			.logitBias(logitBias) | 
|  | 119 | +			.logprobs(true) | 
|  | 120 | +			.topLogprobs(5) | 
|  | 121 | +			.maxTokens(100) | 
|  | 122 | +			.maxCompletionTokens(50) | 
|  | 123 | +			.N(2) | 
|  | 124 | +			.outputModalities(outputModalities) | 
|  | 125 | +			.outputAudio(outputAudio) | 
|  | 126 | +			.presencePenalty(0.8) | 
|  | 127 | +			.responseFormat(responseFormat) | 
|  | 128 | +			.streamUsage(false) | 
|  | 129 | +			.seed(12345) | 
|  | 130 | +			.stop(stopSequences) | 
|  | 131 | +			.temperature(0.7) | 
|  | 132 | +			.topP(0.9) | 
|  | 133 | +			.tools(tools) | 
|  | 134 | +			.toolChoice(toolChoice) | 
|  | 135 | +			.user("test-user") | 
|  | 136 | +			.parallelToolCalls(false) | 
|  | 137 | +			.store(true) | 
|  | 138 | +			.metadata(metadata) | 
|  | 139 | +			.reasoningEffort("low") | 
|  | 140 | +			.proxyToolCalls(true) | 
|  | 141 | +			.httpHeaders(Map.of("header1", "value1")) | 
|  | 142 | +			.build(); | 
|  | 143 | + | 
|  | 144 | +		OpenAiChatOptions copiedOptions = originalOptions.copy(); | 
|  | 145 | +		assertThat(copiedOptions).isNotSameAs(originalOptions).isEqualTo(originalOptions); | 
|  | 146 | +	} | 
|  | 147 | + | 
|  | 148 | +	@Test | 
|  | 149 | +	void testSetters() { | 
|  | 150 | +		Map<String, Integer> logitBias = new HashMap<>(); | 
|  | 151 | +		logitBias.put("token1", 1); | 
|  | 152 | + | 
|  | 153 | +		List<String> outputModalities = List.of("audio"); | 
|  | 154 | +		AudioParameters outputAudio = new AudioParameters(ALLOY, AudioParameters.AudioResponseFormat.MP3); | 
|  | 155 | +		ResponseFormat responseFormat = new ResponseFormat(); | 
|  | 156 | + | 
|  | 157 | +		StreamOptions streamOptions = StreamOptions.INCLUDE_USAGE; | 
|  | 158 | +		List<String> stopSequences = List.of("stop1", "stop2"); | 
|  | 159 | +		List<OpenAiApi.FunctionTool> tools = new ArrayList<>(); | 
|  | 160 | +		Object toolChoice = "auto"; | 
|  | 161 | +		Map<String, String> metadata = Map.of("key2", "value2"); | 
|  | 162 | + | 
|  | 163 | +		OpenAiChatOptions options = new OpenAiChatOptions(); | 
|  | 164 | +		options.setModel("test-model"); | 
|  | 165 | +		options.setFrequencyPenalty(0.5); | 
|  | 166 | +		options.setLogitBias(logitBias); | 
|  | 167 | +		options.setLogprobs(true); | 
|  | 168 | +		options.setTopLogprobs(5); | 
|  | 169 | +		options.setMaxTokens(100); | 
|  | 170 | +		options.setMaxCompletionTokens(50); | 
|  | 171 | +		options.setN(2); | 
|  | 172 | +		options.setOutputModalities(outputModalities); | 
|  | 173 | +		options.setOutputAudio(outputAudio); | 
|  | 174 | +		options.setPresencePenalty(0.8); | 
|  | 175 | +		options.setResponseFormat(responseFormat); | 
|  | 176 | +		options.setStreamOptions(streamOptions); | 
|  | 177 | +		options.setSeed(12345); | 
|  | 178 | +		options.setStop(stopSequences); | 
|  | 179 | +		options.setTemperature(0.7); | 
|  | 180 | +		options.setTopP(0.9); | 
|  | 181 | +		options.setTools(tools); | 
|  | 182 | +		options.setToolChoice(toolChoice); | 
|  | 183 | +		options.setUser("test-user"); | 
|  | 184 | +		options.setParallelToolCalls(true); | 
|  | 185 | +		options.setStore(false); | 
|  | 186 | +		options.setMetadata(metadata); | 
|  | 187 | +		options.setReasoningEffort("high"); | 
|  | 188 | +		options.setProxyToolCalls(false); | 
|  | 189 | +		options.setHttpHeaders(Map.of("header2", "value2")); | 
|  | 190 | + | 
|  | 191 | +		assertThat(options.getModel()).isEqualTo("test-model"); | 
|  | 192 | +		assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); | 
|  | 193 | +		assertThat(options.getLogitBias()).isEqualTo(logitBias); | 
|  | 194 | +		assertThat(options.getLogprobs()).isTrue(); | 
|  | 195 | +		assertThat(options.getTopLogprobs()).isEqualTo(5); | 
|  | 196 | +		assertThat(options.getMaxTokens()).isEqualTo(100); | 
|  | 197 | +		assertThat(options.getMaxCompletionTokens()).isEqualTo(50); | 
|  | 198 | +		assertThat(options.getN()).isEqualTo(2); | 
|  | 199 | +		assertThat(options.getOutputModalities()).isEqualTo(outputModalities); | 
|  | 200 | +		assertThat(options.getOutputAudio()).isEqualTo(outputAudio); | 
|  | 201 | +		assertThat(options.getPresencePenalty()).isEqualTo(0.8); | 
|  | 202 | +		assertThat(options.getResponseFormat()).isEqualTo(responseFormat); | 
|  | 203 | +		assertThat(options.getStreamOptions()).isEqualTo(streamOptions); | 
|  | 204 | +		assertThat(options.getSeed()).isEqualTo(12345); | 
|  | 205 | +		assertThat(options.getStop()).isEqualTo(stopSequences); | 
|  | 206 | +		assertThat(options.getTemperature()).isEqualTo(0.7); | 
|  | 207 | +		assertThat(options.getTopP()).isEqualTo(0.9); | 
|  | 208 | +		assertThat(options.getTools()).isEqualTo(tools); | 
|  | 209 | +		assertThat(options.getToolChoice()).isEqualTo(toolChoice); | 
|  | 210 | +		assertThat(options.getUser()).isEqualTo("test-user"); | 
|  | 211 | +		assertThat(options.getParallelToolCalls()).isTrue(); | 
|  | 212 | +		assertThat(options.getStore()).isFalse(); | 
|  | 213 | +		assertThat(options.getMetadata()).isEqualTo(metadata); | 
|  | 214 | +		assertThat(options.getReasoningEffort()).isEqualTo("high"); | 
|  | 215 | +		assertThat(options.getProxyToolCalls()).isFalse(); | 
|  | 216 | +		assertThat(options.getHttpHeaders()).isEqualTo(Map.of("header2", "value2")); | 
|  | 217 | +		assertThat(options.getStreamUsage()).isTrue(); | 
|  | 218 | +		options.setStreamUsage(false); | 
|  | 219 | +		assertThat(options.getStreamUsage()).isFalse(); | 
|  | 220 | +		assertThat(options.getStreamOptions()).isNull(); | 
|  | 221 | +		options.setStopSequences(List.of("s1", "s2")); | 
|  | 222 | +		assertThat(options.getStopSequences()).isEqualTo(List.of("s1", "s2")); | 
|  | 223 | +		assertThat(options.getStop()).isEqualTo(List.of("s1", "s2")); | 
|  | 224 | +	} | 
|  | 225 | + | 
|  | 226 | +	@Test | 
|  | 227 | +	void testDefaultValues() { | 
|  | 228 | +		OpenAiChatOptions options = new OpenAiChatOptions(); | 
|  | 229 | +		assertThat(options.getModel()).isNull(); | 
|  | 230 | +		assertThat(options.getFrequencyPenalty()).isNull(); | 
|  | 231 | +		assertThat(options.getLogitBias()).isNull(); | 
|  | 232 | +		assertThat(options.getLogprobs()).isNull(); | 
|  | 233 | +		assertThat(options.getTopLogprobs()).isNull(); | 
|  | 234 | +		assertThat(options.getMaxTokens()).isNull(); | 
|  | 235 | +		assertThat(options.getMaxCompletionTokens()).isNull(); | 
|  | 236 | +		assertThat(options.getN()).isNull(); | 
|  | 237 | +		assertThat(options.getOutputModalities()).isNull(); | 
|  | 238 | +		assertThat(options.getOutputAudio()).isNull(); | 
|  | 239 | +		assertThat(options.getPresencePenalty()).isNull(); | 
|  | 240 | +		assertThat(options.getResponseFormat()).isNull(); | 
|  | 241 | +		assertThat(options.getStreamOptions()).isNull(); | 
|  | 242 | +		assertThat(options.getSeed()).isNull(); | 
|  | 243 | +		assertThat(options.getStop()).isNull(); | 
|  | 244 | +		assertThat(options.getTemperature()).isNull(); | 
|  | 245 | +		assertThat(options.getTopP()).isNull(); | 
|  | 246 | +		assertThat(options.getTools()).isNull(); | 
|  | 247 | +		assertThat(options.getToolChoice()).isNull(); | 
|  | 248 | +		assertThat(options.getUser()).isNull(); | 
|  | 249 | +		assertThat(options.getParallelToolCalls()).isNull(); | 
|  | 250 | +		assertThat(options.getStore()).isNull(); | 
|  | 251 | +		assertThat(options.getMetadata()).isNull(); | 
|  | 252 | +		assertThat(options.getReasoningEffort()).isNull(); | 
|  | 253 | +		assertThat(options.getFunctionCallbacks()).isNotNull().isEmpty(); | 
|  | 254 | +		assertThat(options.getFunctions()).isNotNull().isEmpty(); | 
|  | 255 | +		assertThat(options.getProxyToolCalls()).isNull(); | 
|  | 256 | +		assertThat(options.getHttpHeaders()).isNotNull().isEmpty(); | 
|  | 257 | +		assertThat(options.getToolContext()).isEqualTo(new HashMap<>()); | 
|  | 258 | +		assertThat(options.getStreamUsage()).isFalse(); | 
|  | 259 | +		assertThat(options.getStopSequences()).isNull(); | 
|  | 260 | +	} | 
|  | 261 | + | 
|  | 262 | +} | 
0 commit comments