From 4af3439c438e36cc1541b647e547279840782737 Mon Sep 17 00:00:00 2001 From: Filip Hrisafov Date: Wed, 28 May 2025 20:01:51 +0200 Subject: [PATCH 1/3] Resolve OpenAI ApiKey for every request Signed-off-by: Filip Hrisafov --- .../ai/openai/api/OpenAiApi.java | 14 +- .../ai/openai/api/OpenAiAudioApi.java | 20 +- .../ai/openai/api/OpenAiImageApi.java | 12 +- .../ai/openai/api/OpenAiModerationApi.java | 22 +- .../ai/openai/api/OpenAiApiBuilderTests.java | 139 ++++++++++++ .../audio/api/OpenAiAudioApiBuilderTests.java | 209 ++++++++++++++++++ .../image/api/OpenAiImageApiBuilderTests.java | 176 +++++++++++++++ .../api/OpenAiModerationApiBuilderTests.java | 176 +++++++++++++++ 8 files changed, 748 insertions(+), 20 deletions(-) create mode 100644 models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiBuilderTests.java create mode 100644 models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/api/OpenAiImageApiBuilderTests.java create mode 100644 models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/api/OpenAiModerationApiBuilderTests.java diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index d72534b4654..b92ec3d75ff 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -128,10 +128,6 @@ public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap he // @formatter:off Consumer finalHeaders = h -> { - if (!(apiKey instanceof NoopApiKey)) { - h.setBearerAuth(apiKey.getValue()); - } - h.setContentType(MediaType.APPLICATION_JSON); h.addAll(headers); }; @@ -139,11 +135,21 @@ public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap he .baseUrl(baseUrl) .defaultHeaders(finalHeaders) .defaultStatusHandler(responseErrorHandler) + .defaultRequest(requestHeadersSpec -> { + if (!(apiKey instanceof NoopApiKey)) { + requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue()); + } + }) .build(); this.webClient = webClientBuilder.clone() .baseUrl(baseUrl) .defaultHeaders(finalHeaders) + .defaultRequest(requestHeadersSpec -> { + if (!(apiKey instanceof NoopApiKey)) { + requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue()); + } + }) .build(); // @formatter:on } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java index 95604f30075..0e355ffb7a6 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java @@ -71,20 +71,30 @@ public OpenAiAudioApi(String baseUrl, ApiKey apiKey, MultiValueMap authHeaders = h -> { - if (!(apiKey instanceof NoopApiKey)) { - h.setBearerAuth(apiKey.getValue()); - } h.addAll(headers); - // h.setContentType(MediaType.APPLICATION_JSON); }; + // @formatter:off this.restClient = restClientBuilder.clone() .baseUrl(baseUrl) .defaultHeaders(authHeaders) .defaultStatusHandler(responseErrorHandler) + .defaultRequest(requestHeadersSpec -> { + if (!(apiKey instanceof NoopApiKey)) { + requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue()); + } + }) .build(); - this.webClient = webClientBuilder.clone().baseUrl(baseUrl).defaultHeaders(authHeaders).build(); + this.webClient = webClientBuilder.clone() + .baseUrl(baseUrl) + .defaultHeaders(authHeaders) + .defaultRequest(requestHeadersSpec -> { + if (!(apiKey instanceof NoopApiKey)) { + requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue()); + } + }) + .build(); // @formatter:on } public static Builder builder() { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java index bd32d42d9f6..e4f8c50fbf2 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java @@ -27,6 +27,7 @@ import org.springframework.ai.model.SimpleApiKey; import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.retry.RetryUtils; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; @@ -62,15 +63,18 @@ public OpenAiImageApi(String baseUrl, ApiKey apiKey, MultiValueMap { - if (!(apiKey instanceof NoopApiKey)) { - h.setBearerAuth(apiKey.getValue()); - } h.setContentType(MediaType.APPLICATION_JSON); h.addAll(headers); }) .defaultStatusHandler(responseErrorHandler) + .defaultRequest(requestHeadersSpec -> { + if (!(apiKey instanceof NoopApiKey)) { + requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue()); + } + }) .build(); // @formatter:on diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java index b9a6578b636..4fbea8e1469 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java @@ -27,6 +27,7 @@ import org.springframework.ai.model.SimpleApiKey; import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.retry.RetryUtils; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; @@ -64,13 +65,20 @@ public OpenAiModerationApi(String baseUrl, ApiKey apiKey, MultiValueMap { - if (!(apiKey instanceof NoopApiKey)) { - h.setBearerAuth(apiKey.getValue()); - } - h.setContentType(MediaType.APPLICATION_JSON); - h.addAll(headers); - }).defaultStatusHandler(responseErrorHandler).build(); + // @formatter:off + this.restClient = restClientBuilder.clone() + .baseUrl(baseUrl) + .defaultHeaders(h -> { + h.setContentType(MediaType.APPLICATION_JSON); + h.addAll(headers); + }) + .defaultStatusHandler(responseErrorHandler) + .defaultRequest(requestHeadersSpec -> { + if (!(apiKey instanceof NoopApiKey)) { + requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue()); + } + }) + .build(); // @formatter:on } public ResponseEntity createModeration(OpenAiModerationRequest openAiModerationRequest) { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java index 30f28f610c9..e9d188052fa 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java @@ -16,10 +16,27 @@ package org.springframework.ai.openai.api; +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.Queue; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.springframework.ai.model.ApiKey; import org.springframework.ai.model.SimpleApiKey; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; @@ -142,4 +159,126 @@ void testInvalidResponseErrorHandler() { .hasMessageContaining("responseErrorHandler cannot be null"); } + @Nested + class MockRequests { + + MockWebServer mockWebServer; + + @BeforeEach + void setUp() throws IOException { + mockWebServer = new MockWebServer(); + mockWebServer.start(); + } + + @AfterEach + void tearDown() throws IOException { + mockWebServer.shutdown(); + } + + @Test + void dynamicApiKeyRestClient() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + OpenAiApi api = OpenAiApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + { + "id": "chatcmpl-12345", + "object": "chat.completion", + "created": 1677858242, + "model": "gpt-3.5-turbo", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello world" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15 + } + } + """); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world", + OpenAiApi.ChatCompletionMessage.Role.USER); + OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest( + List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, false); + ResponseEntity response = api.chatCompletionEntity(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.chatCompletionEntity(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + + @Test + void dynamicApiKeyWebClient() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + OpenAiApi api = OpenAiApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + { + "id": "chatcmpl-12345", + "object": "chat.completion", + "created": 1677858242, + "model": "gpt-3.5-turbo", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello world" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15 + } + } + """.replace("\n", "")); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world", + OpenAiApi.ChatCompletionMessage.Role.USER); + OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest( + List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, true); + List response = api.chatCompletionStream(request).collectList().block(); + assertThat(response).hasSize(1); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.chatCompletionStream(request).collectList().block(); + assertThat(response).hasSize(1); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + + } + } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiBuilderTests.java new file mode 100644 index 00000000000..1b11fa6807a --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/api/OpenAiAudioApiBuilderTests.java @@ -0,0 +1,209 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.audio.api; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.Queue; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.ai.model.ApiKey; +import org.springframework.ai.model.SimpleApiKey; +import org.springframework.ai.openai.api.OpenAiAudioApi; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; + +/** + * @author Filip Hrisafov + */ +class OpenAiAudioApiBuilderTests { + + private static final ApiKey TEST_API_KEY = new SimpleApiKey("test-api-key"); + + private static final String TEST_BASE_URL = "https://test.openai.com"; + + @Test + void testMinimalBuilder() { + OpenAiAudioApi api = OpenAiAudioApi.builder().apiKey(TEST_API_KEY).build(); + + assertThat(api).isNotNull(); + } + + @Test + void testFullBuilder() { + MultiValueMap headers = new LinkedMultiValueMap<>(); + headers.add("Custom-Header", "test-value"); + RestClient.Builder restClientBuilder = RestClient.builder(); + WebClient.Builder webClientBuilder = WebClient.builder(); + ResponseErrorHandler errorHandler = mock(ResponseErrorHandler.class); + + OpenAiAudioApi api = OpenAiAudioApi.builder() + .baseUrl(TEST_BASE_URL) + .apiKey(TEST_API_KEY) + .headers(headers) + .restClientBuilder(restClientBuilder) + .webClientBuilder(webClientBuilder) + .responseErrorHandler(errorHandler) + .build(); + + assertThat(api).isNotNull(); + } + + @Test + void testMissingApiKey() { + assertThatThrownBy(() -> OpenAiAudioApi.builder().build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("apiKey must be set"); + } + + @Test + void testInvalidBaseUrl() { + assertThatThrownBy(() -> OpenAiAudioApi.builder().baseUrl("").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + + assertThatThrownBy(() -> OpenAiAudioApi.builder().baseUrl(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + } + + @Test + void testInvalidHeaders() { + assertThatThrownBy(() -> OpenAiAudioApi.builder().headers(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("headers cannot be null"); + } + + @Test + void testInvalidRestClientBuilder() { + assertThatThrownBy(() -> OpenAiAudioApi.builder().restClientBuilder(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("restClientBuilder cannot be null"); + } + + @Test + void testInvalidWebClientBuilder() { + assertThatThrownBy(() -> OpenAiAudioApi.builder().webClientBuilder(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("webClientBuilder cannot be null"); + } + + @Test + void testInvalidResponseErrorHandler() { + assertThatThrownBy(() -> OpenAiAudioApi.builder().responseErrorHandler(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("responseErrorHandler cannot be null"); + } + + @Nested + class MockRequests { + + MockWebServer mockWebServer; + + @BeforeEach + void setUp() throws IOException { + mockWebServer = new MockWebServer(); + mockWebServer.start(); + } + + @AfterEach + void tearDown() throws IOException { + mockWebServer.shutdown(); + } + + @Test + void dynamicApiKeyRestClient() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + OpenAiAudioApi api = OpenAiAudioApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_OCTET_STREAM_VALUE) + .setBody("Audio bytes as string"); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + OpenAiAudioApi.SpeechRequest request = OpenAiAudioApi.SpeechRequest.builder() + .model(OpenAiAudioApi.TtsModel.TTS_1.value) + .input("Test input") + .build(); + ResponseEntity response = api.createSpeech(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.createSpeech(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + + @Test + void dynamicApiKeyWebClient() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + OpenAiAudioApi api = OpenAiAudioApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_OCTET_STREAM_VALUE) + .setBody("Audio bytes as string"); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + OpenAiAudioApi.SpeechRequest request = OpenAiAudioApi.SpeechRequest.builder() + .model(OpenAiAudioApi.TtsModel.TTS_1.value) + .input("Test input") + .build(); + List> response = api.stream(request).collectList().block(); + assertThat(response).hasSize(1); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.stream(request).collectList().block(); + assertThat(response).hasSize(1); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/api/OpenAiImageApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/api/OpenAiImageApiBuilderTests.java new file mode 100644 index 00000000000..5564a087a71 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/api/OpenAiImageApiBuilderTests.java @@ -0,0 +1,176 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.image.api; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.Queue; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.ai.model.ApiKey; +import org.springframework.ai.model.SimpleApiKey; +import org.springframework.ai.openai.api.OpenAiImageApi; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; + +/** + * @author Filip Hrisafov + */ +class OpenAiImageApiBuilderTests { + + private static final ApiKey TEST_API_KEY = new SimpleApiKey("test-api-key"); + + private static final String TEST_BASE_URL = "https://test.openai.com"; + + @Test + void testMinimalBuilder() { + OpenAiImageApi api = OpenAiImageApi.builder().apiKey(TEST_API_KEY).build(); + + assertThat(api).isNotNull(); + } + + @Test + void testFullBuilder() { + MultiValueMap headers = new LinkedMultiValueMap<>(); + headers.add("Custom-Header", "test-value"); + RestClient.Builder restClientBuilder = RestClient.builder(); + ResponseErrorHandler errorHandler = mock(ResponseErrorHandler.class); + + OpenAiImageApi api = OpenAiImageApi.builder() + .baseUrl(TEST_BASE_URL) + .apiKey(TEST_API_KEY) + .headers(headers) + .restClientBuilder(restClientBuilder) + .responseErrorHandler(errorHandler) + .build(); + + assertThat(api).isNotNull(); + } + + @Test + void testMissingApiKey() { + assertThatThrownBy(() -> OpenAiImageApi.builder().build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("apiKey must be set"); + } + + @Test + void testInvalidBaseUrl() { + assertThatThrownBy(() -> OpenAiImageApi.builder().baseUrl("").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + + assertThatThrownBy(() -> OpenAiImageApi.builder().baseUrl(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + } + + @Test + void testInvalidHeaders() { + assertThatThrownBy(() -> OpenAiImageApi.builder().headers(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("headers cannot be null"); + } + + @Test + void testInvalidRestClientBuilder() { + assertThatThrownBy(() -> OpenAiImageApi.builder().restClientBuilder(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("restClientBuilder cannot be null"); + } + + @Test + void testInvalidResponseErrorHandler() { + assertThatThrownBy(() -> OpenAiImageApi.builder().responseErrorHandler(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("responseErrorHandler cannot be null"); + } + + @Nested + class MockRequests { + + MockWebServer mockWebServer; + + @BeforeEach + void setUp() throws IOException { + mockWebServer = new MockWebServer(); + mockWebServer.start(); + } + + @AfterEach + void tearDown() throws IOException { + mockWebServer.shutdown(); + } + + @Test + void dynamicApiKeyRestClient() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + OpenAiImageApi api = OpenAiImageApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + { + "created": 1589478378, + "data": [ + { + "url": "https://upload.wikimedia.org/wikipedia/commons/4/4e/Mini_Golden_Doodle.jpg" + } + ] + } + """); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + OpenAiImageApi.OpenAiImageRequest request = new OpenAiImageApi.OpenAiImageRequest("Test", + OpenAiImageApi.ImageModel.DALL_E_3.getValue()); + ResponseEntity response = api.createImage(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.createImage(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/api/OpenAiModerationApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/api/OpenAiModerationApiBuilderTests.java new file mode 100644 index 00000000000..1c757789b27 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/moderation/api/OpenAiModerationApiBuilderTests.java @@ -0,0 +1,176 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.moderation.api; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.Queue; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.ai.model.ApiKey; +import org.springframework.ai.model.SimpleApiKey; +import org.springframework.ai.openai.api.OpenAiModerationApi; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; + +/** + * @author Filip Hrisafov + */ +class OpenAiModerationApiBuilderTests { + + private static final ApiKey TEST_API_KEY = new SimpleApiKey("test-api-key"); + + private static final String TEST_BASE_URL = "https://test.openai.com"; + + @Test + void testMinimalBuilder() { + OpenAiModerationApi api = OpenAiModerationApi.builder().apiKey(TEST_API_KEY).build(); + + assertThat(api).isNotNull(); + } + + @Test + void testFullBuilder() { + MultiValueMap headers = new LinkedMultiValueMap<>(); + headers.add("Custom-Header", "test-value"); + RestClient.Builder restClientBuilder = RestClient.builder(); + ResponseErrorHandler errorHandler = mock(ResponseErrorHandler.class); + + OpenAiModerationApi api = OpenAiModerationApi.builder() + .baseUrl(TEST_BASE_URL) + .apiKey(TEST_API_KEY) + .headers(headers) + .restClientBuilder(restClientBuilder) + .responseErrorHandler(errorHandler) + .build(); + + assertThat(api).isNotNull(); + } + + @Test + void testMissingApiKey() { + assertThatThrownBy(() -> OpenAiModerationApi.builder().build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("apiKey must be set"); + } + + @Test + void testInvalidBaseUrl() { + assertThatThrownBy(() -> OpenAiModerationApi.builder().baseUrl("").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + + assertThatThrownBy(() -> OpenAiModerationApi.builder().baseUrl(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + } + + @Test + void testInvalidHeaders() { + assertThatThrownBy(() -> OpenAiModerationApi.builder().headers(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("headers cannot be null"); + } + + @Test + void testInvalidRestClientBuilder() { + assertThatThrownBy(() -> OpenAiModerationApi.builder().restClientBuilder(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("restClientBuilder cannot be null"); + } + + @Test + void testInvalidResponseErrorHandler() { + assertThatThrownBy(() -> OpenAiModerationApi.builder().responseErrorHandler(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("responseErrorHandler cannot be null"); + } + + @Nested + class MockRequests { + + MockWebServer mockWebServer; + + @BeforeEach + void setUp() throws IOException { + mockWebServer = new MockWebServer(); + mockWebServer.start(); + } + + @AfterEach + void tearDown() throws IOException { + mockWebServer.shutdown(); + } + + @Test + void dynamicApiKeyRestClient() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + OpenAiModerationApi api = OpenAiModerationApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + { + "created": 1589478378, + "data": [ + { + "url": "https://upload.wikimedia.org/wikipedia/commons/4/4e/Mini_Golden_Doodle.jpg" + } + ] + } + """); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + OpenAiModerationApi.OpenAiModerationRequest request = new OpenAiModerationApi.OpenAiModerationRequest( + "Test"); + ResponseEntity response = api.createModeration(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.createModeration(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + + } + +} From 08510d6ac82cccaf9fe5e372f901d76acd0c262e Mon Sep 17 00:00:00 2001 From: Filip Hrisafov Date: Fri, 30 May 2025 14:30:07 +0200 Subject: [PATCH 2/3] Set ApiKey as late as possible Signed-off-by: Filip Hrisafov --- .../ai/openai/api/OpenAiApi.java | 31 +++--- .../ai/openai/api/OpenAiApiBuilderTests.java | 95 ++++++++++++++++++- 2 files changed, 106 insertions(+), 20 deletions(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index b92ec3d75ff..0dcce3d88c2 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -135,21 +135,11 @@ public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap he .baseUrl(baseUrl) .defaultHeaders(finalHeaders) .defaultStatusHandler(responseErrorHandler) - .defaultRequest(requestHeadersSpec -> { - if (!(apiKey instanceof NoopApiKey)) { - requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue()); - } - }) .build(); this.webClient = webClientBuilder.clone() .baseUrl(baseUrl) .defaultHeaders(finalHeaders) - .defaultRequest(requestHeadersSpec -> { - if (!(apiKey instanceof NoopApiKey)) { - requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue()); - } - }) .build(); // @formatter:on } @@ -185,12 +175,12 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null."); - return this.restClient.post() - .uri(this.completionsPath) - .headers(headers -> headers.addAll(additionalHttpHeader)) - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletion.class); + return this.restClient.post().uri(this.completionsPath).headers(headers -> { + headers.addAll(additionalHttpHeader); + if (!headers.containsKey(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) { + headers.setBearerAuth(this.apiKey.getValue()); + } + }).body(chatRequest).retrieve().toEntity(ChatCompletion.class); } /** @@ -219,9 +209,12 @@ public Flux chatCompletionStream(ChatCompletionRequest chat AtomicBoolean isInsideTool = new AtomicBoolean(false); - return this.webClient.post() - .uri(this.completionsPath) - .headers(headers -> headers.addAll(additionalHttpHeader)) + return this.webClient.post().uri(this.completionsPath).headers(headers -> { + headers.addAll(additionalHttpHeader); + if (!headers.containsKey(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) { + headers.setBearerAuth(this.apiKey.getValue()); + } + }) .body(Mono.just(chatRequest), ChatCompletionRequest.class) .retrieve() .bodyToFlux(String.class) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java index e9d188052fa..1d3aec07423 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java @@ -25,11 +25,11 @@ import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; - import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +import org.opentest4j.AssertionFailedError; import org.springframework.ai.model.ApiKey; import org.springframework.ai.model.SimpleApiKey; @@ -227,6 +227,52 @@ void dynamicApiKeyRestClient() throws InterruptedException { assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); } + @Test + void dynamicApiKeyRestClientWithAdditionalAuthorizationHeader() throws InterruptedException { + OpenAiApi api = OpenAiApi.builder().apiKey(() -> { + throw new AssertionFailedError("Should not be called, API key is provided in headers"); + }).baseUrl(mockWebServer.url("/").toString()).build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + { + "id": "chatcmpl-12345", + "object": "chat.completion", + "created": 1677858242, + "model": "gpt-3.5-turbo", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello world" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15 + } + } + """); + mockWebServer.enqueue(mockResponse); + + OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world", + OpenAiApi.ChatCompletionMessage.Role.USER); + OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest( + List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, false); + + MultiValueMap additionalHeaders = new LinkedMultiValueMap<>(); + additionalHeaders.add(HttpHeaders.AUTHORIZATION, "Bearer additional-key"); + ResponseEntity response = api.chatCompletionEntity(request, additionalHeaders); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer additional-key"); + } + @Test void dynamicApiKeyWebClient() throws InterruptedException { Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); @@ -279,6 +325,53 @@ void dynamicApiKeyWebClient() throws InterruptedException { assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); } + @Test + void dynamicApiKeyWebClientWithAdditionalAuthorizationHeader() throws InterruptedException { + OpenAiApi api = OpenAiApi.builder().apiKey(() -> { + throw new AssertionFailedError("Should not be called, API key is provided in headers"); + }).baseUrl(mockWebServer.url("/").toString()).build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + { + "id": "chatcmpl-12345", + "object": "chat.completion", + "created": 1677858242, + "model": "gpt-3.5-turbo", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello world" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15 + } + } + """.replace("\n", "")); + mockWebServer.enqueue(mockResponse); + + OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world", + OpenAiApi.ChatCompletionMessage.Role.USER); + OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest( + List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, true); + MultiValueMap additionalHeaders = new LinkedMultiValueMap<>(); + additionalHeaders.add(HttpHeaders.AUTHORIZATION, "Bearer additional-key"); + List response = api.chatCompletionStream(request, additionalHeaders) + .collectList() + .block(); + assertThat(response).hasSize(1); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer additional-key"); + } + } } From a48e608b68de496cf9c098604d87c8dcaf1b7a0b Mon Sep 17 00:00:00 2001 From: Filip Hrisafov Date: Sat, 31 May 2025 16:32:10 +0200 Subject: [PATCH 3/3] Fix formatting and add API key for embeddings Signed-off-by: Filip Hrisafov --- .../ai/openai/api/OpenAiApi.java | 37 +++++++++----- .../ai/openai/api/OpenAiApiBuilderTests.java | 51 +++++++++++++++++++ 2 files changed, 76 insertions(+), 12 deletions(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 0dcce3d88c2..3071bc6b716 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -175,12 +175,17 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null."); - return this.restClient.post().uri(this.completionsPath).headers(headers -> { - headers.addAll(additionalHttpHeader); - if (!headers.containsKey(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) { - headers.setBearerAuth(this.apiKey.getValue()); - } - }).body(chatRequest).retrieve().toEntity(ChatCompletion.class); + // @formatter:off + return this.restClient.post() + .uri(this.completionsPath) + .headers(headers -> { + headers.addAll(additionalHttpHeader); + addDefaultHeadersIfMissing(headers); + }) + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + // @formatter:on } /** @@ -209,12 +214,13 @@ public Flux chatCompletionStream(ChatCompletionRequest chat AtomicBoolean isInsideTool = new AtomicBoolean(false); - return this.webClient.post().uri(this.completionsPath).headers(headers -> { - headers.addAll(additionalHttpHeader); - if (!headers.containsKey(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) { - headers.setBearerAuth(this.apiKey.getValue()); - } - }) + // @formatter:off + return this.webClient.post() + .uri(this.completionsPath) + .headers(headers -> { + headers.addAll(additionalHttpHeader); + addDefaultHeadersIfMissing(headers); + }) // @formatter:on .body(Mono.just(chatRequest), ChatCompletionRequest.class) .retrieve() .bodyToFlux(String.class) @@ -288,6 +294,7 @@ public ResponseEntity> embeddings(EmbeddingRequest< return this.restClient.post() .uri(this.embeddingsPath) + .headers(this::addDefaultHeadersIfMissing) .body(embeddingRequest) .retrieve() .toEntity(new ParameterizedTypeReference<>() { @@ -295,6 +302,12 @@ public ResponseEntity> embeddings(EmbeddingRequest< }); } + private void addDefaultHeadersIfMissing(HttpHeaders headers) { + if (!headers.containsKey(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) { + headers.setBearerAuth(this.apiKey.getValue()); + } + } + // Package-private getters for mutate/copy String getBaseUrl() { return this.baseUrl; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java index 1d3aec07423..b47a4a91bac 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java @@ -372,6 +372,57 @@ void dynamicApiKeyWebClientWithAdditionalAuthorizationHeader() throws Interrupte assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer additional-key"); } + @Test + void dynamicApiKeyRestClientEmbeddings() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + OpenAiApi api = OpenAiApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + -0.005540426, + 0.0047363234, + -0.015009919, + -0.027093535, + -0.015173893, + 0.015173893, + -0.017608276 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 2, + "total_tokens": 2 + } + } + """); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + OpenAiApi.EmbeddingRequest request = new OpenAiApi.EmbeddingRequest<>("Hello world"); + ResponseEntity> response = api.embeddings(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.embeddings(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + } }