diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java index 72329d3aa88..e2858c9e46d 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java @@ -159,6 +159,122 @@ void testInvalidResponseErrorHandler() { .hasMessageContaining("responseErrorHandler cannot be null"); } + @Test + void testBuilderWithAllCustomPaths() { + OpenAiApi api = OpenAiApi.builder() + .apiKey(TEST_API_KEY) + .baseUrl(TEST_BASE_URL) + .completionsPath("/custom/completions") + .embeddingsPath("/custom/embeddings") + .build(); + + assertThat(api).isNotNull(); + } + + @Test + void testBuilderImmutability() { + OpenAiApi.Builder builder = OpenAiApi.builder().apiKey(TEST_API_KEY).baseUrl(TEST_BASE_URL); + + OpenAiApi api1 = builder.build(); + OpenAiApi api2 = builder.build(); + + assertThat(api1).isNotNull(); + assertThat(api2).isNotNull(); + assertThat(api1).isNotSameAs(api2); + } + + @Test + void testNullApiKeyValue() { + assertThatThrownBy(() -> OpenAiApi.builder().apiKey((ApiKey) null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("apiKey cannot be null"); + } + + @Test + void testBuilderMethodChaining() { + MultiValueMap headers = new LinkedMultiValueMap<>(); + headers.add("Test-Header", "test-value"); + + OpenAiApi api = OpenAiApi.builder() + .apiKey(TEST_API_KEY) + .baseUrl(TEST_BASE_URL) + .completionsPath(TEST_COMPLETIONS_PATH) + .embeddingsPath(TEST_EMBEDDINGS_PATH) + .headers(headers) + .restClientBuilder(RestClient.builder()) + .webClientBuilder(WebClient.builder()) + .responseErrorHandler(mock(ResponseErrorHandler.class)) + .build(); + + assertThat(api).isNotNull(); + } + + @Test + void testCustomHeadersPreservation() { + MultiValueMap customHeaders = new LinkedMultiValueMap<>(); + customHeaders.add("X-Custom-Header", "custom-value"); + customHeaders.add("X-Organization", "org-123"); + customHeaders.add("User-Agent", "Custom-Client/1.0"); + + OpenAiApi api = OpenAiApi.builder().apiKey(TEST_API_KEY).headers(customHeaders).build(); + + assertThat(api).isNotNull(); + } + + @Test + void testComplexMultiValueHeaders() { + MultiValueMap multiHeaders = new LinkedMultiValueMap<>(); + multiHeaders.add("Accept", "application/json"); + multiHeaders.add("Accept", "text/plain"); + multiHeaders.add("Cache-Control", "no-cache"); + multiHeaders.add("Cache-Control", "no-store"); + + OpenAiApi api = OpenAiApi.builder().apiKey(TEST_API_KEY).headers(multiHeaders).build(); + + assertThat(api).isNotNull(); + } + + @Test + void testPathValidationWithSlashes() { + OpenAiApi api1 = OpenAiApi.builder() + .apiKey(TEST_API_KEY) + .completionsPath("/v1/completions") + .embeddingsPath("/v1/embeddings") + .build(); + + OpenAiApi api2 = OpenAiApi.builder() + .apiKey(TEST_API_KEY) + .completionsPath("v1/completions") + .embeddingsPath("v1/embeddings") + .build(); + + assertThat(api1).isNotNull(); + assertThat(api2).isNotNull(); + } + + @Test + void testInvalidPathsWithSpecialCharacters() { + assertThatThrownBy(() -> OpenAiApi.builder().apiKey(TEST_API_KEY).completionsPath(" ").build()) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + void testBuilderWithOnlyRequiredFields() { + OpenAiApi api = OpenAiApi.builder().apiKey(TEST_API_KEY).build(); + + assertThat(api).isNotNull(); + } + + @Test + void testDifferentApiKeyTypes() { + SimpleApiKey simpleKey = new SimpleApiKey("simple-key"); + OpenAiApi api1 = OpenAiApi.builder().apiKey(simpleKey).build(); + assertThat(api1).isNotNull(); + + OpenAiApi api2 = OpenAiApi.builder().apiKey(() -> "supplier-key").build(); + assertThat(api2).isNotNull(); + } + @Nested class MockRequests { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiChatModelMutateTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiChatModelMutateTests.java index dcacca47613..706c0ab7f9a 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiChatModelMutateTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiChatModelMutateTests.java @@ -24,6 +24,7 @@ import org.springframework.util.LinkedMultiValueMap; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; /* * Integration test for mutate/clone functionality on OpenAiApi and OpenAiChatModel. @@ -190,4 +191,31 @@ void testMutateBuilderValidation() { assertThat(unchanged).isNotSameAs(this.baseModel); } + @Test + void testMutateWithInvalidBaseUrl() { + assertThatThrownBy(() -> this.baseApi.mutate().baseUrl("").build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl"); + + assertThatThrownBy(() -> this.baseApi.mutate().baseUrl(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl"); + } + + @Test + void testMutateWithNullOpenAiApi() { + assertThatThrownBy(() -> this.baseModel.mutate().openAiApi(null).build()) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + void testMutatePreservesUnchangedFields() { + String originalBaseUrl = this.baseApi.getBaseUrl(); + String newApiKey = "new-test-key"; + + OpenAiApi mutated = this.baseApi.mutate().apiKey(newApiKey).build(); + + assertThat(mutated.getBaseUrl()).isEqualTo(originalBaseUrl); + assertThat(mutated.getApiKey().getValue()).isEqualTo(newApiKey); + } + } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelperTest.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelperTest.java index 23fcf704fdb..d67245dc49e 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelperTest.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelperTest.java @@ -25,6 +25,7 @@ import org.mockito.Mockito; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; /** * Unit tests for {@link OpenAiStreamFunctionCallingHelper} @@ -206,4 +207,131 @@ public void isStreamingToolFunctionCallFinishDetectsToolCallsFinishReason() { assertThat(this.helper.isStreamingToolFunctionCallFinish(chunk)).isTrue(); } + @Test + public void merge_whenBothChunksAreNull() { + var result = this.helper.merge(null, null); + assertThat(result).isNull(); + } + + @Test + public void merge_whenPreviousIsNull() { + var current = new OpenAiApi.ChatCompletionChunk("id", Collections.emptyList(), System.currentTimeMillis(), + "model", "default", "fingerprint", "object", null); + + var result = this.helper.merge(null, current); + assertThat(result).isEqualTo(current); + } + + @Test + public void merge_whenCurrentIsNull() { + var previous = new OpenAiApi.ChatCompletionChunk("id", Collections.emptyList(), System.currentTimeMillis(), + "model", "default", "fingerprint", "object", null); + + var result = this.helper.merge(previous, null); + assertThat(result).isEqualTo(previous); + } + + @Test + public void merge_partialFieldsFromEachChunk() { + var choices = List.of(Mockito.mock(OpenAiApi.ChatCompletionChunk.ChunkChoice.class)); + var usage = Mockito.mock(OpenAiApi.Usage.class); + + var previous = new OpenAiApi.ChatCompletionChunk(null, choices, 1L, "model1", null, "fp1", null, null); + var current = new OpenAiApi.ChatCompletionChunk("id2", null, null, null, "tier2", null, "object2", usage); + + var result = this.helper.merge(previous, current); + + assertThat(result.id()).isEqualTo("id2"); + assertThat(result.choices()).isEqualTo(choices); + assertThat(result.created()).isEqualTo(1L); + assertThat(result.model()).isEqualTo("model1"); + assertThat(result.serviceTier()).isEqualTo("tier2"); + assertThat(result.systemFingerprint()).isEqualTo("fp1"); + assertThat(result.object()).isEqualTo("object2"); + assertThat(result.usage()).isEqualTo(usage); + } + + @Test + public void isStreamingToolFunctionCall_withMultipleChoicesAndOnlyFirstHasToolCalls() { + var toolCall = Mockito.mock(OpenAiApi.ChatCompletionMessage.ToolCall.class); + var deltaWithToolCalls = new OpenAiApi.ChatCompletionMessage(null, null, null, null, List.of(toolCall), null, + null, null); + var deltaWithoutToolCalls = new OpenAiApi.ChatCompletionMessage(null, null); + + var choice1 = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, deltaWithToolCalls, null); + var choice2 = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, deltaWithoutToolCalls, null); + + var chunk = new OpenAiApi.ChatCompletionChunk(null, List.of(choice1, choice2), null, null, null, null, null, + null); + + assertThat(this.helper.isStreamingToolFunctionCall(chunk)).isTrue(); + } + + @Test + public void isStreamingToolFunctionCall_withMultipleChoicesAndNoneHaveToolCalls() { + var deltaWithoutToolCalls = new OpenAiApi.ChatCompletionMessage(null, null); + + var choice1 = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, deltaWithoutToolCalls, null); + var choice2 = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, deltaWithoutToolCalls, null); + + var chunk = new OpenAiApi.ChatCompletionChunk(null, List.of(choice1, choice2), null, null, null, null, null, + null); + + assertThat(this.helper.isStreamingToolFunctionCall(chunk)).isFalse(); + } + + @Test + public void isStreamingToolFunctionCallFinish_withMultipleChoicesAndOnlyFirstIsToolCallsFinish() { + var choice1 = new OpenAiApi.ChatCompletionChunk.ChunkChoice(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS, + null, new OpenAiApi.ChatCompletionMessage(null, null), null); + var choice2 = new OpenAiApi.ChatCompletionChunk.ChunkChoice(OpenAiApi.ChatCompletionFinishReason.STOP, null, + new OpenAiApi.ChatCompletionMessage(null, null), null); + + var chunk = new OpenAiApi.ChatCompletionChunk(null, List.of(choice1, choice2), null, null, null, null, null, + null); + + assertThat(this.helper.isStreamingToolFunctionCallFinish(chunk)).isTrue(); + } + + @Test + public void chunkToChatCompletion_whenChunkIsNull() { + assertThatThrownBy(() -> this.helper.chunkToChatCompletion(null)).isInstanceOf(NullPointerException.class); + } + + @Test + public void chunkToChatCompletion_withEmptyChoices() { + var chunk = new OpenAiApi.ChatCompletionChunk("id", Collections.emptyList(), 1L, "model", "tier", "fp", + "object", null); + + var result = this.helper.chunkToChatCompletion(chunk); + + assertThat(result.object()).isEqualTo("chat.completion"); + assertThat(result.choices()).isEmpty(); + assertThat(result.id()).isEqualTo("id"); + assertThat(result.created()).isEqualTo(1L); + assertThat(result.model()).isEqualTo("model"); + } + + @Test + public void edgeCases_emptyStringFields() { + var chunk = new OpenAiApi.ChatCompletionChunk("", Collections.emptyList(), 0L, "", "", "", "", null); + + var result = this.helper.chunkToChatCompletion(chunk); + + assertThat(result.id()).isEmpty(); + assertThat(result.model()).isEmpty(); + assertThat(result.serviceTier()).isEmpty(); + assertThat(result.systemFingerprint()).isEmpty(); + assertThat(result.created()).isEqualTo(0L); + } + + @Test + public void isStreamingToolFunctionCall_withNullToolCallsList() { + var delta = new OpenAiApi.ChatCompletionMessage(null, null, null, null, null, null, null, null); + var choice = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, delta, null); + var chunk = new OpenAiApi.ChatCompletionChunk(null, List.of(choice), null, null, null, null, null, null); + + assertThat(this.helper.isStreamingToolFunctionCall(chunk)).isFalse(); + } + }