Skip to content

Commit 08510d6

Browse files
committed
Set ApiKey as late as possible
Signed-off-by: Filip Hrisafov <[email protected]>
1 parent 4af3439 commit 08510d6

File tree

2 files changed

+106
-20
lines changed

2 files changed

+106
-20
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -135,21 +135,11 @@ public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> he
135135
.baseUrl(baseUrl)
136136
.defaultHeaders(finalHeaders)
137137
.defaultStatusHandler(responseErrorHandler)
138-
.defaultRequest(requestHeadersSpec -> {
139-
if (!(apiKey instanceof NoopApiKey)) {
140-
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
141-
}
142-
})
143138
.build();
144139

145140
this.webClient = webClientBuilder.clone()
146141
.baseUrl(baseUrl)
147142
.defaultHeaders(finalHeaders)
148-
.defaultRequest(requestHeadersSpec -> {
149-
if (!(apiKey instanceof NoopApiKey)) {
150-
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
151-
}
152-
})
153143
.build(); // @formatter:on
154144
}
155145

@@ -185,12 +175,12 @@ public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest
185175
Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false.");
186176
Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null.");
187177

188-
return this.restClient.post()
189-
.uri(this.completionsPath)
190-
.headers(headers -> headers.addAll(additionalHttpHeader))
191-
.body(chatRequest)
192-
.retrieve()
193-
.toEntity(ChatCompletion.class);
178+
return this.restClient.post().uri(this.completionsPath).headers(headers -> {
179+
headers.addAll(additionalHttpHeader);
180+
if (!headers.containsKey(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) {
181+
headers.setBearerAuth(this.apiKey.getValue());
182+
}
183+
}).body(chatRequest).retrieve().toEntity(ChatCompletion.class);
194184
}
195185

196186
/**
@@ -219,9 +209,12 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
219209

220210
AtomicBoolean isInsideTool = new AtomicBoolean(false);
221211

222-
return this.webClient.post()
223-
.uri(this.completionsPath)
224-
.headers(headers -> headers.addAll(additionalHttpHeader))
212+
return this.webClient.post().uri(this.completionsPath).headers(headers -> {
213+
headers.addAll(additionalHttpHeader);
214+
if (!headers.containsKey(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) {
215+
headers.setBearerAuth(this.apiKey.getValue());
216+
}
217+
})
225218
.body(Mono.just(chatRequest), ChatCompletionRequest.class)
226219
.retrieve()
227220
.bodyToFlux(String.class)

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@
2525
import okhttp3.mockwebserver.MockResponse;
2626
import okhttp3.mockwebserver.MockWebServer;
2727
import okhttp3.mockwebserver.RecordedRequest;
28-
2928
import org.junit.jupiter.api.AfterEach;
3029
import org.junit.jupiter.api.BeforeEach;
3130
import org.junit.jupiter.api.Nested;
3231
import org.junit.jupiter.api.Test;
32+
import org.opentest4j.AssertionFailedError;
3333

3434
import org.springframework.ai.model.ApiKey;
3535
import org.springframework.ai.model.SimpleApiKey;
@@ -227,6 +227,52 @@ void dynamicApiKeyRestClient() throws InterruptedException {
227227
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2");
228228
}
229229

230+
@Test
231+
void dynamicApiKeyRestClientWithAdditionalAuthorizationHeader() throws InterruptedException {
232+
OpenAiApi api = OpenAiApi.builder().apiKey(() -> {
233+
throw new AssertionFailedError("Should not be called, API key is provided in headers");
234+
}).baseUrl(mockWebServer.url("/").toString()).build();
235+
236+
MockResponse mockResponse = new MockResponse().setResponseCode(200)
237+
.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
238+
.setBody("""
239+
{
240+
"id": "chatcmpl-12345",
241+
"object": "chat.completion",
242+
"created": 1677858242,
243+
"model": "gpt-3.5-turbo",
244+
"choices": [
245+
{
246+
"index": 0,
247+
"message": {
248+
"role": "assistant",
249+
"content": "Hello world"
250+
},
251+
"finish_reason": "stop"
252+
}
253+
],
254+
"usage": {
255+
"prompt_tokens": 10,
256+
"completion_tokens": 5,
257+
"total_tokens": 15
258+
}
259+
}
260+
""");
261+
mockWebServer.enqueue(mockResponse);
262+
263+
OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world",
264+
OpenAiApi.ChatCompletionMessage.Role.USER);
265+
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(
266+
List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, false);
267+
268+
MultiValueMap<String, String> additionalHeaders = new LinkedMultiValueMap<>();
269+
additionalHeaders.add(HttpHeaders.AUTHORIZATION, "Bearer additional-key");
270+
ResponseEntity<OpenAiApi.ChatCompletion> response = api.chatCompletionEntity(request, additionalHeaders);
271+
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK);
272+
RecordedRequest recordedRequest = mockWebServer.takeRequest();
273+
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer additional-key");
274+
}
275+
230276
@Test
231277
void dynamicApiKeyWebClient() throws InterruptedException {
232278
Queue<ApiKey> apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2")));
@@ -279,6 +325,53 @@ void dynamicApiKeyWebClient() throws InterruptedException {
279325
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2");
280326
}
281327

328+
@Test
329+
void dynamicApiKeyWebClientWithAdditionalAuthorizationHeader() throws InterruptedException {
330+
OpenAiApi api = OpenAiApi.builder().apiKey(() -> {
331+
throw new AssertionFailedError("Should not be called, API key is provided in headers");
332+
}).baseUrl(mockWebServer.url("/").toString()).build();
333+
334+
MockResponse mockResponse = new MockResponse().setResponseCode(200)
335+
.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
336+
.setBody("""
337+
{
338+
"id": "chatcmpl-12345",
339+
"object": "chat.completion",
340+
"created": 1677858242,
341+
"model": "gpt-3.5-turbo",
342+
"choices": [
343+
{
344+
"index": 0,
345+
"message": {
346+
"role": "assistant",
347+
"content": "Hello world"
348+
},
349+
"finish_reason": "stop"
350+
}
351+
],
352+
"usage": {
353+
"prompt_tokens": 10,
354+
"completion_tokens": 5,
355+
"total_tokens": 15
356+
}
357+
}
358+
""".replace("\n", ""));
359+
mockWebServer.enqueue(mockResponse);
360+
361+
OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world",
362+
OpenAiApi.ChatCompletionMessage.Role.USER);
363+
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(
364+
List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, true);
365+
MultiValueMap<String, String> additionalHeaders = new LinkedMultiValueMap<>();
366+
additionalHeaders.add(HttpHeaders.AUTHORIZATION, "Bearer additional-key");
367+
List<OpenAiApi.ChatCompletionChunk> response = api.chatCompletionStream(request, additionalHeaders)
368+
.collectList()
369+
.block();
370+
assertThat(response).hasSize(1);
371+
RecordedRequest recordedRequest = mockWebServer.takeRequest();
372+
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer additional-key");
373+
}
374+
282375
}
283376

284377
}

0 commit comments

Comments
 (0)