Skip to content

Commit 4e486c1

Browse files
authored
test: Add comprehensive test coverage for AzureOpenAiChatOptions and PostgresMlEmbeddingOptions (#4042)
Signed-off-by: Alex Klimenko <[email protected]>
1 parent 5be2509 commit 4e486c1

File tree

2 files changed

+220
-0
lines changed

2 files changed

+220
-0
lines changed

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,116 @@ public void createChatOptionsWithPresencePenaltyAndFrequencyPenalty(Double prese
166166
}
167167
}
168168

169+
@Test
170+
public void createRequestWithMinimalOptions() {
171+
OpenAIClientBuilder mockClient = Mockito.mock(OpenAIClientBuilder.class);
172+
173+
var minimalOptions = AzureOpenAiChatOptions.builder().deploymentName("MINIMAL_MODEL").build();
174+
175+
var client = AzureOpenAiChatModel.builder()
176+
.openAIClientBuilder(mockClient)
177+
.defaultOptions(minimalOptions)
178+
.build();
179+
180+
var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message"));
181+
182+
assertThat(requestOptions.getModel()).isEqualTo("MINIMAL_MODEL");
183+
assertThat(requestOptions.getTemperature()).isNull();
184+
assertThat(requestOptions.getMaxTokens()).isNull();
185+
assertThat(requestOptions.getTopP()).isNull();
186+
}
187+
188+
@Test
189+
public void createRequestWithEmptyStopList() {
190+
OpenAIClientBuilder mockClient = Mockito.mock(OpenAIClientBuilder.class);
191+
192+
var options = AzureOpenAiChatOptions.builder().deploymentName("TEST_MODEL").stop(List.of()).build();
193+
194+
var client = AzureOpenAiChatModel.builder().openAIClientBuilder(mockClient).defaultOptions(options).build();
195+
196+
var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message"));
197+
198+
assertThat(requestOptions.getStop()).isEmpty();
199+
}
200+
201+
@Test
202+
public void createRequestWithEmptyLogitBias() {
203+
OpenAIClientBuilder mockClient = Mockito.mock(OpenAIClientBuilder.class);
204+
205+
var options = AzureOpenAiChatOptions.builder().deploymentName("TEST_MODEL").logitBias(Map.of()).build();
206+
207+
var client = AzureOpenAiChatModel.builder().openAIClientBuilder(mockClient).defaultOptions(options).build();
208+
209+
var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message"));
210+
211+
assertThat(requestOptions.getLogitBias()).isEmpty();
212+
}
213+
214+
@Test
215+
public void createRequestWithLogprobsDisabled() {
216+
OpenAIClientBuilder mockClient = Mockito.mock(OpenAIClientBuilder.class);
217+
218+
var options = AzureOpenAiChatOptions.builder()
219+
.deploymentName("TEST_MODEL")
220+
.logprobs(false)
221+
.topLogprobs(0)
222+
.build();
223+
224+
var client = AzureOpenAiChatModel.builder().openAIClientBuilder(mockClient).defaultOptions(options).build();
225+
226+
var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message"));
227+
228+
assertThat(requestOptions.isLogprobs()).isFalse();
229+
assertThat(requestOptions.getTopLogprobs()).isEqualTo(0);
230+
}
231+
232+
@Test
233+
public void createRequestWithSingleStopSequence() {
234+
OpenAIClientBuilder mockClient = Mockito.mock(OpenAIClientBuilder.class);
235+
236+
var options = AzureOpenAiChatOptions.builder().deploymentName("SINGLE_STOP_MODEL").stop(List.of("END")).build();
237+
238+
var client = AzureOpenAiChatModel.builder().openAIClientBuilder(mockClient).defaultOptions(options).build();
239+
240+
var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message"));
241+
242+
assertThat(requestOptions.getStop()).hasSize(1);
243+
assertThat(requestOptions.getStop()).containsExactly("END");
244+
}
245+
246+
@Test
247+
public void builderPatternTest() {
248+
var options = AzureOpenAiChatOptions.builder()
249+
.deploymentName("BUILDER_TEST_MODEL")
250+
.temperature(0.7)
251+
.maxTokens(1500)
252+
.build();
253+
254+
assertThat(options.getDeploymentName()).isEqualTo("BUILDER_TEST_MODEL");
255+
assertThat(options.getTemperature()).isEqualTo(0.7);
256+
assertThat(options.getMaxTokens()).isEqualTo(1500);
257+
}
258+
259+
@ParameterizedTest
260+
@MethodSource("provideResponseFormatTypes")
261+
public void createRequestWithDifferentResponseFormats(Type responseFormatType, Class<?> expectedFormatClass) {
262+
OpenAIClientBuilder mockClient = Mockito.mock(OpenAIClientBuilder.class);
263+
264+
var options = AzureOpenAiChatOptions.builder()
265+
.deploymentName("FORMAT_TEST_MODEL")
266+
.responseFormat(AzureOpenAiResponseFormat.builder().type(responseFormatType).build())
267+
.build();
268+
269+
var client = AzureOpenAiChatModel.builder().openAIClientBuilder(mockClient).defaultOptions(options).build();
270+
271+
var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message"));
272+
273+
assertThat(requestOptions.getResponseFormat()).isInstanceOf(expectedFormatClass);
274+
}
275+
276+
private static Stream<Arguments> provideResponseFormatTypes() {
277+
return Stream.of(Arguments.of(Type.TEXT, ChatCompletionsTextResponseFormat.class),
278+
Arguments.of(Type.JSON_OBJECT, ChatCompletionsJsonResponseFormat.class));
279+
}
280+
169281
}

