Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient,
*/
@Override
public ChatResponse call(Prompt prompt) {
	// Public entry point: no earlier response exists at this stage, so the
	// internal call is started with a null previous response.
	final ChatResponse noPreviousResponse = null;
	return this.internalCall(prompt, noPreviousResponse);
}

private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatResponse) {

ConverseRequest converseRequest = this.createRequest(prompt);

Expand All @@ -185,7 +189,7 @@ public ChatResponse call(Prompt prompt) {

ConverseResponse converseResponse = this.bedrockRuntimeClient.converse(converseRequest);

var response = this.toChatResponse(converseResponse);
var response = this.toChatResponse(converseResponse, perviousChatResponse);

observationContext.setResponse(response);

Expand All @@ -195,7 +199,7 @@ public ChatResponse call(Prompt prompt) {
if (!this.isProxyToolCalls(prompt, this.defaultOptions) && chatResponse != null
&& this.isToolCall(chatResponse, Set.of("tool_use"))) {
var toolCallConversation = this.handleToolCalls(prompt, chatResponse);
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need test coverage for this flow, especially to verify that the final ChatResponse incorporates the preceding one.

}

return chatResponse;
Expand Down Expand Up @@ -402,7 +406,7 @@ else if (mediaData instanceof URL url) {
* @param response The Bedrock Converse response.
* @return The ChatResponse entity.
*/
private ChatResponse toChatResponse(ConverseResponse response) {
private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perviousChatResponse) {

Assert.notNull(response, "'response' must not be null.");

Expand Down Expand Up @@ -448,8 +452,19 @@ private ChatResponse toChatResponse(ConverseResponse response) {
allGenerations.add(toolCallGeneration);
}

DefaultUsage usage = new DefaultUsage(response.usage().inputTokens().longValue(),
response.usage().outputTokens().longValue(), response.usage().totalTokens().longValue());
Long promptTokens = response.usage().inputTokens().longValue();
Long generationTokens = response.usage().outputTokens().longValue();
Long totalTokens = response.usage().totalTokens().longValue();

if (perviousChatResponse != null && perviousChatResponse.getMetadata() != null
&& perviousChatResponse.getMetadata().getUsage() != null) {

promptTokens += perviousChatResponse.getMetadata().getUsage().getPromptTokens();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have manually verified this step. Is there an effective way to validate this additive operation in our tests?

generationTokens += perviousChatResponse.getMetadata().getUsage().getGenerationTokens();
totalTokens += perviousChatResponse.getMetadata().getUsage().getTotalTokens();
}

DefaultUsage usage = new DefaultUsage(promptTokens, generationTokens, totalTokens);

Document modelResponseFields = response.additionalModelResponseFields();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't seem to use modelResponseFields and metrics. Also, we don't copy all the metadata attributes into the ChatResponse here.


Expand All @@ -473,14 +488,16 @@ private ChatResponse toChatResponse(ConverseResponse response) {
*/
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
	// Public entry point: streaming starts without an earlier response, hence
	// the null previous response handed to the internal variant.
	final ChatResponse noPreviousResponse = null;
	return this.internalStream(prompt, noPreviousResponse);
}

private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousChatResponse) {
Assert.notNull(prompt, "'prompt' must not be null");

return Flux.deferContextual(contextView -> {

ConverseRequest converseRequest = this.createRequest(prompt);

// System.out.println(">>>>> CONVERSE REQUEST: " + converseRequest);

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(AiProvider.BEDROCK_CONVERSE.value())
Expand All @@ -504,13 +521,14 @@ public Flux<ChatResponse> stream(Prompt prompt) {
Flux<ConverseStreamOutput> response = converseStream(converseStreamRequest);

// @formatter:off
Flux<ChatResponse> chatResponses = ConverseApiUtils.toChatResponse(response);
Flux<ChatResponse> chatResponses = ConverseApiUtils.toChatResponse(response, perviousChatResponse);

Flux<ChatResponse> chatResponseFlux = chatResponses.switchMap(chatResponse -> {
if (!this.isProxyToolCalls(prompt, this.defaultOptions) && chatResponse != null
&& this.isToolCall(chatResponse, Set.of("tool_use"))) {

var toolCallConversation = this.handleToolCalls(prompt, chatResponse);
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse);
}
return Mono.just(chatResponse);
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ public static boolean isToolUseFinish(ConverseStreamOutput event) {
return true;
}

public static Flux<ChatResponse> toChatResponse(Flux<ConverseStreamOutput> responses) {
public static Flux<ChatResponse> toChatResponse(Flux<ConverseStreamOutput> responses,
ChatResponse perviousChatResponse) {

AtomicBoolean isInsideTool = new AtomicBoolean(false);

Expand Down Expand Up @@ -120,20 +121,30 @@ public static Flux<ChatResponse> toChatResponse(Flux<ConverseStreamOutput> respo

List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();

Long promptTokens = 0L;
Long generationTokens = 0L;
Long totalTokens = 0L;

for (ToolUseAggregationEvent.ToolUseEntry toolUseEntry : toolUseAggregationEvent.toolUseEntries()) {
var functionCallId = toolUseEntry.id();
var functionName = toolUseEntry.name();
var functionArguments = toolUseEntry.input();
toolCalls.add(
new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));

if (toolUseEntry.usage() != null) {
promptTokens += toolUseEntry.usage().getPromptTokens();
generationTokens += toolUseEntry.usage().getGenerationTokens();
totalTokens += toolUseEntry.usage().getTotalTokens();
}
}

AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
Generation toolCallGeneration = new Generation(assistantMessage,
ChatGenerationMetadata.from("tool_use", null));

var chatResponseMetaData = ChatResponseMetadata.builder()
.withUsage(toolUseAggregationEvent.usage)
.withUsage(new DefaultUsage(promptTokens, generationTokens, totalTokens))
.build();

return new Aggregation(
Expand Down Expand Up @@ -181,22 +192,22 @@ else if (nextEvent instanceof ContentBlockStopEvent contentBlockStopEvent) {
return new Aggregation();
}
else if (nextEvent instanceof ConverseStreamMetadataEvent metadataEvent) {
// return new Aggregation();

var newMeta = MetadataAggregation.builder()
.copy(lastAggregation.metadataAggregation())
.withTokenUsage(metadataEvent.usage())
.withMetrics(metadataEvent.metrics())
.withTrace(metadataEvent.trace())
.build();

DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens().longValue(),
metadataEvent.usage().outputTokens().longValue(),
metadataEvent.usage().totalTokens().longValue());

// TODO
Document modelResponseFields = lastAggregation.metadataAggregation().additionalModelResponseFields();
ConverseStreamMetrics metrics = metadataEvent.metrics();

DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens().longValue(),
metadataEvent.usage().outputTokens().longValue(),
metadataEvent.usage().totalTokens().longValue());

var chatResponseMetaData = ChatResponseMetadata.builder().withUsage(usage).build();

return new Aggregation(newMeta, new ChatResponse(List.of(), chatResponseMetaData));
Expand All @@ -206,8 +217,42 @@ else if (nextEvent instanceof ConverseStreamMetadataEvent metadataEvent) {
}
})
// .skip(1)
.map(aggregation -> aggregation.chatResponse())
.filter(chatResponse -> chatResponse != ConverseApiUtils.EMPTY_CHAT_RESPONSE);
.filter(aggregation -> aggregation.chatResponse() != ConverseApiUtils.EMPTY_CHAT_RESPONSE)
.map(aggregation -> {

var chatResponse = aggregation.chatResponse();

// Merge the previous chat response metadata with the current one.
if (perviousChatResponse != null && perviousChatResponse.getMetadata() != null
&& perviousChatResponse.getMetadata().getUsage() != null) {

var metadataBuilder = ChatResponseMetadata.builder();

Long promptTokens = perviousChatResponse.getMetadata().getUsage().getPromptTokens();
Long generationTokens = perviousChatResponse.getMetadata().getUsage().getGenerationTokens();
Long totalTokens = perviousChatResponse.getMetadata().getUsage().getTotalTokens();

if (chatResponse.getMetadata() != null) {
metadataBuilder.withId(chatResponse.getMetadata().getId());
metadataBuilder.withModel(chatResponse.getMetadata().getModel());
metadataBuilder.withRateLimit(chatResponse.getMetadata().getRateLimit());
metadataBuilder.withPromptMetadata(chatResponse.getMetadata().getPromptMetadata());

if (chatResponse.getMetadata().getUsage() != null) {
promptTokens = promptTokens + chatResponse.getMetadata().getUsage().getPromptTokens();
generationTokens = generationTokens
+ chatResponse.getMetadata().getUsage().getGenerationTokens();
totalTokens = totalTokens + chatResponse.getMetadata().getUsage().getTotalTokens();
}
}

metadataBuilder.withUsage(new DefaultUsage(promptTokens, generationTokens, totalTokens));

return new ChatResponse(chatResponse.getResults(), metadataBuilder.build());
}

return aggregation.chatResponse();
});
}

public static ConverseStreamOutput mergeToolUseEvents(ConverseStreamOutput previousEvent,
Expand Down Expand Up @@ -245,7 +290,7 @@ else if (event.sdkEventType() == EventType.METADATA) {
DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens().longValue(),
metadataEvent.usage().outputTokens().longValue(), metadataEvent.usage().totalTokens().longValue());
toolUseEventAggregator.withUsage(usage);
// TODO

if (!toolUseEventAggregator.isEmpty()) {
toolUseEventAggregator.squashIntoContentBlock();
return toolUseEventAggregator;
Expand Down Expand Up @@ -400,7 +445,7 @@ ToolUseAggregationEvent appendPartialJson(String partialJson) {
}

void squashIntoContentBlock() {
this.toolUseEntries.add(new ToolUseEntry(this.index, this.id, this.name, this.partialJson));
this.toolUseEntries.add(new ToolUseEntry(this.index, this.id, this.name, this.partialJson, this.usage));
this.index = null;
this.id = null;
this.name = null;
Expand All @@ -424,7 +469,7 @@ public void accept(Visitor visitor) {
throw new UnsupportedOperationException();
}

public record ToolUseEntry(Integer index, String id, String name, String input) {
/**
 * One tool-use entry produced when the aggregator squashes its current state into
 * a content block.
 *
 * @param index the aggregator's index value at the time of squashing
 * @param id the tool-use id; used as the id of the resulting tool call
 * @param name the function name; used as the name of the resulting tool call
 * @param input the accumulated partial JSON; used as the tool call arguments
 * @param usage the token usage held by the aggregator at squash time; may be
 * {@code null}, and the stream conversion null-checks it before summing
 */
public record ToolUseEntry(Integer index, String id, String name, String input, DefaultUsage usage) {
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.springframework.util.MimeTypeUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.matches;

@SpringBootTest(classes = BedrockConverseTestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
Expand Down Expand Up @@ -227,6 +228,41 @@ void functionCallTest() {
assertThat(response).contains("30", "10", "15");
}

/**
 * Runs a blocking chat call that has a weather function registered and checks
 * that the returned response carries plausible token usage alongside the
 * expected temperatures.
 */
@Test
void functionCallWithUsageMetadataTest() {

	// @formatter:off
	ChatResponse response = ChatClient.create(this.chatModel)
		.prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
		.functions(FunctionCallback.builder()
			.description("Get the weather in location")
			.function("getCurrentWeather", new MockWeatherService())
			.inputType(MockWeatherService.Request.class)
			.build())
		.call()
		.chatResponse();
	// @formatter:on

	// The token checks below all inspect the usage attached to the returned response.
	var usage = response.getMetadata().getUsage();
	assertThat(usage).isNotNull();

	logger.info(usage.toString());

	// NOTE(review): the lower bound of 500 presumably only holds when the usage of
	// the tool-call round trip is included in the final response - confirm.
	assertThat(usage.getPromptTokens()).isGreaterThan(500).isLessThan(3500);
	assertThat(usage.getGenerationTokens()).isGreaterThan(0).isLessThan(1500);

	// The reported total has to match the sum of its two parts.
	assertThat(usage.getTotalTokens()).isEqualTo(usage.getPromptTokens() + usage.getGenerationTokens());

	logger.info("Response: {}", response);

	String answer = response.getResult().getOutput().getContent();
	assertThat(answer).contains("30", "10", "15");
}

@Test
void functionCallWithAdvisorTest() {

Expand Down Expand Up @@ -274,18 +310,39 @@ void defaultFunctionCallTest() {
void streamFunctionCallTest() {

// @formatter:off
Flux<String> response = ChatClient.create(this.chatModel).prompt()
Flux<ChatResponse> response = ChatClient.create(this.chatModel).prompt()
.user("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
.functions(FunctionCallback.builder()
.description("Get the weather in location")
.function("getCurrentWeather", new MockWeatherService())
.inputType(MockWeatherService.Request.class)
.build())
.stream()
.content();
.chatResponse();
// @formatter:on

String content = response.collectList().block().stream().collect(Collectors.joining());
List<ChatResponse> chatResponses = response.collectList().block();

// chatResponses.forEach(cr -> logger.info("Response: {}", cr));
var lastChatResponse = chatResponses.get(chatResponses.size() - 1);
var metadata = lastChatResponse.getMetadata();
assertThat(metadata.getUsage()).isNotNull();

logger.info(metadata.getUsage().toString());

assertThat(metadata.getUsage().getPromptTokens()).isGreaterThan(1500);
assertThat(metadata.getUsage().getPromptTokens()).isLessThan(3500);

assertThat(metadata.getUsage().getGenerationTokens()).isGreaterThan(0);
assertThat(metadata.getUsage().getGenerationTokens()).isLessThan(1500);

assertThat(metadata.getUsage().getTotalTokens())
.isEqualTo(metadata.getUsage().getPromptTokens() + metadata.getUsage().getGenerationTokens());

String content = chatResponses.stream()
.filter(cr -> cr.getResult() != null)
.map(cr -> cr.getResult().getOutput().getContent())
.collect(Collectors.joining());
logger.info("Response: {}", content);

assertThat(content).contains("30", "10", "15");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ public static void main(String[] args) {
Flux<ConverseStreamOutput> responses = chatModel.converseStream(streamRequest);
List<ConverseStreamOutput> responseList = responses.collectList().block();
System.out.println(responseList);
System.out.println("Response count: " + responseList.size());
responseList.forEach(System.out::println);
}

}