Skip to content

Commit 625a557

Browse files
committed
minor tests adjustments
1 parent dfb4b47 commit 625a557

File tree

1 file changed

+52
-5
lines changed

1 file changed

+52
-5
lines changed

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import org.springframework.util.MimeTypeUtils;
5050

5151
import static org.assertj.core.api.Assertions.assertThat;
52+
import static org.mockito.ArgumentMatchers.matches;
5253

5354
@SpringBootTest(classes = BedrockConverseTestConfiguration.class)
5455
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
@@ -227,6 +228,41 @@ void functionCallTest() {
227228
assertThat(response).contains("30", "10", "15");
228229
}
229230

231+
@Test
232+
void functionCallWithUsageMetadataTest() {
233+
234+
// @formatter:off
235+
ChatResponse response = ChatClient.create(this.chatModel)
236+
.prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
237+
.functions(FunctionCallback.builder()
238+
.description("Get the weather in location")
239+
.function("getCurrentWeather", new MockWeatherService())
240+
.inputType(MockWeatherService.Request.class)
241+
.build())
242+
.call()
243+
.chatResponse();
244+
// @formatter:on
245+
246+
var metadata = response.getMetadata();
247+
248+
assertThat(metadata.getUsage()).isNotNull();
249+
250+
logger.info(metadata.getUsage().toString());
251+
252+
assertThat(metadata.getUsage().getPromptTokens()).isGreaterThan(500);
253+
assertThat(metadata.getUsage().getPromptTokens()).isLessThan(3500);
254+
255+
assertThat(metadata.getUsage().getGenerationTokens()).isGreaterThan(0);
256+
assertThat(metadata.getUsage().getGenerationTokens()).isLessThan(1500);
257+
258+
assertThat(metadata.getUsage().getTotalTokens())
259+
.isEqualTo(metadata.getUsage().getPromptTokens() + metadata.getUsage().getGenerationTokens());
260+
261+
logger.info("Response: {}", response);
262+
263+
assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15");
264+
}
265+
230266
@Test
231267
void functionCallWithAdvisorTest() {
232268

@@ -287,13 +323,24 @@ void streamFunctionCallTest() {
287323

288324
List<ChatResponse> chatResponses = response.collectList().block();
289325

290-
chatResponses.forEach(cr -> logger.info("Response: {}", cr));
326+
// chatResponses.forEach(cr -> logger.info("Response: {}", cr));
327+
var lastChatResponse = chatResponses.get(chatResponses.size() - 1);
328+
var metadata = lastChatResponse.getMetadata();
329+
assertThat(metadata.getUsage()).isNotNull();
291330

292-
List<ChatResponse> chatResponses2 = chatResponses.stream()
293-
.filter(cr -> cr.getResult() != null)
294-
.collect(Collectors.toList());
331+
logger.info(metadata.getUsage().toString());
332+
333+
assertThat(metadata.getUsage().getPromptTokens()).isGreaterThan(1500);
334+
assertThat(metadata.getUsage().getPromptTokens()).isLessThan(3500);
295335

296-
String content = chatResponses2.stream()
336+
assertThat(metadata.getUsage().getGenerationTokens()).isGreaterThan(0);
337+
assertThat(metadata.getUsage().getGenerationTokens()).isLessThan(1500);
338+
339+
assertThat(metadata.getUsage().getTotalTokens())
340+
.isEqualTo(metadata.getUsage().getPromptTokens() + metadata.getUsage().getGenerationTokens());
341+
342+
String content = chatResponses.stream()
343+
.filter(cr -> cr.getResult() != null)
297344
.map(cr -> cr.getResult().getOutput().getContent())
298345
.collect(Collectors.joining());
299346
logger.info("Response: {}", content);

0 commit comments

Comments
 (0)