diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index d72534b4654..3071bc6b716 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -128,10 +128,6 @@ public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap he // @formatter:off Consumer finalHeaders = h -> { - if (!(apiKey instanceof NoopApiKey)) { - h.setBearerAuth(apiKey.getValue()); - } - h.setContentType(MediaType.APPLICATION_JSON); h.addAll(headers); }; @@ -179,12 +175,17 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null."); + // @formatter:off return this.restClient.post() .uri(this.completionsPath) - .headers(headers -> headers.addAll(additionalHttpHeader)) + .headers(headers -> { + headers.addAll(additionalHttpHeader); + addDefaultHeadersIfMissing(headers); + }) .body(chatRequest) .retrieve() .toEntity(ChatCompletion.class); + // @formatter:on } /** @@ -213,9 +214,13 @@ public Flux chatCompletionStream(ChatCompletionRequest chat AtomicBoolean isInsideTool = new AtomicBoolean(false); + // @formatter:off return this.webClient.post() .uri(this.completionsPath) - .headers(headers -> headers.addAll(additionalHttpHeader)) + .headers(headers -> { + headers.addAll(additionalHttpHeader); + addDefaultHeadersIfMissing(headers); + }) // @formatter:on .body(Mono.just(chatRequest), ChatCompletionRequest.class) .retrieve() .bodyToFlux(String.class) @@ -289,6 +294,7 @@ public ResponseEntity> embeddings(EmbeddingRequest< return this.restClient.post() .uri(this.embeddingsPath) + .headers(this::addDefaultHeadersIfMissing) .body(embeddingRequest) .retrieve() .toEntity(new ParameterizedTypeReference<>() { @@ -296,6 +302,12 @@ public ResponseEntity> embeddings(EmbeddingRequest< }); } + private void addDefaultHeadersIfMissing(HttpHeaders headers) { + if (!headers.containsKey(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) { + headers.setBearerAuth(this.apiKey.getValue()); + } + } + // Package-private getters for mutate/copy String getBaseUrl() { return this.baseUrl; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java index 95604f30075..0e355ffb7a6 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java @@ -71,20 +71,30 @@ public OpenAiAudioApi(String baseUrl, ApiKey apiKey, MultiValueMap authHeaders = h -> { - if (!(apiKey instanceof NoopApiKey)) { - h.setBearerAuth(apiKey.getValue()); - } h.addAll(headers); - // h.setContentType(MediaType.APPLICATION_JSON); }; + // @formatter:off this.restClient = restClientBuilder.clone() .baseUrl(baseUrl) .defaultHeaders(authHeaders) .defaultStatusHandler(responseErrorHandler) + .defaultRequest(requestHeadersSpec -> { + if (!(apiKey instanceof NoopApiKey)) { + requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue()); + } + }) .build(); - this.webClient = webClientBuilder.clone().baseUrl(baseUrl).defaultHeaders(authHeaders).build(); + this.webClient = webClientBuilder.clone() + .baseUrl(baseUrl) + .defaultHeaders(authHeaders) + .defaultRequest(requestHeadersSpec -> { + if (!(apiKey instanceof NoopApiKey)) { + requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue()); + } + }) + .build(); // @formatter:on } public static Builder builder() { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java index bd32d42d9f6..e4f8c50fbf2 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java @@ -27,6 +27,7 @@ import org.springframework.ai.model.SimpleApiKey; import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.retry.RetryUtils; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; @@ -62,15 +63,18 @@ public OpenAiImageApi(String baseUrl, ApiKey apiKey, MultiValueMap { - if (!(apiKey instanceof NoopApiKey)) { - h.setBearerAuth(apiKey.getValue()); - } h.setContentType(MediaType.APPLICATION_JSON); h.addAll(headers); }) .defaultStatusHandler(responseErrorHandler) + .defaultRequest(requestHeadersSpec -> { + if (!(apiKey instanceof NoopApiKey)) { + requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue()); + } + }) .build(); // @formatter:on diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java index b9a6578b636..4fbea8e1469 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java @@ -27,6 +27,7 @@ import org.springframework.ai.model.SimpleApiKey; import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.retry.RetryUtils; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; @@ -64,13 +65,20 @@ public OpenAiModerationApi(String baseUrl, ApiKey apiKey, MultiValueMap { - if (!(apiKey instanceof NoopApiKey)) { - h.setBearerAuth(apiKey.getValue()); - } - h.setContentType(MediaType.APPLICATION_JSON); - h.addAll(headers); - }).defaultStatusHandler(responseErrorHandler).build(); + // @formatter:off + this.restClient = restClientBuilder.clone() + .baseUrl(baseUrl) + .defaultHeaders(h -> { + h.setContentType(MediaType.APPLICATION_JSON); + h.addAll(headers); + }) + .defaultStatusHandler(responseErrorHandler) + .defaultRequest(requestHeadersSpec -> { + if (!(apiKey instanceof NoopApiKey)) { + requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue()); + } + }) + .build(); // @formatter:on } public ResponseEntity createModeration(OpenAiModerationRequest openAiModerationRequest) { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java index 30f28f610c9..b47a4a91bac 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java @@ -16,10 +16,27 @@ package org.springframework.ai.openai.api; +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.Queue; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +import org.opentest4j.AssertionFailedError; import org.springframework.ai.model.ApiKey; import org.springframework.ai.model.SimpleApiKey; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; @@ -142,4 +159,270 @@ void testInvalidResponseErrorHandler() { .hasMessageContaining("responseErrorHandler cannot be null"); } + @Nested + class MockRequests { + + MockWebServer mockWebServer; + + @BeforeEach + void setUp() throws IOException { + mockWebServer = new MockWebServer(); + mockWebServer.start(); + } + + @AfterEach + void tearDown() throws IOException { + mockWebServer.shutdown(); + } + + @Test + void dynamicApiKeyRestClient() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + OpenAiApi api = OpenAiApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + { + "id": "chatcmpl-12345", + "object": "chat.completion", + "created": 1677858242, + "model": "gpt-3.5-turbo", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello world" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15 + } + } + """); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world", + OpenAiApi.ChatCompletionMessage.Role.USER); + OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest( + List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, false); + ResponseEntity response = api.chatCompletionEntity(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.chatCompletionEntity(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + + @Test + void dynamicApiKeyRestClientWithAdditionalAuthorizationHeader() throws InterruptedException { + OpenAiApi api = OpenAiApi.builder().apiKey(() -> { + throw new AssertionFailedError("Should not be called, API key is provided in headers"); + }).baseUrl(mockWebServer.url("/").toString()).build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + { + "id": "chatcmpl-12345", + "object": "chat.completion", + "created": 1677858242, + "model": "gpt-3.5-turbo", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello world" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15 + } + } + """); + mockWebServer.enqueue(mockResponse); + + OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world", + OpenAiApi.ChatCompletionMessage.Role.USER); + OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest( + List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, false); + + MultiValueMap additionalHeaders = new LinkedMultiValueMap<>(); + additionalHeaders.add(HttpHeaders.AUTHORIZATION, "Bearer additional-key"); + ResponseEntity response = api.chatCompletionEntity(request, additionalHeaders); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer additional-key"); + } + + @Test + void dynamicApiKeyWebClient() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + OpenAiApi api = OpenAiApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + { + "id": "chatcmpl-12345", + "object": "chat.completion", + "created": 1677858242, + "model": "gpt-3.5-turbo", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello world" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15 + } + } + """.replace("\n", "")); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world", + OpenAiApi.ChatCompletionMessage.Role.USER); + OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest( + List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, true); + List response = api.chatCompletionStream(request).collectList().block(); + assertThat(response).hasSize(1); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.chatCompletionStream(request).collectList().block(); + assertThat(response).hasSize(1); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + + @Test + void dynamicApiKeyWebClientWithAdditionalAuthorizationHeader() throws InterruptedException { + OpenAiApi api = OpenAiApi.builder().apiKey(() -> { + throw new AssertionFailedError("Should not be called, API key is provided in headers"); + }).baseUrl(mockWebServer.url("/").toString()).build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + { + "id": "chatcmpl-12345", + "object": "chat.completion", + "created": 1677858242, + "model": "gpt-3.5-turbo", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello world" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15 + } + } + """.replace("\n", "")); + mockWebServer.enqueue(mockResponse); + + OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world", + OpenAiApi.ChatCompletionMessage.Role.USER); + OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest( + List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, true); + MultiValueMap additionalHeaders = new LinkedMultiValueMap<>(); + additionalHeaders.add(HttpHeaders.AUTHORIZATION, "Bearer additional-key"); + List response = api.chatCompletionStream(request, additionalHeaders) + .collectList() + .block(); + assertThat(response).hasSize(1); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer additional-key"); + } + + @Test + void dynamicApiKeyRestClientEmbeddings() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + OpenAiApi api = OpenAiApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + -0.005540426, + 0.0047363234, + -0.015009919, + -0.027093535, + -0.015173893, + 0.015173893, + -0.017608276 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 2, + "total_tokens": 2 + } + } + """); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + OpenAiApi.EmbeddingRequest request = new OpenAiApi.EmbeddingRequest<>("Hello world"); + ResponseEntity> response = api.embeddings(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.embeddings(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + + } + } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiBuilderTests.java new file mode 100644 index 00000000000..1b11fa6807a --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiBuilderTests.java @@ -0,0 +1,209 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.audio.api; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.Queue; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.ai.model.ApiKey; +import org.springframework.ai.model.SimpleApiKey; +import org.springframework.ai.openai.api.OpenAiAudioApi; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; + +/** + * @author Filip Hrisafov + */ +class OpenAiAudioApiBuilderTests { + + private static final ApiKey TEST_API_KEY = new SimpleApiKey("test-api-key"); + + private static final String TEST_BASE_URL = "https://test.openai.com"; + + @Test + void testMinimalBuilder() { + OpenAiAudioApi api = OpenAiAudioApi.builder().apiKey(TEST_API_KEY).build(); + + assertThat(api).isNotNull(); + } + + @Test + void testFullBuilder() { + MultiValueMap headers = new LinkedMultiValueMap<>(); + headers.add("Custom-Header", "test-value"); + RestClient.Builder restClientBuilder = RestClient.builder(); + WebClient.Builder webClientBuilder = WebClient.builder(); + ResponseErrorHandler errorHandler = mock(ResponseErrorHandler.class); + + OpenAiAudioApi api = OpenAiAudioApi.builder() + .baseUrl(TEST_BASE_URL) + .apiKey(TEST_API_KEY) + .headers(headers) + .restClientBuilder(restClientBuilder) + .webClientBuilder(webClientBuilder) + .responseErrorHandler(errorHandler) + .build(); + + assertThat(api).isNotNull(); + } + + @Test + void testMissingApiKey() { + assertThatThrownBy(() -> OpenAiAudioApi.builder().build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("apiKey must be set"); + } + + @Test + void testInvalidBaseUrl() { + assertThatThrownBy(() -> OpenAiAudioApi.builder().baseUrl("").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + + assertThatThrownBy(() -> OpenAiAudioApi.builder().baseUrl(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + } + + @Test + void testInvalidHeaders() { + assertThatThrownBy(() -> OpenAiAudioApi.builder().headers(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("headers cannot be null"); + } + + @Test + void testInvalidRestClientBuilder() { + assertThatThrownBy(() -> OpenAiAudioApi.builder().restClientBuilder(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("restClientBuilder cannot be null"); + } + + @Test + void testInvalidWebClientBuilder() { + assertThatThrownBy(() -> OpenAiAudioApi.builder().webClientBuilder(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("webClientBuilder cannot be null"); + } + + @Test + void testInvalidResponseErrorHandler() { + assertThatThrownBy(() -> OpenAiAudioApi.builder().responseErrorHandler(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("responseErrorHandler cannot be null"); + } + + @Nested + class MockRequests { + + MockWebServer mockWebServer; + + @BeforeEach + void setUp() throws IOException { + mockWebServer = new MockWebServer(); + mockWebServer.start(); + } + + @AfterEach + void tearDown() throws IOException { + mockWebServer.shutdown(); + } + + @Test + void dynamicApiKeyRestClient() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + OpenAiAudioApi api = OpenAiAudioApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_OCTET_STREAM_VALUE) + .setBody("Audio bytes as string"); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + OpenAiAudioApi.SpeechRequest request = OpenAiAudioApi.SpeechRequest.builder() + .model(OpenAiAudioApi.TtsModel.TTS_1.value) + .input("Test input") + .build(); + ResponseEntity response = api.createSpeech(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.createSpeech(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + + @Test + void dynamicApiKeyWebClient() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + OpenAiAudioApi api = OpenAiAudioApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_OCTET_STREAM_VALUE) + .setBody("Audio bytes as string"); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + OpenAiAudioApi.SpeechRequest request = OpenAiAudioApi.SpeechRequest.builder() + .model(OpenAiAudioApi.TtsModel.TTS_1.value) + .input("Test input") + .build(); + List> response = api.stream(request).collectList().block(); + assertThat(response).hasSize(1); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.stream(request).collectList().block(); + assertThat(response).hasSize(1); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/api/OpenAiImageApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/api/OpenAiImageApiBuilderTests.java new file mode 100644 index 00000000000..5564a087a71 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/api/OpenAiImageApiBuilderTests.java @@ -0,0 +1,176 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.image.api; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.Queue; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.ai.model.ApiKey; +import org.springframework.ai.model.SimpleApiKey; +import org.springframework.ai.openai.api.OpenAiImageApi; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; + +/** + * @author Filip Hrisafov + */ +class OpenAiImageApiBuilderTests { + + private static final ApiKey TEST_API_KEY = new SimpleApiKey("test-api-key"); + + private static final String TEST_BASE_URL = "https://test.openai.com"; + + @Test + void testMinimalBuilder() { + OpenAiImageApi api = OpenAiImageApi.builder().apiKey(TEST_API_KEY).build(); + + assertThat(api).isNotNull(); + } + + @Test + void testFullBuilder() { + MultiValueMap headers = new LinkedMultiValueMap<>(); + headers.add("Custom-Header", "test-value"); + RestClient.Builder restClientBuilder = RestClient.builder(); + ResponseErrorHandler errorHandler = mock(ResponseErrorHandler.class); + + OpenAiImageApi api = OpenAiImageApi.builder() + .baseUrl(TEST_BASE_URL) + .apiKey(TEST_API_KEY) + .headers(headers) + .restClientBuilder(restClientBuilder) + .responseErrorHandler(errorHandler) + .build(); + + assertThat(api).isNotNull(); + } + + @Test + void testMissingApiKey() { + assertThatThrownBy(() -> OpenAiImageApi.builder().build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("apiKey must be set"); + } + + @Test + void testInvalidBaseUrl() { + assertThatThrownBy(() -> OpenAiImageApi.builder().baseUrl("").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + + assertThatThrownBy(() -> OpenAiImageApi.builder().baseUrl(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + } + + @Test + void testInvalidHeaders() { + assertThatThrownBy(() -> OpenAiImageApi.builder().headers(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("headers cannot be null"); + } + + @Test + void testInvalidRestClientBuilder() { + assertThatThrownBy(() -> OpenAiImageApi.builder().restClientBuilder(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("restClientBuilder cannot be null"); + } + + @Test + void testInvalidResponseErrorHandler() { + assertThatThrownBy(() -> OpenAiImageApi.builder().responseErrorHandler(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("responseErrorHandler cannot be null"); + } + + @Nested + class MockRequests { + + MockWebServer mockWebServer; + + @BeforeEach + void setUp() throws IOException { + mockWebServer = new MockWebServer(); + mockWebServer.start(); + } + + @AfterEach + void tearDown() throws IOException { + mockWebServer.shutdown(); + } + + @Test + void dynamicApiKeyRestClient() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + OpenAiImageApi api = OpenAiImageApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + { + "created": 1589478378, + "data": [ + { + "url": "https://upload.wikimedia.org/wikipedia/commons/4/4e/Mini_Golden_Doodle.jpg" + } + ] + } + """); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + OpenAiImageApi.OpenAiImageRequest request = new OpenAiImageApi.OpenAiImageRequest("Test", + OpenAiImageApi.ImageModel.DALL_E_3.getValue()); + ResponseEntity response = api.createImage(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.createImage(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/api/OpenAiModerationApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/api/OpenAiModerationApiBuilderTests.java new file mode 100644 index 00000000000..1c757789b27 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/api/OpenAiModerationApiBuilderTests.java @@ -0,0 +1,176 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.moderation.api; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.Queue; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.ai.model.ApiKey; +import org.springframework.ai.model.SimpleApiKey; +import org.springframework.ai.openai.api.OpenAiModerationApi; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; + +/** + * @author Filip Hrisafov + */ +class OpenAiModerationApiBuilderTests { + + private static final ApiKey TEST_API_KEY = new SimpleApiKey("test-api-key"); + + private static final String TEST_BASE_URL = "https://test.openai.com"; + + @Test + void testMinimalBuilder() { + OpenAiModerationApi api = OpenAiModerationApi.builder().apiKey(TEST_API_KEY).build(); + + assertThat(api).isNotNull(); + } + + @Test + void testFullBuilder() { + MultiValueMap headers = new LinkedMultiValueMap<>(); + headers.add("Custom-Header", "test-value"); + RestClient.Builder restClientBuilder = RestClient.builder(); + ResponseErrorHandler errorHandler = mock(ResponseErrorHandler.class); + + OpenAiModerationApi api = OpenAiModerationApi.builder() + .baseUrl(TEST_BASE_URL) + .apiKey(TEST_API_KEY) + .headers(headers) + .restClientBuilder(restClientBuilder) + .responseErrorHandler(errorHandler) + .build(); + + assertThat(api).isNotNull(); + } + + @Test + void testMissingApiKey() { + assertThatThrownBy(() -> OpenAiModerationApi.builder().build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("apiKey must be set"); + } + + @Test + void testInvalidBaseUrl() { + assertThatThrownBy(() -> OpenAiModerationApi.builder().baseUrl("").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + + assertThatThrownBy(() -> OpenAiModerationApi.builder().baseUrl(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + } + + @Test + void testInvalidHeaders() { + assertThatThrownBy(() -> OpenAiModerationApi.builder().headers(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("headers cannot be null"); + } + + @Test + void testInvalidRestClientBuilder() { + assertThatThrownBy(() -> OpenAiModerationApi.builder().restClientBuilder(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("restClientBuilder cannot be null"); + } + + @Test + void testInvalidResponseErrorHandler() { + assertThatThrownBy(() -> OpenAiModerationApi.builder().responseErrorHandler(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("responseErrorHandler cannot be null"); + } + + @Nested + class MockRequests { + + MockWebServer mockWebServer; + + @BeforeEach + void setUp() throws IOException { + mockWebServer = new MockWebServer(); + mockWebServer.start(); + } + + @AfterEach + void tearDown() throws IOException { + mockWebServer.shutdown(); + } + + @Test + void dynamicApiKeyRestClient() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + OpenAiModerationApi api = OpenAiModerationApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + { + "created": 1589478378, + "data": [ + { + "url": "https://upload.wikimedia.org/wikipedia/commons/4/4e/Mini_Golden_Doodle.jpg" + } + ] + } + """); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + OpenAiModerationApi.OpenAiModerationRequest request = new OpenAiModerationApi.OpenAiModerationRequest( + "Test"); + ResponseEntity response = api.createModeration(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.createModeration(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + + } + +}