Skip to content

Commit 9ab2a4b

Browse files
committed
test: Add comprehensive tests for OpenAI API builder, mutation, and streaming functionality
Co-authored-by: Oleksandr Klymenko <[email protected]> Signed-off-by: Oleksandr Klymenko <[email protected]>
1 parent 3e17e16 commit 9ab2a4b

File tree

3 files changed

+272
-0
lines changed

3 files changed

+272
-0
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,4 +425,120 @@ void dynamicApiKeyRestClientEmbeddings() throws InterruptedException {
425425

426426
}
427427

428+
@Test
429+
void testBuilderWithAllCustomPaths() {
430+
OpenAiApi api = OpenAiApi.builder()
431+
.apiKey(TEST_API_KEY)
432+
.baseUrl(TEST_BASE_URL)
433+
.completionsPath("/custom/completions")
434+
.embeddingsPath("/custom/embeddings")
435+
.build();
436+
437+
assertThat(api).isNotNull();
438+
}
439+
440+
@Test
441+
void testBuilderImmutability() {
442+
OpenAiApi.Builder builder = OpenAiApi.builder().apiKey(TEST_API_KEY).baseUrl(TEST_BASE_URL);
443+
444+
OpenAiApi api1 = builder.build();
445+
OpenAiApi api2 = builder.build();
446+
447+
assertThat(api1).isNotNull();
448+
assertThat(api2).isNotNull();
449+
assertThat(api1).isNotSameAs(api2);
450+
}
451+
452+
@Test
453+
void testNullApiKeyValue() {
454+
assertThatThrownBy(() -> OpenAiApi.builder().apiKey((ApiKey) null).build())
455+
.isInstanceOf(IllegalArgumentException.class)
456+
.hasMessageContaining("apiKey cannot be null");
457+
}
458+
459+
@Test
460+
void testBuilderMethodChaining() {
461+
MultiValueMap<String, String> headers = new LinkedMultiValueMap<>();
462+
headers.add("Test-Header", "test-value");
463+
464+
OpenAiApi api = OpenAiApi.builder()
465+
.apiKey(TEST_API_KEY)
466+
.baseUrl(TEST_BASE_URL)
467+
.completionsPath(TEST_COMPLETIONS_PATH)
468+
.embeddingsPath(TEST_EMBEDDINGS_PATH)
469+
.headers(headers)
470+
.restClientBuilder(RestClient.builder())
471+
.webClientBuilder(WebClient.builder())
472+
.responseErrorHandler(mock(ResponseErrorHandler.class))
473+
.build();
474+
475+
assertThat(api).isNotNull();
476+
}
477+
478+
@Test
479+
void testCustomHeadersPreservation() {
480+
MultiValueMap<String, String> customHeaders = new LinkedMultiValueMap<>();
481+
customHeaders.add("X-Custom-Header", "custom-value");
482+
customHeaders.add("X-Organization", "org-123");
483+
customHeaders.add("User-Agent", "Custom-Client/1.0");
484+
485+
OpenAiApi api = OpenAiApi.builder().apiKey(TEST_API_KEY).headers(customHeaders).build();
486+
487+
assertThat(api).isNotNull();
488+
}
489+
490+
@Test
491+
void testComplexMultiValueHeaders() {
492+
MultiValueMap<String, String> multiHeaders = new LinkedMultiValueMap<>();
493+
multiHeaders.add("Accept", "application/json");
494+
multiHeaders.add("Accept", "text/plain");
495+
multiHeaders.add("Cache-Control", "no-cache");
496+
multiHeaders.add("Cache-Control", "no-store");
497+
498+
OpenAiApi api = OpenAiApi.builder().apiKey(TEST_API_KEY).headers(multiHeaders).build();
499+
500+
assertThat(api).isNotNull();
501+
}
502+
503+
@Test
504+
void testPathValidationWithSlashes() {
505+
OpenAiApi api1 = OpenAiApi.builder()
506+
.apiKey(TEST_API_KEY)
507+
.completionsPath("/v1/completions")
508+
.embeddingsPath("/v1/embeddings")
509+
.build();
510+
511+
OpenAiApi api2 = OpenAiApi.builder()
512+
.apiKey(TEST_API_KEY)
513+
.completionsPath("v1/completions")
514+
.embeddingsPath("v1/embeddings")
515+
.build();
516+
517+
assertThat(api1).isNotNull();
518+
assertThat(api2).isNotNull();
519+
}
520+
521+
@Test
522+
void testInvalidPathsWithSpecialCharacters() {
523+
assertThatThrownBy(() -> OpenAiApi.builder().apiKey(TEST_API_KEY).completionsPath(" ").build())
524+
.isInstanceOf(IllegalArgumentException.class);
525+
}
526+
527+
@Test
528+
void testBuilderWithOnlyRequiredFields() {
529+
OpenAiApi api = OpenAiApi.builder().apiKey(TEST_API_KEY).build();
530+
531+
assertThat(api).isNotNull();
532+
}
533+
534+
@Test
535+
void testDifferentApiKeyTypes() {
536+
SimpleApiKey simpleKey = new SimpleApiKey("simple-key");
537+
OpenAiApi api1 = OpenAiApi.builder().apiKey(simpleKey).build();
538+
assertThat(api1).isNotNull();
539+
540+
OpenAiApi api2 = OpenAiApi.builder().apiKey(() -> "supplier-key").build();
541+
assertThat(api2).isNotNull();
542+
}
543+
428544
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiChatModelMutateTests.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.springframework.util.LinkedMultiValueMap;
2525

