Skip to content

Commit d420919

Browse files
committed
Test improvements
1 parent 85baee9 commit d420919

File tree

1 file changed

+27
-17
lines changed

1 file changed

+27
-17
lines changed

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientTest.java

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS;
2020
import static org.assertj.core.api.Assertions.assertThat;
2121

22+
import java.util.Arrays;
2223
import java.util.stream.Collectors;
2324

2425
import org.junit.jupiter.api.Test;
@@ -47,33 +48,42 @@ public class AzureOpenAiChatClientTest {
4748
private ChatClient chatClient;
4849

4950
@Test
50-
void basicAzureOpenAiChatClientStreaming() {
51-
String stitchedResponseContent = chatClient.prompt(
52-
"Name all states in the USA and their capitals, add a space followed by a hyphen, then another space between the two")
51+
void streamingAndImperativeResponsesContainIdenticalRelevantResults() {
52+
String prompt = "Name all states in the USA and their capitals, add a space followed by a hyphen, then another space between the two. "
53+
+ "List them with a numerical index. Do not use any abbreviations in state or capitals.";
54+
55+
// Imperative call
56+
String rawDataFromImperativeCall = chatClient.prompt(prompt).call().content();
57+
String imperativeStatesData = extractStatesData(rawDataFromImperativeCall);
58+
String formattedImperativeResponse = formatResponse(imperativeStatesData);
59+
60+
// Streaming call
61+
String stitchedResponseFromStream = chatClient.prompt(prompt)
5362
.stream()
5463
.content()
5564
.collectList()
5665
.block()
5766
.stream()
5867
.collect(Collectors.joining());
59-
verifyResponse(stitchedResponseContent);
68+
String streamingStatesData = extractStatesData(stitchedResponseFromStream);
69+
String formattedStreamingResponse = formatResponse(streamingStatesData);
70+
71+
// Assertions
72+
assertThat(formattedStreamingResponse).isEqualTo(formattedImperativeResponse);
73+
assertThat(formattedStreamingResponse).contains("1. Alabama - Montgomery");
74+
assertThat(formattedStreamingResponse).contains("50. Wyoming - Cheyenne");
75+
assertThat(formattedStreamingResponse.lines().count()).isEqualTo(50);
6076
}
6177

62-
@Test
63-
void basicAzureOpenAiChatClientImperative() {
64-
String stitchedResponseContent = chatClient.prompt(
65-
"Name all states in the USA and their capitals, add a space followed by a hyphen, then another space between the two")
66-
.call()
67-
.content();
68-
verifyResponse(stitchedResponseContent);
78+
private String extractStatesData(String rawData) {
79+
int firstStateIndex = rawData.indexOf("1. Alabama - Montgomery");
80+
String lastAlphabeticalState = "50. Wyoming - Cheyenne";
81+
int lastStateIndex = rawData.indexOf(lastAlphabeticalState) + lastAlphabeticalState.length();
82+
return rawData.substring(firstStateIndex, lastStateIndex);
6983
}
7084

71-
private static void verifyResponse(String stitchedResponseContent) {
72-
assertThat(stitchedResponseContent).contains("Alabama - Montgomery");
73-
assertThat(stitchedResponseContent).contains("New York - Albany");
74-
assertThat(stitchedResponseContent).contains("Pennsylvania - Harrisburg");
75-
assertThat(stitchedResponseContent).contains("Tennessee - Nashville");
76-
assertThat(stitchedResponseContent).contains("Wyoming - Cheyenne");
85+
private String formatResponse(String response) {
86+
return String.join("\n", Arrays.stream(response.split("\n")).map(String::strip).toArray(String[]::new));
7787
}
7888

7989
@SpringBootConfiguration

0 commit comments

Comments
 (0)