Skip to content

Commit 4af3439

Browse files
committed
Resolve OpenAI ApiKey for every request
Signed-off-by: Filip Hrisafov <[email protected]>
1 parent 313aae0 commit 4af3439

File tree

8 files changed

+748
-20
lines changed

8 files changed

+748
-20
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,22 +128,28 @@ public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> he
128128

129129
// @formatter:off
130130
Consumer<HttpHeaders> finalHeaders = h -> {
131-
if (!(apiKey instanceof NoopApiKey)) {
132-
h.setBearerAuth(apiKey.getValue());
133-
}
134-
135131
h.setContentType(MediaType.APPLICATION_JSON);
136132
h.addAll(headers);
137133
};
138134
this.restClient = restClientBuilder.clone()
139135
.baseUrl(baseUrl)
140136
.defaultHeaders(finalHeaders)
141137
.defaultStatusHandler(responseErrorHandler)
138+
.defaultRequest(requestHeadersSpec -> {
139+
if (!(apiKey instanceof NoopApiKey)) {
140+
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
141+
}
142+
})
142143
.build();
143144

144145
this.webClient = webClientBuilder.clone()
145146
.baseUrl(baseUrl)
146147
.defaultHeaders(finalHeaders)
148+
.defaultRequest(requestHeadersSpec -> {
149+
if (!(apiKey instanceof NoopApiKey)) {
150+
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
151+
}
152+
})
147153
.build(); // @formatter:on
148154
}
149155

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,20 +71,30 @@ public OpenAiAudioApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, Strin
7171
ResponseErrorHandler responseErrorHandler) {
7272

7373
Consumer<HttpHeaders> authHeaders = h -> {
74-
if (!(apiKey instanceof NoopApiKey)) {
75-
h.setBearerAuth(apiKey.getValue());
76-
}
7774
h.addAll(headers);
78-
// h.setContentType(MediaType.APPLICATION_JSON);
7975
};
8076

77+
// @formatter:off
8178
this.restClient = restClientBuilder.clone()
8279
.baseUrl(baseUrl)
8380
.defaultHeaders(authHeaders)
8481
.defaultStatusHandler(responseErrorHandler)
82+
.defaultRequest(requestHeadersSpec -> {
83+
if (!(apiKey instanceof NoopApiKey)) {
84+
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
85+
}
86+
})
8587
.build();
8688

87-
this.webClient = webClientBuilder.clone().baseUrl(baseUrl).defaultHeaders(authHeaders).build();
89+
this.webClient = webClientBuilder.clone()
90+
.baseUrl(baseUrl)
91+
.defaultHeaders(authHeaders)
92+
.defaultRequest(requestHeadersSpec -> {
93+
if (!(apiKey instanceof NoopApiKey)) {
94+
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
95+
}
96+
})
97+
.build(); // @formatter:on
8898
}
8999

