|
49 | 49 | import org.springframework.util.MimeTypeUtils; |
50 | 50 |
|
51 | 51 | import static org.assertj.core.api.Assertions.assertThat; |
| 52 | +import static org.mockito.ArgumentMatchers.matches; |
52 | 53 |
|
53 | 54 | @SpringBootTest(classes = BedrockConverseTestConfiguration.class) |
54 | 55 | @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") |
@@ -227,6 +228,41 @@ void functionCallTest() { |
227 | 228 | assertThat(response).contains("30", "10", "15"); |
228 | 229 | } |
229 | 230 |
|
| 231 | + @Test |
| 232 | + void functionCallWithUsageMetadataTest() { |
| 233 | + |
| 234 | + // @formatter:off |
| 235 | + ChatResponse response = ChatClient.create(this.chatModel) |
| 236 | + .prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.") |
| 237 | + .functions(FunctionCallback.builder() |
| 238 | + .description("Get the weather in location") |
| 239 | + .function("getCurrentWeather", new MockWeatherService()) |
| 240 | + .inputType(MockWeatherService.Request.class) |
| 241 | + .build()) |
| 242 | + .call() |
| 243 | + .chatResponse(); |
| 244 | + // @formatter:on |
| 245 | + |
| 246 | + var metadata = response.getMetadata(); |
| 247 | + |
| 248 | + assertThat(metadata.getUsage()).isNotNull(); |
| 249 | + |
| 250 | + logger.info(metadata.getUsage().toString()); |
| 251 | + |
| 252 | + assertThat(metadata.getUsage().getPromptTokens()).isGreaterThan(500); |
| 253 | + assertThat(metadata.getUsage().getPromptTokens()).isLessThan(3500); |
| 254 | + |
| 255 | + assertThat(metadata.getUsage().getGenerationTokens()).isGreaterThan(0); |
| 256 | + assertThat(metadata.getUsage().getGenerationTokens()).isLessThan(1500); |
| 257 | + |
| 258 | + assertThat(metadata.getUsage().getTotalTokens()) |
| 259 | + .isEqualTo(metadata.getUsage().getPromptTokens() + metadata.getUsage().getGenerationTokens()); |
| 260 | + |
| 261 | + logger.info("Response: {}", response); |
| 262 | + |
| 263 | + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); |
| 264 | + } |
| 265 | + |
230 | 266 | @Test |
231 | 267 | void functionCallWithAdvisorTest() { |
232 | 268 |
|
@@ -287,13 +323,24 @@ void streamFunctionCallTest() { |
287 | 323 |
|
288 | 324 | List<ChatResponse> chatResponses = response.collectList().block(); |
289 | 325 |
|
290 | | - chatResponses.forEach(cr -> logger.info("Response: {}", cr)); |
| 326 | + // chatResponses.forEach(cr -> logger.info("Response: {}", cr)); |
| 327 | + var lastChatResponse = chatResponses.get(chatResponses.size() - 1); |
| 328 | + var metadata = lastChatResponse.getMetadata(); |
| 329 | + assertThat(metadata.getUsage()).isNotNull(); |
291 | 330 |
|
292 | | - List<ChatResponse> chatResponses2 = chatResponses.stream() |
293 | | - .filter(cr -> cr.getResult() != null) |
294 | | - .collect(Collectors.toList()); |
| 331 | + logger.info(metadata.getUsage().toString()); |
| 332 | + |
| 333 | + assertThat(metadata.getUsage().getPromptTokens()).isGreaterThan(1500); |
| 334 | + assertThat(metadata.getUsage().getPromptTokens()).isLessThan(3500); |
295 | 335 |
|
296 | | - String content = chatResponses2.stream() |
| 336 | + assertThat(metadata.getUsage().getGenerationTokens()).isGreaterThan(0); |
| 337 | + assertThat(metadata.getUsage().getGenerationTokens()).isLessThan(1500); |
| 338 | + |
| 339 | + assertThat(metadata.getUsage().getTotalTokens()) |
| 340 | + .isEqualTo(metadata.getUsage().getPromptTokens() + metadata.getUsage().getGenerationTokens()); |
| 341 | + |
| 342 | + String content = chatResponses.stream() |
| 343 | + .filter(cr -> cr.getResult() != null) |
297 | 344 | .map(cr -> cr.getResult().getOutput().getContent()) |
298 | 345 | .collect(Collectors.joining()); |
299 | 346 | logger.info("Response: {}", content); |
|
0 commit comments