diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index 222f6f556e2..c14f8ae0f31 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,10 +18,8 @@ import java.util.ArrayList; import java.util.Base64; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import io.micrometer.observation.Observation; @@ -29,6 +27,12 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.model.tool.LegacyToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionResult; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.lang.Nullable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -88,6 +92,8 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build(); + private final Logger logger = LoggerFactory.getLogger(getClass()); /** @@ -107,11 +113,17 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM */ private final ObservationRegistry observationRegistry; + private final ToolCallingManager toolCallingManager; + /** * Conventions to use for generating observations. */ private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + /** + * @deprecated Use {@link MistralAiChatModel.Builder}. + */ + @Deprecated public MistralAiChatModel(MistralAiApi mistralAiApi) { this(mistralAiApi, MistralAiChatOptions.builder() @@ -122,32 +134,67 @@ public MistralAiChatModel(MistralAiApi mistralAiApi) { .build()); } + /** + * @deprecated Use {@link MistralAiChatModel.Builder}. + */ + @Deprecated public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options) { this(mistralAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE); } + /** + * @deprecated Use {@link MistralAiChatModel.Builder}. + */ + @Deprecated public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options, - FunctionCallbackResolver functionCallbackResolver, RetryTemplate retryTemplate) { + @Nullable FunctionCallbackResolver functionCallbackResolver, @Nullable RetryTemplate retryTemplate) { this(mistralAiApi, options, functionCallbackResolver, List.of(), retryTemplate); } + /** + * @deprecated Use {@link MistralAiChatModel.Builder}. + */ + @Deprecated public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options, - FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks, - RetryTemplate retryTemplate) { + @Nullable FunctionCallbackResolver functionCallbackResolver, + @Nullable List toolFunctionCallbacks, RetryTemplate retryTemplate) { this(mistralAiApi, options, functionCallbackResolver, toolFunctionCallbacks, retryTemplate, ObservationRegistry.NOOP); } + /** + * @deprecated Use {@link MistralAiChatModel.Builder}. + */ + @Deprecated public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options, - FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks, - RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { - super(functionCallbackResolver, options, toolFunctionCallbacks); - Assert.notNull(mistralAiApi, "mistralAiApi must not be null"); - Assert.notNull(options, "options must not be null"); - Assert.notNull(retryTemplate, "retryTemplate must not be null"); - Assert.notNull(observationRegistry, "observationRegistry must not be null"); + @Nullable FunctionCallbackResolver functionCallbackResolver, + @Nullable List toolFunctionCallbacks, RetryTemplate retryTemplate, + ObservationRegistry observationRegistry) { + this(mistralAiApi, options, + LegacyToolCallingManager.builder() + .functionCallbackResolver(functionCallbackResolver) + .functionCallbacks(toolFunctionCallbacks) + .build(), + retryTemplate, observationRegistry); + logger.warn("This constructor is deprecated and will be removed in the next milestone. " + + "Please use the MistralAiChatModel.Builder or the new constructor accepting ToolCallingManager instead."); + } + + public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions defaultOptions, + ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, + ObservationRegistry observationRegistry) { + // We do not pass the 'defaultOptions' to the AbstractToolSupport, + // because it modifies them. We are using ToolCallingManager instead, + // so we just pass empty options here. + super(null, MistralAiChatOptions.builder().build(), List.of()); + Assert.notNull(mistralAiApi, "mistralAiApi cannot be null"); + Assert.notNull(defaultOptions, "defaultOptions cannot be null"); + Assert.notNull(toolCallingManager, "toolCallingManager cannot be null"); + Assert.notNull(retryTemplate, "retryTemplate cannot be null"); + Assert.notNull(observationRegistry, "observationRegistry cannot be null"); this.mistralAiApi = mistralAiApi; - this.defaultOptions = options; + this.defaultOptions = defaultOptions; + this.toolCallingManager = toolCallingManager; this.retryTemplate = retryTemplate; this.observationRegistry = observationRegistry; } @@ -179,7 +226,10 @@ private static DefaultUsage getDefaultUsage(MistralAiApi.Usage usage) { @Override public ChatResponse call(Prompt prompt) { - return this.internalCall(prompt, null); + // Before moving any further, build the final request Prompt, + // merging runtime and default options. + Prompt requestPrompt = buildRequestPrompt(prompt); + return this.internalCall(requestPrompt, null); } public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { @@ -189,7 +239,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(MistralAiApi.PROVIDER_NAME) - .requestOptions(buildRequestOptions(request)) + .requestOptions(prompt.getOptions()) .build(); ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION @@ -228,13 +278,21 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons return chatResponse; }); - if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null - && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), - MistralAiApi.ChatCompletionFinishReason.STOP.name()))) { - var toolCallConversation = handleToolCalls(prompt, response); - // Recursively call the call method with the tool call message - // conversation that contains the call responses. - return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response); + if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null + && response.hasToolCalls()) { + var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return ChatResponse.builder() + .from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build(); + } + else { + // Send the tool execution result back to the model. + return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + response); + } } return response; @@ -242,7 +300,10 @@ && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALL @Override public Flux stream(Prompt prompt) { - return this.internalStream(prompt, null); + // Before moving any further, build the final request Prompt, + // merging runtime and default options. + Prompt requestPrompt = buildRequestPrompt(prompt); + return this.internalStream(requestPrompt, null); } public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { @@ -252,7 +313,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(MistralAiApi.PROVIDER_NAME) - .requestOptions(buildRequestOptions(request)) + .requestOptions(prompt.getOptions()) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( @@ -307,11 +368,18 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha // @formatter:off Flux chatResponseFlux = chatResponse.flatMap(response -> { - if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name()))) { - var toolCallConversation = handleToolCalls(prompt, response); - // Recursively call the stream method with the tool call message - // conversation that contains the call responses. - return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response); + if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response.hasToolCalls()) { + var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return Flux.just(ChatResponse.builder().from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build()); + } else { + // Send the tool execution result back to the model. + return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + response); + } } else { return Flux.just(response); @@ -352,13 +420,57 @@ private ChatCompletion toChatCompletion(ChatCompletionChunk chunk) { chunk.usage()); } + Prompt buildRequestPrompt(Prompt prompt) { + // Process runtime options + MistralAiChatOptions runtimeOptions = null; + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { + runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, + MistralAiChatOptions.class); + } + else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { + runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class, + MistralAiChatOptions.class); + } + else { + runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, + MistralAiChatOptions.class); + } + } + + // Define request options by merging runtime options and default options + MistralAiChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, + MistralAiChatOptions.class); + + // Merge @JsonIgnore-annotated options explicitly since they are ignored by + // Jackson, used by ModelOptionsUtils. + if (runtimeOptions != null) { + requestOptions.setInternalToolExecutionEnabled( + ModelOptionsUtils.mergeOption(runtimeOptions.isInternalToolExecutionEnabled(), + this.defaultOptions.isInternalToolExecutionEnabled())); + requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), + this.defaultOptions.getToolNames())); + requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), + this.defaultOptions.getToolCallbacks())); + requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(), + this.defaultOptions.getToolContext())); + } + else { + requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.isInternalToolExecutionEnabled()); + requestOptions.setToolNames(this.defaultOptions.getToolNames()); + requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); + requestOptions.setToolContext(this.defaultOptions.getToolContext()); + } + + ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks()); + + return new Prompt(prompt.getInstructions(), requestOptions); + } + /** * Accessible for testing. */ MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { - - Set functionsForThisRequest = new HashSet<>(); - List chatCompletionMessages = prompt.getInstructions().stream().map(message -> { if (message instanceof UserMessage userMessage) { Object content = message.getText(); @@ -392,7 +504,6 @@ else if (message instanceof AssistantMessage assistantMessage) { MistralAiApi.ChatCompletionMessage.Role.ASSISTANT, null, toolCalls, null)); } else if (message instanceof ToolResponseMessage toolResponseMessage) { - toolResponseMessage.getResponses() .forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id")); @@ -409,35 +520,15 @@ else if (message instanceof ToolResponseMessage toolResponseMessage) { var request = new MistralAiApi.ChatCompletionRequest(chatCompletionMessages, stream); - if (!CollectionUtils.isEmpty(this.defaultOptions.getFunctions())) { - functionsForThisRequest.addAll(this.defaultOptions.getFunctions()); - } - - request = ModelOptionsUtils.merge(request, this.defaultOptions, MistralAiApi.ChatCompletionRequest.class); - - if (prompt.getOptions() != null) { - MistralAiChatOptions updatedRuntimeOptions; - - if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { - updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, - FunctionCallingOptions.class, MistralAiChatOptions.class); - } - else { - updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, - MistralAiChatOptions.class); - } - - functionsForThisRequest.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions)); - - request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, MistralAiApi.ChatCompletionRequest.class); - } - - // Add the enabled functions definitions to the request's tools parameter. - if (!CollectionUtils.isEmpty(functionsForThisRequest)) { + MistralAiChatOptions requestOptions = (MistralAiChatOptions) prompt.getOptions(); + request = ModelOptionsUtils.merge(requestOptions, request, MistralAiApi.ChatCompletionRequest.class); + // Add the tool definitions to the request's tools parameter. + List toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions); + if (!CollectionUtils.isEmpty(toolDefinitions)) { request = ModelOptionsUtils.merge( - MistralAiChatOptions.builder().tools(this.getFunctionTools(functionsForThisRequest)).build(), - request, ChatCompletionRequest.class); + MistralAiChatOptions.builder().tools(this.getFunctionTools(toolDefinitions)).build(), request, + ChatCompletionRequest.class); } return request; @@ -464,24 +555,14 @@ else if (mediaContentData instanceof String text) { } } - private List getFunctionTools(Set functionNames) { - return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> { - var function = new MistralAiApi.FunctionTool.Function(functionCallback.getDescription(), - functionCallback.getName(), functionCallback.getInputTypeSchema()); + private List getFunctionTools(List toolDefinitions) { + return toolDefinitions.stream().map(toolDefinition -> { + var function = new MistralAiApi.FunctionTool.Function(toolDefinition.description(), toolDefinition.name(), + toolDefinition.inputSchema()); return new MistralAiApi.FunctionTool(function); }).toList(); } - private ChatOptions buildRequestOptions(MistralAiApi.ChatCompletionRequest request) { - return ChatOptions.builder() - .model(request.model()) - .maxTokens(request.maxTokens()) - .stopSequences(request.stop()) - .temperature(request.temperature()) - .topP(request.topP()) - .build(); - } - @Override public ChatOptions getDefaultOptions() { return MistralAiChatOptions.fromOptions(this.defaultOptions); @@ -496,4 +577,96 @@ public void setObservationConvention(ChatModelObservationConvention observationC this.observationConvention = observationConvention; } + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private MistralAiApi mistralAiApi; + + private MistralAiChatOptions defaultOptions = MistralAiChatOptions.builder() + .temperature(0.7) + .topP(1.0) + .safePrompt(false) + .model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue()) + .build(); + + private ToolCallingManager toolCallingManager; + + private FunctionCallbackResolver functionCallbackResolver; + + private List toolFunctionCallbacks; + + private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + + private Builder() { + } + + public Builder mistralAiApi(MistralAiApi mistralAiApi) { + this.mistralAiApi = mistralAiApi; + return this; + } + + public Builder defaultOptions(MistralAiChatOptions defaultOptions) { + this.defaultOptions = defaultOptions; + return this; + } + + public Builder toolCallingManager(ToolCallingManager toolCallingManager) { + this.toolCallingManager = toolCallingManager; + return this; + } + + @Deprecated + public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) { + this.functionCallbackResolver = functionCallbackResolver; + return this; + } + + @Deprecated + public Builder toolFunctionCallbacks(List toolFunctionCallbacks) { + this.toolFunctionCallbacks = toolFunctionCallbacks; + return this; + } + + public Builder retryTemplate(RetryTemplate retryTemplate) { + this.retryTemplate = retryTemplate; + return this; + } + + public Builder observationRegistry(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + return this; + } + + public MistralAiChatModel build() { + if (toolCallingManager != null) { + Assert.isNull(functionCallbackResolver, + "functionCallbackResolver cannot be set when toolCallingManager is set"); + Assert.isNull(toolFunctionCallbacks, + "toolFunctionCallbacks cannot be set when toolCallingManager is set"); + + return new MistralAiChatModel(mistralAiApi, defaultOptions, toolCallingManager, retryTemplate, + observationRegistry); + } + + if (functionCallbackResolver != null) { + Assert.isNull(toolCallingManager, + "toolCallingManager cannot be set when functionCallbackResolver is set"); + List toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks + : List.of(); + + return new MistralAiChatModel(mistralAiApi, defaultOptions, functionCallbackResolver, toolCallbacks, + retryTemplate, observationRegistry); + } + + return new MistralAiChatModel(mistralAiApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate, + observationRegistry); + } + + } + } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index 2524836e001..a59e8a71e58 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ package org.springframework.ai.mistralai; import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -32,7 +34,9 @@ import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice; import org.springframework.ai.mistralai.api.MistralAiApi.FunctionTool; import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -45,7 +49,7 @@ * @since 0.8.1 */ @JsonInclude(JsonInclude.Include.NON_NULL) -public class MistralAiChatOptions implements FunctionCallingOptions { +public class MistralAiChatOptions implements ToolCallingChatOptions { /** * ID of the model to use @@ -112,34 +116,27 @@ public class MistralAiChatOptions implements FunctionCallingOptions { private @JsonProperty("tool_choice") ToolChoice toolChoice; /** - * MistralAI Tool Function Callbacks to register with the ChatModel. For Prompt - * Options the functionCallbacks are automatically enabled for the duration of the - * prompt execution. For Default Options the functionCallbacks are registered but - * disabled by default. Use the enableFunctions to set the functions from the registry - * to be used by the ChatModel chat completion requests. + * Collection of {@link ToolCallback}s to be used for tool calling in the chat + * completion requests. */ @JsonIgnore - private List functionCallbacks = new ArrayList<>(); + private List toolCallbacks = new ArrayList<>(); /** - * List of functions, identified by their names, to configure for function calling in - * the chat completion requests. Functions with those names must exist in the - * functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions - * are automatically enabled for the duration of the prompt execution. - * - * Note that function enabled with the default options are enabled for all chat - * completion requests. This could impact the token count and the billing. If the - * functions is set in a prompt options, then the enabled functions are only active - * for the duration of this prompt execution. + * Collection of tool names to be resolved at runtime and used for tool calling in the + * chat completion requests. */ @JsonIgnore - private Set functions = new HashSet<>(); + private Set toolNames = new HashSet<>(); + /** + * Whether to enable the tool execution lifecycle internally in ChatModel. + */ @JsonIgnore - private Boolean proxyToolCalls; + private Boolean internalToolExecutionEnabled; @JsonIgnore - private Map toolContext; + private Map toolContext = new HashMap<>(); public static Builder builder() { return new Builder(); @@ -156,9 +153,9 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) .stop(fromOptions.getStop()) .tools(fromOptions.getTools()) .toolChoice(fromOptions.getToolChoice()) - .functionCallbacks(fromOptions.getFunctionCallbacks()) - .functions(fromOptions.getFunctions()) - .proxyToolCalls(fromOptions.getProxyToolCalls()) + .toolCallbacks(fromOptions.getToolCallbacks()) + .toolNames(fromOptions.getToolNames()) + .internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled()) .toolContext(fromOptions.getToolContext()) .build(); } @@ -259,25 +256,73 @@ public void setTopP(Double topP) { } @Override + @JsonIgnore + public List getToolCallbacks() { + return this.toolCallbacks; + } + + @Override + @JsonIgnore + public void setToolCallbacks(List toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); + this.toolCallbacks = toolCallbacks; + } + + @Override + @JsonIgnore + public Set getToolNames() { + return this.toolNames; + } + + @Override + @JsonIgnore + public void setToolNames(Set toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + Assert.noNullElements(toolNames, "toolNames cannot contain null elements"); + toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements")); + this.toolNames = toolNames; + } + + @Override + @Nullable + @JsonIgnore + public Boolean isInternalToolExecutionEnabled() { + return internalToolExecutionEnabled; + } + + @Override + @JsonIgnore + public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.internalToolExecutionEnabled = internalToolExecutionEnabled; + } + + @Override + @Deprecated + @JsonIgnore public List getFunctionCallbacks() { - return this.functionCallbacks; + return this.getToolCallbacks(); } @Override + @Deprecated + @JsonIgnore public void setFunctionCallbacks(List functionCallbacks) { - Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); - this.functionCallbacks = functionCallbacks; + this.setToolCallbacks(functionCallbacks); } @Override + @Deprecated + @JsonIgnore public Set getFunctions() { - return this.functions; + return this.getToolNames(); } @Override - public void setFunctions(Set functions) { - Assert.notNull(functions, "Function must not be null"); - this.functions = functions; + @Deprecated + @JsonIgnore + public void setFunctions(Set functionNames) { + this.setToolNames(functionNames); } @Override @@ -299,20 +344,26 @@ public Integer getTopK() { } @Override + @Deprecated + @JsonIgnore public Boolean getProxyToolCalls() { - return this.proxyToolCalls; + return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null; } + @Deprecated + @JsonIgnore public void setProxyToolCalls(Boolean proxyToolCalls) { - this.proxyToolCalls = proxyToolCalls; + this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null; } @Override + @JsonIgnore public Map getToolContext() { return this.toolContext; } @Override + @JsonIgnore public void setToolContext(Map toolContext) { this.toolContext = toolContext; } @@ -324,10 +375,9 @@ public MistralAiChatOptions copy() { @Override public int hashCode() { - return Objects.hash(this.model, this.temperature, this.topP, this.maxTokens, this.safePrompt, this.randomSeed, - this.responseFormat, this.stop, this.tools, this.toolChoice, this.functionCallbacks, this.functions, - this.proxyToolCalls, this.toolContext); + this.responseFormat, this.stop, this.tools, this.toolChoice, this.toolCallbacks, this.tools, + this.internalToolExecutionEnabled, this.toolContext); } @Override @@ -348,9 +398,9 @@ public boolean equals(Object obj) { && Objects.equals(this.randomSeed, other.randomSeed) && Objects.equals(this.responseFormat, other.responseFormat) && Objects.equals(this.stop, other.stop) && Objects.equals(this.tools, other.tools) && Objects.equals(this.toolChoice, other.toolChoice) - && Objects.equals(this.functionCallbacks, other.functionCallbacks) - && Objects.equals(this.functions, other.functions) - && Objects.equals(this.proxyToolCalls, other.proxyToolCalls) + && Objects.equals(this.toolCallbacks, other.toolCallbacks) + && Objects.equals(this.toolNames, other.toolNames) + && Objects.equals(this.internalToolExecutionEnabled, other.internalToolExecutionEnabled) && Objects.equals(this.toolContext, other.toolContext); } @@ -413,25 +463,54 @@ public Builder toolChoice(ToolChoice toolChoice) { return this; } - public Builder functionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; + public Builder toolCallbacks(List toolCallbacks) { + this.options.setToolCallbacks(toolCallbacks); return this; } - public Builder functions(Set functionNames) { - Assert.notNull(functionNames, "Function names must not be null"); - this.options.functions = functionNames; + public Builder toolCallbacks(FunctionCallback... toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks)); return this; } - public Builder function(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); + public Builder toolNames(Set toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + this.options.setToolNames(toolNames); + return this; + } + + public Builder toolNames(String... toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + this.options.toolNames.addAll(Set.of(toolNames)); return this; } + public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled); + return this; + } + + @Deprecated + public Builder functionCallbacks(List functionCallbacks) { + return toolCallbacks(functionCallbacks); + } + + @Deprecated + public Builder functions(Set functionNames) { + return toolNames(functionNames); + } + + @Deprecated + public Builder function(String functionName) { + return toolNames(functionName); + } + + @Deprecated public Builder proxyToolCalls(Boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; + if (proxyToolCalls != null) { + this.options.setInternalToolExecutionEnabled(!proxyToolCalls); + } return this; } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java index 51451208277..a842fddf0d6 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,25 +21,32 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.mistralai.api.MistralAiApi; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.boot.test.context.SpringBootTest; +import java.util.Map; + import static org.assertj.core.api.Assertions.assertThat; /** * @author Ricken Bazolo * @author Alexandros Pappas + * @author Thomas Vitale * @since 0.8.1 */ @SpringBootTest(classes = MistralAiTestConfiguration.class) @EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+") public class MistralAiChatCompletionRequestTest { - MistralAiChatModel chatModel = new MistralAiChatModel(new MistralAiApi("test")); + MistralAiChatModel chatModel = MistralAiChatModel.builder().mistralAiApi(new MistralAiApi("test")).build(); @Test void chatCompletionDefaultRequestTest() { - - var request = this.chatModel.createRequest(new Prompt("test content"), false); + var prompt = this.chatModel.buildRequestPrompt(new Prompt("test content")); + var request = this.chatModel.createRequest(prompt, false); assertThat(request.messages()).hasSize(1); assertThat(request.topP()).isEqualTo(1); @@ -51,10 +58,9 @@ void chatCompletionDefaultRequestTest() { @Test void chatCompletionRequestWithOptionsTest() { - var options = MistralAiChatOptions.builder().temperature(0.5).topP(0.8).build(); - - var request = this.chatModel.createRequest(new Prompt("test content", options), true); + var prompt = this.chatModel.buildRequestPrompt(new Prompt("test content", options)); + var request = this.chatModel.createRequest(prompt, true); assertThat(request.messages().size()).isEqualTo(1); assertThat(request.topP()).isEqualTo(0.8); @@ -62,4 +68,58 @@ void chatCompletionRequestWithOptionsTest() { assertThat(request.stream()).isTrue(); } + @Test + void whenToolRuntimeOptionsThenMergeWithDefaults() { + MistralAiChatOptions defaultOptions = MistralAiChatOptions.builder() + .model("DEFAULT_MODEL") + .internalToolExecutionEnabled(true) + .toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2")) + .toolNames("tool1", "tool2") + .toolContext(Map.of("key1", "value1", "key2", "valueA")) + .build(); + + MistralAiChatModel chatModel = MistralAiChatModel.builder() + .mistralAiApi(new MistralAiApi("test")) + .defaultOptions(defaultOptions) + .build(); + + MistralAiChatOptions runtimeOptions = MistralAiChatOptions.builder() + .internalToolExecutionEnabled(false) + .toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4")) + .toolNames("tool3") + .toolContext(Map.of("key2", "valueB")) + .build(); + Prompt prompt = chatModel.buildRequestPrompt(new Prompt("Test message content", runtimeOptions)); + + assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).isInternalToolExecutionEnabled()).isFalse(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(2); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks() + .stream() + .map(FunctionCallback::getName)).containsExactlyInAnyOrder("tool3", "tool4"); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolNames()).containsExactlyInAnyOrder("tool3"); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolContext()).containsEntry("key1", "value1") + .containsEntry("key2", "valueB"); + } + + static class TestToolCallback implements ToolCallback { + + private final ToolDefinition toolDefinition; + + public TestToolCallback(String name) { + this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build(); + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String toolInput) { + return "Mission accomplished!"; + } + + } + } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/mistralai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/mistralai-chat-functions.adoc index ba5dbda6c24..bcddf3fb66d 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/mistralai-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/mistralai-chat-functions.adoc @@ -1,4 +1,6 @@ -= Mistral AI Function Calling += Mistral AI Function Calling (Deprecated) + +WARNING: This page describes the previous version of the Function Calling API, which has been deprecated and marked for remove in the next release. The current version is available at xref:api/tools.adoc[Tool Calling]. See the xref:api/tools-migration.adoc[Migration Guide] for more information. You can register custom Java functions with the `MistralAiChatModel` and have the Mistral AI models intelligently choose to output a JSON object containing arguments to call one or many of the registered functions. This allows you to connect the LLM capabilities with external tools and APIs. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java index e380e897cc6..712756504fe 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,9 @@ package org.springframework.ai.autoconfigure.mistralai; -import java.util.List; - import io.micrometer.observation.ObservationRegistry; +import org.springframework.ai.autoconfigure.chat.model.ToolCallingAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; @@ -27,8 +26,8 @@ import org.springframework.ai.mistralai.MistralAiEmbeddingModel; import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.ImportAutoConfiguration; @@ -36,7 +35,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; -import org.springframework.boot.autoconfigure.web.reactive.function.client.WebClientAutoConfiguration; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; @@ -54,12 +52,13 @@ * @author Thomas Vitale * @since 0.8.1 */ -@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class }) +@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class, + ToolCallingAutoConfiguration.class }) @EnableConfigurationProperties({ MistralAiEmbeddingProperties.class, MistralAiCommonProperties.class, MistralAiChatProperties.class }) @ConditionalOnClass(MistralAiApi.class) @ImportAutoConfiguration(classes = { SpringAiRetryAutoConfiguration.class, RestClientAutoConfiguration.class, - WebClientAutoConfiguration.class }) + ToolCallingAutoConfiguration.class }) public class MistralAiAutoConfiguration { @Bean @@ -91,17 +90,21 @@ public MistralAiEmbeddingModel mistralAiEmbeddingModel(MistralAiCommonProperties matchIfMissing = true) public MistralAiChatModel mistralAiChatModel(MistralAiCommonProperties commonProperties, MistralAiChatProperties chatProperties, ObjectProvider restClientBuilderProvider, - List toolFunctionCallbacks, FunctionCallbackResolver functionCallbackResolver, - RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler, - ObjectProvider observationRegistry, + ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, + ResponseErrorHandler responseErrorHandler, ObjectProvider observationRegistry, ObjectProvider observationConvention) { var mistralAiApi = mistralAiApi(chatProperties.getApiKey(), commonProperties.getApiKey(), chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), restClientBuilderProvider.getIfAvailable(RestClient::builder), responseErrorHandler); - var chatModel = new MistralAiChatModel(mistralAiApi, chatProperties.getOptions(), functionCallbackResolver, - toolFunctionCallbacks, retryTemplate, observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)); + var chatModel = MistralAiChatModel.builder() + .mistralAiApi(mistralAiApi) + .defaultOptions(chatProperties.getOptions()) + .toolCallingManager(toolCallingManager) + .retryTemplate(retryTemplate) + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .build(); observationConvention.ifAvailable(chatModel::setObservationConvention);