From 922871a79223131e18c2239b19fde2dd3e65a882 Mon Sep 17 00:00:00 2001 From: Alex Klimenko Date: Wed, 6 Aug 2025 22:59:57 +0200 Subject: [PATCH] test: Add comprehensive test coverage for AzureOpenAiChatOptions and PostgresMlEmbeddingOptions Signed-off-by: Alex Klimenko --- .../AzureChatCompletionsOptionsTests.java | 112 ++++++++++++++++++ .../PostgresMlEmbeddingOptionsTests.java | 108 +++++++++++++++++ 2 files changed, 220 insertions(+) diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java index 2c13ced5636..9b1a1fd0cab 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java @@ -166,4 +166,116 @@ public void createChatOptionsWithPresencePenaltyAndFrequencyPenalty(Double prese } } + @Test + public void createRequestWithMinimalOptions() { + OpenAIClientBuilder mockClient = Mockito.mock(OpenAIClientBuilder.class); + + var minimalOptions = AzureOpenAiChatOptions.builder().deploymentName("MINIMAL_MODEL").build(); + + var client = AzureOpenAiChatModel.builder() + .openAIClientBuilder(mockClient) + .defaultOptions(minimalOptions) + .build(); + + var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message")); + + assertThat(requestOptions.getModel()).isEqualTo("MINIMAL_MODEL"); + assertThat(requestOptions.getTemperature()).isNull(); + assertThat(requestOptions.getMaxTokens()).isNull(); + assertThat(requestOptions.getTopP()).isNull(); + } + + @Test + public void createRequestWithEmptyStopList() { + OpenAIClientBuilder mockClient = Mockito.mock(OpenAIClientBuilder.class); + + var options = AzureOpenAiChatOptions.builder().deploymentName("TEST_MODEL").stop(List.of()).build(); + + var client = AzureOpenAiChatModel.builder().openAIClientBuilder(mockClient).defaultOptions(options).build(); + + var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message")); + + assertThat(requestOptions.getStop()).isEmpty(); + } + + @Test + public void createRequestWithEmptyLogitBias() { + OpenAIClientBuilder mockClient = Mockito.mock(OpenAIClientBuilder.class); + + var options = AzureOpenAiChatOptions.builder().deploymentName("TEST_MODEL").logitBias(Map.of()).build(); + + var client = AzureOpenAiChatModel.builder().openAIClientBuilder(mockClient).defaultOptions(options).build(); + + var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message")); + + assertThat(requestOptions.getLogitBias()).isEmpty(); + } + + @Test + public void createRequestWithLogprobsDisabled() { + OpenAIClientBuilder mockClient = Mockito.mock(OpenAIClientBuilder.class); + + var options = AzureOpenAiChatOptions.builder() + .deploymentName("TEST_MODEL") + .logprobs(false) + .topLogprobs(0) + .build(); + + var client = AzureOpenAiChatModel.builder().openAIClientBuilder(mockClient).defaultOptions(options).build(); + + var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message")); + + assertThat(requestOptions.isLogprobs()).isFalse(); + assertThat(requestOptions.getTopLogprobs()).isEqualTo(0); + } + + @Test + public void createRequestWithSingleStopSequence() { + OpenAIClientBuilder mockClient = Mockito.mock(OpenAIClientBuilder.class); + + var options = AzureOpenAiChatOptions.builder().deploymentName("SINGLE_STOP_MODEL").stop(List.of("END")).build(); + + var client = AzureOpenAiChatModel.builder().openAIClientBuilder(mockClient).defaultOptions(options).build(); + + var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message")); + + assertThat(requestOptions.getStop()).hasSize(1); + assertThat(requestOptions.getStop()).containsExactly("END"); + } + + @Test + public void builderPatternTest() { + var options = AzureOpenAiChatOptions.builder() + .deploymentName("BUILDER_TEST_MODEL") + .temperature(0.7) + .maxTokens(1500) + .build(); + + assertThat(options.getDeploymentName()).isEqualTo("BUILDER_TEST_MODEL"); + assertThat(options.getTemperature()).isEqualTo(0.7); + assertThat(options.getMaxTokens()).isEqualTo(1500); + } + + @ParameterizedTest + @MethodSource("provideResponseFormatTypes") + public void createRequestWithDifferentResponseFormats(Type responseFormatType, Class expectedFormatClass) { + OpenAIClientBuilder mockClient = Mockito.mock(OpenAIClientBuilder.class); + + var options = AzureOpenAiChatOptions.builder() + .deploymentName("FORMAT_TEST_MODEL") + .responseFormat(AzureOpenAiResponseFormat.builder().type(responseFormatType).build()) + .build(); + + var client = AzureOpenAiChatModel.builder().openAIClientBuilder(mockClient).defaultOptions(options).build(); + + var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message")); + + assertThat(requestOptions.getResponseFormat()).isInstanceOf(expectedFormatClass); + } + + private static Stream provideResponseFormatTypes() { + return Stream.of(Arguments.of(Type.TEXT, ChatCompletionsTextResponseFormat.class), + Arguments.of(Type.JSON_OBJECT, ChatCompletionsJsonResponseFormat.class)); + } + } diff --git a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptionsTests.java b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptionsTests.java index f39630761bf..bc5cc218c11 100644 --- a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptionsTests.java +++ b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingOptionsTests.java @@ -95,4 +95,112 @@ public void mergeOptions() { assertThat(options.getMetadataMode()).isEqualTo(org.springframework.ai.document.MetadataMode.ALL); } + @Test + public void builderWithEmptyKwargs() { + PostgresMlEmbeddingOptions options = PostgresMlEmbeddingOptions.builder().kwargs(Map.of()).build(); + + assertThat(options.getKwargs()).isEmpty(); + assertThat(options.getKwargs()).isNotNull(); + } + + @Test + public void builderWithMultipleKwargs() { + Map kwargs = Map.of("device", "gpu", "batch_size", 32, "max_length", 512, "normalize", true); + + PostgresMlEmbeddingOptions options = PostgresMlEmbeddingOptions.builder().kwargs(kwargs).build(); + + assertThat(options.getKwargs()).hasSize(4); + assertThat(options.getKwargs().get("device")).isEqualTo("gpu"); + assertThat(options.getKwargs().get("batch_size")).isEqualTo(32); + assertThat(options.getKwargs().get("max_length")).isEqualTo(512); + assertThat(options.getKwargs().get("normalize")).isEqualTo(true); + } + + @Test + public void allVectorTypes() { + for (PostgresMlEmbeddingModel.VectorType vectorType : PostgresMlEmbeddingModel.VectorType.values()) { + PostgresMlEmbeddingOptions options = PostgresMlEmbeddingOptions.builder().vectorType(vectorType).build(); + + assertThat(options.getVectorType()).isEqualTo(vectorType); + } + } + + @Test + public void allMetadataModes() { + for (org.springframework.ai.document.MetadataMode mode : org.springframework.ai.document.MetadataMode + .values()) { + PostgresMlEmbeddingOptions options = PostgresMlEmbeddingOptions.builder().metadataMode(mode).build(); + + assertThat(options.getMetadataMode()).isEqualTo(mode); + } + } + + @Test + public void mergeOptionsWithNullInput() { + var jdbcTemplate = Mockito.mock(JdbcTemplate.class); + PostgresMlEmbeddingModel embeddingModel = new PostgresMlEmbeddingModel(jdbcTemplate); + + PostgresMlEmbeddingOptions options = embeddingModel.mergeOptions(null); + + // Should return default options when input is null + assertThat(options.getTransformer()).isEqualTo(PostgresMlEmbeddingModel.DEFAULT_TRANSFORMER_MODEL); + assertThat(options.getVectorType()).isEqualTo(PostgresMlEmbeddingModel.VectorType.PG_ARRAY); + assertThat(options.getKwargs()).isEqualTo(Map.of()); + assertThat(options.getMetadataMode()).isEqualTo(org.springframework.ai.document.MetadataMode.EMBED); + } + + @Test + public void mergeOptionsPreservesOriginal() { + var jdbcTemplate = Mockito.mock(JdbcTemplate.class); + PostgresMlEmbeddingModel embeddingModel = new PostgresMlEmbeddingModel(jdbcTemplate); + + PostgresMlEmbeddingOptions original = PostgresMlEmbeddingOptions.builder() + .transformer("original-model") + .kwargs(Map.of("original", "value")) + .build(); + + PostgresMlEmbeddingOptions merged = embeddingModel.mergeOptions(original); + + // Verify original options are not modified + assertThat(original.getTransformer()).isEqualTo("original-model"); + assertThat(original.getKwargs()).containsEntry("original", "value"); + + // Verify merged options have expected values + assertThat(merged.getTransformer()).isEqualTo("original-model"); + } + + @Test + public void mergeOptionsWithComplexKwargs() { + var jdbcTemplate = Mockito.mock(JdbcTemplate.class); + PostgresMlEmbeddingModel embeddingModel = new PostgresMlEmbeddingModel(jdbcTemplate); + + Map complexKwargs = Map.of("device", "cuda:0", "model_kwargs", + Map.of("trust_remote_code", true), "encode_kwargs", + Map.of("normalize_embeddings", true, "batch_size", 64)); + + PostgresMlEmbeddingOptions options = embeddingModel + .mergeOptions(PostgresMlEmbeddingOptions.builder().kwargs(complexKwargs).build()); + + assertThat(options.getKwargs()).hasSize(3); + assertThat(options.getKwargs().get("device")).isEqualTo("cuda:0"); + assertThat(options.getKwargs().get("model_kwargs")).isInstanceOf(Map.class); + assertThat(options.getKwargs().get("encode_kwargs")).isInstanceOf(Map.class); + } + + @Test + public void builderChaining() { + PostgresMlEmbeddingOptions options = PostgresMlEmbeddingOptions.builder() + .transformer("model-1") + .transformer("model-2") // Should override previous value + .vectorType(PostgresMlEmbeddingModel.VectorType.PG_VECTOR) + .metadataMode(org.springframework.ai.document.MetadataMode.ALL) + .kwargs(Map.of("key1", "value1")) + .kwargs(Map.of("key2", "value2")) // Should override previous kwargs + .build(); + + assertThat(options.getTransformer()).isEqualTo("model-2"); + assertThat(options.getKwargs()).containsEntry("key2", "value2"); + assertThat(options.getKwargs()).doesNotContainKey("key1"); + } + }