90100
public static Builder builder() {

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.ai.model.SimpleApiKey;
2828
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
2929
import org.springframework.ai.retry.RetryUtils;
30+
import org.springframework.http.HttpHeaders;
3031
import org.springframework.http.MediaType;
3132
import org.springframework.http.ResponseEntity;
3233
import org.springframework.util.Assert;
@@ -62,15 +63,18 @@ public OpenAiImageApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, Strin
6263
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
6364

6465
// @formatter:off
65-
this.restClient = restClientBuilder.baseUrl(baseUrl)
66+
this.restClient = restClientBuilder.clone()
67+
.baseUrl(baseUrl)
6668
.defaultHeaders(h -> {
67-
if (!(apiKey instanceof NoopApiKey)) {
68-
h.setBearerAuth(apiKey.getValue());
69-
}
7069
h.setContentType(MediaType.APPLICATION_JSON);
7170
h.addAll(headers);
7271
})
7372
.defaultStatusHandler(responseErrorHandler)
73+
.defaultRequest(requestHeadersSpec -> {
74+
if (!(apiKey instanceof NoopApiKey)) {
75+
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
76+
}
77+
})
7478
.build();
7579
// @formatter:on
7680

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.ai.model.SimpleApiKey;
2828
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
2929
import org.springframework.ai.retry.RetryUtils;
30+
import org.springframework.http.HttpHeaders;
3031
import org.springframework.http.MediaType;
3132
import org.springframework.http.ResponseEntity;
3233
import org.springframework.util.Assert;
@@ -64,13 +65,20 @@ public OpenAiModerationApi(String baseUrl, ApiKey apiKey, MultiValueMap<String,
6465

6566
this.objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
6667

67-
this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(h -> {
68-
if (!(apiKey instanceof NoopApiKey)) {
69-
h.setBearerAuth(apiKey.getValue());
70-
}
71-
h.setContentType(MediaType.APPLICATION_JSON);
72-
h.addAll(headers);
73-
}).defaultStatusHandler(responseErrorHandler).build();
68+
// @formatter:off
69+
this.restClient = restClientBuilder.clone()
70+
.baseUrl(baseUrl)
71+
.defaultHeaders(h -> {
72+
h.setContentType(MediaType.APPLICATION_JSON);
73+
h.addAll(headers);
74+
})
75+
.defaultStatusHandler(responseErrorHandler)
76+
.defaultRequest(requestHeadersSpec -> {
77+
if (!(apiKey instanceof NoopApiKey)) {
78+
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
79+
}
80+
})
81+
.build(); // @formatter:on
7482
}
7583

7684
public ResponseEntity<OpenAiModerationResponse> createModeration(OpenAiModerationRequest openAiModerationRequest) {

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,27 @@
1616

1717
package org.springframework.ai.openai.api;
1818

19+
import java.io.IOException;
20+
import java.util.LinkedList;
21+
import java.util.List;
22+
import java.util.Objects;
23+
import java.util.Queue;
24+
25+
import okhttp3.mockwebserver.MockResponse;
26+
import okhttp3.mockwebserver.MockWebServer;
27+
import okhttp3.mockwebserver.RecordedRequest;
28+
29+
import org.junit.jupiter.api.AfterEach;
30+
import org.junit.jupiter.api.BeforeEach;
31+
import org.junit.jupiter.api.Nested;
1932
import org.junit.jupiter.api.Test;
2033

2134
import org.springframework.ai.model.ApiKey;
2235
import org.springframework.ai.model.SimpleApiKey;
36+
import org.springframework.http.HttpHeaders;
37+
import org.springframework.http.HttpStatus;
38+
import org.springframework.http.MediaType;
39+
import org.springframework.http.ResponseEntity;
2340
import org.springframework.util.LinkedMultiValueMap;
2441
import org.springframework.util.MultiValueMap;
2542
import org.springframework.web.client.ResponseErrorHandler;
@@ -142,4 +159,126 @@ void testInvalidResponseErrorHandler() {
142159
.hasMessageContaining("responseErrorHandler cannot be null");
143160
}
144161

162+
@Nested
163+
class MockRequests {
164+
165+
MockWebServer mockWebServer;
166+
167+
@BeforeEach
168+
void setUp() throws IOException {
169+
mockWebServer = new MockWebServer();
170+
mockWebServer.start();
171+
}
172+
173+
@AfterEach
174+
void tearDown() throws IOException {
175+
mockWebServer.shutdown();
176+
}
177+
178+
@Test
179+
void dynamicApiKeyRestClient() throws InterruptedException {
180+
Queue<ApiKey> apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2")));
181+
OpenAiApi api = OpenAiApi.builder()
182+
.apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue())
183+
.baseUrl(mockWebServer.url("/").toString())
184+
.build();
185+
186+
MockResponse mockResponse = new MockResponse().setResponseCode(200)
187+
.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
188+
.setBody("""
189+
{
190+
"id": "chatcmpl-12345",
191+
"object": "chat.completion",
192+
"created": 1677858242,
193+
"model": "gpt-3.5-turbo",
194+
"choices": [
195+
{
196+
"index": 0,
197+
"message": {
198+
"role": "assistant",
199+
"content": "Hello world"
200+
},
201+
"finish_reason": "stop"
202+
}
203+
],
204+
"usage": {
205+
"prompt_tokens": 10,
206+
"completion_tokens": 5,
207+
"total_tokens": 15
208+
}
209+
}
210+
""");
211+
mockWebServer.enqueue(mockResponse);
212+
mockWebServer.enqueue(mockResponse);
213+
214+
OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world",
215+
OpenAiApi.ChatCompletionMessage.Role.USER);
216+
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(
217+
List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, false);
218+
ResponseEntity<OpenAiApi.ChatCompletion> response = api.chatCompletionEntity(request);
219+
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK);
220+
RecordedRequest recordedRequest = mockWebServer.takeRequest();
221+
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1");
222+
223+
response = api.chatCompletionEntity(request);
224+
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK);
225+
226+
recordedRequest = mockWebServer.takeRequest();
227+
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2");
228+
}
229+
230+
@Test
231+
void dynamicApiKeyWebClient() throws InterruptedException {
232+
Queue<ApiKey> apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2")));
233+
OpenAiApi api = OpenAiApi.builder()
234+
.apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue())
235+
.baseUrl(mockWebServer.url("/").toString())
236+
.build();
237+
238+
MockResponse mockResponse = new MockResponse().setResponseCode(200)
239+
.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
240+
.setBody("""
241+
{
242+
"id": "chatcmpl-12345",
243+
"object": "chat.completion",
244+
"created": 1677858242,
245+
"model": "gpt-3.5-turbo",
246+
"choices": [
247+
{
248+
"index": 0,
249+
"message": {
250+
"role": "assistant",
251+
"content": "Hello world"
252+
},
253+
"finish_reason": "stop"
254+
}
255+
],
256+
"usage": {
257+
"prompt_tokens": 10,
258+
"completion_tokens": 5,
259+
"total_tokens": 15
260+
}
261+
}
262+
""".replace("\n", ""));
263+
mockWebServer.enqueue(mockResponse);
264+
mockWebServer.enqueue(mockResponse);
265+
266+
OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world",
267+
OpenAiApi.ChatCompletionMessage.Role.USER);
268+
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(
269+
List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, true);
270+
List<OpenAiApi.ChatCompletionChunk> response = api.chatCompletionStream(request).collectList().block();
271+
assertThat(response).hasSize(1);
272+
RecordedRequest recordedRequest = mockWebServer.takeRequest();
273+
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1");
274+
275+
response = api.chatCompletionStream(request).collectList().block();
276+
assertThat(response).hasSize(1);
277+
278+
recordedRequest = mockWebServer.takeRequest();
279+
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2");
280+
}
281+
282+
}
283+
145284
}

0 commit comments

Comments
 (0)