models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptionsTests.java

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,112 @@ public void mergeOptions() {
9595
assertThat(options.getMetadataMode()).isEqualTo(org.springframework.ai.document.MetadataMode.ALL);
9696
}
9797

98+
@Test
99+
public void builderWithEmptyKwargs() {
100+
PostgresMlEmbeddingOptions options = PostgresMlEmbeddingOptions.builder().kwargs(Map.of()).build();
101+
102+
assertThat(options.getKwargs()).isEmpty();
103+
assertThat(options.getKwargs()).isNotNull();
104+
}
105+
106+
@Test
107+
public void builderWithMultipleKwargs() {
108+
Map<String, Object> kwargs = Map.of("device", "gpu", "batch_size", 32, "max_length", 512, "normalize", true);
109+
110+
PostgresMlEmbeddingOptions options = PostgresMlEmbeddingOptions.builder().kwargs(kwargs).build();
111+
112+
assertThat(options.getKwargs()).hasSize(4);
113+
assertThat(options.getKwargs().get("device")).isEqualTo("gpu");
114+
assertThat(options.getKwargs().get("batch_size")).isEqualTo(32);
115+
assertThat(options.getKwargs().get("max_length")).isEqualTo(512);
116+
assertThat(options.getKwargs().get("normalize")).isEqualTo(true);
117+
}
118+
119+
@Test
120+
public void allVectorTypes() {
121+
for (PostgresMlEmbeddingModel.VectorType vectorType : PostgresMlEmbeddingModel.VectorType.values()) {
122+
PostgresMlEmbeddingOptions options = PostgresMlEmbeddingOptions.builder().vectorType(vectorType).build();
123+
124+
assertThat(options.getVectorType()).isEqualTo(vectorType);
125+
}
126+
}
127+
128+
@Test
129+
public void allMetadataModes() {
130+
for (org.springframework.ai.document.MetadataMode mode : org.springframework.ai.document.MetadataMode
131+
.values()) {
132+
PostgresMlEmbeddingOptions options = PostgresMlEmbeddingOptions.builder().metadataMode(mode).build();
133+
134+
assertThat(options.getMetadataMode()).isEqualTo(mode);
135+
}
136+
}
137+
138+
@Test
139+
public void mergeOptionsWithNullInput() {
140+
var jdbcTemplate = Mockito.mock(JdbcTemplate.class);
141+
PostgresMlEmbeddingModel embeddingModel = new PostgresMlEmbeddingModel(jdbcTemplate);
142+
143+
PostgresMlEmbeddingOptions options = embeddingModel.mergeOptions(null);
144+
145+
// Should return default options when input is null
146+
assertThat(options.getTransformer()).isEqualTo(PostgresMlEmbeddingModel.DEFAULT_TRANSFORMER_MODEL);
147+
assertThat(options.getVectorType()).isEqualTo(PostgresMlEmbeddingModel.VectorType.PG_ARRAY);
148+
assertThat(options.getKwargs()).isEqualTo(Map.of());
149+
assertThat(options.getMetadataMode()).isEqualTo(org.springframework.ai.document.MetadataMode.EMBED);
150+
}
151+
152+
@Test
153+
public void mergeOptionsPreservesOriginal() {
154+
var jdbcTemplate = Mockito.mock(JdbcTemplate.class);
155+
PostgresMlEmbeddingModel embeddingModel = new PostgresMlEmbeddingModel(jdbcTemplate);
156+
157+
PostgresMlEmbeddingOptions original = PostgresMlEmbeddingOptions.builder()
158+
.transformer("original-model")
159+
.kwargs(Map.of("original", "value"))
160+
.build();
161+
162+
PostgresMlEmbeddingOptions merged = embeddingModel.mergeOptions(original);
163+
164+
// Verify original options are not modified
165+
assertThat(original.getTransformer()).isEqualTo("original-model");
166+
assertThat(original.getKwargs()).containsEntry("original", "value");
167+
168+
// Verify merged options have expected values
169+
assertThat(merged.getTransformer()).isEqualTo("original-model");
170+
}
171+
172+
@Test
173+
public void mergeOptionsWithComplexKwargs() {
174+
var jdbcTemplate = Mockito.mock(JdbcTemplate.class);
175+
PostgresMlEmbeddingModel embeddingModel = new PostgresMlEmbeddingModel(jdbcTemplate);
176+
177+
Map<String, Object> complexKwargs = Map.of("device", "cuda:0", "model_kwargs",
178+
Map.of("trust_remote_code", true), "encode_kwargs",
179+
Map.of("normalize_embeddings", true, "batch_size", 64));
180+
181+
PostgresMlEmbeddingOptions options = embeddingModel
182+
.mergeOptions(PostgresMlEmbeddingOptions.builder().kwargs(complexKwargs).build());
183+
184+
assertThat(options.getKwargs()).hasSize(3);
185+
assertThat(options.getKwargs().get("device")).isEqualTo("cuda:0");
186+
assertThat(options.getKwargs().get("model_kwargs")).isInstanceOf(Map.class);
187+
assertThat(options.getKwargs().get("encode_kwargs")).isInstanceOf(Map.class);
188+
}
189+
190+
@Test
191+
public void builderChaining() {
192+
PostgresMlEmbeddingOptions options = PostgresMlEmbeddingOptions.builder()
193+
.transformer("model-1")
194+
.transformer("model-2") // Should override previous value
195+
.vectorType(PostgresMlEmbeddingModel.VectorType.PG_VECTOR)
196+
.metadataMode(org.springframework.ai.document.MetadataMode.ALL)
197+
.kwargs(Map.of("key1", "value1"))
198+
.kwargs(Map.of("key2", "value2")) // Should override previous kwargs
199+
.build();
200+
201+
assertThat(options.getTransformer()).isEqualTo("model-2");
202+
assertThat(options.getKwargs()).containsEntry("key2", "value2");
203+
assertThat(options.getKwargs()).doesNotContainKey("key1");
204+
}
205+
98206
}

0 commit comments

Comments
 (0)