diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 94fe45595a9..600361cd6f4 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,17 +18,24 @@ import java.util.ArrayList; import java.util.Base64; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import com.fasterxml.jackson.core.type.TypeReference; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.model.tool.LegacyToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionResult; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.util.json.JsonParser; +import org.springframework.lang.Nullable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -96,6 +103,8 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build(); + /** * The retry template used to retry the OpenAI API calls. */ @@ -116,6 +125,8 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM */ private final ObservationRegistry observationRegistry; + private final ToolCallingManager toolCallingManager; + /** * Conventions to use for generating observations. */ @@ -124,7 +135,9 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM /** * Construct a new {@link AnthropicChatModel} instance. * @param anthropicApi the lower-level API for the Anthropic service. + * @deprecated Use {@link AnthropicChatModel.Builder}. */ + @Deprecated public AnthropicChatModel(AnthropicApi anthropicApi) { this(anthropicApi, AnthropicChatOptions.builder() @@ -138,7 +151,9 @@ public AnthropicChatModel(AnthropicApi anthropicApi) { * Construct a new {@link AnthropicChatModel} instance. * @param anthropicApi the lower-level API for the Anthropic service. * @param defaultOptions the default options used for the chat completion requests. + * @deprecated Use {@link AnthropicChatModel.Builder}. */ + @Deprecated public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions) { this(anthropicApi, defaultOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE); } @@ -148,7 +163,9 @@ public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaul * @param anthropicApi the lower-level API for the Anthropic service. * @param defaultOptions the default options used for the chat completion requests. * @param retryTemplate the retry template used to retry the Anthropic API calls. + * @deprecated Use {@link AnthropicChatModel.Builder}. */ + @Deprecated public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions, RetryTemplate retryTemplate) { this(anthropicApi, defaultOptions, retryTemplate, null); @@ -161,10 +178,11 @@ public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaul * @param retryTemplate the retry template used to retry the Anthropic API calls. * @param functionCallbackResolver the function callback resolver used to resolve the * function by its name. + * @deprecated Use {@link AnthropicChatModel.Builder}. */ + @Deprecated public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions, RetryTemplate retryTemplate, FunctionCallbackResolver functionCallbackResolver) { - this(anthropicApi, defaultOptions, retryTemplate, functionCallbackResolver, List.of()); } @@ -177,7 +195,9 @@ public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaul * function by its name. * @param toolFunctionCallbacks the tool function callbacks used to handle the tool * calls. + * @deprecated Use {@link AnthropicChatModel.Builder}. */ + @Deprecated public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions, RetryTemplate retryTemplate, FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks) { @@ -194,27 +214,49 @@ public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaul * function by its name. * @param toolFunctionCallbacks the tool function callbacks used to handle the tool * calls. + * @deprecated Use {@link AnthropicChatModel.Builder}. */ + @Deprecated public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions, - RetryTemplate retryTemplate, FunctionCallbackResolver functionCallbackResolver, - List toolFunctionCallbacks, ObservationRegistry observationRegistry) { - - super(functionCallbackResolver, defaultOptions, toolFunctionCallbacks); + RetryTemplate retryTemplate, @Nullable FunctionCallbackResolver functionCallbackResolver, + @Nullable List toolFunctionCallbacks, ObservationRegistry observationRegistry) { + this(anthropicApi, defaultOptions, + LegacyToolCallingManager.builder() + .functionCallbackResolver(functionCallbackResolver) + .functionCallbacks(toolFunctionCallbacks) + .build(), + retryTemplate, observationRegistry); + logger.warn("This constructor is deprecated and will be removed in the next milestone. " + + "Please use the MistralAiChatModel.Builder or the new constructor accepting ToolCallingManager instead."); + } - Assert.notNull(anthropicApi, "AnthropicApi must not be null"); - Assert.notNull(defaultOptions, "DefaultOptions must not be null"); - Assert.notNull(retryTemplate, "RetryTemplate must not be null"); - Assert.notNull(observationRegistry, "ObservationRegistry must not be null"); + public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions, + ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, + ObservationRegistry observationRegistry) { + // We do not pass the 'defaultOptions' to the AbstractToolSupport, + // because it modifies them. We are using ToolCallingManager instead, + // so we just pass empty options here. + super(null, AnthropicChatOptions.builder().build(), List.of()); + + Assert.notNull(anthropicApi, "anthropicApi cannot be null"); + Assert.notNull(defaultOptions, "defaultOptions cannot be null"); + Assert.notNull(toolCallingManager, "toolCallingManager cannot be null"); + Assert.notNull(retryTemplate, "retryTemplate cannot be null"); + Assert.notNull(observationRegistry, "observationRegistry cannot be null"); this.anthropicApi = anthropicApi; this.defaultOptions = defaultOptions; + this.toolCallingManager = toolCallingManager; this.retryTemplate = retryTemplate; this.observationRegistry = observationRegistry; } @Override public ChatResponse call(Prompt prompt) { - return this.internalCall(prompt, null); + // Before moving any further, build the final request Prompt, + // merging runtime and default options. + Prompt requestPrompt = buildRequestPrompt(prompt); + return this.internalCall(requestPrompt, null); } public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { @@ -223,7 +265,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(AnthropicApi.PROVIDER_NAME) - .requestOptions(buildRequestOptions(request)) + .requestOptions(prompt.getOptions()) .build(); ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION @@ -247,10 +289,21 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons return chatResponse; }); - if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null - && this.isToolCall(response, Set.of("tool_use"))) { - var toolCallConversation = handleToolCalls(prompt, response); - return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response); + if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null + && response.hasToolCalls()) { + var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return ChatResponse.builder() + .from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build(); + } + else { + // Send the tool execution result back to the model. + return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + response); + } } return response; @@ -263,7 +316,10 @@ private DefaultUsage getDefaultUsage(AnthropicApi.Usage usage) { @Override public Flux stream(Prompt prompt) { - return this.internalStream(prompt, null); + // Before moving any further, build the final request Prompt, + // merging runtime and default options. + Prompt requestPrompt = buildRequestPrompt(prompt); + return this.internalStream(requestPrompt, null); } public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { @@ -273,7 +329,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(AnthropicApi.PROVIDER_NAME) - .requestOptions(buildRequestOptions(request)) + .requestOptions(prompt.getOptions()) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( @@ -291,9 +347,18 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage); - if (!isProxyToolCalls(prompt, this.defaultOptions) && this.isToolCall(chatResponse, Set.of("tool_use"))) { - var toolCallConversation = handleToolCalls(prompt, chatResponse); - return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse); + if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && chatResponse.hasToolCalls() && chatResponse.hasFinishReasons(Set.of("tool_use"))) { + var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse); + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return Flux.just(ChatResponse.builder().from(chatResponse) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build()); + } else { + // Send the tool execution result back to the model. + return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + chatResponse); + } } return Mono.just(chatResponse); @@ -341,7 +406,7 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage var functionCallId = toolToUse.id(); var functionName = toolToUse.name(); - var functionArguments = ModelOptionsUtils.toJsonString(toolToUse.input()); + var functionArguments = JsonParser.toJson(toolToUse.input()); toolCalls .add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments)); @@ -397,9 +462,55 @@ else if (mimeType.contains("pdf")) { + ". Supported types are: images (image/*) and PDF documents (application/pdf)"); } - ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + Prompt buildRequestPrompt(Prompt prompt) { + // Process runtime options + AnthropicChatOptions runtimeOptions = null; + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { + runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, + AnthropicChatOptions.class); + } + else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { + runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class, + AnthropicChatOptions.class); + } + else { + runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, + AnthropicChatOptions.class); + } + } - Set functionsForThisRequest = new HashSet<>(); + // Define request options by merging runtime options and default options + AnthropicChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, + AnthropicChatOptions.class); + + // Merge @JsonIgnore-annotated options explicitly since they are ignored by + // Jackson, used by ModelOptionsUtils. + if (runtimeOptions != null) { + requestOptions.setInternalToolExecutionEnabled( + ModelOptionsUtils.mergeOption(runtimeOptions.isInternalToolExecutionEnabled(), + this.defaultOptions.isInternalToolExecutionEnabled())); + requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), + this.defaultOptions.getToolNames())); + requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), + this.defaultOptions.getToolCallbacks())); + requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(), + this.defaultOptions.getToolContext())); + } + else { + requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.isInternalToolExecutionEnabled()); + + requestOptions.setToolNames(this.defaultOptions.getToolNames()); + requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); + requestOptions.setToolContext(this.defaultOptions.getToolContext()); + } + + ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks()); + + return new Prompt(prompt.getInstructions(), requestOptions); + } + + ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { List userMessages = prompt.getInstructions() .stream() @@ -457,58 +568,29 @@ else if (message.getMessageType() == MessageType.TOOL) { ChatCompletionRequest request = new ChatCompletionRequest(this.defaultOptions.getModel(), userMessages, systemPrompt, this.defaultOptions.getMaxTokens(), this.defaultOptions.getTemperature(), stream); - if (prompt.getOptions() != null) { - AnthropicChatOptions updatedRuntimeOptions; - if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { - updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, - FunctionCallingOptions.class, AnthropicChatOptions.class); - } - else { - updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, - AnthropicChatOptions.class); - } - - functionsForThisRequest.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions)); - - request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class); - } - - if (!CollectionUtils.isEmpty(this.defaultOptions.getFunctions())) { - functionsForThisRequest.addAll(this.defaultOptions.getFunctions()); - } - - request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class); + AnthropicChatOptions requestOptions = (AnthropicChatOptions) prompt.getOptions(); + request = ModelOptionsUtils.merge(requestOptions, request, ChatCompletionRequest.class); - if (!CollectionUtils.isEmpty(functionsForThisRequest)) { - - List tools = getFunctionTools(functionsForThisRequest); - - request = ChatCompletionRequest.from(request).withTools(tools).build(); + // Add the tool definitions to the request's tools parameter. + List toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions); + if (!CollectionUtils.isEmpty(toolDefinitions)) { + request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class); + request = ChatCompletionRequest.from(request).withTools(getFunctionTools(toolDefinitions)).build(); } return request; } - private List getFunctionTools(Set functionNames) { - return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> { - var description = functionCallback.getDescription(); - var name = functionCallback.getName(); - String inputSchema = functionCallback.getInputTypeSchema(); - return new AnthropicApi.Tool(name, description, ModelOptionsUtils.jsonToMap(inputSchema)); + private List getFunctionTools(List toolDefinitions) { + return toolDefinitions.stream().map(toolDefinition -> { + var name = toolDefinition.name(); + var description = toolDefinition.description(); + String inputSchema = toolDefinition.inputSchema(); + return new AnthropicApi.Tool(name, description, JsonParser.fromJson(inputSchema, new TypeReference<>() { + })); }).toList(); } - private ChatOptions buildRequestOptions(AnthropicApi.ChatCompletionRequest request) { - return ChatOptions.builder() - .model(request.model()) - .maxTokens(request.maxTokens()) - .stopSequences(request.stopSequences()) - .temperature(request.temperature()) - .topK(request.topK()) - .topP(request.topP()) - .build(); - } - @Override public ChatOptions getDefaultOptions() { return AnthropicChatOptions.fromOptions(this.defaultOptions); @@ -523,4 +605,92 @@ public void setObservationConvention(ChatModelObservationConvention observationC this.observationConvention = observationConvention; } + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private AnthropicApi anthropicApi; + + private AnthropicChatOptions defaultOptions = AnthropicChatOptions.builder() + .model(DEFAULT_MODEL_NAME) + .maxTokens(DEFAULT_MAX_TOKENS) + .temperature(DEFAULT_TEMPERATURE) + .build(); + + private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + + private FunctionCallbackResolver functionCallbackResolver; + + private List toolCallbacks; + + private ToolCallingManager toolCallingManager; + + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + + private Builder() { + } + + public Builder anthropicApi(AnthropicApi anthropicApi) { + this.anthropicApi = anthropicApi; + return this; + } + + public Builder defaultOptions(AnthropicChatOptions defaultOptions) { + this.defaultOptions = defaultOptions; + return this; + } + + public Builder retryTemplate(RetryTemplate retryTemplate) { + this.retryTemplate = retryTemplate; + return this; + } + + public Builder toolCallingManager(ToolCallingManager toolCallingManager) { + this.toolCallingManager = toolCallingManager; + return this; + } + + @Deprecated + public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) { + this.functionCallbackResolver = functionCallbackResolver; + return this; + } + + @Deprecated + public Builder toolCallbacks(List toolCallbacks) { + this.toolCallbacks = toolCallbacks; + return this; + } + + public Builder observationRegistry(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + return this; + } + + public AnthropicChatModel build() { + if (toolCallingManager != null) { + Assert.isNull(functionCallbackResolver, + "functionCallbackResolver cannot be set when toolCallingManager is set"); + Assert.isNull(toolCallbacks, "toolCallbacks cannot be set when toolCallingManager is set"); + + return new AnthropicChatModel(anthropicApi, defaultOptions, toolCallingManager, retryTemplate, + observationRegistry); + } + if (functionCallbackResolver != null) { + Assert.isNull(toolCallingManager, + "toolCallingManager cannot be set when functionCallbackResolver is set"); + List toolCallbacks = this.toolCallbacks != null ? this.toolCallbacks : List.of(); + + return new AnthropicChatModel(anthropicApi, defaultOptions, retryTemplate, functionCallbackResolver, + toolCallbacks, observationRegistry); + } + + return new AnthropicChatModel(anthropicApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate, + observationRegistry); + } + + } + } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index 229b5113126..c1b319a27ff 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ package org.springframework.ai.anthropic; import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -30,7 +32,9 @@ import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest; import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -42,7 +46,7 @@ * @since 1.0.0 */ @JsonInclude(Include.NON_NULL) -public class AnthropicChatOptions implements FunctionCallingOptions { +public class AnthropicChatOptions implements ToolCallingChatOptions { // @formatter:off private @JsonProperty("model") String model; @@ -54,34 +58,27 @@ public class AnthropicChatOptions implements FunctionCallingOptions { private @JsonProperty("top_k") Integer topK; /** - * Tool Function Callbacks to register with the ChatModel. For Prompt - * Options the functionCallbacks are automatically enabled for the duration of the - * prompt execution. For Default Options the functionCallbacks are registered but - * disabled by default. Use the enableFunctions to set the functions from the registry - * to be used by the ChatModel chat completion requests. + * Collection of {@link ToolCallback}s to be used for tool calling in the chat + * completion requests. */ @JsonIgnore - private List functionCallbacks = new ArrayList<>(); + private List toolCallbacks = new ArrayList<>(); /** - * List of functions, identified by their names, to configure for function calling in - * the chat completion requests. Functions with those names must exist in the - * functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions - * are automatically enabled for the duration of the prompt execution. - * - * Note that function enabled with the default options are enabled for all chat - * completion requests. This could impact the token count and the billing. If the - * functions is set in a prompt options, then the enabled functions are only active - * for the duration of this prompt execution. + * Collection of tool names to be resolved at runtime and used for tool calling in the + * chat completion requests. */ @JsonIgnore - private Set functions = new HashSet<>(); + private Set toolNames = new HashSet<>(); + /** + * Whether to enable the tool execution lifecycle internally in ChatModel. + */ @JsonIgnore - private Boolean proxyToolCalls; + private Boolean internalToolExecutionEnabled; @JsonIgnore - private Map toolContext; + private Map toolContext = new HashMap<>(); // @formatter:on @@ -97,9 +94,9 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) .temperature(fromOptions.getTemperature()) .topP(fromOptions.getTopP()) .topK(fromOptions.getTopK()) - .functionCallbacks(fromOptions.getFunctionCallbacks()) - .functions(fromOptions.getFunctions()) - .proxyToolCalls(fromOptions.getProxyToolCalls()) + .toolCallbacks(fromOptions.getToolCallbacks()) + .toolNames(fromOptions.getToolNames()) + .internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled()) .toolContext(fromOptions.getToolContext()) .build(); } @@ -167,25 +164,73 @@ public void setTopK(Integer topK) { } @Override + @JsonIgnore + public List getToolCallbacks() { + return this.toolCallbacks; + } + + @Override + @JsonIgnore + public void setToolCallbacks(List toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); + this.toolCallbacks = toolCallbacks; + } + + @Override + @JsonIgnore + public Set getToolNames() { + return this.toolNames; + } + + @Override + @JsonIgnore + public void setToolNames(Set toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + Assert.noNullElements(toolNames, "toolNames cannot contain null elements"); + toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements")); + this.toolNames = toolNames; + } + + @Override + @Nullable + @JsonIgnore + public Boolean isInternalToolExecutionEnabled() { + return internalToolExecutionEnabled; + } + + @Override + @JsonIgnore + public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.internalToolExecutionEnabled = internalToolExecutionEnabled; + } + + @Override + @Deprecated + @JsonIgnore public List getFunctionCallbacks() { - return this.functionCallbacks; + return this.getToolCallbacks(); } @Override + @Deprecated + @JsonIgnore public void setFunctionCallbacks(List functionCallbacks) { - Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); - this.functionCallbacks = functionCallbacks; + this.setToolCallbacks(functionCallbacks); } @Override + @Deprecated + @JsonIgnore public Set getFunctions() { - return this.functions; + return this.getToolNames(); } @Override - public void setFunctions(Set functions) { - Assert.notNull(functions, "Function must not be null"); - this.functions = functions; + @Deprecated + @JsonIgnore + public void setFunctions(Set functionNames) { + this.setToolNames(functionNames); } @Override @@ -201,20 +246,26 @@ public Double getPresencePenalty() { } @Override + @Deprecated + @JsonIgnore public Boolean getProxyToolCalls() { - return this.proxyToolCalls; + return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null; } + @Deprecated + @JsonIgnore public void setProxyToolCalls(Boolean proxyToolCalls) { - this.proxyToolCalls = proxyToolCalls; + this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null; } @Override + @JsonIgnore public Map getToolContext() { return this.toolContext; } @Override + @JsonIgnore public void setToolContext(Map toolContext) { this.toolContext = toolContext; } @@ -268,25 +319,54 @@ public Builder topK(Integer topK) { return this; } - public Builder functionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; + public Builder toolCallbacks(List toolCallbacks) { + this.options.setToolCallbacks(toolCallbacks); return this; } - public Builder functions(Set functionNames) { - Assert.notNull(functionNames, "Function names must not be null"); - this.options.functions = functionNames; + public Builder toolCallbacks(FunctionCallback... toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks)); return this; } - public Builder function(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); + public Builder toolNames(Set toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + this.options.setToolNames(toolNames); + return this; + } + + public Builder toolNames(String... toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + this.options.toolNames.addAll(Set.of(toolNames)); return this; } + public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled); + return this; + } + + @Deprecated + public Builder functionCallbacks(List functionCallbacks) { + return toolCallbacks(functionCallbacks); + } + + @Deprecated + public Builder functions(Set functionNames) { + return toolNames(functionNames); + } + + @Deprecated + public Builder function(String functionName) { + return toolNames(functionName); + } + + @Deprecated public Builder proxyToolCalls(Boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; + if (proxyToolCalls != null) { + this.options.setInternalToolExecutionEnabled(!proxyToolCalls); + } return this; } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/ChatCompletionRequestTests.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/ChatCompletionRequestTests.java index 5f84372ea08..5f1aedd51b6 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/ChatCompletionRequestTests.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/ChatCompletionRequestTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ /** * @author Christian Tzolov * @author Alexandros Pappas + * @author Thomas Vitale */ public class ChatCompletionRequestTests { @@ -35,7 +36,9 @@ public void createRequestWithChatOptions() { var client = new AnthropicChatModel(new AnthropicApi("TEST"), AnthropicChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build()); - var request = client.createRequest(new Prompt("Test message content"), false); + var prompt = client.buildRequestPrompt(new Prompt("Test message content")); + + var request = client.createRequest(prompt, false); assertThat(request.messages()).hasSize(1); assertThat(request.stream()).isFalse(); @@ -43,8 +46,10 @@ public void createRequestWithChatOptions() { assertThat(request.model()).isEqualTo("DEFAULT_MODEL"); assertThat(request.temperature()).isEqualTo(66.6); - request = client.createRequest(new Prompt("Test message content", - AnthropicChatOptions.builder().model("PROMPT_MODEL").temperature(99.9).build()), true); + prompt = client.buildRequestPrompt(new Prompt("Test message content", + AnthropicChatOptions.builder().model("PROMPT_MODEL").temperature(99.9).build())); + + request = client.createRequest(prompt, true); assertThat(request.messages()).hasSize(1); assertThat(request.stream()).isTrue(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java index adf3075706c..24eac16a53b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java @@ -23,6 +23,7 @@ import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.model.ModelResponse; +import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; /** @@ -111,6 +112,21 @@ public boolean hasToolCalls() { return generations.stream().anyMatch(generation -> generation.getOutput().hasToolCalls()); } + /** + * Whether the model has finished with any of the given finish reasons. + */ + public boolean hasFinishReasons(Set finishReasons) { + Assert.notNull(finishReasons, "finishReasons cannot be null"); + if (CollectionUtils.isEmpty(generations)) { + return false; + } + return generations.stream().anyMatch(generation -> { + var finishReason = (generation.getMetadata().getFinishReason() != null) + ? generation.getMetadata().getFinishReason() : ""; + return finishReasons.stream().map(String::toLowerCase).toList().contains(finishReason.toLowerCase()); + }); + } + @Override public String toString() { return "ChatResponse [metadata=" + this.chatResponseMetadata + ", generations=" + this.generations + "]"; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java index 2cdd976a815..d73054986d1 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java @@ -18,11 +18,14 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import java.util.List; import java.util.Map; +import java.util.Set; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Unit tests for {@link ChatResponse}. @@ -48,4 +51,32 @@ void whenNoToolCallsArePresentThenReturnFalse() { assertThat(chatResponse.hasToolCalls()).isFalse(); } + @Test + void whenFinishReasonIsNullThenThrow() { + var chatResponse = ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("Result"), + ChatGenerationMetadata.builder().finishReason("completed").build()))) + .build(); + assertThatThrownBy(() -> chatResponse.hasFinishReasons(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("finishReasons cannot be null"); + } + + @Test + void whenFinishReasonIsPresent() { + ChatResponse chatResponse = ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("Result"), + ChatGenerationMetadata.builder().finishReason("completed").build()))) + .build(); + assertThat(chatResponse.hasFinishReasons(Set.of("completed"))).isTrue(); + } + + @Test + void whenFinishReasonIsNotPresent() { + ChatResponse chatResponse = ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("Result"), + ChatGenerationMetadata.builder().finishReason("failed").build()))) + .build(); + assertThat(chatResponse.hasFinishReasons(Set.of("completed"))).isFalse(); + } + } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index f94c7ea0843..cebe0432942 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -8,7 +8,7 @@ *** xref:api/chat/comparison.adoc[Chat Models Comparison] *** xref:api/chat/bedrock-converse.adoc[Amazon Bedrock Converse] *** xref:api/chat/anthropic-chat.adoc[Anthropic 3] -**** xref:api/chat/functions/anthropic-chat-functions.adoc[Anthropic Function Calling] +**** xref:api/chat/functions/anthropic-chat-functions.adoc[Anthropic Function Calling (Deprecated)] *** xref:api/chat/azure-openai-chat.adoc[Azure OpenAI] **** xref:api/chat/functions/azure-open-ai-chat-functions.adoc[Azure OpenAI Function Calling] *** xref:api/chat/deepseek-chat.adoc[DeepSeek AI] @@ -18,19 +18,19 @@ *** xref:api/chat/groq-chat.adoc[Groq] *** xref:api/chat/huggingface.adoc[Hugging Face] *** xref:api/chat/mistralai-chat.adoc[Mistral AI] -**** xref:api/chat/functions/mistralai-chat-functions.adoc[Mistral Function Calling] +**** xref:api/chat/functions/mistralai-chat-functions.adoc[Mistral Function Calling (Deprecated)] *** xref:api/chat/minimax-chat.adoc[MiniMax] **** xref:api/chat/functions/minimax-chat-functions.adoc[MinmaxFunction Calling] *** xref:api/chat/moonshot-chat.adoc[Moonshot AI] //// **** xref:api/chat/functions/moonshot-chat-functions.adoc[Moonshot Function Calling] *** xref:api/chat/nvidia-chat.adoc[NVIDIA] *** xref:api/chat/ollama-chat.adoc[Ollama] -**** xref:api/chat/functions/ollama-chat-functions.adoc[Ollama Function Calling] +**** xref:api/chat/functions/ollama-chat-functions.adoc[Ollama Function Calling (Deprecated)] *** xref:api/chat/perplexity-chat.adoc[Perplexity AI] *** OCI Generative AI **** xref:api/chat/oci-genai/cohere-chat.adoc[Cohere] *** xref:api/chat/openai-chat.adoc[OpenAI] -**** xref:api/chat/functions/openai-chat-functions.adoc[OpenAI Function Calling] +**** xref:api/chat/functions/openai-chat-functions.adoc[OpenAI Function Calling (Deprecated)] *** xref:api/chat/qianfan-chat.adoc[QianFan] *** xref:api/chat/zhipuai-chat.adoc[ZhiPu AI] // **** xref:api/chat/functions/zhipuai-chat-functions.adoc[Function Calling] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc index f40f0f83855..52c7ce8daca 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc @@ -1,4 +1,6 @@ -= Anthropic Function Calling += Anthropic Function Calling (Deprecated) + +WARNING: This page describes the previous version of the Function Calling API, which has been deprecated and marked for remove in the next release. The current version is available at xref:api/tools.adoc[Tool Calling]. See the xref:api/tools-migration.adoc[Migration Guide] for more information. TIP: Starting of Jul 1st, 2024, streaming function calling and Tool use is supported. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfiguration.java index 3f653672ff4..40addc8d69e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/anthropic/AnthropicAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,17 +16,16 @@ package org.springframework.ai.autoconfigure.anthropic; -import java.util.List; - import io.micrometer.observation.ObservationRegistry; import org.springframework.ai.anthropic.AnthropicChatModel; import org.springframework.ai.anthropic.api.AnthropicApi; +import org.springframework.ai.autoconfigure.chat.model.ToolCallingAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.ImportAutoConfiguration; @@ -50,13 +49,14 @@ * @author Thomas Vitale * @since 1.0.0 */ -@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class }) +@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class, + ToolCallingAutoConfiguration.class }) @EnableConfigurationProperties({ AnthropicChatProperties.class, AnthropicConnectionProperties.class }) @ConditionalOnClass(AnthropicApi.class) @ConditionalOnProperty(prefix = AnthropicChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) @ImportAutoConfiguration(classes = { SpringAiRetryAutoConfiguration.class, RestClientAutoConfiguration.class, - WebClientAutoConfiguration.class }) + ToolCallingAutoConfiguration.class, WebClientAutoConfiguration.class }) public class AnthropicAutoConfiguration { @Bean @@ -74,13 +74,17 @@ public AnthropicApi anthropicApi(AnthropicConnectionProperties connectionPropert @Bean @ConditionalOnMissingBean public AnthropicChatModel anthropicChatModel(AnthropicApi anthropicApi, AnthropicChatProperties chatProperties, - RetryTemplate retryTemplate, FunctionCallbackResolver functionCallbackResolver, - List toolFunctionCallbacks, ObjectProvider observationRegistry, + RetryTemplate retryTemplate, ToolCallingManager toolCallingManager, + ObjectProvider observationRegistry, ObjectProvider observationConvention) { - var chatModel = new AnthropicChatModel(anthropicApi, chatProperties.getOptions(), retryTemplate, - functionCallbackResolver, toolFunctionCallbacks, - observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)); + var chatModel = AnthropicChatModel.builder() + .anthropicApi(anthropicApi) + .defaultOptions(chatProperties.getOptions()) + .toolCallingManager(toolCallingManager) + .retryTemplate(retryTemplate) + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .build(); observationConvention.ifAvailable(chatModel::setObservationConvention);