|
16 | 16 |
|
17 | 17 | package org.springframework.ai.openai.api; |
18 | 18 |
|
| 19 | +import java.io.IOException; |
| 20 | +import java.util.LinkedList; |
| 21 | +import java.util.List; |
| 22 | +import java.util.Objects; |
| 23 | +import java.util.Queue; |
| 24 | + |
| 25 | +import okhttp3.mockwebserver.MockResponse; |
| 26 | +import okhttp3.mockwebserver.MockWebServer; |
| 27 | +import okhttp3.mockwebserver.RecordedRequest; |
| 28 | + |
| 29 | +import org.junit.jupiter.api.AfterEach; |
| 30 | +import org.junit.jupiter.api.BeforeEach; |
| 31 | +import org.junit.jupiter.api.Nested; |
19 | 32 | import org.junit.jupiter.api.Test; |
20 | 33 |
|
21 | 34 | import org.springframework.ai.model.ApiKey; |
22 | 35 | import org.springframework.ai.model.SimpleApiKey; |
| 36 | +import org.springframework.http.HttpHeaders; |
| 37 | +import org.springframework.http.HttpStatus; |
| 38 | +import org.springframework.http.MediaType; |
| 39 | +import org.springframework.http.ResponseEntity; |
23 | 40 | import org.springframework.util.LinkedMultiValueMap; |
24 | 41 | import org.springframework.util.MultiValueMap; |
25 | 42 | import org.springframework.web.client.ResponseErrorHandler; |
@@ -142,4 +159,126 @@ void testInvalidResponseErrorHandler() { |
142 | 159 | .hasMessageContaining("responseErrorHandler cannot be null"); |
143 | 160 | } |
144 | 161 |
|
| 162 | + @Nested |
| 163 | + class MockRequests { |
| 164 | + |
| 165 | + MockWebServer mockWebServer; |
| 166 | + |
| 167 | + @BeforeEach |
| 168 | + void setUp() throws IOException { |
| 169 | + mockWebServer = new MockWebServer(); |
| 170 | + mockWebServer.start(); |
| 171 | + } |
| 172 | + |
| 173 | + @AfterEach |
| 174 | + void tearDown() throws IOException { |
| 175 | + mockWebServer.shutdown(); |
| 176 | + } |
| 177 | + |
| 178 | + @Test |
| 179 | + void dynamicApiKeyRestClient() throws InterruptedException { |
| 180 | + Queue<ApiKey> apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); |
| 181 | + OpenAiApi api = OpenAiApi.builder() |
| 182 | + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) |
| 183 | + .baseUrl(mockWebServer.url("/").toString()) |
| 184 | + .build(); |
| 185 | + |
| 186 | + MockResponse mockResponse = new MockResponse().setResponseCode(200) |
| 187 | + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) |
| 188 | + .setBody(""" |
| 189 | + { |
| 190 | + "id": "chatcmpl-12345", |
| 191 | + "object": "chat.completion", |
| 192 | + "created": 1677858242, |
| 193 | + "model": "gpt-3.5-turbo", |
| 194 | + "choices": [ |
| 195 | + { |
| 196 | + "index": 0, |
| 197 | + "message": { |
| 198 | + "role": "assistant", |
| 199 | + "content": "Hello world" |
| 200 | + }, |
| 201 | + "finish_reason": "stop" |
| 202 | + } |
| 203 | + ], |
| 204 | + "usage": { |
| 205 | + "prompt_tokens": 10, |
| 206 | + "completion_tokens": 5, |
| 207 | + "total_tokens": 15 |
| 208 | + } |
| 209 | + } |
| 210 | + """); |
| 211 | + mockWebServer.enqueue(mockResponse); |
| 212 | + mockWebServer.enqueue(mockResponse); |
| 213 | + |
| 214 | + OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world", |
| 215 | + OpenAiApi.ChatCompletionMessage.Role.USER); |
| 216 | + OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest( |
| 217 | + List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, false); |
| 218 | + ResponseEntity<OpenAiApi.ChatCompletion> response = api.chatCompletionEntity(request); |
| 219 | + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); |
| 220 | + RecordedRequest recordedRequest = mockWebServer.takeRequest(); |
| 221 | + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); |
| 222 | + |
| 223 | + response = api.chatCompletionEntity(request); |
| 224 | + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); |
| 225 | + |
| 226 | + recordedRequest = mockWebServer.takeRequest(); |
| 227 | + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); |
| 228 | + } |
| 229 | + |
| 230 | + @Test |
| 231 | + void dynamicApiKeyWebClient() throws InterruptedException { |
| 232 | + Queue<ApiKey> apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); |
| 233 | + OpenAiApi api = OpenAiApi.builder() |
| 234 | + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) |
| 235 | + .baseUrl(mockWebServer.url("/").toString()) |
| 236 | + .build(); |
| 237 | + |
| 238 | + MockResponse mockResponse = new MockResponse().setResponseCode(200) |
| 239 | + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) |
| 240 | + .setBody(""" |
| 241 | + { |
| 242 | + "id": "chatcmpl-12345", |
| 243 | + "object": "chat.completion", |
| 244 | + "created": 1677858242, |
| 245 | + "model": "gpt-3.5-turbo", |
| 246 | + "choices": [ |
| 247 | + { |
| 248 | + "index": 0, |
| 249 | + "message": { |
| 250 | + "role": "assistant", |
| 251 | + "content": "Hello world" |
| 252 | + }, |
| 253 | + "finish_reason": "stop" |
| 254 | + } |
| 255 | + ], |
| 256 | + "usage": { |
| 257 | + "prompt_tokens": 10, |
| 258 | + "completion_tokens": 5, |
| 259 | + "total_tokens": 15 |
| 260 | + } |
| 261 | + } |
| 262 | + """.replace("\n", "")); |
| 263 | + mockWebServer.enqueue(mockResponse); |
| 264 | + mockWebServer.enqueue(mockResponse); |
| 265 | + |
| 266 | + OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world", |
| 267 | + OpenAiApi.ChatCompletionMessage.Role.USER); |
| 268 | + OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest( |
| 269 | + List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, true); |
| 270 | + List<OpenAiApi.ChatCompletionChunk> response = api.chatCompletionStream(request).collectList().block(); |
| 271 | + assertThat(response).hasSize(1); |
| 272 | + RecordedRequest recordedRequest = mockWebServer.takeRequest(); |
| 273 | + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); |
| 274 | + |
| 275 | + response = api.chatCompletionStream(request).collectList().block(); |
| 276 | + assertThat(response).hasSize(1); |
| 277 | + |
| 278 | + recordedRequest = mockWebServer.takeRequest(); |
| 279 | + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); |
| 280 | + } |
| 281 | + |
| 282 | + } |
| 283 | + |
145 | 284 | } |
0 commit comments