|
19 | 19 | import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS; |
20 | 20 | import static org.assertj.core.api.Assertions.assertThat; |
21 | 21 |
|
| 22 | +import java.util.Arrays; |
22 | 23 | import java.util.stream.Collectors; |
23 | 24 |
|
24 | 25 | import org.junit.jupiter.api.Test; |
@@ -47,33 +48,42 @@ public class AzureOpenAiChatClientTest { |
47 | 48 | private ChatClient chatClient; |
48 | 49 |
|
49 | 50 | @Test |
50 | | - void basicAzureOpenAiChatClientStreaming() { |
51 | | - String stitchedResponseContent = chatClient.prompt( |
52 | | - "Name all states in the USA and their capitals, add a space followed by a hyphen, then another space between the two") |
| 51 | + void streamingAndImperativeResponsesContainIdenticalRelevantResults() { |
| 52 | + String prompt = "Name all states in the USA and their capitals, add a space followed by a hyphen, then another space between the two. " |
| 53 | + + "List them with a numerical index. Do not use any abbreviations in state or capitals."; |
| 54 | + |
| 55 | + // Imperative call |
| 56 | + String rawDataFromImperativeCall = chatClient.prompt(prompt).call().content(); |
| 57 | + String imperativeStatesData = extractStatesData(rawDataFromImperativeCall); |
| 58 | + String formattedImperativeResponse = formatResponse(imperativeStatesData); |
| 59 | + |
| 60 | + // Streaming call |
| 61 | + String stitchedResponseFromStream = chatClient.prompt(prompt) |
53 | 62 | .stream() |
54 | 63 | .content() |
55 | 64 | .collectList() |
56 | 65 | .block() |
57 | 66 | .stream() |
58 | 67 | .collect(Collectors.joining()); |
59 | | - verifyResponse(stitchedResponseContent); |
| 68 | + String streamingStatesData = extractStatesData(stitchedResponseFromStream); |
| 69 | + String formattedStreamingResponse = formatResponse(streamingStatesData); |
| 70 | + |
| 71 | + // Assertions |
| 72 | + assertThat(formattedStreamingResponse).isEqualTo(formattedImperativeResponse); |
| 73 | + assertThat(formattedStreamingResponse).contains("1. Alabama - Montgomery"); |
| 74 | + assertThat(formattedStreamingResponse).contains("50. Wyoming - Cheyenne"); |
| 75 | + assertThat(formattedStreamingResponse.lines().count()).isEqualTo(50); |
60 | 76 | } |
61 | 77 |
|
62 | | - @Test |
63 | | - void basicAzureOpenAiChatClientImperative() { |
64 | | - String stitchedResponseContent = chatClient.prompt( |
65 | | - "Name all states in the USA and their capitals, add a space followed by a hyphen, then another space between the two") |
66 | | - .call() |
67 | | - .content(); |
68 | | - verifyResponse(stitchedResponseContent); |
| 78 | + private String extractStatesData(String rawData) { |
| 79 | + int firstStateIndex = rawData.indexOf("1. Alabama - Montgomery"); |
| 80 | + String lastAlphabeticalState = "50. Wyoming - Cheyenne"; |
| 81 | + int lastStateIndex = rawData.indexOf(lastAlphabeticalState) + lastAlphabeticalState.length(); |
| 82 | + return rawData.substring(firstStateIndex, lastStateIndex); |
69 | 83 | } |
70 | 84 |
|
71 | | - private static void verifyResponse(String stitchedResponseContent) { |
72 | | - assertThat(stitchedResponseContent).contains("Alabama - Montgomery"); |
73 | | - assertThat(stitchedResponseContent).contains("New York - Albany"); |
74 | | - assertThat(stitchedResponseContent).contains("Pennsylvania - Harrisburg"); |
75 | | - assertThat(stitchedResponseContent).contains("Tennessee - Nashville"); |
76 | | - assertThat(stitchedResponseContent).contains("Wyoming - Cheyenne"); |
| 85 | + private String formatResponse(String response) { |
| 86 | + return String.join("\n", Arrays.stream(response.split("\n")).map(String::strip).toArray(String[]::new)); |
77 | 87 | } |
78 | 88 |
|
79 | 89 | @SpringBootConfiguration |
|
0 commit comments