From 483055c2ef9f1920ce85a658a73784da6681ccaa Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 11 Feb 2025 16:46:20 +0100 Subject: [PATCH] refactor(bedrock): Migrate from function calling to tool calling - Replace function calling with tool calling in BedrockProxyChatModel - Deprecate function calling related code and APIs - Add new tool calling manager and options - Update builder pattern to remove "with" prefix from methods - Update tests and documentation for tool calling Part of the #2207 epic Signed-off-by: Christian Tzolov --- .../converse/BedrockProxyChatModel.java | 389 ++++++++++++++---- .../converse/api/ConverseApiUtils.java | 6 +- .../converse/BedrockConverseChatClientIT.java | 23 +- .../BedrockConverseTestConfiguration.java | 10 +- .../BedrockConverseUsageAggregationTests.java | 1 + .../converse/BedrockProxyChatModelIT.java | 28 +- .../client/BedrockNovaChatClientIT.java | 9 +- .../BedrockConverseChatModelMain2.java | 76 ---- .../ai/openai/chat/OpenAiChatModelIT.java | 26 +- .../ROOT/pages/api/chat/bedrock-converse.adoc | 40 +- ...ockConverseProxyChatAutoConfiguration.java | 30 +- .../BedrockConverseProxyChatProperties.java | 8 +- 12 files changed, 447 insertions(+), 199 deletions(-) delete mode 100644 models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain2.java diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 37afa84de43..fb2d4834d39 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,10 +24,8 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Base64; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; @@ -35,7 +33,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; import reactor.core.publisher.Sinks.EmitFailureHandler; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; @@ -60,7 +57,6 @@ import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration; import software.amazon.awssdk.services.bedrockruntime.model.Message; import software.amazon.awssdk.services.bedrockruntime.model.S3Location; -import software.amazon.awssdk.services.bedrockruntime.model.StopReason; import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock; import software.amazon.awssdk.services.bedrockruntime.model.Tool; import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration; @@ -96,11 +92,15 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.Media; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.DefaultFunctionCallingOptions; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackResolver; import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.LegacyToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StreamUtils; @@ -138,36 +138,87 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build(); + private final BedrockRuntimeClient bedrockRuntimeClient; private final BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient; - private FunctionCallingOptions defaultOptions; + private ToolCallingChatOptions defaultOptions; /** * Observation registry used for instrumentation. */ private final ObservationRegistry observationRegistry; + private final ToolCallingManager toolCallingManager; + /** * Conventions to use for generating observations. */ private ChatModelObservationConvention observationConvention; + /** + * @deprecated Use + * {@link #BedrockProxyChatModel(BedrockRuntimeClient, BedrockRuntimeAsyncClient, ToolCallingChatOptions, ObservationRegistry, ToolCallingManager)} + * instead. + */ + @Deprecated public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient, BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, FunctionCallingOptions defaultOptions, FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks, ObservationRegistry observationRegistry) { - super(functionCallbackResolver, defaultOptions, toolFunctionCallbacks); + this(bedrockRuntimeClient, bedrockRuntimeAsyncClient, from(defaultOptions), observationRegistry, + LegacyToolCallingManager.builder() + .functionCallbackResolver(functionCallbackResolver) + .functionCallbacks(toolFunctionCallbacks) + .build()); + } + + public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient, + BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions, + ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager) { + + // We do not pass the 'defaultOptions' to the AbstractToolSupport, + // because it modifies them. We are using ToolCallingManager instead, + // so we just pass empty options here. + super(null, FunctionCallingOptions.builder().build(), List.of()); Assert.notNull(bedrockRuntimeClient, "bedrockRuntimeClient must not be null"); Assert.notNull(bedrockRuntimeAsyncClient, "bedrockRuntimeAsyncClient must not be null"); + Assert.notNull(toolCallingManager, "toolCallingManager must not be null"); this.bedrockRuntimeClient = bedrockRuntimeClient; this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient; this.defaultOptions = defaultOptions; this.observationRegistry = observationRegistry; + this.toolCallingManager = toolCallingManager; + } + + @Deprecated + private static ToolCallingChatOptions from(FunctionCallingOptions options) { + return ToolCallingChatOptions.builder() + .model(options.getModel()) + .maxTokens(options.getMaxTokens()) + .stopSequences(options.getStopSequences()) + .temperature(options.getTemperature()) + .topP(options.getTopP()) + .toolCallbacks(options.getFunctionCallbacks()) + .toolNames(options.getFunctions()) + .internalToolExecutionEnabled(options.getProxyToolCalls() != null ? !options.getProxyToolCalls() : false) + .toolContext(options.getToolContext()) + .build(); + } + + private static ToolCallingChatOptions from(ChatOptions options) { + return ToolCallingChatOptions.builder() + .model(options.getModel()) + .maxTokens(options.getMaxTokens()) + .stopSequences(options.getStopSequences()) + .temperature(options.getTemperature()) + .topP(options.getTopP()) + .build(); } /** @@ -180,7 +231,8 @@ public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient, */ @Override public ChatResponse call(Prompt prompt) { - return this.internalCall(prompt, null); + Prompt requestPrompt = buildRequestPrompt(prompt); + return this.internalCall(requestPrompt, null); } private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatResponse) { @@ -190,7 +242,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(AiProvider.BEDROCK_CONVERSE.value()) - .requestOptions(buildRequestOptions(converseRequest)) + .requestOptions(prompt.getOptions()) .build(); ChatResponse chatResponse = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION @@ -209,48 +261,100 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon return response; }); - if (!this.isProxyToolCalls(prompt, this.defaultOptions) && chatResponse != null - && this.isToolCall(chatResponse, Set.of(StopReason.TOOL_USE.toString()))) { - var toolCallConversation = this.handleToolCalls(prompt, chatResponse); - return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse); + if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && chatResponse != null + && chatResponse.hasToolCalls()) { + var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse); + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return ChatResponse.builder() + .from(chatResponse) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build(); + } + else { + // Send the tool execution result back to the model. + return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + chatResponse); + } } - return chatResponse; } - private ChatOptions buildRequestOptions(ConverseRequest request) { - return ChatOptions.builder() - .model(request.modelId()) - .maxTokens(request.inferenceConfig().maxTokens()) - .stopSequences(request.inferenceConfig().stopSequences()) - .temperature(request.inferenceConfig().temperature() != null - ? request.inferenceConfig().temperature().doubleValue() : null) - .topP(request.inferenceConfig().topP() != null ? request.inferenceConfig().topP().doubleValue() : null) - .build(); - } + // private ToolCallingChatOptions buildRequestOptions(ConverseRequest request) { + + // ToolCallingChatOptions toolCallbackChatOptions = ToolCallingChatOptions.builder() + // .model(request.modelId()) + // .maxTokens(request.inferenceConfig().maxTokens()) + // .stopSequences(request.inferenceConfig().stopSequences()) + // .temperature(request.inferenceConfig().temperature() != null + // ? request.inferenceConfig().temperature().doubleValue() + // : null) + // .topP(request.inferenceConfig().topP() != null ? + // request.inferenceConfig().topP().doubleValue() : null) + // .build(); + + // return toolCallbackChatOptions; + // } @Override public ChatOptions getDefaultOptions() { return this.defaultOptions; } - public ConverseStreamRequest createStreamRequest(Prompt prompt) { + Prompt buildRequestPrompt(Prompt prompt) { + ToolCallingChatOptions runtimeOptions = null; + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { + runtimeOptions = toolCallingChatOptions.copy(); + } + else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { + runtimeOptions = from(functionCallingOptions); + } + else { + runtimeOptions = from(prompt.getOptions()); + } + } - ConverseRequest converseRequest = this.createRequest(prompt); + // Merge runtime options with the default options + ToolCallingChatOptions updatedRuntimeOptions = null; + if (runtimeOptions == null) { + updatedRuntimeOptions = this.defaultOptions.copy(); + } + else { + updatedRuntimeOptions = ToolCallingChatOptions.builder() + .model(runtimeOptions.getModel() != null ? runtimeOptions.getModel() : this.defaultOptions.getModel()) + .frequencyPenalty(runtimeOptions.getFrequencyPenalty() != null ? runtimeOptions.getFrequencyPenalty() + : this.defaultOptions.getFrequencyPenalty()) + .maxTokens(runtimeOptions.getMaxTokens() != null ? runtimeOptions.getMaxTokens() + : this.defaultOptions.getMaxTokens()) + .presencePenalty(runtimeOptions.getPresencePenalty() != null ? runtimeOptions.getPresencePenalty() + : this.defaultOptions.getPresencePenalty()) + .stopSequences(runtimeOptions.getStopSequences() != null ? runtimeOptions.getStopSequences() + : this.defaultOptions.getStopSequences()) + .temperature(runtimeOptions.getTemperature() != null ? runtimeOptions.getTemperature() + : this.defaultOptions.getTemperature()) + .topK(runtimeOptions.getTopK() != null ? runtimeOptions.getTopK() : this.defaultOptions.getTopK()) + .topP(runtimeOptions.getTopP() != null ? runtimeOptions.getTopP() : this.defaultOptions.getTopP()) + + .toolCallbacks(runtimeOptions.getToolCallbacks() != null ? runtimeOptions.getToolCallbacks() + : this.defaultOptions.getToolCallbacks()) + .toolNames(runtimeOptions.getToolNames() != null ? runtimeOptions.getToolNames() + : this.defaultOptions.getToolNames()) + .toolContext(runtimeOptions.getToolContext() != null ? runtimeOptions.getToolContext() + : this.defaultOptions.getToolContext()) + .internalToolExecutionEnabled(runtimeOptions.isInternalToolExecutionEnabled() != null + ? runtimeOptions.isInternalToolExecutionEnabled() + : this.defaultOptions.isInternalToolExecutionEnabled()) + .build(); + } - return ConverseStreamRequest.builder() - .modelId(converseRequest.modelId()) - .messages(converseRequest.messages()) - .system(converseRequest.system()) - .additionalModelRequestFields(converseRequest.additionalModelRequestFields()) - .toolConfig(converseRequest.toolConfig()) - .build(); + ToolCallingChatOptions.validateToolCallbacks(updatedRuntimeOptions.getToolCallbacks()); + + return new Prompt(prompt.getInstructions(), updatedRuntimeOptions); } ConverseRequest createRequest(Prompt prompt) { - Set functionsForThisRequest = new HashSet<>(); - List instructionMessages = prompt.getInstructions() .stream() .filter(message -> message.getMessageType() != MessageType.SYSTEM) @@ -318,26 +422,29 @@ else if (message.getMessageType() == MessageType.TOOL) { .map(sysMessage -> SystemContentBlock.builder().text(sysMessage.getText()).build()) .toList(); - FunctionCallingOptions updatedRuntimeOptions = (FunctionCallingOptions) this.defaultOptions.copy(); - - if (prompt.getOptions() != null) { - if (prompt.getOptions() instanceof FunctionCallingOptions) { - var functionCallingOptions = (FunctionCallingOptions) prompt.getOptions(); - updatedRuntimeOptions = ((DefaultFunctionCallingOptions) updatedRuntimeOptions) - .merge(functionCallingOptions); - } - else if (prompt.getOptions() instanceof ChatOptions) { - var chatOptions = (ChatOptions) prompt.getOptions(); - updatedRuntimeOptions = ((DefaultFunctionCallingOptions) updatedRuntimeOptions).merge(chatOptions); - } - } - - functionsForThisRequest.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions)); + ToolCallingChatOptions updatedRuntimeOptions = prompt.getOptions().copy(); ToolConfiguration toolConfiguration = null; - if (!CollectionUtils.isEmpty(functionsForThisRequest)) { - toolConfiguration = ToolConfiguration.builder().tools(getFunctionTools(functionsForThisRequest)).build(); + // Add the tool definitions to the request's tools parameter. + List toolDefinitions = this.toolCallingManager.resolveToolDefinitions(updatedRuntimeOptions); + + if (!CollectionUtils.isEmpty(toolDefinitions)) { + List bedrockTools = toolDefinitions.stream().map(toolDefinition -> { + var description = toolDefinition.description(); + var name = toolDefinition.name(); + String inputSchema = toolDefinition.inputSchema(); + return Tool.builder() + .toolSpec(ToolSpecification.builder() + .name(name) + .description(description) + .inputSchema(ToolInputSchema.fromJson( + ConverseApiUtils.convertObjectToDocument(ModelOptionsUtils.jsonToMap(inputSchema)))) + .build()) + .build(); + }).toList(); + + toolConfiguration = ToolConfiguration.builder().tools(bedrockTools).build(); } InferenceConfiguration inferenceConfiguration = InferenceConfiguration.builder() @@ -446,22 +553,6 @@ else if (BedrockMediaFormat.isSupportedDocumentFormat(mimeType)) { // Document throw new IllegalArgumentException("Unsupported media format: " + mimeType); } - private List getFunctionTools(Set functionNames) { - return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> { - var description = functionCallback.getDescription(); - var name = functionCallback.getName(); - String inputSchema = functionCallback.getInputTypeSchema(); - return Tool.builder() - .toolSpec(ToolSpecification.builder() - .name(name) - .description(description) - .inputSchema(ToolInputSchema - .fromJson(ConverseApiUtils.convertObjectToDocument(ModelOptionsUtils.jsonToMap(inputSchema)))) - .build()) - .build(); - }).toList(); - } - private static byte[] getContentMediaData(Object mediaData) { if (mediaData instanceof byte[] bytes) { return bytes; @@ -583,7 +674,8 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv */ @Override public Flux stream(Prompt prompt) { - return this.internalStream(prompt, null); + Prompt requestPrompt = buildRequestPrompt(prompt); + return this.internalStream(requestPrompt, null); } private Flux internalStream(Prompt prompt, ChatResponse perviousChatResponse) { @@ -596,7 +688,7 @@ private Flux internalStream(Prompt prompt, ChatResponse perviousCh ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(AiProvider.BEDROCK_CONVERSE.value()) - .requestOptions(buildRequestOptions(converseRequest)) + .requestOptions(prompt.getOptions()) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( @@ -620,12 +712,23 @@ private Flux internalStream(Prompt prompt, ChatResponse perviousCh Flux chatResponses = ConverseApiUtils.toChatResponse(response, perviousChatResponse); Flux chatResponseFlux = chatResponses.switchMap(chatResponse -> { - if (!this.isProxyToolCalls(prompt, this.defaultOptions) && chatResponse != null - && this.isToolCall(chatResponse, Set.of(StopReason.TOOL_USE.toString()))) { - var toolCallConversation = this.handleToolCalls(prompt, chatResponse); - return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse); + + if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && chatResponse.hasToolCalls()) { + var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse); + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return Flux.just(ChatResponse.builder().from(chatResponse) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build()); + } else { + // Send the tool execution result back to the model. + return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + chatResponse); + } + } + else { + return Flux.just(chatResponse); } - return Mono.just(chatResponse); }) .doOnError(observation::error) .doFinally(s -> observation.stop()) @@ -699,7 +802,9 @@ public static final class Builder { private Duration timeout = Duration.ofMinutes(10); - private FunctionCallingOptions defaultOptions = new DefaultFunctionCallingOptions(); + private ToolCallingManager toolCallingManager; + + private ToolCallingChatOptions defaultOptions = ToolCallingChatOptions.builder().build(); private FunctionCallbackResolver functionCallbackResolver; @@ -716,33 +821,87 @@ public static final class Builder { private Builder() { } + public Builder toolCallingManager(ToolCallingManager toolCallingManager) { + this.toolCallingManager = toolCallingManager; + return this; + } + + /** + * @deprecated Use {@link #credentialsProvider(AwsCredentialsProvider)} instead. + */ + @Deprecated public Builder withCredentialsProvider(AwsCredentialsProvider credentialsProvider) { Assert.notNull(credentialsProvider, "'credentialsProvider' must not be null."); this.credentialsProvider = credentialsProvider; return this; } + public Builder credentialsProvider(AwsCredentialsProvider credentialsProvider) { + Assert.notNull(credentialsProvider, "'credentialsProvider' must not be null."); + this.credentialsProvider = credentialsProvider; + return this; + } + + /** + * @deprecated Use {@link #region(Region)} instead. + */ + @Deprecated public Builder withRegion(Region region) { Assert.notNull(region, "'region' must not be null."); this.region = region; return this; } + public Builder region(Region region) { + Assert.notNull(region, "'region' must not be null."); + this.region = region; + return this; + } + + /** + * @deprecated Use {@link #timeout(Duration)} instead. + */ + @Deprecated public Builder withTimeout(Duration timeout) { Assert.notNull(timeout, "'timeout' must not be null."); this.timeout = timeout; return this; } + public Builder timeout(Duration timeout) { + Assert.notNull(timeout, "'timeout' must not be null."); + this.timeout = timeout; + return this; + } + + /** + * @deprecated Use {@link #defaultOptions(ToolCallingChatOptions)} instead. + */ + @Deprecated public Builder withDefaultOptions(FunctionCallingOptions defaultOptions) { + Assert.notNull(defaultOptions, "'defaultOptions' must not be null."); + return this.defaultOptions(ToolCallingChatOptions.builder() + .model(defaultOptions.getModel()) + .maxTokens(defaultOptions.getMaxTokens()) + .stopSequences(defaultOptions.getStopSequences()) + .temperature(defaultOptions.getTemperature()) + .topP(defaultOptions.getTopP()) + .toolCallbacks(defaultOptions.getFunctionCallbacks()) + .toolNames(defaultOptions.getFunctions()) + .internalToolExecutionEnabled( + defaultOptions.getProxyToolCalls() != null ? !defaultOptions.getProxyToolCalls() : false) + .toolContext(defaultOptions.getToolContext()) + .build()); + } + + public Builder defaultOptions(ToolCallingChatOptions defaultOptions) { Assert.notNull(defaultOptions, "'defaultOptions' must not be null."); this.defaultOptions = defaultOptions; return this; } /** - * @deprecated Use {@link #functionCallbackResolver(FunctionCallbackResolver)} - * instead. + * @deprecated To be removed after M6 */ @Deprecated public Builder withFunctionCallbackContext(FunctionCallbackResolver functionCallbackResolver) { @@ -755,33 +914,77 @@ public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbac return this; } + /** + * @deprecated To be removed after M6 + */ + @Deprecated public Builder withToolFunctionCallbacks(List toolFunctionCallbacks) { this.toolFunctionCallbacks = toolFunctionCallbacks; return this; } + /** + * @deprecated Use {@link #observationRegistry(ObservationRegistry)} instead. + */ + @Deprecated public Builder withObservationRegistry(ObservationRegistry observationRegistry) { Assert.notNull(observationRegistry, "'observationRegistry' must not be null."); this.observationRegistry = observationRegistry; return this; } + public Builder observationRegistry(ObservationRegistry observationRegistry) { + Assert.notNull(observationRegistry, "'observationRegistry' must not be null."); + this.observationRegistry = observationRegistry; + return this; + } + + /** + * @deprecated Use + * {@link #customObservationConvention(ChatModelObservationConvention)} instead. + */ + @Deprecated public Builder withCustomObservationConvention(ChatModelObservationConvention observationConvention) { Assert.notNull(observationConvention, "'observationConvention' must not be null."); this.customObservationConvention = observationConvention; return this; } + public Builder customObservationConvention(ChatModelObservationConvention observationConvention) { + Assert.notNull(observationConvention, "'observationConvention' must not be null."); + this.customObservationConvention = observationConvention; + return this; + } + + /** + * @deprecated Use {@link #bedrockRuntimeClient(BedrockRuntimeClient)} instead. + */ + @Deprecated public Builder withBedrockRuntimeClient(BedrockRuntimeClient bedrockRuntimeClient) { this.bedrockRuntimeClient = bedrockRuntimeClient; return this; } + public Builder bedrockRuntimeClient(BedrockRuntimeClient bedrockRuntimeClient) { + this.bedrockRuntimeClient = bedrockRuntimeClient; + return this; + } + + /** + * @deprecated Use {@link #bedrockRuntimeAsyncClient(BedrockRuntimeAsyncClient)} + * instead. + */ + @Deprecated public Builder withBedrockRuntimeAsyncClient(BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient) { this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient; return this; } + public Builder bedrockRuntimeAsyncClient(BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient) { + this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient; + return this; + } + public BedrockProxyChatModel build() { if (this.bedrockRuntimeClient == null) { @@ -809,9 +1012,31 @@ public BedrockProxyChatModel build() { this.bedrockRuntimeAsyncClient = builder.build(); } - var bedrockProxyChatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient, - this.bedrockRuntimeAsyncClient, this.defaultOptions, this.functionCallbackResolver, - this.toolFunctionCallbacks, this.observationRegistry); + BedrockProxyChatModel bedrockProxyChatModel = null; + + if (this.toolCallingManager != null) { + Assert.isNull(functionCallbackResolver, + "functionCallbackResolver cannot be set when toolCallingManager is set"); + Assert.isNull(toolFunctionCallbacks, + "toolFunctionCallbacks cannot be set when toolCallingManager is set"); + + bedrockProxyChatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient, + this.bedrockRuntimeAsyncClient, this.defaultOptions, this.observationRegistry, + this.toolCallingManager); + + } + else if (this.functionCallbackResolver != null) { + Assert.isNull(toolCallingManager, + "toolCallingManager cannot be set when functionCallbackResolver is set"); + bedrockProxyChatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient, + this.bedrockRuntimeAsyncClient, this.defaultOptions, this.functionCallbackResolver, + this.toolFunctionCallbacks, this.observationRegistry); + } + else { + bedrockProxyChatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient, + this.bedrockRuntimeAsyncClient, this.defaultOptions, this.observationRegistry, + DEFAULT_TOOL_CALLING_MANAGER); + } if (this.customObservationConvention != null) { bedrockProxyChatModel.setObservationConvention(this.customObservationConvention); diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java index 1d170567bf4..fc233beb5ea 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -331,6 +331,10 @@ public static Document getChatOptionsAdditionalModelRequestFields(ChatOptions de attributes.remove("toolContext"); attributes.remove("functionCallbacks"); + attributes.remove("toolCallbacks"); + attributes.remove("toolNames"); + attributes.remove("internalToolExecutionEnabled"); + attributes.remove("temperature"); attributes.remove("topK"); attributes.remove("stopSequences"); diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java index 3f6638b9f12..cfbe4fd1b3d 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java @@ -380,7 +380,8 @@ void multiModalityEmbeddedImage(String modelName) throws IOException { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "anthropic.claude-3-5-sonnet-20240620-v1:0" }) - void multiModalityImageUrl(String modelName) throws IOException { + @Deprecated + void multiModalityImageUrl2(String modelName) throws IOException { // TODO: add url method that wrapps the checked exception. URL url = new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); @@ -398,6 +399,26 @@ void multiModalityImageUrl(String modelName) throws IOException { assertThat(response).containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand"); } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "anthropic.claude-3-5-sonnet-20240620-v1:0" }) + void multiModalityImageUrl(String modelName) throws IOException { + + // TODO: add url method that wrapps the checked exception. + URL url = new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); + + // @formatter:off + String response = ChatClient.create(this.chatModel).prompt() + // TODO consider adding model(...) method to ChatClient as a shortcut to + .options(ToolCallingChatOptions.builder().model(modelName).build()) + .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url)) + .call() + .content(); + // @formatter:on + + logger.info(response); + assertThat(response).containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand"); + } + @Test void streamingMultiModalityImageUrl() throws IOException { diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java index f2537e95b0b..7f2fd2c979f 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,10 +38,10 @@ public BedrockProxyChatModel bedrockConverseChatModel() { String modelId = "anthropic.claude-3-5-sonnet-20240620-v1:0"; return BedrockProxyChatModel.builder() - .withCredentialsProvider(EnvironmentVariableCredentialsProvider.create()) - .withRegion(Region.US_EAST_1) - .withTimeout(Duration.ofSeconds(120)) - // .withRegion(Region.US_EAST_1) + .credentialsProvider(EnvironmentVariableCredentialsProvider.create()) + .region(Region.US_EAST_1) + // .region(Region.US_EAST_1) + .timeout(Duration.ofSeconds(120)) .withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build()) .build(); } diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java index bdbe95bda48..25a44b63afe 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java @@ -20,6 +20,7 @@ import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java index 7eda0b87761..3652d322bd1 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java @@ -48,6 +48,8 @@ import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; @@ -244,8 +246,9 @@ void multiModalityTest() throws IOException { "fruit stand"); } + @Deprecated @Test - void functionCallTest() { + void functionCallTestDeprecated() { UserMessage userMessage = new UserMessage( "What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius."); @@ -269,6 +272,29 @@ void functionCallTest() { assertThat(generation.getOutput().getText()).contains("30", "10", "15"); } + @Test + void functionCallTest() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = ToolCallingChatOptions.builder() + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location. Return in 36°C format") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + Generation generation = response.getResult(); + assertThat(generation.getOutput().getText()).contains("30", "10", "15"); + } + @Test void streamFunctionCallTest() { diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java index eeadcaf9d5a..cd90c436850 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java @@ -20,6 +20,7 @@ import java.time.Duration; import java.util.Set; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,7 +33,6 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; @@ -47,6 +47,7 @@ /** * @author Christian Tzolov */ +@Disabled @SpringBootTest(classes = BedrockNovaChatClientIT.Config.class) @RequiresAwsCredentials public class BedrockNovaChatClientIT { @@ -181,9 +182,9 @@ public BedrockProxyChatModel bedrockConverseChatModel() { String modelId = "amazon.nova-pro-v1:0"; return BedrockProxyChatModel.builder() - .withCredentialsProvider(EnvironmentVariableCredentialsProvider.create()) - .withRegion(Region.US_EAST_1) - .withTimeout(Duration.ofSeconds(120)) + .credentialsProvider(EnvironmentVariableCredentialsProvider.create()) + .region(Region.US_EAST_1) + .timeout(Duration.ofSeconds(120)) .withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build()) .build(); } diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain2.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain2.java deleted file mode 100644 index b2d95eab70f..00000000000 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain2.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.bedrock.converse.experiments; - -import java.util.List; - -import reactor.core.publisher.Flux; -import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; - -import org.springframework.ai.bedrock.converse.BedrockProxyChatModel; -import org.springframework.ai.bedrock.converse.MockWeatherService; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; - -/** - * Used for reverse engineering the protocol - */ -public final class BedrockConverseChatModelMain2 { - - private BedrockConverseChatModelMain2() { - - } - - public static void main(String[] args) { - - // String modelId = "anthropic.claude-3-5-sonnet-20240620-v1:0"; - // String modelId = "ai21.jamba-1-5-large-v1:0"; - String modelId = "anthropic.claude-3-5-sonnet-20240620-v1:0"; - - // var prompt = new Prompt("Tell me a joke?", - // ChatOptions.builder().model(modelId).build(); - var prompt = new Prompt( - // "What's the weather like in San Francisco, Tokyo, and Paris? Return the - // temperature in Celsius.", - "What's the weather like in Paris? Return the temperature in Celsius.", - FunctionCallingOptions.builder() - .model(modelId) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) - .description("Get the weather in location") - .inputType(MockWeatherService.Request.class) - .build())) - .build()); - - BedrockProxyChatModel chatModel = BedrockProxyChatModel.builder() - .withCredentialsProvider(EnvironmentVariableCredentialsProvider.create()) - .withRegion(Region.US_EAST_1) - .build(); - - var streamRequest = chatModel.createStreamRequest(prompt); - - Flux responses = chatModel.converseStream(streamRequest); - List responseList = responses.collectList().block(); - System.out.println(responseList); - System.out.println("Response count: " + responseList.size()); - responseList.forEach(System.out::println); - } - -} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java index 20b23e5b5b1..105ecdad392 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java @@ -62,6 +62,7 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters.Voice; import org.springframework.ai.openai.api.tool.MockWeatherService; import org.springframework.ai.openai.testutils.AbstractIT; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.convert.support.DefaultConversionService; @@ -332,7 +333,8 @@ void beanStreamOutputConverterRecords() { } @Test - void functionCallTest() { + @Deprecated + void functionCallTestDeprecated() { UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); @@ -356,6 +358,28 @@ void functionCallTest() { assertThat(response.getResult().getOutput().getText()).containsAnyOf("15.0", "15"); } + @Test + void functionCallTest() { + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = OpenAiChatOptions.builder() + .model(OpenAiApi.ChatModel.GPT_4_O.getValue()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); + } + @Test void streamFunctionCallTest() { diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock-converse.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock-converse.adoc index 0863efb4bd2..cd3c7f76be9 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock-converse.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock-converse.adoc @@ -90,7 +90,7 @@ The prefix `spring.ai.bedrock.converse.chat` is the property prefix that configu == Runtime Options [[chat-options]] -Use the portable `ChatOptions` or `FunctionCallingOptions` portable builders to create model configurations, such as temperature, maxToken, topP, etc. +Use the portable `ChatOptions` or `ToolCallingChatOptions` portable builders to create model configurations, such as temperature, maxToken, topP, etc. On start-up, the default options can be configured with the `BedrockConverseProxyChatModel(api, options)` constructor or the `spring.ai.bedrock.converse.chat.options.*` properties. @@ -98,12 +98,11 @@ At run-time you can override the default options by adding new, request specific [source,java] ---- -var options = FunctionCallingOptions.builder() - .withModel("anthropic.claude-3-5-sonnet-20240620-v1:0") - .withTemperature(0.6) - .withMaxTokens(300) - .withFunctionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new WeatherService()) +var options = ToolCallingChatOptions.builder() + .model("anthropic.claude-3-5-sonnet-20240620-v1:0") + .temperature(0.6) + .maxTokens(300) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new WeatherService()) .description("Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.") .inputType(WeatherService.Request.class) .build())) @@ -118,7 +117,28 @@ String response = ChatClient.create(this.chatModel) == Tool/Function Calling -The Bedrock Converse API supports function calling capabilities, allowing models to use tools during conversations. Here's an example of how to define and use functions: +The Bedrock Converse API supports tool calling capabilities, allowing models to use tools during conversations. +Here's an example of how to define and use @Tool based tools: + +[source,java] +---- + +public class WeatherService { + + @Tool(description = "Get the weather in location") + public String weatherByLocation(@ToolParam(description= "City or state name") String location) { + ... + } +} + +String response = ChatClient.create(this.chatModel) + .prompt("What's the weather like in Boston?") + .tools(new WeatherService()) + .call() + .content(); +---- + +You can use the java.util.function beans as tools as well: [source,java] ---- @@ -130,12 +150,14 @@ public Function weatherFunction() { String response = ChatClient.create(this.chatModel) .prompt("What's the weather like in Boston?") - .function("weatherFunction") + .tools("weatherFunction") .inputType(Request.class) .call() .content(); ---- +Find more in xref:api/tools.adoc[Tools] documentation. + == Multimodal Multimodality refers to a model's ability to simultaneously understand and process information from various sources, including text, images, video, pdf, doc, html, md and more data formats. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatAutoConfiguration.java index 2f6b21dbc72..e732d106286 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatAutoConfiguration.java @@ -16,8 +16,6 @@ package org.springframework.ai.autoconfigure.bedrock.converse; -import java.util.List; - import io.micrometer.observation.ObservationRegistry; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.providers.AwsRegionProvider; @@ -26,13 +24,15 @@ import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.chat.model.ToolCallingAutoConfiguration; import org.springframework.ai.bedrock.converse.BedrockProxyChatModel; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.ImportAutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; @@ -50,12 +50,13 @@ * @author Christian Tzolov * @author Wei Jiang */ -@AutoConfiguration +@AutoConfiguration(after = { ToolCallingAutoConfiguration.class }) @EnableConfigurationProperties({ BedrockConverseProxyChatProperties.class, BedrockAwsConnectionConfiguration.class }) @ConditionalOnClass({ BedrockProxyChatModel.class, BedrockRuntimeClient.class, BedrockRuntimeAsyncClient.class }) @ConditionalOnProperty(prefix = BedrockConverseProxyChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) @Import(BedrockAwsConnectionConfiguration.class) +@ImportAutoConfiguration({ ToolCallingAutoConfiguration.class }) public class BedrockConverseProxyChatAutoConfiguration { @Bean @@ -63,22 +64,21 @@ public class BedrockConverseProxyChatAutoConfiguration { @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) public BedrockProxyChatModel bedrockProxyChatModel(AwsCredentialsProvider credentialsProvider, AwsRegionProvider regionProvider, BedrockAwsConnectionProperties connectionProperties, - BedrockConverseProxyChatProperties chatProperties, FunctionCallbackResolver functionCallbackResolver, - List toolFunctionCallbacks, ObjectProvider observationRegistry, + BedrockConverseProxyChatProperties chatProperties, ToolCallingManager toolCallingManager, + ObjectProvider observationRegistry, ObjectProvider observationConvention, ObjectProvider bedrockRuntimeClient, ObjectProvider bedrockRuntimeAsyncClient) { var chatModel = BedrockProxyChatModel.builder() - .withCredentialsProvider(credentialsProvider) - .withRegion(regionProvider.getRegion()) - .withTimeout(connectionProperties.getTimeout()) - .withDefaultOptions(chatProperties.getOptions()) - .withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) - .functionCallbackResolver(functionCallbackResolver) - .withToolFunctionCallbacks(toolFunctionCallbacks) - .withBedrockRuntimeClient(bedrockRuntimeClient.getIfAvailable()) - .withBedrockRuntimeAsyncClient(bedrockRuntimeAsyncClient.getIfAvailable()) + .credentialsProvider(credentialsProvider) + .region(regionProvider.getRegion()) + .timeout(connectionProperties.getTimeout()) + .defaultOptions(chatProperties.getOptions()) + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .toolCallingManager(toolCallingManager) + .bedrockRuntimeClient(bedrockRuntimeClient.getIfAvailable()) + .bedrockRuntimeAsyncClient(bedrockRuntimeAsyncClient.getIfAvailable()) .build(); observationConvention.ifAvailable(chatModel::setObservationConvention); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatProperties.java index bca0e2fba0d..128eaf71b02 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatProperties.java @@ -16,7 +16,7 @@ package org.springframework.ai.autoconfigure.bedrock.converse; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.util.Assert; @@ -38,7 +38,7 @@ public class BedrockConverseProxyChatProperties { private boolean enabled = true; @NestedConfigurationProperty - private FunctionCallingOptions options = FunctionCallingOptions.builder() + private ToolCallingChatOptions options = ToolCallingChatOptions.builder() .temperature(0.7) .maxTokens(300) .topK(10) @@ -52,11 +52,11 @@ public void setEnabled(boolean enabled) { this.enabled = enabled; } - public FunctionCallingOptions getOptions() { + public ToolCallingChatOptions getOptions() { return this.options; } - public void setOptions(FunctionCallingOptions options) { + public void setOptions(ToolCallingChatOptions options) { Assert.notNull(options, "FunctionCallingOptions must not be null"); this.options = options; }