diff --git a/models/spring-ai-bedrock-converse/pom.xml b/models/spring-ai-bedrock-converse/pom.xml
new file mode 100644
index 00000000000..e684d6bf133
--- /dev/null
+++ b/models/spring-ai-bedrock-converse/pom.xml
@@ -0,0 +1,84 @@
+
+
+ 4.0.0
+
+ org.springframework.ai
+ spring-ai
+ 1.0.0-SNAPSHOT
+ ../../pom.xml
+
+ spring-ai-bedrock-converse
+ jar
+ Spring AI Model - Amazon Bedrock Converse API
+ Amazon Bedrock models support using the Converse API
+ https://github.com/spring-projects/spring-ai
+
+
+ https://github.com/spring-projects/spring-ai
+ git://github.com/spring-projects/spring-ai.git
+ git@github.com:spring-projects/spring-ai.git
+
+
+
+ 2.29.3
+
+
+
+
+
+
+ org.springframework.ai
+ spring-ai-core
+ ${project.parent.version}
+
+
+
+ org.springframework.ai
+ spring-ai-retry
+ ${project.parent.version}
+
+
+
+ software.amazon.awssdk
+ bedrockruntime
+ ${aws.sdk.version}
+
+
+ commons-logging
+ commons-logging
+
+
+
+
+
+ software.amazon.awssdk
+ sts
+ ${aws.sdk.version}
+
+
+
+ software.amazon.awssdk
+ netty-nio-client
+ ${aws.sdk.version}
+
+
+
+
+ org.springframework.ai
+ spring-ai-test
+ ${project.version}
+ test
+
+
+
+ io.micrometer
+ micrometer-observation-test
+ test
+
+
+
+
+
+
\ No newline at end of file
diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java
new file mode 100644
index 00000000000..f661d59c01c
--- /dev/null
+++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java
@@ -0,0 +1,710 @@
+/*
+ * Copyright 2024 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.bedrock.converse;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URL;
+import java.net.URLConnection;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.bedrock.converse.api.ConverseApiUtils;
+import org.springframework.ai.bedrock.converse.api.URLValidator;
+import org.springframework.ai.chat.messages.AssistantMessage;
+import org.springframework.ai.chat.messages.MessageType;
+import org.springframework.ai.chat.messages.ToolResponseMessage;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
+import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.metadata.DefaultUsage;
+import org.springframework.ai.chat.model.AbstractToolCallSupport;
+import org.springframework.ai.chat.model.ChatModel;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.model.Generation;
+import org.springframework.ai.chat.model.MessageAggregator;
+import org.springframework.ai.chat.observation.ChatModelObservationContext;
+import org.springframework.ai.chat.observation.ChatModelObservationConvention;
+import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
+import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
+import org.springframework.ai.chat.prompt.ChatOptions;
+import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.model.ModelOptionsUtils;
+import org.springframework.ai.model.function.FunctionCallback;
+import org.springframework.ai.model.function.FunctionCallbackContext;
+import org.springframework.ai.model.function.FunctionCallingOptions;
+import org.springframework.ai.model.function.FunctionCallingOptionsBuilder;
+import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
+import org.springframework.ai.observation.conventions.AiProvider;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+import org.springframework.util.StreamUtils;
+import org.springframework.util.StringUtils;
+
+import io.micrometer.observation.Observation;
+import io.micrometer.observation.ObservationRegistry;
+import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.core.publisher.Sinks;
+import reactor.core.publisher.Sinks.EmitFailureHandler;
+import reactor.core.publisher.Sinks.EmitResult;
+import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.core.SdkBytes;
+import software.amazon.awssdk.core.document.Document;
+import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
+import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
+import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
+import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
+import software.amazon.awssdk.services.bedrockruntime.model.ConverseMetrics;
+import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
+import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
+import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
+import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
+import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
+import software.amazon.awssdk.services.bedrockruntime.model.ImageBlock;
+import software.amazon.awssdk.services.bedrockruntime.model.ImageSource;
+import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration;
+import software.amazon.awssdk.services.bedrockruntime.model.Message;
+import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock;
+import software.amazon.awssdk.services.bedrockruntime.model.Tool;
+import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration;
+import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema;
+import software.amazon.awssdk.services.bedrockruntime.model.ToolResultBlock;
+import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock;
+import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification;
+import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;
+
+/**
+ * A {@link ChatModel} implementation that uses the Amazon Bedrock Converse API to
+ * interact with the Supported
+ * models.
+ *
+ * The Converse API doesn't support any embedding models (such as Titan Embeddings G1 -
+ * Text) or image generation models (such as Stability AI).
+ *
+ *
+ * https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
+ *
+ * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
+ *
+ * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html
+ *
+ * https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
+ *
+ * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
+ *
+ * @author Christian Tzolov
+ * @author Wei Jiang
+ * @since 1.0.0
+ */
+public class BedrockProxyChatModel extends AbstractToolCallSupport implements ChatModel {
+
+ private static final Logger logger = LoggerFactory.getLogger(BedrockProxyChatModel.class);
+
+ private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
+
+ private final BedrockRuntimeClient bedrockRuntimeClient;
+
+ private final BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient;
+
+ private FunctionCallingOptions defaultOptions;
+
+ /**
+ * Observation registry used for instrumentation.
+ */
+ private final ObservationRegistry observationRegistry;
+
+ /**
+ * Conventions to use for generating observations.
+ */
+ private ChatModelObservationConvention observationConvention;
+
+ public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient,
+ BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, FunctionCallingOptions defaultOptions,
+ FunctionCallbackContext functionCallbackContext, List toolFunctionCallbacks,
+ ObservationRegistry observationRegistry) {
+
+ super(functionCallbackContext, defaultOptions, toolFunctionCallbacks);
+
+ Assert.notNull(bedrockRuntimeClient, "bedrockRuntimeClient must not be null");
+ Assert.notNull(bedrockRuntimeAsyncClient, "bedrockRuntimeAsyncClient must not be null");
+
+ this.bedrockRuntimeClient = bedrockRuntimeClient;
+ this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient;
+ this.defaultOptions = defaultOptions;
+ this.observationRegistry = observationRegistry;
+ }
+
+ /**
+ * Invoke the model and return the response.
+ *
+ * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
+ * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
+ * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient.html#converse
+ * @param bedrockConverseRequest Model invocation request.
+ * @return The model invocation response.
+ */
+ @Override
+ public ChatResponse call(Prompt prompt) {
+
+ ConverseRequest converseRequest = this.createRequest(prompt);
+
+ ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
+ .prompt(prompt)
+ .provider(AiProvider.BEDROCK_CONVERSE.value())
+ .requestOptions(buildRequestOptions(converseRequest))
+ .build();
+
+ ChatResponse chatResponse = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
+ .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
+ this.observationRegistry)
+ .observe(() -> {
+
+ ConverseResponse converseResponse = this.bedrockRuntimeClient.converse(converseRequest);
+
+ var response = this.toChatResponse(converseResponse);
+
+ observationContext.setResponse(response);
+
+ return response;
+ });
+
+ if (!this.isProxyToolCalls(prompt, this.defaultOptions) && chatResponse != null
+ && this.isToolCall(chatResponse, Set.of("tool_use"))) {
+ var toolCallConversation = this.handleToolCalls(prompt, chatResponse);
+ return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
+ }
+
+ return chatResponse;
+ }
+
+ private ChatOptions buildRequestOptions(ConverseRequest request) {
+ return ChatOptionsBuilder.builder()
+ .withModel(request.modelId())
+ .withMaxTokens(request.inferenceConfig().maxTokens())
+ .withStopSequences(request.inferenceConfig().stopSequences())
+ .withTemperature(request.inferenceConfig().temperature() != null
+ ? request.inferenceConfig().temperature().doubleValue() : null)
+ .withTopP(request.inferenceConfig().topP() != null ? request.inferenceConfig().topP().doubleValue() : null)
+ .build();
+ }
+
+ @Override
+ public ChatOptions getDefaultOptions() {
+ return this.defaultOptions;
+ }
+
+ public ConverseStreamRequest createStreamRequest(Prompt prompt) {
+
+ ConverseRequest converseRequest = this.createRequest(prompt);
+
+ return ConverseStreamRequest.builder()
+ .modelId(converseRequest.modelId())
+ .messages(converseRequest.messages())
+ .system(converseRequest.system())
+ .additionalModelRequestFields(converseRequest.additionalModelRequestFields())
+ .toolConfig(converseRequest.toolConfig())
+ .build();
+ }
+
+ ConverseRequest createRequest(Prompt prompt) {
+
+ Set functionsForThisRequest = new HashSet<>();
+
+ List instructionMessages = prompt.getInstructions()
+ .stream()
+ .filter(message -> message.getMessageType() != MessageType.SYSTEM)
+ .map(message -> {
+ if (message.getMessageType() == MessageType.USER) {
+ List contents = new ArrayList<>();
+ if (message instanceof UserMessage) {
+ var userMessage = (UserMessage) message;
+ contents.add(ContentBlock.fromText(userMessage.getContent()));
+
+ if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
+ List mediaContent = userMessage.getMedia().stream().map(media -> {
+ ContentBlock cb = ContentBlock.fromImage(ImageBlock.builder()
+ .format(media.getMimeType().getSubtype())
+ .source(ImageSource
+ .fromBytes(SdkBytes.fromByteArray(getContentMediaData(media.getData()))))
+ .build());
+ return cb;
+ }).toList();
+ contents.addAll(mediaContent);
+ }
+ }
+ return Message.builder().content(contents).role(ConversationRole.USER).build();
+ }
+ else if (message.getMessageType() == MessageType.ASSISTANT) {
+ AssistantMessage assistantMessage = (AssistantMessage) message;
+ List contentBlocks = new ArrayList<>();
+ if (StringUtils.hasText(message.getContent())) {
+ contentBlocks.add(ContentBlock.fromText(message.getContent()));
+ }
+ if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
+ for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
+
+ var argumentsDocument = ConverseApiUtils
+ .convertObjectToDocument(ModelOptionsUtils.jsonToMap(toolCall.arguments()));
+
+ contentBlocks.add(ContentBlock.fromToolUse(ToolUseBlock.builder()
+ .toolUseId(toolCall.id())
+ .name(toolCall.name())
+ .input(argumentsDocument)
+ .build()));
+
+ }
+ }
+ return Message.builder().content(contentBlocks).role(ConversationRole.ASSISTANT).build();
+ }
+ else if (message.getMessageType() == MessageType.TOOL) {
+ List contentBlocks = ((ToolResponseMessage) message).getResponses()
+ .stream()
+ .map(toolResponse -> {
+ ToolResultBlock toolResultBlock = ToolResultBlock.builder()
+ .toolUseId(toolResponse.id())
+ .content(ToolResultContentBlock.builder().text(toolResponse.responseData()).build())
+ .build();
+ return ContentBlock.fromToolResult(toolResultBlock);
+ })
+ .toList();
+ return Message.builder().content(contentBlocks).role(ConversationRole.USER).build();
+ }
+ else {
+ throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
+ }
+ })
+ .toList();
+
+ List systemMessages = prompt.getInstructions()
+ .stream()
+ .filter(m -> m.getMessageType() == MessageType.SYSTEM)
+ .map(sysMessage -> SystemContentBlock.builder().text(sysMessage.getContent()).build())
+ .toList();
+
+ FunctionCallingOptions updatedRuntimeOptions = (FunctionCallingOptions) this.defaultOptions.copy();
+
+ if (prompt.getOptions() != null) {
+ if (prompt.getOptions() instanceof FunctionCallingOptions) {
+ var functionCallingOptions = (FunctionCallingOptions) prompt.getOptions();
+ updatedRuntimeOptions = ((PortableFunctionCallingOptions) updatedRuntimeOptions)
+ .merge(functionCallingOptions);
+ }
+ else if (prompt.getOptions() instanceof ChatOptions) {
+ var chatOptions = (ChatOptions) prompt.getOptions();
+ updatedRuntimeOptions = ((PortableFunctionCallingOptions) updatedRuntimeOptions).merge(chatOptions);
+ }
+ }
+
+ functionsForThisRequest.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));
+
+ ToolConfiguration toolConfiguration = null;
+
+ if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
+ toolConfiguration = ToolConfiguration.builder().tools(getFunctionTools(functionsForThisRequest)).build();
+ }
+
+ InferenceConfiguration inferenceConfiguration = InferenceConfiguration.builder()
+ .maxTokens(updatedRuntimeOptions.getMaxTokens())
+ .stopSequences(updatedRuntimeOptions.getStopSequences())
+ .temperature(updatedRuntimeOptions.getTemperature() != null
+ ? updatedRuntimeOptions.getTemperature().floatValue() : null)
+ .topP(updatedRuntimeOptions.getTopP() != null ? updatedRuntimeOptions.getTopP().floatValue() : null)
+ .build();
+ Document additionalModelRequestFields = ConverseApiUtils
+ .getChatOptionsAdditionalModelRequestFields(defaultOptions, prompt.getOptions());
+
+ return ConverseRequest.builder()
+ .modelId(updatedRuntimeOptions.getModel())
+ .inferenceConfig(inferenceConfiguration)
+ .messages(instructionMessages)
+ .system(systemMessages)
+ .additionalModelRequestFields(additionalModelRequestFields)
+ .toolConfig(toolConfiguration)
+ .build();
+ }
+
+ private List getFunctionTools(Set functionNames) {
+ return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
+ var description = functionCallback.getDescription();
+ var name = functionCallback.getName();
+ String inputSchema = functionCallback.getInputTypeSchema();
+ return Tool.builder()
+ .toolSpec(ToolSpecification.builder()
+ .name(name)
+ .description(description)
+ .inputSchema(ToolInputSchema
+ .fromJson(ConverseApiUtils.convertObjectToDocument(ModelOptionsUtils.jsonToMap(inputSchema))))
+ .build())
+ .build();
+ }).toList();
+ }
+
+ private static byte[] getContentMediaData(Object mediaData) {
+ if (mediaData instanceof byte[] bytes) {
+ return bytes;
+ }
+ else if (mediaData instanceof String text) {
+ if (URLValidator.isValidURLBasic(text)) {
+ try {
+ URL url = new URL(text);
+ URLConnection connection = url.openConnection();
+ try (InputStream is = connection.getInputStream()) {
+ return StreamUtils.copyToByteArray(is);
+ }
+ }
+ catch (IOException e) {
+ throw new RuntimeException("Failed to read media data from URL: " + text, e);
+ }
+ }
+ return text.getBytes();
+ }
+ else if (mediaData instanceof URL url) {
+ try (InputStream is = url.openConnection().getInputStream()) {
+ return StreamUtils.copyToByteArray(is);
+ }
+ catch (IOException e) {
+ throw new RuntimeException("Failed to read media data from URL: " + url, e);
+ }
+ }
+ else {
+ throw new IllegalArgumentException("Unsupported media data type: " + mediaData.getClass().getSimpleName());
+ }
+ }
+
+ /**
+ * Convert {@link ConverseResponse} to {@link ChatResponse} includes model output,
+ * stopReason, usage, metrics etc.
+ * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_ResponseSyntax
+ * @param response The Bedrock Converse response.
+ * @return The ChatResponse entity.
+ */
+ private ChatResponse toChatResponse(ConverseResponse response) {
+
+ Assert.notNull(response, "'response' must not be null.");
+
+ Message message = response.output().message();
+
+ List generations = message.content()
+ .stream()
+ .filter(content -> content.type() != ContentBlock.Type.TOOL_USE)
+ .map(content -> {
+ return new Generation(new AssistantMessage(content.text(), Map.of()),
+ ChatGenerationMetadata.from(response.stopReasonAsString(), null));
+ })
+ .toList();
+
+ List allGenerations = new ArrayList<>(generations);
+
+ if (response.stopReasonAsString() != null && generations.isEmpty()) {
+ Generation generation = new Generation(new AssistantMessage(null, Map.of()),
+ ChatGenerationMetadata.from(response.stopReasonAsString(), null));
+ allGenerations.add(generation);
+ }
+
+ List toolUseContentBlocks = message.content()
+ .stream()
+ .filter(c -> c.type() == ContentBlock.Type.TOOL_USE)
+ .toList();
+
+ if (!CollectionUtils.isEmpty(toolUseContentBlocks)) {
+
+ List toolCalls = new ArrayList<>();
+
+ for (ContentBlock toolUseContentBlock : toolUseContentBlocks) {
+
+ var functionCallId = toolUseContentBlock.toolUse().toolUseId();
+ var functionName = toolUseContentBlock.toolUse().name();
+ var functionArguments = toolUseContentBlock.toolUse().input().toString();
+
+ toolCalls
+ .add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
+ }
+
+ AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
+ Generation toolCallGeneration = new Generation(assistantMessage,
+ ChatGenerationMetadata.from(response.stopReasonAsString(), null));
+ allGenerations.add(toolCallGeneration);
+ }
+
+ DefaultUsage usage = new DefaultUsage(response.usage().inputTokens().longValue(),
+ response.usage().outputTokens().longValue(), response.usage().totalTokens().longValue());
+
+ Document modelResponseFields = response.additionalModelResponseFields();
+
+ ConverseMetrics metrics = response.metrics();
+
+ var chatResponseMetaData = ChatResponseMetadata.builder()
+ .withId(response.responseMetadata().requestId())
+ .withUsage(usage)
+ .build();
+
+ return new ChatResponse(allGenerations, chatResponseMetaData);
+ }
+
+ /**
+ * Invoke the model and return the response stream.
+ *
+ * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
+ * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
+ * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream
+ * @param bedrockConverseRequest Model invocation request.
+ * @return The model invocation response stream.
+ */
+ @Override
+ public Flux stream(Prompt prompt) {
+ Assert.notNull(prompt, "'prompt' must not be null");
+
+ return Flux.deferContextual(contextView -> {
+
+ ConverseRequest converseRequest = this.createRequest(prompt);
+
+ // System.out.println(">>>>> CONVERSE REQUEST: " + converseRequest);
+
+ ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
+ .prompt(prompt)
+ .provider(AiProvider.BEDROCK_CONVERSE.value())
+ .requestOptions(buildRequestOptions(converseRequest))
+ .build();
+
+ Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
+ this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
+ this.observationRegistry);
+
+ observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
+
+ ConverseStreamRequest converseStreamRequest = ConverseStreamRequest.builder()
+ .modelId(converseRequest.modelId())
+ .messages(converseRequest.messages())
+ .system(converseRequest.system())
+ .additionalModelRequestFields(converseRequest.additionalModelRequestFields())
+ .toolConfig(converseRequest.toolConfig())
+ .build();
+
+ Flux response = converseStream(converseStreamRequest);
+
+ // @formatter:off
+ Flux chatResponses = ConverseApiUtils.toChatResponse(response);
+
+ Flux chatResponseFlux = chatResponses.switchMap(chatResponse -> {
+ if (!this.isProxyToolCalls(prompt, this.defaultOptions) && chatResponse != null
+ && this.isToolCall(chatResponse, Set.of("tool_use"))) {
+ var toolCallConversation = this.handleToolCalls(prompt, chatResponse);
+ return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
+ }
+ return Mono.just(chatResponse);
+ })
+ .doOnError(observation::error)
+ .doFinally(s -> observation.stop())
+ .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
+ // @formatter:on
+
+ return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
+ });
+ }
+
+ /**
+ * Invoke the model and return the response stream.
+ *
+ * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
+ * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
+ * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream
+ * @param converseStreamRequest Model invocation request.
+ * @return The model invocation response stream.
+ */
+ public Flux converseStream(ConverseStreamRequest converseStreamRequest) {
+ Assert.notNull(converseStreamRequest, "'converseStreamRequest' must not be null");
+
+ Sinks.Many eventSink = Sinks.many().multicast().onBackpressureBuffer();
+
+ ConverseStreamResponseHandler.Visitor visitor = ConverseStreamResponseHandler.Visitor.builder()
+ .onDefault((output) -> {
+ logger.debug("Received converse stream output:{}", output);
+ eventSink.tryEmitNext(output);
+ })
+ .build();
+
+ ConverseStreamResponseHandler responseHandler = ConverseStreamResponseHandler.builder()
+ .onEventStream(stream -> stream.subscribe((e) -> e.accept(visitor)))
+ .onComplete(() -> {
+ EmitResult emitResult = eventSink.tryEmitComplete();
+
+ while (!emitResult.isSuccess()) {
+ logger.info("Emitting complete:{}", emitResult);
+ emitResult = eventSink.tryEmitComplete();
+ }
+
+ eventSink.emitComplete(EmitFailureHandler.busyLooping(Duration.ofSeconds(3)));
+ logger.info("Completed streaming response.");
+ })
+ .onError((error) -> {
+ logger.error("Error handling Bedrock converse stream response", error);
+ eventSink.tryEmitError(error);
+ })
+ .build();
+
+ this.bedrockRuntimeAsyncClient.converseStream(converseStreamRequest, responseHandler);
+
+ return eventSink.asFlux();
+
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static class Builder {
+
+ private AwsCredentialsProvider credentialsProvider;
+
+ private Region region = Region.US_EAST_1;
+
+ private Duration timeout = Duration.ofMinutes(10);
+
+ private FunctionCallingOptions defaultOptions = new FunctionCallingOptionsBuilder().build();
+
+ private FunctionCallbackContext functionCallbackContext;
+
+ private List toolFunctionCallbacks;
+
+ private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
+
+ private ChatModelObservationConvention customObservationConvention;
+
+ private BedrockRuntimeClient bedrockRuntimeClient;
+
+ private BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient;
+
+ private Builder() {
+ }
+
+ public Builder withCredentialsProvider(AwsCredentialsProvider credentialsProvider) {
+ Assert.notNull(credentialsProvider, "'credentialsProvider' must not be null.");
+ this.credentialsProvider = credentialsProvider;
+ return this;
+ }
+
+ public Builder withRegion(Region region) {
+ Assert.notNull(region, "'region' must not be null.");
+ this.region = region;
+ return this;
+ }
+
+ public Builder withTimeout(Duration timeout) {
+ Assert.notNull(timeout, "'timeout' must not be null.");
+ this.timeout = timeout;
+ return this;
+ }
+
+ public Builder withDefaultOptions(FunctionCallingOptions defaultOptions) {
+ Assert.notNull(defaultOptions, "'defaultOptions' must not be null.");
+ this.defaultOptions = defaultOptions;
+ return this;
+ }
+
+ public Builder withFunctionCallbackContext(FunctionCallbackContext functionCallbackContext) {
+ this.functionCallbackContext = functionCallbackContext;
+ return this;
+ }
+
+ public Builder withToolFunctionCallbacks(List toolFunctionCallbacks) {
+ this.toolFunctionCallbacks = toolFunctionCallbacks;
+ return this;
+ }
+
+ public Builder withObservationRegistry(ObservationRegistry observationRegistry) {
+ Assert.notNull(observationRegistry, "'observationRegistry' must not be null.");
+ this.observationRegistry = observationRegistry;
+ return this;
+ }
+
+ public Builder withCustomObservationConvention(ChatModelObservationConvention observationConvention) {
+ Assert.notNull(observationConvention, "'observationConvention' must not be null.");
+ this.customObservationConvention = observationConvention;
+ return this;
+ }
+
+ public Builder withBedrockRuntimeClient(BedrockRuntimeClient bedrockRuntimeClient) {
+ this.bedrockRuntimeClient = bedrockRuntimeClient;
+ return this;
+ }
+
+ public Builder withBedrockRuntimeAsyncClient(BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient) {
+ this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient;
+ return this;
+ }
+
+ public BedrockProxyChatModel build() {
+
+ if (this.bedrockRuntimeClient == null) {
+ this.bedrockRuntimeClient = BedrockRuntimeClient.builder()
+ .region(this.region)
+ .httpClientBuilder(null)
+ .credentialsProvider(this.credentialsProvider)
+ .overrideConfiguration(c -> c.apiCallTimeout(this.timeout))
+ .build();
+ }
+
+ if (this.bedrockRuntimeAsyncClient == null) {
+
+ // TODO: Is it ok to configure the NettyNioAsyncHttpClient explicitly???
+ var httpClientBuilder = NettyNioAsyncHttpClient.builder()
+ .tcpKeepAlive(true)
+ .connectionAcquisitionTimeout(Duration.ofSeconds(30))
+ .maxConcurrency(200);
+
+ var builder = BedrockRuntimeAsyncClient.builder()
+ .region(this.region)
+ .httpClientBuilder(httpClientBuilder)
+ .credentialsProvider(this.credentialsProvider)
+ .overrideConfiguration(c -> c.apiCallTimeout(this.timeout));
+ this.bedrockRuntimeAsyncClient = builder.build();
+ }
+
+ var bedrockProxyChatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient,
+ this.bedrockRuntimeAsyncClient, this.defaultOptions, this.functionCallbackContext,
+ this.toolFunctionCallbacks, this.observationRegistry);
+
+ if (this.customObservationConvention != null) {
+ bedrockProxyChatModel.setObservationConvention(this.customObservationConvention);
+ }
+
+ return bedrockProxyChatModel;
+ }
+
+ }
+
+ /**
+ * Use the provided convention for reporting observation data
+ * @param observationConvention The provided convention
+ */
+ public void setObservationConvention(ChatModelObservationConvention observationConvention) {
+ Assert.notNull(observationConvention, "observationConvention cannot be null");
+ this.observationConvention = observationConvention;
+ }
+
+}
diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/BedrockUsage.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/BedrockUsage.java
new file mode 100644
index 00000000000..96186b9b782
--- /dev/null
+++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/BedrockUsage.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2024 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.bedrock.converse.api;
+
+import org.springframework.ai.chat.metadata.Usage;
+import org.springframework.util.Assert;
+
+import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage;
+
+/**
+ * {@link Usage} implementation for Bedrock Converse API.
+ *
+ * @author Christian Tzolov
+ * @author Wei Jiang
+ * @since 1.0.0
+ */
+public class BedrockUsage implements Usage {
+
+ public static BedrockUsage from(TokenUsage usage) {
+ Assert.notNull(usage, "'TokenUsage' must not be null.");
+
+ return new BedrockUsage(usage.inputTokens().longValue(), usage.outputTokens().longValue());
+ }
+
+ private final Long inputTokens;
+
+ private final Long outputTokens;
+
+ protected BedrockUsage(Long inputTokens, Long outputTokens) {
+ this.inputTokens = inputTokens;
+ this.outputTokens = outputTokens;
+ }
+
+ @Override
+ public Long getPromptTokens() {
+ return inputTokens;
+ }
+
+ @Override
+ public Long getGenerationTokens() {
+ return outputTokens;
+ }
+
+ @Override
+ public String toString() {
+ return "BedrockUsage [inputTokens=" + inputTokens + ", outputTokens=" + outputTokens + "]";
+ }
+
+}
\ No newline at end of file
diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java
new file mode 100644
index 00000000000..ce5730e59c2
--- /dev/null
+++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java
@@ -0,0 +1,503 @@
+/*
+ * Copyright 2024 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.bedrock.converse.api;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.stream.Collectors;
+
+import org.springframework.ai.bedrock.converse.api.ConverseApiUtils.Aggregation;
+import org.springframework.ai.bedrock.converse.api.ConverseApiUtils.MetadataAggregation;
+import org.springframework.ai.bedrock.converse.api.ConverseApiUtils.ToolUseAggregationEvent;
+import org.springframework.ai.bedrock.converse.api.ConverseApiUtils.ToolUseAggregationEvent.ToolUseEntry;
+import org.springframework.ai.chat.messages.AssistantMessage;
+import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
+import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.metadata.DefaultUsage;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.model.Generation;
+import org.springframework.ai.chat.prompt.ChatOptions;
+import org.springframework.ai.model.ModelOptions;
+import org.springframework.ai.model.ModelOptionsUtils;
+import org.springframework.util.CollectionUtils;
+import org.springframework.util.StringUtils;
+
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import software.amazon.awssdk.core.SdkField;
+import software.amazon.awssdk.core.document.Document;
+import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDelta;
+import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent;
+import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStart;
+import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStartEvent;
+import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStopEvent;
+import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent;
+import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetrics;
+import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
+import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput.EventType;
+import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler.Visitor;
+import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamTrace;
+import software.amazon.awssdk.services.bedrockruntime.model.MessageStartEvent;
+import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent;
+import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage;
+import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlockStart;
+
+/**
+ * Amazon Bedrock Converse API utils.
+ *
+ * @author Wei Jiang
+ * @author Christian Tzolov
+ * @since 1.0.0
+ */
+public class ConverseApiUtils {
+
+ public static boolean isToolUseStart(ConverseStreamOutput event) {
+ if (event == null || event.sdkEventType() == null || event.sdkEventType() != EventType.CONTENT_BLOCK_START) {
+ return false;
+ }
+
+ return ContentBlockStart.Type.TOOL_USE == ((ContentBlockStartEvent) event).start().type();
+ }
+
+ public static boolean isToolUseFinish(ConverseStreamOutput event) {
+ if (event == null || event.sdkEventType() == null || event.sdkEventType() != EventType.METADATA) {
+ return false;
+ }
+ return true;
+ }
+
+ public record Aggregation(MetadataAggregation metadataAggregation, ChatResponse chatResponse) {
+ public Aggregation() {
+ this(MetadataAggregation.builder().build(), EMPTY_CHAT_RESPONSE);
+ }
+ }
+
+ /**
+ * Special event used to aggregate multiple tool use events into a single event with
+ * list of aggregated ContentBlockToolUse.
+ */
+ public static class ToolUseAggregationEvent implements ConverseStreamOutput {
+
+ public record ToolUseEntry(Integer index, String id, String name, String input) {
+ }
+
+ private Integer index;
+
+ private String id;
+
+ private String name;
+
+ private String partialJson = "";
+
+ private List toolUseEntries = new ArrayList<>();
+
+ private DefaultUsage usage;
+
+ public List toolUseEntries() {
+ return this.toolUseEntries;
+ }
+
+ public boolean isEmpty() {
+ return (this.index == null || this.id == null || this.name == null
+ || !StringUtils.hasText(this.partialJson));
+ }
+
+ ToolUseAggregationEvent withIndex(Integer index) {
+ this.index = index;
+ return this;
+ }
+
+ ToolUseAggregationEvent withId(String id) {
+ this.id = id;
+ return this;
+ }
+
+ ToolUseAggregationEvent withName(String name) {
+ this.name = name;
+ return this;
+ }
+
+ ToolUseAggregationEvent withUsage(DefaultUsage usage) {
+ this.usage = usage;
+ return this;
+ }
+
+ ToolUseAggregationEvent appendPartialJson(String partialJson) {
+ this.partialJson = this.partialJson + partialJson;
+ return this;
+ }
+
+ void squashIntoContentBlock() {
+ this.toolUseEntries.add(new ToolUseEntry(this.index, this.id, this.name, this.partialJson));
+ this.index = null;
+ this.id = null;
+ this.name = null;
+ this.partialJson = "";
+ this.usage = null;
+ }
+
+ @Override
+ public String toString() {
+ return "EventToolUseBuilder [index=" + this.index + ", id=" + this.id + ", name=" + this.name
+ + ", partialJson=" + this.partialJson + ", toolUseMap=" + "]";
+ }
+
+ @Override
+ public List> sdkFields() {
+ return List.of();
+ }
+
+ @Override
+ public void accept(Visitor visitor) {
+ throw new UnsupportedOperationException();
+ }
+
+ }
+
+ public static ConverseStreamOutput mergeToolUseEvents(ConverseStreamOutput previousEvent,
+ ConverseStreamOutput event) {
+
+ ToolUseAggregationEvent toolUseEventAggregator = (ToolUseAggregationEvent) previousEvent;
+
+ if (event.sdkEventType() == EventType.CONTENT_BLOCK_START) {
+
+ ContentBlockStartEvent contentBlockStart = (ContentBlockStartEvent) event;
+
+ if (ContentBlockStart.Type.TOOL_USE.equals(contentBlockStart.start().type())) {
+ ToolUseBlockStart cbToolUse = contentBlockStart.start().toolUse();
+
+ return toolUseEventAggregator.withIndex(contentBlockStart.contentBlockIndex())
+ .withId(cbToolUse.toolUseId())
+ .withName(cbToolUse.name())
+ .appendPartialJson(""); // CB START always has empty JSON.
+ }
+ }
+ else if (event.sdkEventType() == EventType.CONTENT_BLOCK_DELTA) {
+ ContentBlockDeltaEvent contentBlockDelta = (ContentBlockDeltaEvent) event;
+ if (ContentBlockDelta.Type.TOOL_USE == contentBlockDelta.delta().type()) {
+ return toolUseEventAggregator.appendPartialJson(contentBlockDelta.delta().toolUse().input());
+ }
+ }
+ else if (event.sdkEventType() == EventType.CONTENT_BLOCK_STOP) {
+ return toolUseEventAggregator;
+ }
+ else if (event.sdkEventType() == EventType.MESSAGE_STOP) {
+ return toolUseEventAggregator;
+ }
+ else if (event.sdkEventType() == EventType.METADATA) {
+ ConverseStreamMetadataEvent metadataEvent = (ConverseStreamMetadataEvent) event;
+ DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens().longValue(),
+ metadataEvent.usage().outputTokens().longValue(), metadataEvent.usage().totalTokens().longValue());
+ toolUseEventAggregator.withUsage(usage);
+ // TODO
+ if (!toolUseEventAggregator.isEmpty()) {
+ toolUseEventAggregator.squashIntoContentBlock();
+ return toolUseEventAggregator;
+ }
+ }
+
+ return event;
+ }
+
+ public static Flux toChatResponse(Flux responses) {
+
+ AtomicBoolean isInsideTool = new AtomicBoolean(false);
+
+ return responses.map(event -> {
+ if (ConverseApiUtils.isToolUseStart(event)) {
+ isInsideTool.set(true);
+ }
+ return event;
+ }).windowUntil(event -> { // Group all chunks belonging to the same function call.
+ if (isInsideTool.get() && ConverseApiUtils.isToolUseFinish(event)) {
+ isInsideTool.set(false);
+ return true;
+ }
+ return !isInsideTool.get();
+ }).concatMapIterable(window -> {// Merging the window chunks into a single chunk.
+ Mono monoChunk = window.reduce(new ToolUseAggregationEvent(),
+ ConverseApiUtils::mergeToolUseEvents);
+ return List.of(monoChunk);
+ }).flatMap(mono -> mono).scanWith(() -> new Aggregation(), (lastAggregation, nextEvent) -> {
+
+ // System.out.println(nextEvent);
+ if (nextEvent instanceof ToolUseAggregationEvent toolUseAggregationEvent) {
+
+ if (CollectionUtils.isEmpty(toolUseAggregationEvent.toolUseEntries())) {
+ return new Aggregation();
+ }
+
+ List toolCalls = new ArrayList<>();
+
+ for (ToolUseAggregationEvent.ToolUseEntry toolUseEntry : toolUseAggregationEvent.toolUseEntries()) {
+ var functionCallId = toolUseEntry.id();
+ var functionName = toolUseEntry.name();
+ var functionArguments = toolUseEntry.input();
+ toolCalls.add(
+ new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
+ }
+
+ AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
+ Generation toolCallGeneration = new Generation(assistantMessage,
+ ChatGenerationMetadata.from("tool_use", null));
+
+ var chatResponseMetaData = ChatResponseMetadata.builder()
+ .withUsage(toolUseAggregationEvent.usage)
+ .build();
+
+ return new Aggregation(
+ MetadataAggregation.builder().copy(lastAggregation.metadataAggregation()).build(),
+ new ChatResponse(List.of(toolCallGeneration), chatResponseMetaData));
+
+ }
+ else if (nextEvent instanceof MessageStartEvent messageStartEvent) {
+ var newMeta = MetadataAggregation.builder()
+ .copy(lastAggregation.metadataAggregation())
+ .withRole(messageStartEvent.role().toString())
+ .build();
+ return new Aggregation(newMeta, ConverseApiUtils.EMPTY_CHAT_RESPONSE);
+ }
+ else if (nextEvent instanceof MessageStopEvent messageStopEvent) {
+ var newMeta = MetadataAggregation.builder()
+ .copy(lastAggregation.metadataAggregation())
+ .withStopReason(messageStopEvent.stopReasonAsString())
+ .withAdditionalModelResponseFields(messageStopEvent.additionalModelResponseFields())
+ .build();
+ return new Aggregation(newMeta, ConverseApiUtils.EMPTY_CHAT_RESPONSE);
+ }
+ else if (nextEvent instanceof ContentBlockStartEvent contentBlockStartEvent) {
+ // TODO ToolUse support
+ return new Aggregation();
+ }
+ else if (nextEvent instanceof ContentBlockDeltaEvent contentBlockDeltaEvent) {
+ if (contentBlockDeltaEvent.delta().type().equals(ContentBlockDelta.Type.TEXT)) {
+
+ var generation = new Generation(
+ new AssistantMessage(contentBlockDeltaEvent.delta().text(), Map.of()),
+ ChatGenerationMetadata.from(lastAggregation.metadataAggregation().stopReason(), null));
+
+ return new Aggregation(
+ MetadataAggregation.builder().copy(lastAggregation.metadataAggregation()).build(),
+ new ChatResponse(List.of(generation)));
+ }
+ else if (contentBlockDeltaEvent.delta().type().equals(ContentBlockDelta.Type.TOOL_USE)) {
+ // TODO ToolUse support
+ }
+ return new Aggregation();
+ }
+ else if (nextEvent instanceof ContentBlockStopEvent contentBlockStopEvent) {
+ // TODO ToolUse support
+ return new Aggregation();
+ }
+ else if (nextEvent instanceof ConverseStreamMetadataEvent metadataEvent) {
+ // return new Aggregation();
+ var newMeta = MetadataAggregation.builder()
+ .copy(lastAggregation.metadataAggregation())
+ .withTokenUsage(metadataEvent.usage())
+ .withMetrics(metadataEvent.metrics())
+ .withTrace(metadataEvent.trace())
+ .build();
+
+ DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens().longValue(),
+ metadataEvent.usage().outputTokens().longValue(),
+ metadataEvent.usage().totalTokens().longValue());
+
+ // TODO
+ Document modelResponseFields = lastAggregation.metadataAggregation().additionalModelResponseFields();
+ ConverseStreamMetrics metrics = metadataEvent.metrics();
+
+ var chatResponseMetaData = ChatResponseMetadata.builder().withUsage(usage).build();
+
+ return new Aggregation(newMeta, new ChatResponse(List.of(), chatResponseMetaData));
+ }
+ else {
+ return new Aggregation();
+ }
+ })
+ // .skip(1)
+ .map(aggregation -> aggregation.chatResponse())
+ .filter(chatResponse -> chatResponse != ConverseApiUtils.EMPTY_CHAT_RESPONSE);
+ }
+
+ public static final ChatResponse EMPTY_CHAT_RESPONSE = ChatResponse.builder()
+ .withGenerations(List.of())
+ .withMetadata("empty", true)
+ .build();
+
+ public record MetadataAggregation(String role, String stopReason, Document additionalModelResponseFields,
+ TokenUsage tokenUsage, ConverseStreamMetrics metrics, ConverseStreamTrace trace) {
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public final static class Builder {
+
+ private String role;
+
+ private String stopReason;
+
+ private Document additionalModelResponseFields;
+
+ private TokenUsage tokenUsage;
+
+ private ConverseStreamMetrics metrics;
+
+ private ConverseStreamTrace trace;
+
+ private Builder() {
+ }
+
+ public Builder copy(MetadataAggregation metadataAggregation) {
+ this.role = metadataAggregation.role;
+ this.stopReason = metadataAggregation.stopReason;
+ this.additionalModelResponseFields = metadataAggregation.additionalModelResponseFields;
+ this.tokenUsage = metadataAggregation.tokenUsage;
+ this.metrics = metadataAggregation.metrics;
+ this.trace = metadataAggregation.trace;
+ return this;
+ }
+
+ public Builder withRole(String role) {
+ this.role = role;
+ return this;
+ }
+
+ public Builder withStopReason(String stopReason) {
+ this.stopReason = stopReason;
+ return this;
+ }
+
+ public Builder withAdditionalModelResponseFields(Document additionalModelResponseFields) {
+ this.additionalModelResponseFields = additionalModelResponseFields;
+ return this;
+ }
+
+ public Builder withTokenUsage(TokenUsage tokenUsage) {
+ this.tokenUsage = tokenUsage;
+ return this;
+ }
+
+ public Builder withMetrics(ConverseStreamMetrics metrics) {
+ this.metrics = metrics;
+ return this;
+ }
+
+ public Builder withTrace(ConverseStreamTrace trace) {
+ this.trace = trace;
+ return this;
+ }
+
+ public MetadataAggregation build() {
+ return new MetadataAggregation(role, stopReason, additionalModelResponseFields, tokenUsage, metrics,
+ trace);
+ }
+
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ public static Document getChatOptionsAdditionalModelRequestFields(ChatOptions defaultOptions,
+ ModelOptions promptOptions) {
+ if (defaultOptions == null && promptOptions == null) {
+ return null;
+ }
+
+ Map attributes = new HashMap<>();
+
+ if (defaultOptions != null) {
+ attributes.putAll(ModelOptionsUtils.objectToMap(defaultOptions));
+ }
+
+ if (promptOptions != null) {
+ if (promptOptions instanceof ChatOptions runtimeOptions) {
+ attributes.putAll(ModelOptionsUtils.objectToMap(runtimeOptions));
+ }
+ else {
+ throw new IllegalArgumentException(
+ "Prompt options are not of type ChatOptions:" + promptOptions.getClass().getSimpleName());
+ }
+ }
+
+ attributes.remove("model");
+ attributes.remove("proxyToolCalls");
+ attributes.remove("functions");
+ attributes.remove("toolContext");
+ attributes.remove("functionCallbacks");
+
+ attributes.remove("temperature");
+ attributes.remove("topK");
+ attributes.remove("stopSequences");
+ attributes.remove("maxTokens");
+ attributes.remove("topP");
+
+ return convertObjectToDocument(attributes);
+ }
+
+ @SuppressWarnings("unchecked")
+ public static Document convertObjectToDocument(Object value) {
+ if (value == null) {
+ return Document.fromNull();
+ }
+ else if (value instanceof String stringValue) {
+ return Document.fromString(stringValue);
+ }
+ else if (value instanceof Boolean booleanValue) {
+ return Document.fromBoolean(booleanValue);
+ }
+ else if (value instanceof Integer integerValue) {
+ return Document.fromNumber(integerValue);
+ }
+ else if (value instanceof Long longValue) {
+ return Document.fromNumber(longValue);
+ }
+ else if (value instanceof Float floatValue) {
+ return Document.fromNumber(floatValue);
+ }
+ else if (value instanceof Double doubleValue) {
+ return Document.fromNumber(doubleValue);
+ }
+ else if (value instanceof BigDecimal bigDecimalValue) {
+ return Document.fromNumber(bigDecimalValue);
+ }
+ else if (value instanceof BigInteger bigIntegerValue) {
+ return Document.fromNumber(bigIntegerValue);
+ }
+ else if (value instanceof List listValue) {
+ return Document.fromList(listValue.stream().map(v -> convertObjectToDocument(v)).toList());
+ }
+ else if (value instanceof Map mapValue) {
+ return convertMapToDocument(mapValue);
+ }
+ else {
+ throw new IllegalArgumentException("Unsupported value type:" + value.getClass().getSimpleName());
+ }
+ }
+
+ private static Document convertMapToDocument(Map value) {
+ Map attr = value.entrySet()
+ .stream()
+ .collect(Collectors.toMap(e -> e.getKey(), e -> convertObjectToDocument(e.getValue())));
+
+ return Document.fromMap(attr);
+ }
+
+}
\ No newline at end of file
diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/URLValidator.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/URLValidator.java
new file mode 100644
index 00000000000..342ce5ba545
--- /dev/null
+++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/URLValidator.java
@@ -0,0 +1,124 @@
+/*
+* Copyright 2024 - 2024 the original author or authors.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* https://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+package org.springframework.ai.bedrock.converse.api;
+
+import java.net.MalformedURLException;
+import java.net.URISyntaxException;
+import java.net.URL;
+import java.util.regex.Pattern;
+
+/**
+ * Utility class for detecting and normalizing URLs. Intended for use with multimodal user
+ * inputs.
+ *
+ * @author Christian Tzolov
+ * @since 1.0.0
+ */
+public class URLValidator {
+
+ // Basic URL regex pattern
+ // Protocol (http:// or https://)
+ private static final Pattern URL_PATTERN = Pattern.compile("^(https?://)" +
+
+ "((([a-zA-Z0-9-]+\\.)+[a-zA-Z]{2,6})|" + // Domain name
+ "(localhost))" + // OR localhost
+ "(:[0-9]{1,5})?" + // Optional port
+ "(/[\\w\\-./]*)*" + // Optional path
+ "(\\?[\\w=&\\-.]*)?" + // Optional query parameters
+ "(#[\\w-]*)?" + // Optional fragment
+ "$");
+
+ /**
+ * Quick validation using regex pattern Good for basic checks but may not catch all
+ * edge cases
+ */
+ public static boolean isValidURLBasic(String urlString) {
+ if (urlString == null || urlString.trim().isEmpty()) {
+ return false;
+ }
+ return URL_PATTERN.matcher(urlString).matches();
+ }
+
+ /**
+ * Thorough validation using URL class More comprehensive but might be slower
+ * Validates protocol, host, port, and basic structure
+ */
+ public static boolean isValidURLStrict(String urlString) {
+ if (urlString == null || urlString.trim().isEmpty()) {
+ return false;
+ }
+
+ try {
+ URL url = new URL(urlString);
+ // Additional validation by attempting to convert to URI
+ url.toURI();
+
+ // Ensure protocol is http or https
+ String protocol = url.getProtocol().toLowerCase();
+ if (!protocol.equals("http") && !protocol.equals("https")) {
+ return false;
+ }
+
+ // Validate host (not empty and contains at least one dot, unless it's
+ // localhost)
+ String host = url.getHost();
+ if (host == null || host.isEmpty()) {
+ return false;
+ }
+ if (!host.equals("localhost") && !host.contains(".")) {
+ return false;
+ }
+
+ // Validate port (if specified)
+ int port = url.getPort();
+ if (port != -1 && (port < 1 || port > 65535)) {
+ return false;
+ }
+
+ return true;
+ }
+ catch (MalformedURLException | URISyntaxException e) {
+ return false;
+ }
+ }
+
+ /**
+ * Attempts to fix common URL issues Adds protocol if missing, removes extra spaces
+ */
+ public static String normalizeURL(String urlString) {
+ if (urlString == null || urlString.trim().isEmpty()) {
+ return null;
+ }
+
+ String normalized = urlString.trim();
+
+ // Add protocol if missing
+ if (!normalized.toLowerCase().startsWith("http://") && !normalized.toLowerCase().startsWith("https://")) {
+ normalized = "https://" + normalized;
+ }
+
+ // Remove multiple forward slashes in path (except after protocol)
+ normalized = normalized.replaceAll("(? s.text(this.systemTextResource)
+ .param("name", "Bob")
+ .param("voice", "pirate"))
+ .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
+ .call()
+ .chatResponse();
+ // @formatter:on
+
+ logger.info("" + response);
+ assertThat(response.getResults()).hasSize(1);
+ assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard");
+ }
+
+ @Test
+ void listOutputConverterString() {
+ // @formatter:off
+ List collection = ChatClient.create(this.chatModel).prompt()
+ .user(u -> u.text("List five {subject}")
+ .param("subject", "ice cream flavors"))
+ .call()
+ .entity(new ParameterizedTypeReference>() {});
+ // @formatter:on
+
+ logger.info(collection.toString());
+ assertThat(collection).hasSize(5);
+ }
+
+ @Test
+ void listOutputConverterBean() {
+
+ // @formatter:off
+ List actorsFilms = ChatClient.create(this.chatModel).prompt()
+ .user("Generate the filmography of 5 movies for Tom Hanks and Bill Murray.")
+ .call()
+ .entity(new ParameterizedTypeReference>() {
+ });
+ // @formatter:on
+
+ logger.info("" + actorsFilms);
+ assertThat(actorsFilms).hasSize(2);
+ }
+
+ @Test
+ void customOutputConverter() {
+
+ var toStringListConverter = new ListOutputConverter(new DefaultConversionService());
+
+ // @formatter:off
+ List flavors = ChatClient.create(this.chatModel).prompt()
+ .user(u -> u.text("List five {subject}")
+ .param("subject", "ice cream flavors"))
+ .call()
+ .entity(toStringListConverter);
+ // @formatter:on
+
+ logger.info("ice cream flavors" + flavors);
+ assertThat(flavors).hasSize(5);
+ assertThat(flavors).containsAnyOf("Vanilla", "vanilla");
+ }
+
+ @Test
+ void mapOutputConverter() {
+ // @formatter:off
+ Map result = ChatClient.create(this.chatModel).prompt()
+ .user(u -> u.text("Provide me a List of {subject}")
+ .param("subject", "an array of numbers from 1 to 9 under they key name 'numbers'"))
+ .call()
+ .entity(new ParameterizedTypeReference