Skip to content

Tool filtering as ToolCallingChatOption #3941

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
* @see McpAsyncClient
* @see Tool
*/
public class AsyncMcpToolCallback implements ToolCallback {
public class AsyncMcpToolCallback implements McpToolCallback {

private final McpAsyncClient asyncMcpClient;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright 2025-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.mcp;

import org.springframework.ai.tool.ToolCallback;

/**
* Custom type for MCP specific tool.
*/
public interface McpToolCallback extends ToolCallback {

// TODO: Add MCP metadata

}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
* @see McpSyncClient
* @see Tool
*/
public class SyncMcpToolCallback implements ToolCallback {
public class SyncMcpToolCallback implements McpToolCallback {

private static final Logger logger = LoggerFactory.getLogger(SyncMcpToolCallback.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,17 +455,22 @@ Prompt buildRequestPrompt(Prompt prompt) {
this.defaultOptions.getInternalToolExecutionEnabled()));
requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
this.defaultOptions.getToolNames()));
requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
this.defaultOptions.getToolCallbacks()));
// Make sure to set the tool context before setting toolcallbacks so that the
// context can be used to filter the toolcallbacks.
requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
this.defaultOptions.getToolContext()));
requestOptions.setToolCallbacks(runtimeOptions.getFilteredToolCallbacks(ToolCallingChatOptions
.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), this.defaultOptions.getToolCallbacks())));
}
else {
requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
requestOptions.setToolNames(this.defaultOptions.getToolNames());
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
// Make sure to set the tool context before setting toolcallbacks so that the
// context can be used to filter the toolcallbacks.
requestOptions.setToolContext(this.defaultOptions.getToolContext());
requestOptions
.setToolCallbacks(this.defaultOptions.getFilteredToolCallbacks(this.defaultOptions.getToolCallbacks()));
}

ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiPredicate;
import java.util.function.Predicate;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
Expand Down Expand Up @@ -82,13 +84,15 @@ public class AnthropicChatOptions implements ToolCallingChatOptions {
@JsonIgnore
private Map<String, Object> toolContext = new HashMap<>();


/**
* Optional HTTP headers to be added to the chat completion request.
*/
@JsonIgnore
private Map<String, String> httpHeaders = new HashMap<>();

@JsonIgnore
private Predicate<? extends ToolCallback> toolCallbackFilter;

// @formatter:on

public static Builder builder() {
Expand All @@ -110,6 +114,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
.toolCallbackFilter(fromOptions.getToolCallbackFilter())
.httpHeaders(fromOptions.getHttpHeaders() != null ? new HashMap<>(fromOptions.getHttpHeaders()) : null)
.build();
}
Expand Down Expand Up @@ -259,6 +264,16 @@ public void setHttpHeaders(Map<String, String> httpHeaders) {
this.httpHeaders = httpHeaders;
}

@JsonIgnore
public Predicate<? extends ToolCallback> getToolCallbackFilter() {
return this.toolCallbackFilter;
}

@Override
public void setToolCallbackFilter(Predicate<? extends ToolCallback> toolCallbackFilter) {
this.toolCallbackFilter = toolCallbackFilter;
}

@Override
@SuppressWarnings("unchecked")
public AnthropicChatOptions copy() {
Expand Down Expand Up @@ -384,6 +399,11 @@ public Builder toolContext(Map<String, Object> toolContext) {
return this;
}

public Builder toolCallbackFilter(Predicate<? extends ToolCallback> toolCallbackFilter) {
this.options.setToolCallbackFilter(toolCallbackFilter);
return this;
}

public Builder httpHeaders(Map<String, String> httpHeaders) {
this.options.setHttpHeaders(httpHeaders);
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.BiPredicate;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import org.junit.jupiter.api.Test;
Expand All @@ -42,6 +44,7 @@
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
Expand All @@ -50,6 +53,7 @@
import org.springframework.ai.converter.ListOutputConverter;
import org.springframework.ai.converter.MapOutputConverter;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
Expand Down Expand Up @@ -284,11 +288,23 @@ void functionCallTest() {

var promptOptions = AnthropicChatOptions.builder()
.model(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getName())
.toolContext(Map.of("tool_prefix", "get"))
.toolCallbacks(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
.description(
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
.inputType(MockWeatherService.Request.class)
.build())
.build(),
FunctionToolCallback.builder("retrieveWeather", new MockWeatherService())
.description(
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
.inputType(MockWeatherService.Request.class)
.build())
.toolCallbackFilter(new Predicate<ToolCallback>() {
@Override
public boolean test(ToolCallback toolCallback) {
return (toolCallback.getToolDefinition().name().startsWith("get")) ? true : false;
}
})
.build();

ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;

import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
import com.azure.ai.openai.models.ChatCompletionStreamOptions;
Expand Down Expand Up @@ -257,6 +258,9 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
}

@JsonIgnore
private Predicate<? extends ToolCallback> toolCallbackFilter;

public static Builder builder() {
return new Builder();
}
Expand Down Expand Up @@ -288,6 +292,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
.toolCallbacks(
fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null)
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
.toolCallbackFilter(fromOptions.getToolCallbackFilter())
.build();
}

Expand Down Expand Up @@ -474,6 +479,16 @@ public void setToolContext(Map<String, Object> toolContext) {
this.toolContext = toolContext;
}

@JsonIgnore
public Predicate<? extends ToolCallback> getToolCallbackFilter() {
return this.toolCallbackFilter;
}

@Override
public void setToolCallbackFilter(Predicate<? extends ToolCallback> toolCallbackFilter) {
this.toolCallbackFilter = toolCallbackFilter;
}

public ChatCompletionStreamOptions getStreamOptions() {
return this.streamOptions;
}
Expand Down Expand Up @@ -664,6 +679,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut
return this;
}

public Builder toolCallbackFilter(Predicate<? extends ToolCallback> toolCallbackFilter) {
this.options.setToolCallbackFilter(toolCallbackFilter);
return this;
}

public AzureOpenAiChatOptions build() {
return this.options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
Expand Down Expand Up @@ -77,6 +79,9 @@ public class BedrockChatOptions implements ToolCallingChatOptions {
@JsonIgnore
private Boolean internalToolExecutionEnabled;

@JsonIgnore
private Predicate<? extends ToolCallback> toolCallbackFilter;

public static Builder builder() {
return new Builder();
}
Expand All @@ -96,6 +101,7 @@ public static BedrockChatOptions fromOptions(BedrockChatOptions fromOptions) {
.toolNames(new HashSet<>(fromOptions.getToolNames()))
.toolContext(new HashMap<>(fromOptions.getToolContext()))
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
.toolCallbackFilter(fromOptions.getToolCallbackFilter())
.build();
}

Expand Down Expand Up @@ -224,6 +230,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
}

@JsonIgnore
public Predicate<? extends ToolCallback> getToolCallbackFilter() {
return this.toolCallbackFilter;
}

@Override
public void setToolCallbackFilter(Predicate<? extends ToolCallback> toolCallbackFilter) {
this.toolCallbackFilter = toolCallbackFilter;
}

@Override
@SuppressWarnings("unchecked")
public BedrockChatOptions copy() {
Expand Down Expand Up @@ -337,6 +353,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut
return this;
}

public Builder toolCallbackFilter(Predicate<? extends ToolCallback> toolCallbackFilter) {
this.options.setToolCallbackFilter(toolCallbackFilter);
return this;
}

public BedrockChatOptions build() {
return this.options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,15 +303,16 @@ Prompt buildRequestPrompt(Prompt prompt) {
: this.defaultOptions.getTemperature())
.topP(runtimeOptions.getTopP() != null ? runtimeOptions.getTopP() : this.defaultOptions.getTopP())

.toolCallbacks(runtimeOptions.getToolCallbacks() != null ? runtimeOptions.getToolCallbacks()
: this.defaultOptions.getToolCallbacks())
.toolCallbacks(runtimeOptions.getFilteredToolCallbacks(runtimeOptions.getToolCallbacks() != null
? runtimeOptions.getToolCallbacks() : this.defaultOptions.getToolCallbacks()))
.toolNames(runtimeOptions.getToolNames() != null ? runtimeOptions.getToolNames()
: this.defaultOptions.getToolNames())
.toolContext(runtimeOptions.getToolContext() != null ? runtimeOptions.getToolContext()
: this.defaultOptions.getToolContext())
.internalToolExecutionEnabled(runtimeOptions.getInternalToolExecutionEnabled() != null
? runtimeOptions.getInternalToolExecutionEnabled()
: this.defaultOptions.getInternalToolExecutionEnabled())
.toolCallbackFilter(runtimeOptions.getToolCallbackFilter())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
Expand Down Expand Up @@ -143,7 +144,10 @@ public class DeepSeekChatOptions implements ToolCallingChatOptions {
private Set<String> toolNames = new HashSet<>();

@JsonIgnore
private Map<String, Object> toolContext = new HashMap<>();;
private Map<String, Object> toolContext = new HashMap<>();

@JsonIgnore
private Predicate<? extends ToolCallback> toolCallbackFilter;

public static Builder builder() {
return new Builder();
Expand Down Expand Up @@ -246,7 +250,6 @@ public void setToolChoice(Object toolChoice) {
this.toolChoice = toolChoice;
}


@Override
@JsonIgnore
public List<ToolCallback> getToolCallbacks() {
Expand Down Expand Up @@ -322,6 +325,16 @@ public void setToolContext(Map<String, Object> toolContext) {
this.toolContext = toolContext;
}

@JsonIgnore
public Predicate<? extends ToolCallback> getToolCallbackFilter() {
return this.toolCallbackFilter;
}

@Override
public void setToolCallbackFilter(Predicate<? extends ToolCallback> toolCallbackFilter) {
this.toolCallbackFilter = toolCallbackFilter;
}

@Override
public DeepSeekChatOptions copy() {
return DeepSeekChatOptions.fromOptions(this);
Expand Down Expand Up @@ -379,6 +392,7 @@ public static DeepSeekChatOptions fromOptions(DeepSeekChatOptions fromOptions) {
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
.toolCallbackFilter(fromOptions.getToolCallbackFilter())
.build();
}

Expand Down Expand Up @@ -497,6 +511,11 @@ public Builder toolContext(Map<String, Object> toolContext) {
return this;
}

public Builder toolCallbackFilter(Predicate<? extends ToolCallback> toolCallbackFilter) {
this.options.setToolCallbackFilter(toolCallbackFilter);
return this;
}

public DeepSeekChatOptions build() {
return this.options;
}
Expand Down
Loading