diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java index 5f8da416109..eda63ec20d8 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java @@ -59,7 +59,7 @@ * @see McpAsyncClient * @see Tool */ -public class AsyncMcpToolCallback implements ToolCallback { +public class AsyncMcpToolCallback implements McpToolCallback { private final McpAsyncClient asyncMcpClient; diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolCallback.java new file mode 100644 index 00000000000..732b170bac2 --- /dev/null +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolCallback.java @@ -0,0 +1,28 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp; + +import org.springframework.ai.tool.ToolCallback; + +/** + * Custom type for MCP specific tool. + */ +public interface McpToolCallback extends ToolCallback { + + // TODO: Add MCP metadata + +} diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java index fc61d801df1..0e65d43fc35 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java @@ -61,7 +61,7 @@ * @see McpSyncClient * @see Tool */ -public class SyncMcpToolCallback implements ToolCallback { +public class SyncMcpToolCallback implements McpToolCallback { private static final Logger logger = LoggerFactory.getLogger(SyncMcpToolCallback.class); diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 2ded856a05f..1b1eb1d1596 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -455,17 +455,22 @@ Prompt buildRequestPrompt(Prompt prompt) { this.defaultOptions.getInternalToolExecutionEnabled())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); - requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), - this.defaultOptions.getToolCallbacks())); + // Make sure to set the tool context before setting toolcallbacks so that the + // context can be used to filter the toolcallbacks. requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(), this.defaultOptions.getToolContext())); + requestOptions.setToolCallbacks(runtimeOptions.getFilteredToolCallbacks(ToolCallingChatOptions + .mergeToolCallbacks(runtimeOptions.getToolCallbacks(), this.defaultOptions.getToolCallbacks()))); } else { requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders()); requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); - requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); + // Make sure to set the tool context before setting toolcallbacks so that the + // context can be used to filter the toolcallbacks. requestOptions.setToolContext(this.defaultOptions.getToolContext()); + requestOptions + .setToolCallbacks(this.defaultOptions.getFilteredToolCallbacks(this.defaultOptions.getToolCallbacks())); } ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks()); diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index dbfbee561c8..4bf3f831af2 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -24,6 +24,8 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.BiPredicate; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -82,13 +84,15 @@ public class AnthropicChatOptions implements ToolCallingChatOptions { @JsonIgnore private Map toolContext = new HashMap<>(); - /** * Optional HTTP headers to be added to the chat completion request. */ @JsonIgnore private Map httpHeaders = new HashMap<>(); + @JsonIgnore + private Predicate toolCallbackFilter; + // @formatter:on public static Builder builder() { @@ -110,6 +114,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()) .httpHeaders(fromOptions.getHttpHeaders() != null ? new HashMap<>(fromOptions.getHttpHeaders()) : null) .build(); } @@ -259,6 +264,16 @@ public void setHttpHeaders(Map httpHeaders) { this.httpHeaders = httpHeaders; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override @SuppressWarnings("unchecked") public AnthropicChatOptions copy() { @@ -384,6 +399,11 @@ public Builder toolContext(Map toolContext) { return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public Builder httpHeaders(Map httpHeaders) { this.options.setHttpHeaders(httpHeaders); return this; diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java index 6570d5ee6a6..e35d8b680a9 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java @@ -21,6 +21,8 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.function.BiPredicate; +import java.util.function.Predicate; import java.util.stream.Collectors; import org.junit.jupiter.api.Test; @@ -42,6 +44,7 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -50,6 +53,7 @@ import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; @@ -284,11 +288,23 @@ void functionCallTest() { var promptOptions = AnthropicChatOptions.builder() .model(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getName()) + .toolContext(Map.of("tool_prefix", "get")) .toolCallbacks(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description( "Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.") .inputType(MockWeatherService.Request.class) - .build()) + .build(), + FunctionToolCallback.builder("retrieveWeather", new MockWeatherService()) + .description( + "Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.") + .inputType(MockWeatherService.Request.class) + .build()) + .toolCallbackFilter(new Predicate() { + @Override + public boolean test(ToolCallback toolCallback) { + return (toolCallback.getToolDefinition().name().startsWith("get")) ? true : false; + } + }) .build(); ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index da442b4ad4d..6a83f0970d7 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; import com.azure.ai.openai.models.ChatCompletionStreamOptions; @@ -257,6 +258,9 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @JsonIgnore + private Predicate toolCallbackFilter; + public static Builder builder() { return new Builder(); } @@ -288,6 +292,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti .toolCallbacks( fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null) .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()) .build(); } @@ -474,6 +479,16 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + public ChatCompletionStreamOptions getStreamOptions() { return this.streamOptions; } @@ -664,6 +679,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public AzureOpenAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java index 776cba66d58..09578a59fd6 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java @@ -24,10 +24,12 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.lang.Nullable; @@ -77,6 +79,9 @@ public class BedrockChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Predicate toolCallbackFilter; + public static Builder builder() { return new Builder(); } @@ -96,6 +101,7 @@ public static BedrockChatOptions fromOptions(BedrockChatOptions fromOptions) { .toolNames(new HashSet<>(fromOptions.getToolNames())) .toolContext(new HashMap<>(fromOptions.getToolContext())) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()) .build(); } @@ -224,6 +230,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override @SuppressWarnings("unchecked") public BedrockChatOptions copy() { @@ -337,6 +353,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public BedrockChatOptions build() { return this.options; } diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 071e77a78cb..b81fca818f0 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -303,8 +303,8 @@ Prompt buildRequestPrompt(Prompt prompt) { : this.defaultOptions.getTemperature()) .topP(runtimeOptions.getTopP() != null ? runtimeOptions.getTopP() : this.defaultOptions.getTopP()) - .toolCallbacks(runtimeOptions.getToolCallbacks() != null ? runtimeOptions.getToolCallbacks() - : this.defaultOptions.getToolCallbacks()) + .toolCallbacks(runtimeOptions.getFilteredToolCallbacks(runtimeOptions.getToolCallbacks() != null + ? runtimeOptions.getToolCallbacks() : this.defaultOptions.getToolCallbacks())) .toolNames(runtimeOptions.getToolNames() != null ? runtimeOptions.getToolNames() : this.defaultOptions.getToolNames()) .toolContext(runtimeOptions.getToolContext() != null ? runtimeOptions.getToolContext() @@ -312,6 +312,7 @@ Prompt buildRequestPrompt(Prompt prompt) { .internalToolExecutionEnabled(runtimeOptions.getInternalToolExecutionEnabled() != null ? runtimeOptions.getInternalToolExecutionEnabled() : this.defaultOptions.getInternalToolExecutionEnabled()) + .toolCallbackFilter(runtimeOptions.getToolCallbackFilter()) .build(); } diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java index b9c7a3d4962..e7c02347dc5 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -143,7 +144,10 @@ public class DeepSeekChatOptions implements ToolCallingChatOptions { private Set toolNames = new HashSet<>(); @JsonIgnore - private Map toolContext = new HashMap<>();; + private Map toolContext = new HashMap<>(); + + @JsonIgnore + private Predicate toolCallbackFilter; public static Builder builder() { return new Builder(); @@ -246,7 +250,6 @@ public void setToolChoice(Object toolChoice) { this.toolChoice = toolChoice; } - @Override @JsonIgnore public List getToolCallbacks() { @@ -322,6 +325,16 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override public DeepSeekChatOptions copy() { return DeepSeekChatOptions.fromOptions(this); @@ -379,6 +392,7 @@ public static DeepSeekChatOptions fromOptions(DeepSeekChatOptions fromOptions) { .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()) .build(); } @@ -497,6 +511,11 @@ public Builder toolContext(Map toolContext) { return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public DeepSeekChatOptions build() { return this.options; } diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java index 2ee9e4fa029..1ad98d216da 100644 --- a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -138,6 +139,9 @@ public class GoogleGenAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private List safetySettings = new ArrayList<>(); + + @JsonIgnore + private Predicate toolCallbackFilter; // @formatter:on public static Builder builder() { @@ -327,6 +331,16 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -489,6 +503,11 @@ public Builder toolContext(Map toolContext) { return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public GoogleGenAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java index a8f1e62e77e..0b14fb600ab 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java @@ -25,6 +25,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -156,6 +157,9 @@ public class MiniMaxChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Predicate toolCallbackFilter; + // @formatter:on public static Builder builder() { @@ -180,6 +184,7 @@ public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) { .toolNames(fromOptions.getToolNames()) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) .toolContext(fromOptions.getToolContext()) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()) .build(); } @@ -362,6 +367,16 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override public int hashCode() { return Objects.hash(model, frequencyPenalty, maxTokens, n, presencePenalty, responseFormat, seed, stop, @@ -508,6 +523,11 @@ public Builder toolContext(Map toolContext) { return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public MiniMaxChatOptions build() { return this.options; } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index 801c35f2118..8993bf6db62 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -159,6 +160,9 @@ public class MistralAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private Map toolContext = new HashMap<>(); + @JsonIgnore + private Predicate toolCallbackFilter; + public static Builder builder() { return new Builder(); } @@ -182,6 +186,7 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()) .build(); } @@ -366,6 +371,16 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override @SuppressWarnings("unchecked") public MistralAiChatOptions copy() { @@ -517,6 +532,11 @@ public Builder toolContext(Map toolContext) { return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public MistralAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index a71be1ce2b2..4e631fd6ade 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import java.util.stream.Collectors; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -344,6 +345,9 @@ public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions { @JsonIgnore private Map toolContext = new HashMap<>(); + @JsonIgnore + private Predicate toolCallbackFilter; + public static Builder builder() { return new Builder(); } @@ -398,7 +402,8 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) { .toolNames(fromOptions.getToolNames()) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) .toolCallbacks(fromOptions.getToolCallbacks()) - .toolContext(fromOptions.getToolContext()).build(); + .toolContext(fromOptions.getToolContext()) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()).build(); } // ------------------- @@ -764,6 +769,16 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + /** * Convert the {@link OllamaOptions} object to a {@link Map} of key/value pairs. * @return The {@link Map} of key/value pairs. @@ -1039,6 +1054,11 @@ public Builder toolContext(Map toolContext) { return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public OllamaOptions build() { return this.options; } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 2ad584fa82f..7026c8d05d4 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -529,10 +529,10 @@ Prompt buildRequestPrompt(Prompt prompt) { this.defaultOptions.getInternalToolExecutionEnabled())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); - requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), - this.defaultOptions.getToolCallbacks())); requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(), this.defaultOptions.getToolContext())); + requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), + this.defaultOptions.getToolCallbacks())); } else { requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders()); @@ -542,7 +542,8 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setToolContext(this.defaultOptions.getToolContext()); } - ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks()); + ToolCallingChatOptions + .validateToolCallbacks(requestOptions.getFilteredToolCallbacks(requestOptions.getToolCallbacks())); return new Prompt(prompt.getInstructions(), requestOptions); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index afbbd803ec6..b115c25b335 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -228,6 +229,9 @@ public class OpenAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private Map toolContext = new HashMap<>(); + @JsonIgnore + private Predicate toolCallbackFilter; + // @formatter:on public static Builder builder() { @@ -268,6 +272,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) { .metadata(fromOptions.getMetadata()) .reasoningEffort(fromOptions.getReasoningEffort()) .webSearchOptions(fromOptions.getWebSearchOptions()) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()) .build(); } @@ -564,6 +569,16 @@ public void setWebSearchOptions(WebSearchOptions webSearchOptions) { this.webSearchOptions = webSearchOptions; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override public OpenAiChatOptions copy() { return OpenAiChatOptions.fromOptions(this); @@ -802,6 +817,11 @@ public Builder webSearchOptions(WebSearchOptions webSearchOptions) { return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public OpenAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index 9c7788c82a3..ca551f6daf4 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -24,6 +24,8 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.BiPredicate; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -151,6 +153,9 @@ public class VertexAiGeminiChatOptions implements ToolCallingChatOptions { @JsonIgnore private List safetySettings = new ArrayList<>(); + + @JsonIgnore + private Predicate toolCallbackFilter; // @formatter:on public static Builder builder() { @@ -178,6 +183,7 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr options.setToolContext(fromOptions.getToolContext()); options.setLogprobs(fromOptions.getLogprobs()); options.setResponseLogprobs(fromOptions.getResponseLogprobs()); + options.setToolCallbackFilter(fromOptions.getToolCallbackFilter()); return options; } @@ -358,6 +364,16 @@ public boolean getResponseLogprobs() { return responseLogprobs; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java index c31320defe1..2a74d41e56c 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Predicate; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -126,6 +127,9 @@ public class ZhiPuAiChatOptions implements ToolCallingChatOptions { private Map toolContext = new HashMap<>(); // @formatter:on + @JsonIgnore + private Predicate toolCallbackFilter; + public static Builder builder() { return new Builder(); } @@ -146,6 +150,7 @@ public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) { .toolNames(fromOptions.getToolNames()) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) .toolContext(fromOptions.getToolContext()) + .toolCallbackFilter(fromOptions.getToolCallbackFilter()) .build(); } @@ -314,6 +319,16 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + @JsonIgnore + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + @Override + public void setToolCallbackFilter(Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override public int hashCode() { final int prime = 31; @@ -610,6 +625,11 @@ public Builder toolContext(Map toolContext) { return this; } + public Builder toolCallbackFilter(Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + public ZhiPuAiChatOptions build() { return this.options; } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java index 870db6931b9..da295eda950 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java @@ -23,6 +23,8 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.BiPredicate; +import java.util.function.Predicate; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.tool.ToolCallback; @@ -70,6 +72,9 @@ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions { @Nullable private Double topP; + @Nullable + private Predicate toolCallbackFilter; + @Override public List getToolCallbacks() { return List.copyOf(this.toolCallbacks); @@ -198,6 +203,16 @@ public void setTopP(@Nullable Double topP) { this.topP = topP; } + @Override + @Nullable + public Predicate getToolCallbackFilter() { + return this.toolCallbackFilter; + } + + public void setToolCallbackFilter(@Nullable Predicate toolCallbackFilter) { + this.toolCallbackFilter = toolCallbackFilter; + } + @Override @SuppressWarnings("unchecked") public T copy() { @@ -214,6 +229,7 @@ public T copy() { options.setTemperature(getTemperature()); options.setTopK(getTopK()); options.setTopP(getTopP()); + options.setToolCallbackFilter(getToolCallbackFilter()); return (T) options; } @@ -325,6 +341,13 @@ public ToolCallingChatOptions.Builder topP(@Nullable Double topP) { return this; } + @Override + public ToolCallingChatOptions.Builder toolCallbackFilter( + @Nullable Predicate toolCallbackFilter) { + this.options.setToolCallbackFilter(toolCallbackFilter); + return this; + } + @Override public ToolCallingChatOptions build() { return this.options; diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java index f06e71aa869..09463719379 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java @@ -22,8 +22,11 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.BiPredicate; +import java.util.function.Predicate; import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.support.ToolUtils; @@ -88,6 +91,26 @@ public interface ToolCallingChatOptions extends ChatOptions { */ void setToolContext(Map toolContext); + void setToolCallbackFilter(Predicate toolCallbackFilter); + + Predicate getToolCallbackFilter(); + + default List getFilteredToolCallbacks(List toolCallbacks) { + Predicate filter = getToolCallbackFilter(); + if (filter == null) { + return this.getToolCallbacks(); + } + else { + return applyFilter(toolCallbacks, filter); + } + } + + private List applyFilter(List toolCallbacks, + Predicate filter) { + + return toolCallbacks.stream().filter(toolCallback -> filter.test((T) toolCallback)).toList(); + } + /** * A builder to create a new {@link ToolCallingChatOptions} instance. */ @@ -193,6 +216,8 @@ interface Builder extends ChatOptions.Builder { */ Builder toolContext(String key, Object value); + Builder toolCallbackFilter(@Nullable Predicate toolCallbackFilter); + // ChatOptions.Builder methods @Override