diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 3c4e287ec2c..8bdc1edd3a7 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -34,8 +34,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; -import reactor.core.publisher.Sinks; -import reactor.core.publisher.Sinks.EmitFailureHandler; import reactor.core.scheduler.Schedulers; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.core.SdkBytes; @@ -51,9 +49,7 @@ import software.amazon.awssdk.services.bedrockruntime.model.ConverseMetrics; import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler; import software.amazon.awssdk.services.bedrockruntime.model.DocumentBlock; import software.amazon.awssdk.services.bedrockruntime.model.DocumentSource; import software.amazon.awssdk.services.bedrockruntime.model.ImageBlock; @@ -76,6 +72,7 @@ import org.springframework.ai.bedrock.converse.api.BedrockMediaFormat; import org.springframework.ai.bedrock.converse.api.ConverseApiUtils; +import org.springframework.ai.bedrock.converse.api.ConverseChatResponseStream; import org.springframework.ai.bedrock.converse.api.URLValidator; import org.springframework.ai.chat.messages.AssistantMessage; import 
org.springframework.ai.chat.messages.MessageType; @@ -84,6 +81,7 @@ import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.DefaultUsage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -682,11 +680,17 @@ private Flux internalStream(Prompt prompt, ChatResponse perviousCh .system(converseRequest.system()) .additionalModelRequestFields(converseRequest.additionalModelRequestFields()) .toolConfig(converseRequest.toolConfig()) + .requestMetadata(converseRequest.requestMetadata()) .build(); - Flux response = converseStream(converseStreamRequest); + Usage accumulatedUsage = null; + if (perviousChatResponse != null && perviousChatResponse.getMetadata() != null) { + accumulatedUsage = perviousChatResponse.getMetadata().getUsage(); + } - Flux chatResponses = ConverseApiUtils.toChatResponse(response, perviousChatResponse); + Flux chatResponses = new ConverseChatResponseStream(this.bedrockRuntimeAsyncClient, + converseStreamRequest, accumulatedUsage) + .stream(); Flux chatResponseFlux = chatResponses.switchMap(chatResponse -> { @@ -733,48 +737,6 @@ private Flux internalStream(Prompt prompt, ChatResponse perviousCh }); } - public static final EmitFailureHandler DEFAULT_EMIT_FAILURE_HANDLER = EmitFailureHandler - .busyLooping(Duration.ofSeconds(10)); - - /** - * Invoke the model and return the response stream. - * - * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html - * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html - * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream - * @param converseStreamRequest Model invocation request. 
- * @return The model invocation response stream. - */ - public Flux converseStream(ConverseStreamRequest converseStreamRequest) { - Assert.notNull(converseStreamRequest, "'converseStreamRequest' must not be null"); - - Sinks.Many eventSink = Sinks.many().multicast().onBackpressureBuffer(); - - ConverseStreamResponseHandler.Visitor visitor = ConverseStreamResponseHandler.Visitor.builder() - .onDefault(output -> { - logger.debug("Received converse stream output:{}", output); - eventSink.emitNext(output, DEFAULT_EMIT_FAILURE_HANDLER); - }) - .build(); - - ConverseStreamResponseHandler responseHandler = ConverseStreamResponseHandler.builder() - .onEventStream(stream -> stream.subscribe(e -> e.accept(visitor))) - .onComplete(() -> { - eventSink.emitComplete(DEFAULT_EMIT_FAILURE_HANDLER); - logger.info("Completed streaming response."); - }) - .onError(error -> { - logger.error("Error handling Bedrock converse stream response", error); - eventSink.emitError(error, DEFAULT_EMIT_FAILURE_HANDLER); - }) - .build(); - - this.bedrockRuntimeAsyncClient.converseStream(converseStreamRequest, responseHandler); - - return eventSink.asFlux(); - - } - /** * Use the provided convention for reporting observation data * @param observationConvention The provided convention diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java index 0fb0ee8e113..328d7a48a04 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java @@ -18,44 +18,16 @@ import java.math.BigDecimal; import java.math.BigInteger; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import 
java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import software.amazon.awssdk.core.SdkField; import software.amazon.awssdk.core.document.Document; -import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDelta; -import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent; -import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStart; -import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStartEvent; -import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStopEvent; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetrics; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput.EventType; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler.Visitor; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamTrace; -import software.amazon.awssdk.services.bedrockruntime.model.MessageStartEvent; -import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent; -import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage; -import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlockStart; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; -import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.metadata.DefaultUsage; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.model.ModelOptions; import 
org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; /** * Amazon Bedrock Converse API utils. @@ -67,249 +39,9 @@ */ public final class ConverseApiUtils { - public static final ChatResponse EMPTY_CHAT_RESPONSE = ChatResponse.builder() - .generations(List.of()) - .metadata("empty", true) - .build(); - private ConverseApiUtils() { - - } - - public static boolean isToolUseStart(ConverseStreamOutput event) { - if (event == null || event.sdkEventType() == null || event.sdkEventType() != EventType.CONTENT_BLOCK_START) { - return false; - } - - return ContentBlockStart.Type.TOOL_USE == ((ContentBlockStartEvent) event).start().type(); - } - - public static boolean isToolUseFinish(ConverseStreamOutput event) { - if (event == null || event.sdkEventType() == null || event.sdkEventType() != EventType.METADATA) { - return false; - } - return true; - } - - public static Flux toChatResponse(Flux responses, - ChatResponse perviousChatResponse) { - - AtomicBoolean isInsideTool = new AtomicBoolean(false); - - return responses.map(event -> { - if (ConverseApiUtils.isToolUseStart(event)) { - isInsideTool.set(true); - } - return event; - }).windowUntil(event -> { // Group all chunks belonging to the same function call. - if (isInsideTool.get() && ConverseApiUtils.isToolUseFinish(event)) { - isInsideTool.set(false); - return true; - } - return !isInsideTool.get(); - }).concatMapIterable(window -> { // Merging the window chunks into a single chunk. 
- Mono monoChunk = window.reduce(new ToolUseAggregationEvent(), - ConverseApiUtils::mergeToolUseEvents); - return List.of(monoChunk); - }).flatMap(mono -> mono).scanWith(() -> new Aggregation(), (lastAggregation, nextEvent) -> { - - // System.out.println(nextEvent); - if (nextEvent instanceof ToolUseAggregationEvent toolUseAggregationEvent) { - - if (CollectionUtils.isEmpty(toolUseAggregationEvent.toolUseEntries())) { - return new Aggregation(); - } - - List toolCalls = new ArrayList<>(); - - Integer promptTokens = 0; - Integer generationTokens = 0; - Integer totalTokens = 0; - - for (ToolUseAggregationEvent.ToolUseEntry toolUseEntry : toolUseAggregationEvent.toolUseEntries()) { - var functionCallId = toolUseEntry.id(); - var functionName = toolUseEntry.name(); - var functionArguments = toolUseEntry.input(); - toolCalls.add( - new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments)); - - if (toolUseEntry.usage() != null) { - promptTokens += toolUseEntry.usage().getPromptTokens(); - generationTokens += toolUseEntry.usage().getCompletionTokens(); - totalTokens += toolUseEntry.usage().getTotalTokens(); - } - } - - AssistantMessage assistantMessage = AssistantMessage.builder() - .content("") - .properties(Map.of()) - .toolCalls(toolCalls) - .build(); - Generation toolCallGeneration = new Generation(assistantMessage, - ChatGenerationMetadata.builder().finishReason("tool_use").build()); - - var chatResponseMetaData = ChatResponseMetadata.builder() - .usage(new DefaultUsage(promptTokens, generationTokens, totalTokens)) - .build(); - - return new Aggregation( - MetadataAggregation.builder().copy(lastAggregation.metadataAggregation()).build(), - new ChatResponse(List.of(toolCallGeneration), chatResponseMetaData)); - - } - else if (nextEvent instanceof MessageStartEvent messageStartEvent) { - var newMeta = MetadataAggregation.builder() - .copy(lastAggregation.metadataAggregation()) - .withRole(messageStartEvent.role().toString()) - 
.build(); - return new Aggregation(newMeta, ConverseApiUtils.EMPTY_CHAT_RESPONSE); - } - else if (nextEvent instanceof MessageStopEvent messageStopEvent) { - var newMeta = MetadataAggregation.builder() - .copy(lastAggregation.metadataAggregation()) - .withStopReason(messageStopEvent.stopReasonAsString()) - .withAdditionalModelResponseFields(messageStopEvent.additionalModelResponseFields()) - .build(); - return new Aggregation(newMeta, ConverseApiUtils.EMPTY_CHAT_RESPONSE); - } - else if (nextEvent instanceof ContentBlockStartEvent contentBlockStartEvent) { - // TODO ToolUse support - return new Aggregation(); - } - else if (nextEvent instanceof ContentBlockDeltaEvent contentBlockDeltaEvent) { - if (contentBlockDeltaEvent.delta().type().equals(ContentBlockDelta.Type.TEXT)) { - - var generation = new Generation( - AssistantMessage.builder() - .content(contentBlockDeltaEvent.delta().text()) - .properties(Map.of()) - .build(), - ChatGenerationMetadata.builder() - .finishReason(lastAggregation.metadataAggregation().stopReason()) - .build()); - - return new Aggregation( - MetadataAggregation.builder().copy(lastAggregation.metadataAggregation()).build(), - new ChatResponse(List.of(generation))); - } - else if (contentBlockDeltaEvent.delta().type().equals(ContentBlockDelta.Type.TOOL_USE)) { - // TODO ToolUse support - } - return new Aggregation(); - } - else if (nextEvent instanceof ContentBlockStopEvent contentBlockStopEvent) { - // TODO ToolUse support - return new Aggregation(); - } - else if (nextEvent instanceof ConverseStreamMetadataEvent metadataEvent) { - - var newMeta = MetadataAggregation.builder() - .copy(lastAggregation.metadataAggregation()) - .withTokenUsage(metadataEvent.usage()) - .withMetrics(metadataEvent.metrics()) - .withTrace(metadataEvent.trace()) - .build(); - - // TODO - Document modelResponseFields = lastAggregation.metadataAggregation().additionalModelResponseFields(); - ConverseStreamMetrics metrics = metadataEvent.metrics(); - - DefaultUsage 
usage = new DefaultUsage(metadataEvent.usage().inputTokens(), - metadataEvent.usage().outputTokens(), metadataEvent.usage().totalTokens()); - - var chatResponseMetaData = ChatResponseMetadata.builder().usage(usage).build(); - - return new Aggregation(newMeta, new ChatResponse(List.of(), chatResponseMetaData)); - } - else { - return new Aggregation(); - } - }) - // .skip(1) - .filter(aggregation -> aggregation.chatResponse() != ConverseApiUtils.EMPTY_CHAT_RESPONSE) - .map(aggregation -> { - - var chatResponse = aggregation.chatResponse(); - - // Merge the previous chat response metadata with the current one. - if (perviousChatResponse != null && perviousChatResponse.getMetadata() != null - && perviousChatResponse.getMetadata().getUsage() != null) { - - var metadataBuilder = ChatResponseMetadata.builder(); - - Integer promptTokens = perviousChatResponse.getMetadata().getUsage().getPromptTokens(); - Integer generationTokens = perviousChatResponse.getMetadata().getUsage().getCompletionTokens(); - int totalTokens = perviousChatResponse.getMetadata().getUsage().getTotalTokens(); - - if (chatResponse.getMetadata() != null) { - metadataBuilder.id(chatResponse.getMetadata().getId()); - metadataBuilder.model(chatResponse.getMetadata().getModel()); - metadataBuilder.rateLimit(chatResponse.getMetadata().getRateLimit()); - metadataBuilder.promptMetadata(chatResponse.getMetadata().getPromptMetadata()); - - if (chatResponse.getMetadata().getUsage() != null) { - promptTokens = promptTokens + chatResponse.getMetadata().getUsage().getPromptTokens(); - generationTokens = generationTokens - + chatResponse.getMetadata().getUsage().getCompletionTokens(); - totalTokens = totalTokens + chatResponse.getMetadata().getUsage().getTotalTokens(); - } - } - - metadataBuilder.usage(new DefaultUsage(promptTokens, generationTokens, totalTokens)); - - return new ChatResponse(chatResponse.getResults(), metadataBuilder.build()); - } - - return aggregation.chatResponse(); - }); - } - - public static 
ConverseStreamOutput mergeToolUseEvents(ConverseStreamOutput previousEvent, - ConverseStreamOutput event) { - - ToolUseAggregationEvent toolUseEventAggregator = (ToolUseAggregationEvent) previousEvent; - - if (event.sdkEventType() == EventType.CONTENT_BLOCK_START) { - - ContentBlockStartEvent contentBlockStart = (ContentBlockStartEvent) event; - - if (ContentBlockStart.Type.TOOL_USE.equals(contentBlockStart.start().type())) { - ToolUseBlockStart cbToolUse = contentBlockStart.start().toolUse(); - - return toolUseEventAggregator.withIndex(contentBlockStart.contentBlockIndex()) - .withId(cbToolUse.toolUseId()) - .withName(cbToolUse.name()) - .appendPartialJson(""); // CB START always has empty JSON. - } - } - else if (event.sdkEventType() == EventType.CONTENT_BLOCK_DELTA) { - ContentBlockDeltaEvent contentBlockDelta = (ContentBlockDeltaEvent) event; - if (ContentBlockDelta.Type.TOOL_USE == contentBlockDelta.delta().type()) { - return toolUseEventAggregator.appendPartialJson(contentBlockDelta.delta().toolUse().input()); - } - } - else if (event.sdkEventType() == EventType.CONTENT_BLOCK_STOP) { - return toolUseEventAggregator; - } - else if (event.sdkEventType() == EventType.MESSAGE_STOP) { - return toolUseEventAggregator; - } - else if (event.sdkEventType() == EventType.METADATA) { - ConverseStreamMetadataEvent metadataEvent = (ConverseStreamMetadataEvent) event; - DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens(), - metadataEvent.usage().outputTokens(), metadataEvent.usage().totalTokens()); - toolUseEventAggregator.withUsage(usage); - - if (!toolUseEventAggregator.isEmpty()) { - toolUseEventAggregator.squashIntoContentBlock(); - return toolUseEventAggregator; - } - } - - return event; } - @SuppressWarnings("unchecked") public static Document getChatOptionsAdditionalModelRequestFields(ChatOptions defaultOptions, ModelOptions promptOptions) { if (defaultOptions == null && promptOptions == null) { @@ -391,7 +123,6 @@ else if (value instanceof Map 
mapValue) { } } - @SuppressWarnings("unchecked") public static Map getRequestMetadata(Map metadata) { if (metadata.isEmpty()) { @@ -419,165 +150,4 @@ private static Document convertMapToDocument(Map value) { return Document.fromMap(attr); } - public record Aggregation(MetadataAggregation metadataAggregation, ChatResponse chatResponse) { - public Aggregation() { - this(MetadataAggregation.builder().build(), EMPTY_CHAT_RESPONSE); - } - } - - /** - * Special event used to aggregate multiple tool use events into a single event with - * list of aggregated ContentBlockToolUse. - */ - public static class ToolUseAggregationEvent implements ConverseStreamOutput { - - private Integer index; - - private String id; - - private String name; - - private String partialJson = ""; - - private List toolUseEntries = new ArrayList<>(); - - private DefaultUsage usage; - - public List toolUseEntries() { - return this.toolUseEntries; - } - - public boolean isEmpty() { - return (this.index == null || this.id == null || this.name == null || this.partialJson == null); - } - - ToolUseAggregationEvent withIndex(Integer index) { - this.index = index; - return this; - } - - ToolUseAggregationEvent withId(String id) { - this.id = id; - return this; - } - - ToolUseAggregationEvent withName(String name) { - this.name = name; - return this; - } - - ToolUseAggregationEvent withUsage(DefaultUsage usage) { - this.usage = usage; - return this; - } - - ToolUseAggregationEvent appendPartialJson(String partialJson) { - this.partialJson = this.partialJson + partialJson; - return this; - } - - void squashIntoContentBlock() { - // Workaround to handle streaming tool calling with no input arguments. - String json = StringUtils.hasText(this.partialJson) ? 
this.partialJson : "{}"; - this.toolUseEntries.add(new ToolUseEntry(this.index, this.id, this.name, json, this.usage)); - this.index = null; - this.id = null; - this.name = null; - this.partialJson = ""; - this.usage = null; - } - - @Override - public String toString() { - return "EventToolUseBuilder [index=" + this.index + ", id=" + this.id + ", name=" + this.name - + ", partialJson=" + this.partialJson + ", toolUseMap=" + "]"; - } - - @Override - public List> sdkFields() { - return List.of(); - } - - @Override - public void accept(Visitor visitor) { - throw new UnsupportedOperationException(); - } - - public record ToolUseEntry(Integer index, String id, String name, String input, DefaultUsage usage) { - } - - } - - public record MetadataAggregation(String role, String stopReason, Document additionalModelResponseFields, - TokenUsage tokenUsage, ConverseStreamMetrics metrics, ConverseStreamTrace trace) { - - public static Builder builder() { - return new Builder(); - } - - public static final class Builder { - - private String role; - - private String stopReason; - - private Document additionalModelResponseFields; - - private TokenUsage tokenUsage; - - private ConverseStreamMetrics metrics; - - private ConverseStreamTrace trace; - - private Builder() { - } - - public Builder copy(MetadataAggregation metadataAggregation) { - this.role = metadataAggregation.role; - this.stopReason = metadataAggregation.stopReason; - this.additionalModelResponseFields = metadataAggregation.additionalModelResponseFields; - this.tokenUsage = metadataAggregation.tokenUsage; - this.metrics = metadataAggregation.metrics; - this.trace = metadataAggregation.trace; - return this; - } - - public Builder withRole(String role) { - this.role = role; - return this; - } - - public Builder withStopReason(String stopReason) { - this.stopReason = stopReason; - return this; - } - - public Builder withAdditionalModelResponseFields(Document additionalModelResponseFields) { - 
this.additionalModelResponseFields = additionalModelResponseFields; - return this; - } - - public Builder withTokenUsage(TokenUsage tokenUsage) { - this.tokenUsage = tokenUsage; - return this; - } - - public Builder withMetrics(ConverseStreamMetrics metrics) { - this.metrics = metrics; - return this; - } - - public Builder withTrace(ConverseStreamTrace trace) { - this.trace = trace; - return this; - } - - public MetadataAggregation build() { - return new MetadataAggregation(this.role, this.stopReason, this.additionalModelResponseFields, - this.tokenUsage, this.metrics, this.trace); - } - - } - } - } diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseChatResponseStream.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseChatResponseStream.java new file mode 100644 index 00000000000..0c1a360ab62 --- /dev/null +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseChatResponseStream.java @@ -0,0 +1,230 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.bedrock.converse.api; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Sinks; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDelta; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStart; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStartEvent; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler; +import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent; +import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.DefaultUsage; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.util.Assert; + +/** + * Sends a {@link ConverseStreamRequest} to Bedrock and returns {@link ChatResponse} + * stream. 
+ * + * @author Jared Rufer + * @since 1.1.0 + */ +public class ConverseChatResponseStream implements ConverseStreamResponseHandler.Visitor { + + private static final Logger logger = LoggerFactory.getLogger(ConverseChatResponseStream.class); + + public static final Sinks.EmitFailureHandler DEFAULT_EMIT_FAILURE_HANDLER = Sinks.EmitFailureHandler + .busyLooping(Duration.ofSeconds(10)); + + private final AtomicReference requestIdRef = new AtomicReference<>("Unknown"); + + private final AtomicReference tokenUsageRef = new AtomicReference<>(); + + private final AtomicInteger promptTokens = new AtomicInteger(); + + private final AtomicInteger generationTokens = new AtomicInteger(); + + private final AtomicInteger totalTokens = new AtomicInteger(); + + private final AtomicReference stopReason = new AtomicReference<>(); + + private final Map toolUseMap = new ConcurrentHashMap<>(); + + private final Sinks.Many eventSink = Sinks.many().multicast().onBackpressureBuffer(); + + private final BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient; + + private final ConverseStreamRequest converseStreamRequest; + + public ConverseChatResponseStream(BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, + ConverseStreamRequest converseStreamRequest, Usage accumulatedUsage) { + + Assert.notNull(bedrockRuntimeAsyncClient, "'bedrockRuntimeAsyncClient' must not be null"); + Assert.notNull(converseStreamRequest, "'converseStreamRequest' must not be null"); + + this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient; + this.converseStreamRequest = converseStreamRequest; + if (accumulatedUsage != null) { + this.totalTokens.set(accumulatedUsage.getTotalTokens()); + this.promptTokens.set(accumulatedUsage.getPromptTokens()); + this.generationTokens.set(accumulatedUsage.getCompletionTokens()); + if (accumulatedUsage.getNativeUsage() instanceof TokenUsage tokenUsage) { + this.mergeNativeTokenUsage(tokenUsage); + } + } + } + + @Override + public void visitContentBlockStart(ContentBlockStartEvent 
event) { + if (ContentBlockStart.Type.TOOL_USE.equals(event.start().type())) { + this.toolUseMap.put(event.contentBlockIndex(), + new StreamingToolCallBuilder().id(event.start().toolUse().toolUseId()) + .name(event.start().toolUse().name())); + } + } + + @Override + public void visitContentBlockDelta(ContentBlockDeltaEvent event) { + StreamingToolCallBuilder toolCallBuilder = this.toolUseMap.get(event.contentBlockIndex()); + + if (toolCallBuilder != null) { + toolCallBuilder.delta(event.delta().toolUse().input()); + } + else if (ContentBlockDelta.Type.TEXT.equals(event.delta().type())) { + this.emitChatResponse(new Generation(AssistantMessage.builder().content(event.delta().text()).build())); + } + } + + @Override + public void visitMessageStop(MessageStopEvent event) { + this.stopReason.set(event.stopReasonAsString()); + } + + @Override + public void visitMetadata(ConverseStreamMetadataEvent event) { + this.promptTokens.addAndGet(event.usage().inputTokens()); + this.generationTokens.addAndGet(event.usage().outputTokens()); + this.totalTokens.addAndGet(event.usage().totalTokens()); + this.mergeNativeTokenUsage(event.usage()); + + ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.builder() + .finishReason(this.stopReason.get()) + .build(); + + List toolCalls = this.toolUseMap.entrySet() + .stream() + .sorted(Map.Entry.comparingByKey()) + .map(Map.Entry::getValue) + .map(StreamingToolCallBuilder::build) + .toList(); + + if (!toolCalls.isEmpty()) { + this.emitChatResponse(new Generation(AssistantMessage.builder().content("").toolCalls(toolCalls).build(), + generationMetadata)); + } + else { + this.emitChatResponse(new Generation(AssistantMessage.builder().content("").build(), generationMetadata)); + } + } + + private void mergeNativeTokenUsage(TokenUsage tokenUsage) { + this.tokenUsageRef.accumulateAndGet(tokenUsage, (current, next) -> { + if (current == null) { + return next; + } + else { + return TokenUsage.builder() + 
.inputTokens(addTokens(current.inputTokens(), next.inputTokens())) + .outputTokens(addTokens(current.outputTokens(), next.outputTokens())) + .totalTokens(addTokens(current.totalTokens(), next.totalTokens())) + .cacheReadInputTokens(addTokens(current.cacheReadInputTokens(), next.cacheReadInputTokens())) + .cacheWriteInputTokens(addTokens(current.cacheWriteInputTokens(), next.cacheWriteInputTokens())) + .build(); + } + }); + } + + private static Integer addTokens(Integer current, Integer next) { + if (current == null) { + return next; + } + if (next == null) { + return current; + } + return current + next; + } + + private void emitChatResponse(Generation generation) { + var metadataBuilder = ChatResponseMetadata.builder(); + metadataBuilder.id(this.requestIdRef.get()); + metadataBuilder.usage(this.getCurrentUsage()); + + ChatResponse chatResponse = new ChatResponse(generation == null ? List.of() : List.of(generation), + metadataBuilder.build()); + + this.eventSink.emitNext(chatResponse, DEFAULT_EMIT_FAILURE_HANDLER); + } + + private Usage getCurrentUsage() { + return new DefaultUsage(this.promptTokens.get(), this.generationTokens.get(), this.totalTokens.get(), + this.tokenUsageRef.get()); + } + + /** + * Invoke the model and return the chat response stream. 
+ * @see + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html + * @see + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + * @see + * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream + */ + public Flux stream() { + + ConverseStreamResponseHandler responseHandler = ConverseStreamResponseHandler.builder() + .subscriber(this) + .onResponse(converseStreamResponse -> this.requestIdRef + .set(converseStreamResponse.responseMetadata().requestId())) + .onComplete(() -> { + this.eventSink.emitComplete(DEFAULT_EMIT_FAILURE_HANDLER); + logger.info("Completed streaming response."); + }) + .onError(error -> { + logger.error("Error handling Bedrock converse stream response", error); + this.eventSink.emitError(error, DEFAULT_EMIT_FAILURE_HANDLER); + }) + .build(); + this.bedrockRuntimeAsyncClient.converseStream(this.converseStreamRequest, responseHandler); + + return this.eventSink.asFlux(); + } + +} diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/StreamingToolCallBuilder.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/StreamingToolCallBuilder.java new file mode 100644 index 00000000000..bd6231b5791 --- /dev/null +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/StreamingToolCallBuilder.java @@ -0,0 +1,54 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.bedrock.converse.api; + +import org.springframework.ai.chat.messages.AssistantMessage; + +/** + * @author Jared Rufer + * @since 1.1.0 + */ +public class StreamingToolCallBuilder { + + private final StringBuffer arguments = new StringBuffer(); + + private volatile String id; + + private volatile String name; + + public StreamingToolCallBuilder id(String id) { + this.id = id; + return this; + } + + public StreamingToolCallBuilder name(String name) { + this.name = name; + return this; + } + + public StreamingToolCallBuilder delta(String delta) { + this.arguments.append(delta); + return this; + } + + public AssistantMessage.ToolCall build() { + // Workaround to handle streaming tool calling with no input arguments. + String toolArgs = this.arguments.isEmpty() ? 
"{}" : this.arguments.toString(); + return new AssistantMessage.ToolCall(this.id, "function", this.name, toolArgs); + } + +} diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java index cfe22f2f01c..dcea19f7c56 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java @@ -324,11 +324,11 @@ void streamFunctionCallTest() { logger.info(metadata.getUsage().toString()); - assertThat(metadata.getUsage().getPromptTokens()).isGreaterThan(1500); - assertThat(metadata.getUsage().getPromptTokens()).isLessThan(3500); + assertThat(metadata.getUsage().getPromptTokens()).isGreaterThan(1000); + assertThat(metadata.getUsage().getPromptTokens()).isLessThan(1500); assertThat(metadata.getUsage().getCompletionTokens()).isGreaterThan(0); - assertThat(metadata.getUsage().getCompletionTokens()).isLessThan(1500); + assertThat(metadata.getUsage().getCompletionTokens()).isLessThan(600); assertThat(metadata.getUsage().getTotalTokens()) .isEqualTo(metadata.getUsage().getPromptTokens() + metadata.getUsage().getCompletionTokens()); @@ -363,7 +363,7 @@ void singularStreamFunctionCallTest() { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "anthropic.claude-3-5-sonnet-20240620-v1:0" }) + @ValueSource(strings = { "us.anthropic.claude-3-5-sonnet-20240620-v1:0" }) void multiModalityEmbeddedImage(String modelName) throws IOException { // @formatter:off @@ -380,7 +380,7 @@ void multiModalityEmbeddedImage(String modelName) throws IOException { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "anthropic.claude-3-5-sonnet-20240620-v1:0" }) + 
@ValueSource(strings = { "us.anthropic.claude-3-5-sonnet-20240620-v1:0" }) void multiModalityImageUrl2(String modelName) throws IOException { // TODO: add url method that wraps the checked exception. @@ -400,7 +400,7 @@ void multiModalityImageUrl2(String modelName) throws IOException { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "anthropic.claude-3-5-sonnet-20240620-v1:0" }) + @ValueSource(strings = { "us.anthropic.claude-3-5-sonnet-20240620-v1:0" }) void multiModalityImageUrl(String modelName) throws IOException { // TODO: add url method that wraps the checked exception. diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java index 05992349a01..4f2da467eb9 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java @@ -34,7 +34,7 @@ public BedrockProxyChatModel bedrockConverseChatModel() { // String modelId = "anthropic.claude-3-5-sonnet-20241022-v2:0"; // String modelId = "meta.llama3-8b-instruct-v1:0"; // String modelId = "ai21.jamba-1-5-large-v1:0"; - String modelId = "anthropic.claude-3-5-sonnet-20240620-v1:0"; + String modelId = "us.anthropic.claude-3-5-sonnet-20240620-v1:0"; return BedrockProxyChatModel.builder() .credentialsProvider(EnvironmentVariableCredentialsProvider.create()) diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java index 2b2361cba03..9ba03c121c4 100644 --- 
a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java @@ -76,14 +76,15 @@ class BedrockProxyChatModelIT { private static void validateChatResponseMetadata(ChatResponse response, String model) { // assertThat(response.getMetadata().getId()).isNotEmpty(); // assertThat(response.getMetadata().getModel()).containsIgnoringCase(model); + assertThat(response.getMetadata().getId()).isNotEqualTo("Unknown").isNotBlank(); assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive(); assertThat(response.getMetadata().getUsage().getCompletionTokens()).isPositive(); assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0", - "anthropic.claude-3-5-sonnet-20240620-v1:0" }) + @ValueSource(strings = { "us.anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0", + "us.anthropic.claude-3-5-sonnet-20240620-v1:0" }) void roleTest(String modelName) { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); @@ -314,7 +315,7 @@ void streamFunctionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = BedrockChatOptions.builder() - .model("anthropic.claude-3-5-sonnet-20240620-v1:0") + .model("us.anthropic.claude-3-5-sonnet-20240620-v1:0") .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description( "Get the weather in location. Return temperature in 36°F or 36°C format. 
Use multi-turn if needed.")
@@ -335,9 +336,42 @@ void streamFunctionCallTest() {
 		assertThat(content).contains("30", "10", "15");
 	}
 
+	/**
+	 * Streams a multi-city weather prompt with the given {@code maxTokens} limit and
+	 * checks that the last streamed response reports the {@code max_tokens} finish reason.
+	 */
+	@ParameterizedTest(name = "{displayName} - {0} ")
+	@ValueSource(ints = { 50, 200 })
+	void streamFunctionCallTestWithMaxTokens(int maxTokens) {
+
+		UserMessage userMessage = new UserMessage(
+				"What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius.");
+
+		List<Message> messages = new ArrayList<>(List.of(userMessage));
+
+		var promptOptions = BedrockChatOptions.builder()
+			.maxTokens(maxTokens)
+			.model("us.anthropic.claude-3-5-sonnet-20240620-v1:0")
+			.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
+				.description(
+						"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
+				.inputType(MockWeatherService.Request.class)
+				.build()))
+			.build();
+
+		Flux<ChatResponse> response = this.chatModel.stream(new Prompt(messages, promptOptions));
+		ChatResponse lastResponse = response.blockLast();
+		// blockLast() returns null when the stream completes empty; fail with a clear assertion, not an NPE.
+		assertThat(lastResponse).isNotNull();
+		String finishReason = lastResponse.getResult().getMetadata().getFinishReason();
+
+		logger.info("Finish reason: {}", finishReason);
+		assertThat(finishReason).isEqualTo("max_tokens");
+	}
+
 	@Test
 	void validateCallResponseMetadata() {
-		String model = "anthropic.claude-3-5-sonnet-20240620-v1:0";
+		String model = "us.anthropic.claude-3-5-sonnet-20240620-v1:0";
 		// @formatter:off
 		ChatResponse response = ChatClient.create(this.chatModel).prompt()
 				.options(BedrockChatOptions.builder().model(model).build())
@@ -352,7 +386,7 @@ void validateCallResponseMetadata() {
 
 	@Test
 	void validateStreamCallResponseMetadata() {
-		String model = "anthropic.claude-3-5-sonnet-20240620-v1:0";
+		String model = "us.anthropic.claude-3-5-sonnet-20240620-v1:0";
 		// @formatter:off
 		ChatResponse response = ChatClient.create(this.chatModel).prompt()
.options(BedrockChatOptions.builder().model(model).build()) diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java index 1824a1b84d2..ccc0d7257e1 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java @@ -67,7 +67,7 @@ void beforeEach() { @Test void observationForChatOperation() { var options = BedrockChatOptions.builder() - .model("anthropic.claude-3-5-sonnet-20240620-v1:0") + .model("us.anthropic.claude-3-5-sonnet-20240620-v1:0") .maxTokens(2048) .stopSequences(List.of("this-is-the-end")) .temperature(0.7) @@ -89,7 +89,7 @@ void observationForChatOperation() { @Test void observationForStreamingChatOperation() { var options = BedrockChatOptions.builder() - .model("anthropic.claude-3-5-sonnet-20240620-v1:0") + .model("us.anthropic.claude-3-5-sonnet-20240620-v1:0") .maxTokens(2048) .stopSequences(List.of("this-is-the-end")) .temperature(0.7) @@ -124,13 +124,13 @@ private void validate(ChatResponseMetadata responseMetadata, String finishReason .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() - .hasContextualNameEqualTo("chat " + "anthropic.claude-3-5-sonnet-20240620-v1:0") + .hasContextualNameEqualTo("chat " + "us.anthropic.claude-3-5-sonnet-20240620-v1:0") .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), AiOperationType.CHAT.value()) .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.BEDROCK_CONVERSE.value()) .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), - 
"anthropic.claude-3-5-sonnet-20240620-v1:0") + "us.anthropic.claude-3-5-sonnet-20240620-v1:0") // .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), // responseMetadata.getModel()) .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString()) @@ -166,7 +166,7 @@ public TestObservationRegistry observationRegistry() { @Bean public BedrockProxyChatModel bedrockConverseChatModel(ObservationRegistry observationRegistry) { - String modelId = "anthropic.claude-3-5-sonnet-20240620-v1:0"; + String modelId = "us.anthropic.claude-3-5-sonnet-20240620-v1:0"; return BedrockProxyChatModel.builder() .credentialsProvider(EnvironmentVariableCredentialsProvider.create()) diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java index decddfd7355..fe91af13ac3 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java @@ -191,7 +191,7 @@ void toolAnnotationWeatherForecast() { // https://github.com/spring-projects/spring-ai/issues/1878 @ParameterizedTest - @ValueSource(strings = { "amazon.nova-pro-v1:0", "us.anthropic.claude-3-7-sonnet-20250219-v1:0" }) + @ValueSource(strings = { "us.amazon.nova-pro-v1:0", "us.anthropic.claude-3-7-sonnet-20250219-v1:0" }) void toolAnnotationWeatherForecastStreaming(String modelName) { ChatClient chatClient = ChatClient.builder(this.chatModel).build(); @@ -262,7 +262,7 @@ public static class Config { @Bean public BedrockProxyChatModel bedrockConverseChatModel() { - String modelId = "amazon.nova-pro-v1:0"; + String modelId = "us.amazon.nova-pro-v1:0"; // String modelId = 
"us.anthropic.claude-3-7-sonnet-20250219-v1:0"; return BedrockProxyChatModel.builder() diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain3.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain3.java index 8d408cf0f4e..a11093d0e37 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain3.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain3.java @@ -40,7 +40,7 @@ public static void main(String[] args) { // String modelId = "anthropic.claude-3-5-sonnet-20240620-v1:0"; // String modelId = "ai21.jamba-1-5-large-v1:0"; - String modelId = "anthropic.claude-3-5-sonnet-20240620-v1:0"; + String modelId = "us.anthropic.claude-3-5-sonnet-20240620-v1:0"; // var prompt = new Prompt("Tell me a joke?", // ChatOptions.builder().model(modelId).build();