| 
25 | 25 | import okhttp3.mockwebserver.MockResponse;  | 
26 | 26 | import okhttp3.mockwebserver.MockWebServer;  | 
27 | 27 | import okhttp3.mockwebserver.RecordedRequest;  | 
28 |  | - | 
29 | 28 | import org.junit.jupiter.api.AfterEach;  | 
30 | 29 | import org.junit.jupiter.api.BeforeEach;  | 
31 | 30 | import org.junit.jupiter.api.Nested;  | 
32 | 31 | import org.junit.jupiter.api.Test;  | 
 | 32 | +import org.opentest4j.AssertionFailedError;  | 
33 | 33 | 
 
  | 
34 | 34 | import org.springframework.ai.model.ApiKey;  | 
35 | 35 | import org.springframework.ai.model.SimpleApiKey;  | 
@@ -227,6 +227,52 @@ void dynamicApiKeyRestClient() throws InterruptedException {  | 
227 | 227 | 			assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2");  | 
228 | 228 | 		}  | 
229 | 229 | 
 
  | 
 | 230 | +		@Test  | 
 | 231 | +		void dynamicApiKeyRestClientWithAdditionalAuthorizationHeader() throws InterruptedException {  | 
 | 232 | +			OpenAiApi api = OpenAiApi.builder().apiKey(() -> {  | 
 | 233 | +				throw new AssertionFailedError("Should not be called, API key is provided in headers");  | 
 | 234 | +			}).baseUrl(mockWebServer.url("/").toString()).build();  | 
 | 235 | + | 
 | 236 | +			MockResponse mockResponse = new MockResponse().setResponseCode(200)  | 
 | 237 | +				.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)  | 
 | 238 | +				.setBody("""  | 
 | 239 | +						{  | 
 | 240 | +							"id": "chatcmpl-12345",  | 
 | 241 | +							"object": "chat.completion",  | 
 | 242 | +							"created": 1677858242,  | 
 | 243 | +							"model": "gpt-3.5-turbo",  | 
 | 244 | +							"choices": [  | 
 | 245 | +								{  | 
 | 246 | +						    		"index": 0,  | 
 | 247 | +									"message": {  | 
 | 248 | +									"role": "assistant",  | 
 | 249 | +									"content": "Hello world"  | 
 | 250 | +									},  | 
 | 251 | +									"finish_reason": "stop"  | 
 | 252 | +								}  | 
 | 253 | +							],  | 
 | 254 | +							"usage": {  | 
 | 255 | +								"prompt_tokens": 10,  | 
 | 256 | +								"completion_tokens": 5,  | 
 | 257 | +								"total_tokens": 15  | 
 | 258 | +							}  | 
 | 259 | +						}  | 
 | 260 | +						""");  | 
 | 261 | +			mockWebServer.enqueue(mockResponse);  | 
 | 262 | + | 
 | 263 | +			OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world",  | 
 | 264 | +					OpenAiApi.ChatCompletionMessage.Role.USER);  | 
 | 265 | +			OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(  | 
 | 266 | +					List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, false);  | 
 | 267 | + | 
 | 268 | +			MultiValueMap<String, String> additionalHeaders = new LinkedMultiValueMap<>();  | 
 | 269 | +			additionalHeaders.add(HttpHeaders.AUTHORIZATION, "Bearer additional-key");  | 
 | 270 | +			ResponseEntity<OpenAiApi.ChatCompletion> response = api.chatCompletionEntity(request, additionalHeaders);  | 
 | 271 | +			assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK);  | 
 | 272 | +			RecordedRequest recordedRequest = mockWebServer.takeRequest();  | 
 | 273 | +			assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer additional-key");  | 
 | 274 | +		}  | 
 | 275 | + | 
230 | 276 | 		@Test  | 
231 | 277 | 		void dynamicApiKeyWebClient() throws InterruptedException {  | 
232 | 278 | 			Queue<ApiKey> apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2")));  | 
@@ -279,6 +325,53 @@ void dynamicApiKeyWebClient() throws InterruptedException {  | 
279 | 325 | 			assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2");  | 
280 | 326 | 		}  | 
281 | 327 | 
 
  | 
 | 328 | +		@Test  | 
 | 329 | +		void dynamicApiKeyWebClientWithAdditionalAuthorizationHeader() throws InterruptedException {  | 
 | 330 | +			OpenAiApi api = OpenAiApi.builder().apiKey(() -> {  | 
 | 331 | +				throw new AssertionFailedError("Should not be called, API key is provided in headers");  | 
 | 332 | +			}).baseUrl(mockWebServer.url("/").toString()).build();  | 
 | 333 | + | 
 | 334 | +			MockResponse mockResponse = new MockResponse().setResponseCode(200)  | 
 | 335 | +				.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)  | 
 | 336 | +				.setBody("""  | 
 | 337 | +						{  | 
 | 338 | +							"id": "chatcmpl-12345",  | 
 | 339 | +							"object": "chat.completion",  | 
 | 340 | +							"created": 1677858242,  | 
 | 341 | +							"model": "gpt-3.5-turbo",  | 
 | 342 | +							"choices": [  | 
 | 343 | +								{  | 
 | 344 | +						    		"index": 0,  | 
 | 345 | +									"message": {  | 
 | 346 | +									"role": "assistant",  | 
 | 347 | +									"content": "Hello world"  | 
 | 348 | +									},  | 
 | 349 | +									"finish_reason": "stop"  | 
 | 350 | +								}  | 
 | 351 | +							],  | 
 | 352 | +							"usage": {  | 
 | 353 | +								"prompt_tokens": 10,  | 
 | 354 | +								"completion_tokens": 5,  | 
 | 355 | +								"total_tokens": 15  | 
 | 356 | +							}  | 
 | 357 | +						}  | 
 | 358 | +						""".replace("\n", ""));  | 
 | 359 | +			mockWebServer.enqueue(mockResponse);  | 
 | 360 | + | 
 | 361 | +			OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world",  | 
 | 362 | +					OpenAiApi.ChatCompletionMessage.Role.USER);  | 
 | 363 | +			OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(  | 
 | 364 | +					List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, true);  | 
 | 365 | +			MultiValueMap<String, String> additionalHeaders = new LinkedMultiValueMap<>();  | 
 | 366 | +			additionalHeaders.add(HttpHeaders.AUTHORIZATION, "Bearer additional-key");  | 
 | 367 | +			List<OpenAiApi.ChatCompletionChunk> response = api.chatCompletionStream(request, additionalHeaders)  | 
 | 368 | +				.collectList()  | 
 | 369 | +				.block();  | 
 | 370 | +			assertThat(response).hasSize(1);  | 
 | 371 | +			RecordedRequest recordedRequest = mockWebServer.takeRequest();  | 
 | 372 | +			assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer additional-key");  | 
 | 373 | +		}  | 
 | 374 | + | 
282 | 375 | 	}  | 
283 | 376 | 
 
  | 
284 | 377 | }  | 
0 commit comments