diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 6dcb48d28f2..ab8bfe27317 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -129,28 +129,18 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, @Nullable FunctionCallbackResolver functionCallbackResolver, @Nullable List toolFunctionCallbacks, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) { - super(functionCallbackResolver, defaultOptions, toolFunctionCallbacks); - Assert.notNull(ollamaApi, "ollamaApi must not be null"); - Assert.notNull(defaultOptions, "defaultOptions must not be null"); - Assert.notNull(observationRegistry, "observationRegistry must not be null"); - Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null"); - this.chatApi = ollamaApi; - this.defaultOptions = defaultOptions; - this.toolCallingManager = new LegacyToolCallingManager(functionCallbackResolver, toolFunctionCallbacks); - this.observationRegistry = observationRegistry; - this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions); - initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy()); + this(ollamaApi, defaultOptions, new LegacyToolCallingManager(functionCallbackResolver, toolFunctionCallbacks), + observationRegistry, modelManagementOptions); logger.warn("This constructor is deprecated and will be removed in the next milestone. " - + "Please use the new constructor accepting ToolCallingManager instead."); + + "Please use the OllamaChatModel.Builder or the new constructor accepting ToolCallingManager instead."); } public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) { - // We do not pass the 'defaultOptions' to the AbstractToolSupport, because it - // modifies them. - // We are not using the AbstractToolSupport class in this path, so we just pass - // empty options. + // We do not pass the 'defaultOptions' to the AbstractToolSupport, + // because it modifies them. We are using ToolCallingManager instead, + // so we just pass empty options here. super(null, OllamaOptions.builder().build(), List.of()); Assert.notNull(ollamaApi, "ollamaApi must not be null"); Assert.notNull(defaultOptions, "defaultOptions must not be null"); @@ -424,6 +414,8 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOp throw new IllegalArgumentException("model cannot be null or empty"); } + ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks()); + return new Prompt(prompt.getInstructions(), requestOptions); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java index 740c299ab8a..e3221f02640 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java @@ -20,6 +20,7 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; @@ -48,7 +49,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { .internalToolExecutionEnabled(true) .toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2")) .toolNames("tool1", "tool2") - .toolContext(Map.of("key1", "value1")) + .toolContext(Map.of("key1", "value1", "key2", "valueA")) .build(); OllamaChatModel chatModel = OllamaChatModel.builder() .ollamaApi(new OllamaApi()) @@ -59,17 +60,19 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { .internalToolExecutionEnabled(false) .toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4")) .toolNames("tool3") - .toolContext(Map.of("key2", "value2")) + .toolContext(Map.of("key2", "valueB")) .build(); Prompt prompt = chatModel.buildRequestPrompt(new Prompt("Test message content", runtimeOptions)); assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull(); assertThat(((ToolCallingChatOptions) prompt.getOptions()).isInternalToolExecutionEnabled()).isFalse(); - assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(4); - assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolNames()).containsExactlyInAnyOrder("tool1", - "tool2", "tool3"); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(2); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks() + .stream() + .map(FunctionCallback::getName)).containsExactlyInAnyOrder("tool3", "tool4"); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolNames()).containsExactlyInAnyOrder("tool3"); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolContext()).containsEntry("key1", "value1") - .containsEntry("key2", "value2"); + .containsEntry("key2", "valueB"); } @Test diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index fdb7f2020c4..8065f0b8b8d 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,10 +19,8 @@ import java.util.ArrayList; import java.util.Base64; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; @@ -31,6 +29,12 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.model.tool.LegacyToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionResult; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.lang.Nullable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -112,6 +116,8 @@ public class OpenAiChatModel extends AbstractToolCallSupport implements ChatMode private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build(); + /** * The default options used for the chat completion requests. */ @@ -132,6 +138,8 @@ public class OpenAiChatModel extends AbstractToolCallSupport implements ChatMode */ private final ObservationRegistry observationRegistry; + private final ToolCallingManager toolCallingManager; + /** * Conventions to use for generating observations. */ @@ -142,7 +150,9 @@ public class OpenAiChatModel extends AbstractToolCallSupport implements ChatMode * @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI * Chat API. * @throws IllegalArgumentException if openAiApi is null + * @deprecated Use OpenAiChatModel.Builder. */ + @Deprecated public OpenAiChatModel(OpenAiApi openAiApi) { this(openAiApi, OpenAiChatOptions.builder().model(OpenAiApi.DEFAULT_CHAT_MODEL).temperature(0.7).build()); } @@ -152,7 +162,9 @@ public OpenAiChatModel(OpenAiApi openAiApi) { * @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI * Chat API. * @param options The OpenAiChatOptions to configure the chat model. + * @deprecated Use OpenAiChatModel.Builder. */ + @Deprecated public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options) { this(openAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE); } @@ -164,9 +176,11 @@ public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options) { * @param options The OpenAiChatOptions to configure the chat model. * @param functionCallbackResolver The function callback resolver. * @param retryTemplate The retry template. + * @deprecated Use OpenAiChatModel.Builder. */ + @Deprecated public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options, - FunctionCallbackResolver functionCallbackResolver, RetryTemplate retryTemplate) { + @Nullable FunctionCallbackResolver functionCallbackResolver, RetryTemplate retryTemplate) { this(openAiApi, options, functionCallbackResolver, List.of(), retryTemplate); } @@ -178,10 +192,12 @@ public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options, * @param functionCallbackResolver The function callback resolver. * @param toolFunctionCallbacks The tool function callbacks. * @param retryTemplate The retry template. + * @deprecated Use OpenAiChatModel.Builder. */ + @Deprecated public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options, - FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks, - RetryTemplate retryTemplate) { + @Nullable FunctionCallbackResolver functionCallbackResolver, + @Nullable List toolFunctionCallbacks, RetryTemplate retryTemplate) { this(openAiApi, options, functionCallbackResolver, toolFunctionCallbacks, retryTemplate, ObservationRegistry.NOOP); } @@ -195,29 +211,48 @@ public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options, * @param toolFunctionCallbacks The tool function callbacks. * @param retryTemplate The retry template. * @param observationRegistry The ObservationRegistry used for instrumentation. + * @deprecated Use OpenAiChatModel.Builder or OpenAiChatModel(OpenAiApi, + * OpenAiChatOptions, ToolCallingManager, RetryTemplate, ObservationRegistry). */ + @Deprecated public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options, - FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks, - RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { - - super(functionCallbackResolver, options, toolFunctionCallbacks); - - Assert.notNull(openAiApi, "OpenAiApi must not be null"); - Assert.notNull(options, "Options must not be null"); - Assert.notNull(retryTemplate, "RetryTemplate must not be null"); - Assert.isTrue(CollectionUtils.isEmpty(options.getFunctionCallbacks()), - "The default function callbacks must be set via the toolFunctionCallbacks constructor parameter"); - Assert.notNull(observationRegistry, "ObservationRegistry must not be null"); + @Nullable FunctionCallbackResolver functionCallbackResolver, + @Nullable List toolFunctionCallbacks, RetryTemplate retryTemplate, + ObservationRegistry observationRegistry) { + this(openAiApi, options, + LegacyToolCallingManager.builder() + .functionCallbackResolver(functionCallbackResolver) + .functionCallbacks(toolFunctionCallbacks) + .build(), + retryTemplate, observationRegistry); + logger.warn("This constructor is deprecated and will be removed in the next milestone. " + + "Please use the OpenAiChatModel.Builder or the new constructor accepting ToolCallingManager instead."); + } + public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions, ToolCallingManager toolCallingManager, + RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { + // We do not pass the 'defaultOptions' to the AbstractToolSupport, + // because it modifies them. We are using ToolCallingManager instead, + // so we just pass empty options here. + super(null, OpenAiChatOptions.builder().build(), List.of()); + Assert.notNull(openAiApi, "openAiApi cannot be null"); + Assert.notNull(defaultOptions, "defaultOptions cannot be null"); + Assert.notNull(toolCallingManager, "toolCallingManager cannot be null"); + Assert.notNull(retryTemplate, "retryTemplate cannot be null"); + Assert.notNull(observationRegistry, "observationRegistry cannot be null"); this.openAiApi = openAiApi; - this.defaultOptions = options; + this.defaultOptions = defaultOptions; + this.toolCallingManager = toolCallingManager; this.retryTemplate = retryTemplate; this.observationRegistry = observationRegistry; } @Override public ChatResponse call(Prompt prompt) { - return this.internalCall(prompt, null); + // Before moving any further, build the final request Prompt, + // merging runtime and default options. + Prompt requestPrompt = buildRequestPrompt(prompt); + return this.internalCall(requestPrompt, null); } public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { @@ -227,7 +262,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(OpenAiApiConstants.PROVIDER_NAME) - .requestOptions(buildRequestOptions(request)) + .requestOptions(prompt.getOptions()) .build(); ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION @@ -263,8 +298,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons return buildGeneration(choice, metadata, request); }).toList(); - // Non function calling. RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity); + // Current usage OpenAiApi.Usage usage = completionEntity.getBody().usage(); Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage(); @@ -278,13 +313,21 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons }); - if (!isProxyToolCalls(prompt, this.defaultOptions) - && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), - OpenAiApi.ChatCompletionFinishReason.STOP.name()))) { - var toolCallConversation = handleToolCalls(prompt, response); - // Recursively call the call method with the tool call message - // conversation that contains the call responses. - return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response); + if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null + && response.hasToolCalls()) { + var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return ChatResponse.builder() + .from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build(); + } + else { + // Send the tool execution result back to the model. + return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + response); + } } return response; @@ -292,7 +335,10 @@ && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.n @Override public Flux stream(Prompt prompt) { - return internalStream(prompt, null); + // Before moving any further, build the final request Prompt, + // merging runtime and default options. + Prompt requestPrompt = buildRequestPrompt(prompt); + return internalStream(requestPrompt, null); } public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { @@ -320,7 +366,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(OpenAiApiConstants.PROVIDER_NAME) - .requestOptions(buildRequestOptions(request)) + .requestOptions(prompt.getOptions()) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( @@ -392,12 +438,18 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha // @formatter:off Flux flux = chatResponse.flatMap(response -> { - if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), - OpenAiApi.ChatCompletionFinishReason.STOP.name()))) { - var toolCallConversation = handleToolCalls(prompt, response); - // Recursively call the stream method with the tool call message - // conversation that contains the call responses. - return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response); + if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response.hasToolCalls()) { + var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return Flux.just(ChatResponse.builder().from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build()); + } else { + // Send the tool execution result back to the model. + return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + response); + } } else { return Flux.just(response); @@ -505,6 +557,63 @@ private DefaultUsage getDefaultUsage(OpenAiApi.Usage usage) { return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage); } + Prompt buildRequestPrompt(Prompt prompt) { + // Process runtime options + OpenAiChatOptions runtimeOptions = null; + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { + runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, + OpenAiChatOptions.class); + } + else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { + runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class, + OpenAiChatOptions.class); + } + else { + runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, + OpenAiChatOptions.class); + } + } + + // Define request options by merging runtime options and default options + OpenAiChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, + OpenAiChatOptions.class); + + // Merge @JsonIgnore-annotated options explicitly since they are ignored by + // Jackson, used by ModelOptionsUtils. + if (runtimeOptions != null) { + requestOptions.setHttpHeaders( + mergeHttpHeaders(runtimeOptions.getHttpHeaders(), this.defaultOptions.getHttpHeaders())); + requestOptions.setInternalToolExecutionEnabled( + ModelOptionsUtils.mergeOption(runtimeOptions.isInternalToolExecutionEnabled(), + this.defaultOptions.isInternalToolExecutionEnabled())); + requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), + this.defaultOptions.getToolNames())); + requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), + this.defaultOptions.getToolCallbacks())); + requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(), + this.defaultOptions.getToolContext())); + } + else { + requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders()); + requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.isInternalToolExecutionEnabled()); + requestOptions.setToolNames(this.defaultOptions.getToolNames()); + requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); + requestOptions.setToolContext(this.defaultOptions.getToolContext()); + } + + ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks()); + + return new Prompt(prompt.getInstructions(), requestOptions); + } + + private Map mergeHttpHeaders(Map runtimeHttpHeaders, + Map defaultHttpHeaders) { + var mergedHttpHeaders = new HashMap<>(defaultHttpHeaders); + mergedHttpHeaders.putAll(runtimeHttpHeaders); + return mergedHttpHeaders; + } + /** * Accessible for testing. */ @@ -563,38 +672,17 @@ else if (message.getMessageType() == MessageType.TOOL) { ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream); - Set enabledToolsToUse = new HashSet<>(); - - if (prompt.getOptions() != null) { - OpenAiChatOptions updatedRuntimeOptions = null; - - if (prompt.getOptions() instanceof FunctionCallingOptions) { - updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(((FunctionCallingOptions) prompt.getOptions()), - FunctionCallingOptions.class, OpenAiChatOptions.class); - } - else { - updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, - OpenAiChatOptions.class); - } - - enabledToolsToUse.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions)); - - request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class); - } - - if (!CollectionUtils.isEmpty(this.defaultOptions.getFunctions())) { - enabledToolsToUse.addAll(this.defaultOptions.getFunctions()); - } - - request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class); - - // Add the enabled functions definitions to the request's tools parameter. - if (!CollectionUtils.isEmpty(enabledToolsToUse)) { + OpenAiChatOptions requestOptions = (OpenAiChatOptions) prompt.getOptions(); + request = ModelOptionsUtils.merge(requestOptions, request, ChatCompletionRequest.class); + // Add the tool definitions to the request's tools parameter. + List toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions); + if (!CollectionUtils.isEmpty(toolDefinitions)) { request = ModelOptionsUtils.merge( - OpenAiChatOptions.builder().tools(this.getFunctionTools(enabledToolsToUse)).build(), request, + OpenAiChatOptions.builder().tools(this.getFunctionTools(toolDefinitions)).build(), request, ChatCompletionRequest.class); } + // Remove `streamOptions` from the request if it is not a streaming request if (request.streamOptions() != null && !stream) { logger.warn("Removing streamOptions from the request as it is not a streaming request!"); @@ -643,26 +731,14 @@ else if (mediaContentData instanceof String text) { } } - private List getFunctionTools(Set functionNames) { - return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> { - var function = new OpenAiApi.FunctionTool.Function(functionCallback.getDescription(), - functionCallback.getName(), functionCallback.getInputTypeSchema()); + private List getFunctionTools(List toolDefinitions) { + return toolDefinitions.stream().map(toolDefinition -> { + var function = new OpenAiApi.FunctionTool.Function(toolDefinition.description(), toolDefinition.name(), + toolDefinition.inputSchema()); return new OpenAiApi.FunctionTool(function); }).toList(); } - private ChatOptions buildRequestOptions(OpenAiApi.ChatCompletionRequest request) { - return ChatOptions.builder() - .model(request.model()) - .frequencyPenalty(request.frequencyPenalty()) - .maxTokens(request.maxTokens()) - .presencePenalty(request.presencePenalty()) - .stopSequences(request.stop()) - .temperature(request.temperature()) - .topP(request.topP()) - .build(); - } - @Override public ChatOptions getDefaultOptions() { return OpenAiChatOptions.fromOptions(this.defaultOptions); @@ -682,4 +758,94 @@ public void setObservationConvention(ChatModelObservationConvention observationC this.observationConvention = observationConvention; } + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private OpenAiApi openAiApi; + + private OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder() + .model(OpenAiApi.DEFAULT_CHAT_MODEL) + .temperature(0.7) + .build(); + + private ToolCallingManager toolCallingManager; + + private FunctionCallbackResolver functionCallbackResolver; + + private List toolFunctionCallbacks; + + private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + + private Builder() { + } + + public Builder openAiApi(OpenAiApi openAiApi) { + this.openAiApi = openAiApi; + return this; + } + + public Builder defaultOptions(OpenAiChatOptions defaultOptions) { + this.defaultOptions = defaultOptions; + return this; + } + + public Builder toolCallingManager(ToolCallingManager toolCallingManager) { + this.toolCallingManager = toolCallingManager; + return this; + } + + @Deprecated + public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) { + this.functionCallbackResolver = functionCallbackResolver; + return this; + } + + @Deprecated + public Builder toolFunctionCallbacks(List toolFunctionCallbacks) { + this.toolFunctionCallbacks = toolFunctionCallbacks; + return this; + } + + public Builder retryTemplate(RetryTemplate retryTemplate) { + this.retryTemplate = retryTemplate; + return this; + } + + public Builder observationRegistry(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + return this; + } + + public OpenAiChatModel build() { + if (toolCallingManager != null) { + Assert.isNull(functionCallbackResolver, + "functionCallbackResolver cannot be set when toolCallingManager is set"); + Assert.isNull(toolFunctionCallbacks, + "toolFunctionCallbacks cannot be set when toolCallingManager is set"); + + return new OpenAiChatModel(openAiApi, defaultOptions, toolCallingManager, retryTemplate, + observationRegistry); + } + + if (functionCallbackResolver != null) { + Assert.isNull(toolCallingManager, + "toolCallingManager cannot be set when functionCallbackResolver is set"); + List toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks + : List.of(); + + return new OpenAiChatModel(openAiApi, defaultOptions, functionCallbackResolver, toolCallbacks, + retryTemplate, observationRegistry); + } + + return new OpenAiChatModel(openAiApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate, + observationRegistry); + } + + } + } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index 6c3e1246ca2..c1193d4b261 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -17,6 +17,7 @@ package org.springframework.ai.openai; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -31,12 +32,14 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.StreamOptions; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder; import org.springframework.ai.openai.api.ResponseFormat; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -49,7 +52,7 @@ * @since 0.8.0 */ @JsonInclude(Include.NON_NULL) -public class OpenAiChatOptions implements FunctionCallingOptions { +public class OpenAiChatOptions implements ToolCallingChatOptions { // @formatter:off /** @@ -192,33 +195,22 @@ public class OpenAiChatOptions implements FunctionCallingOptions { private @JsonProperty("reasoning_effort") String reasoningEffort; /** - * OpenAI Tool Function Callbacks to register with the ChatModel. - * For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution. - * For Default Options the functionCallbacks are registered but disabled by default. Use the enableFunctions to set the functions - * from the registry to be used by the ChatModel chat completion requests. + * Collection of {@link ToolCallback}s to be used for tool calling in the chat completion requests. */ @JsonIgnore - private List functionCallbacks = new ArrayList<>(); + private List toolCallbacks = new ArrayList<>(); /** - * List of functions, identified by their names, to configure for function calling in - * the chat completion requests. - * Functions with those names must exist in the functionCallbacks registry. - * The {@link #functionCallbacks} from the PromptOptions are automatically enabled for the duration of the prompt execution. - * - * Note that function enabled with the default options are enabled for all chat completion requests. This could impact the token count and the billing. - * If the functions is set in a prompt options, then the enabled functions are only active for the duration of this prompt execution. + * Collection of tool names to be resolved at runtime and used for tool calling in the chat completion requests. */ @JsonIgnore - private Set functions = new HashSet<>(); + private Set toolNames = new HashSet<>(); /** - * If true, the Spring AI will not handle the function calls internally, but will proxy them to the client. - * It is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. - * If false, the Spring AI will handle the function calls internally. + * Whether to enable the tool execution lifecycle internally in ChatModel. */ @JsonIgnore - private Boolean proxyToolCalls; + private Boolean internalToolExecutionEnabled; /** * Optional HTTP headers to be added to the chat completion request. @@ -227,7 +219,7 @@ public class OpenAiChatOptions implements FunctionCallingOptions { private Map httpHeaders = new HashMap<>(); @JsonIgnore - private Map toolContext; + private Map toolContext = new HashMap<>(); // @formatter:on @@ -258,10 +250,10 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) { .toolChoice(fromOptions.getToolChoice()) .user(fromOptions.getUser()) .parallelToolCalls(fromOptions.getParallelToolCalls()) - .functionCallbacks(fromOptions.getFunctionCallbacks()) - .functions(fromOptions.getFunctions()) + .toolCallbacks(fromOptions.getToolCallbacks()) + .toolNames(fromOptions.getToolNames()) .httpHeaders(fromOptions.getHttpHeaders()) - .proxyToolCalls(fromOptions.getProxyToolCalls()) + .internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled()) .toolContext(fromOptions.getToolContext()) .store(fromOptions.getStore()) .metadata(fromOptions.getMetadata()) @@ -447,12 +439,16 @@ public void setToolChoice(Object toolChoice) { } @Override + @Deprecated + @JsonIgnore public Boolean getProxyToolCalls() { - return this.proxyToolCalls; + return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null; } + @Deprecated + @JsonIgnore public void setProxyToolCalls(Boolean proxyToolCalls) { - this.proxyToolCalls = proxyToolCalls; + this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null; } public String getUser() { @@ -472,22 +468,73 @@ public void setParallelToolCalls(Boolean parallelToolCalls) { } @Override + @JsonIgnore + public List getToolCallbacks() { + return this.toolCallbacks; + } + + @Override + @JsonIgnore + public void setToolCallbacks(List toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); + this.toolCallbacks = toolCallbacks; + } + + @Override + @JsonIgnore + public Set getToolNames() { + return this.toolNames; + } + + @Override + @JsonIgnore + public void setToolNames(Set toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + Assert.noNullElements(toolNames, "toolNames cannot contain null elements"); + toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements")); + this.toolNames = toolNames; + } + + @Override + @Nullable + @JsonIgnore + public Boolean isInternalToolExecutionEnabled() { + return internalToolExecutionEnabled; + } + + @Override + @JsonIgnore + public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.internalToolExecutionEnabled = internalToolExecutionEnabled; + } + + @Override + @Deprecated + @JsonIgnore public List getFunctionCallbacks() { - return this.functionCallbacks; + return this.getToolCallbacks(); } @Override + @Deprecated + @JsonIgnore public void setFunctionCallbacks(List functionCallbacks) { - this.functionCallbacks = functionCallbacks; + this.setToolCallbacks(functionCallbacks); } @Override + @Deprecated + @JsonIgnore public Set getFunctions() { - return this.functions; + return this.getToolNames(); } + @Override + @Deprecated + @JsonIgnore public void setFunctions(Set functionNames) { - this.functions = functionNames; + this.setToolNames(functionNames); } public Map getHttpHeaders() { @@ -505,11 +552,13 @@ public Integer getTopK() { } @Override + @JsonIgnore public Map getToolContext() { return this.toolContext; } @Override + @JsonIgnore public void setToolContext(Map toolContext) { this.toolContext = toolContext; } @@ -548,9 +597,9 @@ public int hashCode() { return Objects.hash(this.model, this.frequencyPenalty, this.logitBias, this.logprobs, this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.presencePenalty, this.responseFormat, this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice, - this.user, this.parallelToolCalls, this.functionCallbacks, this.functions, this.httpHeaders, - this.proxyToolCalls, this.toolContext, this.outputModalities, this.outputAudio, this.store, - this.metadata, this.reasoningEffort); + this.user, this.parallelToolCalls, this.toolCallbacks, this.toolNames, this.httpHeaders, + this.internalToolExecutionEnabled, this.toolContext, this.outputModalities, this.outputAudio, + this.store, this.metadata, this.reasoningEffort); } @Override @@ -574,11 +623,11 @@ public boolean equals(Object o) { && Objects.equals(this.topP, other.topP) && Objects.equals(this.tools, other.tools) && Objects.equals(this.toolChoice, other.toolChoice) && Objects.equals(this.user, other.user) && Objects.equals(this.parallelToolCalls, other.parallelToolCalls) - && Objects.equals(this.functionCallbacks, other.functionCallbacks) - && Objects.equals(this.functions, other.functions) + && Objects.equals(this.toolCallbacks, other.toolCallbacks) + && Objects.equals(this.toolNames, other.toolNames) && Objects.equals(this.httpHeaders, other.httpHeaders) && Objects.equals(this.toolContext, other.toolContext) - && Objects.equals(this.proxyToolCalls, other.proxyToolCalls) + && Objects.equals(this.internalToolExecutionEnabled, other.internalToolExecutionEnabled) && Objects.equals(this.outputModalities, other.outputModalities) && Objects.equals(this.outputAudio, other.outputAudio) && Objects.equals(this.store, other.store) && Objects.equals(this.metadata, other.metadata) @@ -712,25 +761,54 @@ public Builder parallelToolCalls(Boolean parallelToolCalls) { return this; } - public Builder functionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; + public Builder toolCallbacks(List toolCallbacks) { + this.options.setToolCallbacks(toolCallbacks); return this; } - public Builder functions(Set functionNames) { - Assert.notNull(functionNames, "Function names must not be null"); - this.options.functions = functionNames; + public Builder toolCallbacks(FunctionCallback... toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks)); return this; } - public Builder function(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); + public Builder toolNames(Set toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + this.options.setToolNames(toolNames); + return this; + } + + public Builder toolNames(String... toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + this.options.toolNames.addAll(Set.of(toolNames)); return this; } + public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled); + return this; + } + + @Deprecated + public Builder functionCallbacks(List functionCallbacks) { + return toolCallbacks(functionCallbacks); + } + + @Deprecated + public Builder functions(Set functionNames) { + return toolNames(functionNames); + } + + @Deprecated + public Builder function(String functionName) { + return toolNames(functionName); + } + + @Deprecated public Builder proxyToolCalls(Boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; + if (proxyToolCalls != null) { + this.options.setInternalToolExecutionEnabled(!proxyToolCalls); + } return this; } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java index 6725b18d7ba..561b0fe95ec 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,28 +17,69 @@ package org.springframework.ai.openai; import java.util.List; +import java.util.Map; import org.junit.jupiter.api.Test; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.SimpleApiKey; import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.tool.MockWeatherService; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov + * @author Thomas Vitale */ -public class ChatCompletionRequestTests { +class ChatCompletionRequestTests { @Test - public void createRequestWithChatOptions() { + void whenToolRuntimeOptionsThenMergeWithDefaults() { + OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder() + .model("DEFAULT_MODEL") + .internalToolExecutionEnabled(true) + .toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2")) + .toolNames("tool1", "tool2") + .toolContext(Map.of("key1", "value1", "key2", "valueA")) + .build(); + + OpenAiChatModel chatModel = OpenAiChatModel.builder() + .openAiApi(OpenAiApi.builder().apiKey(new SimpleApiKey("TEST")).build()) + .defaultOptions(defaultOptions) + .build(); + + OpenAiChatOptions runtimeOptions = OpenAiChatOptions.builder() + .internalToolExecutionEnabled(false) + .toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4")) + .toolNames("tool3") + .toolContext(Map.of("key2", "valueB")) + .build(); + Prompt prompt = chatModel.buildRequestPrompt(new Prompt("Test message content", runtimeOptions)); + + assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).isInternalToolExecutionEnabled()).isFalse(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(2); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks() + .stream() + .map(FunctionCallback::getName)).containsExactlyInAnyOrder("tool3", "tool4"); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolNames()).containsExactlyInAnyOrder("tool3"); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolContext()).containsEntry("key1", "value1") + .containsEntry("key2", "valueB"); + } + @Test + void createRequestWithChatOptions() { var client = new OpenAiChatModel(new OpenAiApi("TEST"), OpenAiChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build()); - var request = client.createRequest(new Prompt("Test message content"), false); + var prompt = client.buildRequestPrompt(new Prompt("Test message content")); + + var request = client.createRequest(prompt, false); assertThat(request.messages()).hasSize(1); assertThat(request.stream()).isFalse(); @@ -57,14 +98,13 @@ public void createRequestWithChatOptions() { } @Test - public void promptOptionsTools() { - + void promptOptionsTools() { final String TOOL_FUNCTION_NAME = "CurrentWeather"; var client = new OpenAiChatModel(new OpenAiApi("TEST"), OpenAiChatOptions.builder().model("DEFAULT_MODEL").build()); - var request = client.createRequest(new Prompt("Test message content", + var prompt = client.buildRequestPrompt(new Prompt("Test message content", OpenAiChatOptions.builder() .model("PROMPT_MODEL") .functionCallbacks(List.of(FunctionCallback.builder() @@ -72,11 +112,9 @@ public void promptOptionsTools() { .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) - .build()), - false); + .build())); - assertThat(client.getFunctionCallbackRegister()).hasSize(1); - assertThat(client.getFunctionCallbackRegister()).containsKeys(TOOL_FUNCTION_NAME); + var request = client.createRequest(prompt, false); assertThat(request.messages()).hasSize(1); assertThat(request.stream()).isFalse(); @@ -87,8 +125,7 @@ public void promptOptionsTools() { } @Test - public void defaultOptionsTools() { - + void defaultOptionsTools() { final String TOOL_FUNCTION_NAME = "CurrentWeather"; var client = new OpenAiChatModel(new OpenAiApi("TEST"), @@ -101,48 +138,43 @@ public void defaultOptionsTools() { .build())) .build()); - var request = client.createRequest(new Prompt("Test message content"), false); + var prompt = client.buildRequestPrompt(new Prompt("Test message content")); - assertThat(client.getFunctionCallbackRegister()).hasSize(1); - assertThat(client.getFunctionCallbackRegister()).containsKeys(TOOL_FUNCTION_NAME); - assertThat(client.getFunctionCallbackRegister().get(TOOL_FUNCTION_NAME).getDescription()) - .isEqualTo("Get the weather in location"); + var request = client.createRequest(prompt, false); assertThat(request.messages()).hasSize(1); assertThat(request.stream()).isFalse(); assertThat(request.model()).isEqualTo("DEFAULT_MODEL"); + assertThat(request.tools()).hasSize(1); + assertThat(request.tools().get(0).getFunction().getName()).isEqualTo(TOOL_FUNCTION_NAME); - assertThat(request.tools()).as("Default Options callback functions are not automatically enabled!") - .isNullOrEmpty(); - - // Explicitly enable the function - request = client.createRequest( - new Prompt("Test message content", OpenAiChatOptions.builder().function(TOOL_FUNCTION_NAME).build()), - false); + // Reference the default options tool by name at runtime + prompt = client.buildRequestPrompt( + new Prompt("Test message content", OpenAiChatOptions.builder().function(TOOL_FUNCTION_NAME).build())); + request = client.createRequest(prompt, false); assertThat(request.tools()).hasSize(1); - assertThat(request.tools().get(0).getFunction().getName()).as("Explicitly enabled function") - .isEqualTo(TOOL_FUNCTION_NAME); + assertThat(request.tools().get(0).getFunction().getName()).isEqualTo(TOOL_FUNCTION_NAME); + } - // Override the default options function with one from the prompt - request = client.createRequest(new Prompt("Test message content", - OpenAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function(TOOL_FUNCTION_NAME, new MockWeatherService()) - .description("Overridden function description") - .inputType(MockWeatherService.Request.class) - .build())) - .build()), - false); + static class TestToolCallback implements ToolCallback { - assertThat(request.tools()).hasSize(1); - assertThat(request.tools().get(0).getFunction().getName()).as("Explicitly enabled function") - .isEqualTo(TOOL_FUNCTION_NAME); + private final ToolDefinition toolDefinition; + + public TestToolCallback(String name) { + this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build(); + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String toolInput) { + return "Mission accomplished!"; + } - assertThat(client.getFunctionCallbackRegister()).hasSize(1); - assertThat(client.getFunctionCallbackRegister()).containsKeys(TOOL_FUNCTION_NAME); - assertThat(client.getFunctionCallbackRegister().get(TOOL_FUNCTION_NAME).getDescription()) - .isEqualTo("Overridden function description"); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java index c991e6a60ae..ed9b739d7dc 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java @@ -90,6 +90,12 @@ public List resolveToolDefinitions(ToolCallingChatOptions chatOp List toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks()); for (String toolName : chatOptions.getToolNames()) { + // Skip the tool if it is already present in the request toolCallbacks. + // That might happen if a tool is defined in the options + // both as a ToolCallback and as a tool name. + if (chatOptions.getToolCallbacks().stream().anyMatch(tool -> tool.getName().equals(toolName))) { + continue; + } FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName); if (toolCallback == null) { throw new IllegalStateException("No ToolCallback found for tool name: " + toolName); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/LegacyToolCallingManager.java b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/LegacyToolCallingManager.java index 4176c6c7c22..f88f7c7b2b2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/LegacyToolCallingManager.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/LegacyToolCallingManager.java @@ -78,6 +78,12 @@ public List resolveToolDefinitions(ToolCallingChatOptions chatOp List toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks()); for (String toolName : chatOptions.getToolNames()) { + // Skip the tool if it is already present in the request toolCallbacks. + // That might happen if a tool is defined in the options + // both as a ToolCallback and as a tool name. + if (chatOptions.getToolCallbacks().stream().anyMatch(tool -> tool.getName().equals(toolName))) { + continue; + } FunctionCallback toolCallback = resolveFunctionCallback(toolName); if (toolCallback == null) { throw new IllegalStateException("No ToolCallback found for tool name: " + toolName); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java index b4c25f91172..430fe4ad056 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java @@ -20,8 +20,10 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.tool.util.ToolUtils; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import java.util.ArrayList; import java.util.HashMap; @@ -191,18 +193,20 @@ else if (chatOptions instanceof FunctionCallingOptions functionCallingOptions static Set mergeToolNames(Set runtimeToolNames, Set defaultToolNames) { Assert.notNull(runtimeToolNames, "runtimeToolNames cannot be null"); Assert.notNull(defaultToolNames, "defaultToolNames cannot be null"); - var mergedToolNames = new HashSet<>(runtimeToolNames); - mergedToolNames.addAll(defaultToolNames); - return mergedToolNames; + if (CollectionUtils.isEmpty(runtimeToolNames)) { + return new HashSet<>(defaultToolNames); + } + return new HashSet<>(runtimeToolNames); } static List mergeToolCallbacks(List runtimeToolCallbacks, List defaultToolCallbacks) { Assert.notNull(runtimeToolCallbacks, "runtimeToolCallbacks cannot be null"); Assert.notNull(defaultToolCallbacks, "defaultToolCallbacks cannot be null"); - var mergedToolCallbacks = new ArrayList<>(runtimeToolCallbacks); - mergedToolCallbacks.addAll(defaultToolCallbacks); - return mergedToolCallbacks; + if (CollectionUtils.isEmpty(runtimeToolCallbacks)) { + return new ArrayList<>(defaultToolCallbacks); + } + return new ArrayList<>(runtimeToolCallbacks); } static Map mergeToolContext(Map runtimeToolContext, @@ -216,4 +220,12 @@ static Map mergeToolContext(Map runtimeToolConte return mergedToolContext; } + static void validateToolCallbacks(List toolCallbacks) { + List duplicateToolNames = ToolUtils.getDuplicateToolNames(toolCallbacks); + if (!duplicateToolNames.isEmpty()) { + throw new IllegalStateException("Multiple tools with the same name (%s) found in ToolCallingChatOptions" + .formatted(String.join(", ", duplicateToolNames))); + } + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/util/ToolUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/util/ToolUtils.java index bbbfe50922b..6b3e351d8c8 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/util/ToolUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/util/ToolUtils.java @@ -26,10 +26,10 @@ import org.springframework.util.StringUtils; import java.lang.reflect.Method; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.stream.Collectors; -import java.util.stream.Stream; /** * Miscellaneous tool utility methods. Mainly for internal use within the framework. @@ -85,9 +85,9 @@ public static ToolCallResultConverter getToolCallResultConverter(Method method) } } - public static List getDuplicateToolNames(FunctionCallback... toolCallbacks) { + public static List getDuplicateToolNames(List toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); - return Stream.of(toolCallbacks) + return toolCallbacks.stream() .collect(Collectors.groupingBy(FunctionCallback::getName, Collectors.counting())) .entrySet() .stream() @@ -96,4 +96,9 @@ public static List getDuplicateToolNames(FunctionCallback... toolCallbac .collect(Collectors.toList()); } + public static List getDuplicateToolNames(FunctionCallback... toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + return getDuplicateToolNames(Arrays.asList(toolCallbacks)); + } + } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java index 4ac0f4c6168..711fc99c0ed 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java @@ -217,6 +217,28 @@ void whenMultipleToolCallsInChatResponseThenExecute() { assertThat(toolExecutionResult.conversationHistory()).contains(expectedToolResponse); } + @Test + void whenDuplicateMixedToolCallsInChatResponseThenExecute() { + ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder().build(); + + Prompt prompt = new Prompt(new UserMessage("Hello"), + ToolCallingChatOptions.builder() + .toolCallbacks(new TestToolCallback("toolA")) + .toolNames("toolA") + .build()); + ChatResponse chatResponse = ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("", Map.of(), + List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")))))) + .build(); + + ToolResponseMessage expectedToolResponse = new ToolResponseMessage( + List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!"))); + + ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse); + + assertThat(toolExecutionResult.conversationHistory()).contains(expectedToolResponse); + } + @Test void whenMultipleToolCallsWithReturnDirectInChatResponseThenExecute() { ToolCallback toolCallbackA = new TestToolCallback("toolA", true); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java index 134151ab0b9..8077b6bf8ef 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java @@ -26,6 +26,7 @@ import java.util.Set; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Unit tests for {@link ToolCallingChatOptions}. @@ -79,7 +80,7 @@ void whenMergeRuntimeAndDefaultToolNames() { Set runtimeToolNames = Set.of("toolA"); Set defaultToolNames = Set.of("toolB"); Set mergedToolNames = ToolCallingChatOptions.mergeToolNames(runtimeToolNames, defaultToolNames); - assertThat(mergedToolNames).containsExactlyInAnyOrder("toolA", "toolB"); + assertThat(mergedToolNames).containsExactlyInAnyOrder("toolA"); } @Test @@ -112,7 +113,8 @@ void whenMergeRuntimeAndDefaultToolCallbacks() { List defaultToolCallbacks = List.of(new TestToolCallback("toolB")); List mergedToolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(runtimeToolCallbacks, defaultToolCallbacks); - assertThat(mergedToolCallbacks).hasSize(2); + assertThat(mergedToolCallbacks).hasSize(1); + assertThat(mergedToolCallbacks.get(0).getName()).isEqualTo("toolA"); } @Test @@ -122,6 +124,7 @@ void whenMergeRuntimeAndEmptyDefaultToolCallbacks() { List mergedToolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(runtimeToolCallbacks, defaultToolCallbacks); assertThat(mergedToolCallbacks).hasSize(1); + assertThat(mergedToolCallbacks.get(0).getName()).isEqualTo("toolA"); } @Test @@ -131,6 +134,7 @@ void whenMergeEmptyRuntimeAndDefaultToolCallbacks() { List mergedToolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(runtimeToolCallbacks, defaultToolCallbacks); assertThat(mergedToolCallbacks).hasSize(1); + assertThat(mergedToolCallbacks.get(0).getName()).isEqualTo("toolB"); } @Test @@ -183,6 +187,14 @@ void whenMergeEmptyRuntimeAndEmptyDefaultToolContext() { assertThat(mergedToolContext).hasSize(0); } + @Test + void shouldEnsureUniqueToolNames() { + List toolCallbacks = List.of(new TestToolCallback("toolA"), new TestToolCallback("toolA")); + assertThatThrownBy(() -> ToolCallingChatOptions.validateToolCallbacks(toolCallbacks)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Multiple tools with the same name (toolA)"); + } + static class TestToolCallback implements ToolCallback { private final ToolDefinition toolDefinition; diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/comparison.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/comparison.adoc index ac121825e71..51961e49c07 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/comparison.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/comparison.adoc @@ -7,7 +7,7 @@ This table compares various Chat Models supported by Spring AI, detailing their capabilities: - xref:api/multimodality.adoc[Multimodality]: The types of input the model can process (e.g., text, image, audio, video). -- xref:api/functions.adoc[Tools/Functions]: Whether the model supports function calling or tool use. +- xref:api/tools.adoc[Tools/Function Calling]: Whether the model supports function calling or tool use. - Streaming: If the model offers streaming responses. - Retry: Support for retry mechanisms. - xref:observability/index.adoc[Observability]: Features for monitoring and debugging. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc index bef52710a89..61d40c613e4 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc @@ -1,5 +1,7 @@ = Function Calling +WARNING: This page describes the previous version of the Function Calling API, which has been deprecated and marked for remove in the next release. The current version is available at xref:api/tools.adoc[Tool Calling]. See the xref:api/tools-migration.adoc[Migration Guide] for more information. + You can register custom Java functions with the `OpenAiChatModel` and have the OpenAI model intelligently choose to output a JSON object containing arguments to call one or many of the registered functions. This allows you to connect the LLM capabilities with external tools and APIs. The OpenAI models are trained to detect when a function should be called and to respond with JSON that adheres to the function signature. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/index.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/index.adoc index 43cf65ad873..51a93575a79 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/index.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/index.adoc @@ -22,13 +22,13 @@ image::spring-ai-chat-completions-clients.jpg[align="center", width="800px"] Portable `Vector Store API` across multiple providers, including a novel `SQL-like metadata filter API` that is also portable. Support for 14 vector databases are available. -=== Function Calling API +=== Tool Calling API -`Function calling`. Spring AI makes it easy to have the AI model invoke your POJO `java.util.Function` object. +Spring AI makes it easy to have the AI model invoke your services as `@Tool`-annotated methods or POJO `java.util.Function` objects. -image::function-calling-basic-flow.jpg[Function calling, width=500, align="center"] +image::tools/tool-calling-01.jpg[The main sequence of actions for tool calling, width=500, align="center"] -Check the Spring AI xref::api/functions.adoc[Function Calling] documentation. +Check the Spring AI xref::api/tools.adoc[Tool Calling] documentation. === Auto Configuration diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc index 8b277d6d7ea..d61b64c6978 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc @@ -24,7 +24,7 @@ The AI Model is not guaranteed to return the structured output as requested. The model may not understand the prompt or be unable to generate the structured output as requested. Consider implementing a validation mechanism to ensure the model output is as expected. -TIP: The `StructuredOutputConverter` is not used for LLM xref:api/functions.adoc[Function Calling], as this feature inherently provides structured outputs by default. +TIP: The `StructuredOutputConverter` is not used for LLM xref:api/tools.adoc[Tool Calling], as this feature inherently provides structured outputs by default. == Structured Output API diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/tools.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/tools.adoc index 81bfd0cc482..e83f2849f73 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/tools.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/tools.adoc @@ -232,6 +232,7 @@ ToolCallback[] dateTimeTools = ToolCallbacks.from(new DateTimeTools()); ==== Adding Default Tools to `ChatClient` When using the declarative specification approach, you can add default tools to a `ChatClient.Builder` by passing the tool class instance to the `defaultTools()` method. +If both default and runtime tools are provided, the runtime tools will completely override the default tools. WARNING: Default tools are shared across all the chat requests performed by all the `ChatClient` instances built from the same `ChatClient.Builder`. They are useful for tools that are commonly used across different chat requests, but they can also be dangerous if not used carefully, risking to make them available when they shouldn't. @@ -261,6 +262,7 @@ chatModel.call(prompt); ==== Adding Default Tools to `ChatModel` When using the declarative specification approach, you can add default tools to `ChatModel` at construction time by passing the tool class instance to the `toolCallbacks()` method of the `ToolCallingChatOptions` instance used to create the `ChatModel`. +If both default and runtime tools are provided, the runtime tools will completely override the default tools. WARNING: Default tools are shared across all the chat requests performed by that `ChatModel` instance. They are useful for tools that are commonly used across different chat requests, but they can also be dangerous if not used carefully, risking to make them available when they shouldn't. @@ -397,6 +399,7 @@ ChatClient.create(chatModel) ==== Adding Default Tools to `ChatClient` When using the programmatic specification approach, you can add default tools to a `ChatClient.Builder` by passing the `MethodToolCallback` instance to the `defaultTools()` method. +If both default and runtime tools are provided, the runtime tools will completely override the default tools. WARNING: Default tools are shared across all the chat requests performed by all the `ChatClient` instances built from the same `ChatClient.Builder`. They are useful for tools that are commonly used across different chat requests, but they can also be dangerous if not used carefully, risking to make them available when they shouldn't. @@ -427,6 +430,7 @@ chatModel.call(prompt); ==== Adding Default Tools to `ChatModel` When using the programmatic specification approach, you can add default tools to a `ChatModel` at construction time by passing the `MethodToolCallback` instance to the `toolCallbacks()` method of the `ToolCallingChatOptions` instance used to create the `ChatModel`. +If both default and runtime tools are provided, the runtime tools will completely override the default tools. WARNING: Default tools are shared across all the chat requests performed by that `ChatModel` instance. They are useful for tools that are commonly used across different chat requests, but they can also be dangerous if not used carefully, risking to make them available when they shouldn't. @@ -517,6 +521,7 @@ ChatClient.create(chatModel) ==== Adding Default Tools to `ChatClient` When using the programmatic specification approach, you can add default tools to a `ChatClient.Builder` by passing the `FunctionToolCallback` instance to the `defaultTools()` method. +If both default and runtime tools are provided, the runtime tools will completely override the default tools. WARNING: Default tools are shared across all the chat requests performed by all the `ChatClient` instances built from the same `ChatClient.Builder`. They are useful for tools that are commonly used across different chat requests, but they can also be dangerous if not used carefully, risking to make them available when they shouldn't. @@ -547,6 +552,7 @@ chatModel.call(prompt); ==== Adding Default Tools to `ChatModel` When using the programmatic specification approach, you can add default tools to a `ChatModel` at construction time by passing the `FunctionToolCallback` instance to the `toolCallbacks()` method of the `ToolCallingChatOptions` instance used to create the `ChatModel`. +If both default and runtime tools are provided, the runtime tools will completely override the default tools. WARNING: Default tools are shared across all the chat requests performed by that `ChatModel` instance. They are useful for tools that are commonly used across different chat requests, but they can also be dangerous if not used carefully, risking to make them available when they shouldn't. @@ -625,6 +631,7 @@ ChatClient.create(chatModel) ==== Adding Default Tools to `ChatClient` When using the dynamic specification approach, you can add default tools to a `ChatClient.Builder` by passing the tool name to the `defaultTools()` method. +If both default and runtime tools are provided, the runtime tools will completely override the default tools. WARNING: Default tools are shared across all the chat requests performed by all the `ChatClient` instances built from the same `ChatClient.Builder`. They are useful for tools that are commonly used across different chat requests, but they can also be dangerous if not used carefully, risking to make them available when they shouldn't. @@ -653,6 +660,7 @@ chatModel.call(prompt); ==== Adding Default Tools to `ChatModel` When using the dynamic specification approach, you can add default tools to `ChatModel` at construction time by passing the tool name to the `toolNames()` method of the `ToolCallingChatOptions` instance used to create the `ChatModel`. +If both default and runtime tools are provided, the runtime tools will completely override the default tools. WARNING: Default tools are shared across all the chat requests performed by that `ChatModel` instance. They are useful for tools that are commonly used across different chat requests, but they can also be dangerous if not used carefully, risking to make them available when they shouldn't. @@ -967,6 +975,9 @@ Prompt prompt = new Prompt("Tell me more about the customer with ID 42", chatOpt chatModel.call(prompt); ---- +If the `toolContext` option is set both in the default options and in the runtime options, the resulting `ToolContext` will be the merge of the two, +where the runtime options take precedence over the default options. + === Return Direct By default, the result of a tool call is sent back to the model as a response. Then, the model can use the result to continue the conversation. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/concepts.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/concepts.adoc index 3d923662ff8..6ca49f3ed32 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/concepts.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/concepts.adoc @@ -143,8 +143,8 @@ The Spring AI library helps you implement solutions based on the "`stuffing the image::spring-ai-prompt-stuffing.jpg[Prompt stuffing, width=700, align="center"] -* **xref::concepts.adoc#concept-fc[Function Calling]**: This technique allows registering custom, user functions that connect the large language models to the APIs of external systems. -Spring AI greatly simplifies code you need to write to support xref:api/functions.adoc[function calling]. +* **xref::concepts.adoc#concept-fc[Tool Calling]**: This technique allows registering tools (user-defined services) that connect the large language models to the APIs of external systems. +Spring AI greatly simplifies code you need to write to support xref:api/tools.adoc[tool calling]. [[concept-rag]] === Retrieval Augmented Generation @@ -173,30 +173,29 @@ image::spring-ai-rag.jpg[Spring AI RAG, width=1000, align="center"] * The xref::api/chatclient.adoc#_retrieval_augmented_generation[ChatClient - RAG] explains how to use the `QuestionAnswerAdvisor` to enable the RAG capability in your application. [[concept-fc]] -=== Function Calling +=== Tool Calling Large Language Models (LLMs) are frozen after training, leading to stale knowledge, and they are unable to access or modify external data. -The xref::api/functions.adoc[Function Calling] mechanism addresses these shortcomings. -It allows you to register your own functions to connect the large language models to the APIs of external systems. +The xref::api/tools.adoc[Tool Calling] mechanism addresses these shortcomings. +It allows you to register your own services as tools to connect the large language models to the APIs of external systems. These systems can provide LLMs with real-time data and perform data processing actions on their behalf. -Spring AI greatly simplifies code you need to write to support function invocation. -It handles the function invocation conversation for you. -You can provide your function as a `@Bean` and then provide the bean name of the function in your prompt options to activate that function. -Additionally, you can define and reference multiple functions in a single prompt. +Spring AI greatly simplifies code you need to write to support tool invocation. +It handles the tool invocation conversation for you. +You can provide your tool as a `@Tool`-annotated method and provide it in your prompt options to make it available to the model. +Additionally, you can define and reference multiple tools in a single prompt. -image::function-calling-basic-flow.jpg[Function calling, width=700, align="center"] +image::tools/tool-calling-01.jpg[The main sequence of actions for tool calling, width=700, align="center"] -1. Perform a chat request sending along function definition information. -The latter provides the `name`, `description` (e.g. explaining when the Model should call the function), and `input parameters` (e.g. the function's input parameters schema). -2. When the Model decides to call the function, it will call the function with the input parameters and return the output to the model. -3. Spring AI handles this conversation for you. -It dispatches the function call to the appropriate function and returns the result to the model. -4. The Model can perform multiple function calls to retrieve all the information it needs. -5. Once all information needed is acquired, the Model will generate a response. +1. When we want to make a tool available to the model, we include its definition in the chat request. Each tool definition comprises of a name, a description, and the schema of the input parameters. +2. When the model decides to call a tool, it sends a response with the tool name and the input parameters modeled after the defined schema. +3. The application is responsible for using the tool name to identify and execute the tool with the provided input parameters. +4. The result of the tool call is processed by the application. +5. The application sends the tool call result back to the model. +6. The model generates the final response using the tool call result as additional context. -Follow the xref::api/functions.adoc[Function Calling] documentation for further information on how to use this feature with different AI models. +Follow the xref::api/tools.adoc[Tool Calling] documentation for further information on how to use this feature with different AI models. [[concept-evaluating-ai-responses]] == Evaluating AI responses diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/index.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/index.adoc index c841a5d9d3d..4c02d2315d6 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/index.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/index.adoc @@ -28,7 +28,7 @@ Spring AI provides the following features: * xref:api/structured-output-converter.adoc[Structured Outputs] - Mapping of AI Model output to POJOs. * Support for all major xref:api/vectordbs.adoc[Vector Database providers] such as Apache Cassandra, Azure Cosmos DB, Azure Vector Search, Chroma, Elasticsearch, GemFire, MariaDB, Milvus, MongoDB Atlas, Neo4j, OpenSearch, Oracle, PostgreSQL/PGVector, PineCone, Qdrant, Redis, SAP Hana, Typesense and Weaviate. * Portable API across Vector Store providers, including a novel SQL-like metadata filter API. -* xref:api/functions.adoc[Tools/Function Calling] - permits the model to request the execution of client-side tools and functions, thereby accessing necessary real-time information as required. +* xref:api/tools.adoc[Tools/Function Calling] - Permits the model to request the execution of client-side tools and functions, thereby accessing necessary real-time information as required and taking action. * xref:observability/index.adoc[Observability] - Provides insights into AI-related operations. * Document ingestion xref:api/etl-pipeline.adoc[ETL framework] for Data Engineering. * xref:api/testing.adoc[AI Model Evaluation] - Utilities to help evaluate generated content and protect against hallucinated response. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc index dd6980ad552..e5ae565a964 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc @@ -47,6 +47,32 @@ Example of the new JSON format: } ---- +=== Changes to usage of FunctionCallingOptions for tool calling + +Each `ChatModel` instance, at construction time, accepts an optional `ChatOptions` or `FunctionCallingOptions` instance +that can be used to configure default tools used for calling the model. + +Before 1.0.0-M6: + +- any tool passed via the `functions()` method of the default `FunctionCallingOptions` instance was included in +each call to the model from that `ChatModel` instance, possibly overwritten by runtime options. +- any tool passed via the `functionCallbacks()` method of the default `FunctionCallingOptions` instance was only +made available for runtime dynamic resolution (see xref:api/tools.adoc#_tool_resolution[Tool Resolution]), but never +included in any call to the model unless explicitly requested. + +Starting 1.0.0-M6: + +- any tool passed via the `functions()` method or the `functionCallbacks()` of the default `FunctionCallingOptions` +instance is now handled in the same way: it is included in each call to the model from that `ChatModel` instance, +possibly overwritten by runtime options. With that, there is consistency in the way tools are included in calls +to the model and prevents any confusion due to a difference in behavior between `functionCallbacks()` and all the other options. + +If you want to make a tool available for runtime dynamic resolution and include it in a chat request to the model only +when explicitly requested, you can use one of the strategies described in xref:api/tools.adoc#_tool_resolution[Tool Resolution]. + +NOTE: 1.0.0-M6 introduced new APIs for handling tool calling. Backward compatibility is maintained for the old APIs across +all scenarios, except the one described above. The old APIs are still available, but they are deprecated +and will be removed in 1.0.0-M7. === Removal of deprecated Amazon Bedrock chat models diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index 35e1fcf9ecc..1f50ff40ce1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,14 +23,15 @@ import io.micrometer.observation.ObservationRegistry; import org.jetbrains.annotations.NotNull; +import org.springframework.ai.autoconfigure.chat.model.ToolCallingAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.image.observation.ImageModelObservationConvention; import org.springframework.ai.model.SimpleApiKey; import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.openai.OpenAiAudioSpeechModel; import org.springframework.ai.openai.OpenAiAudioTranscriptionModel; import org.springframework.ai.openai.OpenAiChatModel; @@ -69,13 +70,13 @@ * @author Thomas Vitale */ @AutoConfiguration(after = { RestClientAutoConfiguration.class, WebClientAutoConfiguration.class, - SpringAiRetryAutoConfiguration.class }) + SpringAiRetryAutoConfiguration.class, ToolCallingAutoConfiguration.class }) @ConditionalOnClass(OpenAiApi.class) @EnableConfigurationProperties({ OpenAiConnectionProperties.class, OpenAiChatProperties.class, OpenAiEmbeddingProperties.class, OpenAiImageProperties.class, OpenAiAudioTranscriptionProperties.class, OpenAiAudioSpeechProperties.class, OpenAiModerationProperties.class }) @ImportAutoConfiguration(classes = { SpringAiRetryAutoConfiguration.class, RestClientAutoConfiguration.class, - WebClientAutoConfiguration.class }) + WebClientAutoConfiguration.class, ToolCallingAutoConfiguration.class }) public class OpenAiAutoConfiguration { private static @NotNull ResolvedConnectionProperties resolveConnectionProperties( @@ -114,17 +115,22 @@ public class OpenAiAutoConfiguration { matchIfMissing = true) public OpenAiChatModel openAiChatModel(OpenAiConnectionProperties commonProperties, OpenAiChatProperties chatProperties, ObjectProvider restClientBuilderProvider, - ObjectProvider webClientBuilderProvider, List toolFunctionCallbacks, - FunctionCallbackResolver functionCallbackResolver, RetryTemplate retryTemplate, - ResponseErrorHandler responseErrorHandler, ObjectProvider observationRegistry, + ObjectProvider webClientBuilderProvider, ToolCallingManager toolCallingManager, + RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler, + ObjectProvider observationRegistry, ObjectProvider observationConvention) { var openAiApi = openAiApi(chatProperties, commonProperties, restClientBuilderProvider.getIfAvailable(RestClient::builder), webClientBuilderProvider.getIfAvailable(WebClient::builder), responseErrorHandler, "chat"); - var chatModel = new OpenAiChatModel(openAiApi, chatProperties.getOptions(), functionCallbackResolver, - toolFunctionCallbacks, retryTemplate, observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)); + var chatModel = OpenAiChatModel.builder() + .openAiApi(openAiApi) + .defaultOptions(chatProperties.getOptions()) + .toolCallingManager(toolCallingManager) + .retryTemplate(retryTemplate) + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .build(); observationConvention.ifAvailable(chatModel::setObservationConvention);