2626
import static org.assertj.core.api.Assertions.assertThat;
27+
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
2728

2829
/*
2930
* Integration test for mutate/clone functionality on OpenAiApi and OpenAiChatModel.
@@ -190,4 +191,31 @@ void testMutateBuilderValidation() {
190191
assertThat(unchanged).isNotSameAs(this.baseModel);
191192
}
192193

194+
@Test
195+
void testMutateWithInvalidBaseUrl() {
196+
assertThatThrownBy(() -> this.baseApi.mutate().baseUrl("").build()).isInstanceOf(IllegalArgumentException.class)
197+
.hasMessageContaining("baseUrl");
198+
199+
assertThatThrownBy(() -> this.baseApi.mutate().baseUrl(null).build())
200+
.isInstanceOf(IllegalArgumentException.class)
201+
.hasMessageContaining("baseUrl");
202+
}
203+
204+
@Test
205+
void testMutateWithNullOpenAiApi() {
206+
assertThatThrownBy(() -> this.baseModel.mutate().openAiApi(null).build())
207+
.isInstanceOf(IllegalArgumentException.class);
208+
}
209+
210+
@Test
211+
void testMutatePreservesUnchangedFields() {
212+
String originalBaseUrl = this.baseApi.getBaseUrl();
213+
String newApiKey = "new-test-key";
214+
215+
OpenAiApi mutated = this.baseApi.mutate().apiKey(newApiKey).build();
216+
217+
assertThat(mutated.getBaseUrl()).isEqualTo(originalBaseUrl);
218+
assertThat(mutated.getApiKey().getValue()).isEqualTo(newApiKey);
219+
}
220+
193221
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelperTest.java

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.mockito.Mockito;
2626

2727
import static org.assertj.core.api.Assertions.assertThat;
28+
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
2829

2930
/**
3031
* Unit tests for {@link OpenAiStreamFunctionCallingHelper}
@@ -206,4 +207,131 @@ public void isStreamingToolFunctionCallFinishDetectsToolCallsFinishReason() {
206207
assertThat(this.helper.isStreamingToolFunctionCallFinish(chunk)).isTrue();
207208
}
208209

210+
@Test
211+
public void merge_whenBothChunksAreNull() {
212+
var result = this.helper.merge(null, null);
213+
assertThat(result).isNull();
214+
}
215+
216+
@Test
217+
public void merge_whenPreviousIsNull() {
218+
var current = new OpenAiApi.ChatCompletionChunk("id", Collections.emptyList(), System.currentTimeMillis(),
219+
"model", "default", "fingerprint", "object", null);
220+
221+
var result = this.helper.merge(null, current);
222+
assertThat(result).isEqualTo(current);
223+
}
224+
225+
@Test
226+
public void merge_whenCurrentIsNull() {
227+
var previous = new OpenAiApi.ChatCompletionChunk("id", Collections.emptyList(), System.currentTimeMillis(),
228+
"model", "default", "fingerprint", "object", null);
229+
230+
var result = this.helper.merge(previous, null);
231+
assertThat(result).isEqualTo(previous);
232+
}
233+
234+
@Test
235+
public void merge_partialFieldsFromEachChunk() {
236+
var choices = List.of(Mockito.mock(OpenAiApi.ChatCompletionChunk.ChunkChoice.class));
237+
var usage = Mockito.mock(OpenAiApi.Usage.class);
238+
239+
var previous = new OpenAiApi.ChatCompletionChunk(null, choices, 1L, "model1", null, "fp1", null, null);
240+
var current = new OpenAiApi.ChatCompletionChunk("id2", null, null, null, "tier2", null, "object2", usage);
241+
242+
var result = this.helper.merge(previous, current);
243+
244+
assertThat(result.id()).isEqualTo("id2");
245+
assertThat(result.choices()).isEqualTo(choices);
246+
assertThat(result.created()).isEqualTo(1L);
247+
assertThat(result.model()).isEqualTo("model1");
248+
assertThat(result.serviceTier()).isEqualTo("tier2");
249+
assertThat(result.systemFingerprint()).isEqualTo("fp1");
250+
assertThat(result.object()).isEqualTo("object2");
251+
assertThat(result.usage()).isEqualTo(usage);
252+
}
253+
254+
@Test
255+
public void isStreamingToolFunctionCall_withMultipleChoicesAndOnlyFirstHasToolCalls() {
256+
var toolCall = Mockito.mock(OpenAiApi.ChatCompletionMessage.ToolCall.class);
257+
var deltaWithToolCalls = new OpenAiApi.ChatCompletionMessage(null, null, null, null, List.of(toolCall), null,
258+
null, null);
259+
var deltaWithoutToolCalls = new OpenAiApi.ChatCompletionMessage(null, null);
260+
261+
var choice1 = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, deltaWithToolCalls, null);
262+
var choice2 = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, deltaWithoutToolCalls, null);
263+
264+
var chunk = new OpenAiApi.ChatCompletionChunk(null, List.of(choice1, choice2), null, null, null, null, null,
265+
null);
266+
267+
assertThat(this.helper.isStreamingToolFunctionCall(chunk)).isTrue();
268+
}
269+
270+
@Test
271+
public void isStreamingToolFunctionCall_withMultipleChoicesAndNoneHaveToolCalls() {
272+
var deltaWithoutToolCalls = new OpenAiApi.ChatCompletionMessage(null, null);
273+
274+
var choice1 = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, deltaWithoutToolCalls, null);
275+
var choice2 = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, deltaWithoutToolCalls, null);
276+
277+
var chunk = new OpenAiApi.ChatCompletionChunk(null, List.of(choice1, choice2), null, null, null, null, null,
278+
null);
279+
280+
assertThat(this.helper.isStreamingToolFunctionCall(chunk)).isFalse();
281+
}
282+
283+
@Test
284+
public void isStreamingToolFunctionCallFinish_withMultipleChoicesAndOnlyFirstIsToolCallsFinish() {
285+
var choice1 = new OpenAiApi.ChatCompletionChunk.ChunkChoice(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS,
286+
null, new OpenAiApi.ChatCompletionMessage(null, null), null);
287+
var choice2 = new OpenAiApi.ChatCompletionChunk.ChunkChoice(OpenAiApi.ChatCompletionFinishReason.STOP, null,
288+
new OpenAiApi.ChatCompletionMessage(null, null), null);
289+
290+
var chunk = new OpenAiApi.ChatCompletionChunk(null, List.of(choice1, choice2), null, null, null, null, null,
291+
null);
292+
293+
assertThat(this.helper.isStreamingToolFunctionCallFinish(chunk)).isTrue();
294+
}
295+
296+
@Test
297+
public void chunkToChatCompletion_whenChunkIsNull() {
298+
assertThatThrownBy(() -> this.helper.chunkToChatCompletion(null)).isInstanceOf(NullPointerException.class);
299+
}
300+
301+
@Test
302+
public void chunkToChatCompletion_withEmptyChoices() {
303+
var chunk = new OpenAiApi.ChatCompletionChunk("id", Collections.emptyList(), 1L, "model", "tier", "fp",
304+
"object", null);
305+
306+
var result = this.helper.chunkToChatCompletion(chunk);
307+
308+
assertThat(result.object()).isEqualTo("chat.completion");
309+
assertThat(result.choices()).isEmpty();
310+
assertThat(result.id()).isEqualTo("id");
311+
assertThat(result.created()).isEqualTo(1L);
312+
assertThat(result.model()).isEqualTo("model");
313+
}
314+
315+
@Test
316+
public void edgeCases_emptyStringFields() {
317+
var chunk = new OpenAiApi.ChatCompletionChunk("", Collections.emptyList(), 0L, "", "", "", "", null);
318+
319+
var result = this.helper.chunkToChatCompletion(chunk);
320+
321+
assertThat(result.id()).isEmpty();
322+
assertThat(result.model()).isEmpty();
323+
assertThat(result.serviceTier()).isEmpty();
324+
assertThat(result.systemFingerprint()).isEmpty();
325+
assertThat(result.created()).isEqualTo(0L);
326+
}
327+
328+
@Test
329+
public void isStreamingToolFunctionCall_withNullToolCallsList() {
330+
var delta = new OpenAiApi.ChatCompletionMessage(null, null, null, null, null, null, null, null);
331+
var choice = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, delta, null);
332+
var chunk = new OpenAiApi.ChatCompletionChunk(null, List.of(choice), null, null, null, null, null, null);
333+
334+
assertThat(this.helper.isStreamingToolFunctionCall(chunk)).isFalse();
335+
}
336+
209337
}

0 commit comments

Comments
 (0)