Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import java.util.function.BiFunction;
import java.util.stream.Stream;

import com.fasterxml.jackson.annotation.JsonAlias;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import io.micrometer.common.util.StringUtils;
import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.client.McpSyncClient;
Expand All @@ -39,7 +37,7 @@
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
import org.springframework.ai.tool.method.MethodToolCallback;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;

Expand Down Expand Up @@ -196,12 +194,13 @@ public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(To
public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(ToolCallback toolCallback,
MimeType mimeType) {

SharedSyncToolSpecification sharedSpec = toSharedSyncToolSpecification(toolCallback, mimeType);
SharedAsyncToolSpecification sharedSpec = toSharedAsyncToolSpecification(toolCallback, mimeType);

return new McpServerFeatures.SyncToolSpecification(sharedSpec.tool(),
(exchange, map) -> sharedSpec.sharedHandler()
.apply(exchange, new CallToolRequest(sharedSpec.tool().name(), map)),
(exchange, request) -> sharedSpec.sharedHandler().apply(exchange, request));
.apply(exchange, new CallToolRequest(sharedSpec.tool().name(), map))
.block(),
(exchange, request) -> sharedSpec.sharedHandler().apply(exchange, request).block());
}

/**
Expand All @@ -219,15 +218,15 @@ public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(To
public static McpStatelessServerFeatures.SyncToolSpecification toStatelessSyncToolSpecification(
ToolCallback toolCallback, MimeType mimeType) {

var sharedSpec = toSharedSyncToolSpecification(toolCallback, mimeType);
var sharedSpec = toSharedAsyncToolSpecification(toolCallback, mimeType);

return McpStatelessServerFeatures.SyncToolSpecification.builder()
.tool(sharedSpec.tool())
.callHandler((exchange, request) -> sharedSpec.sharedHandler().apply(exchange, request))
.callHandler((exchange, request) -> sharedSpec.sharedHandler().apply(exchange, request).block())
.build();
}

private static SharedSyncToolSpecification toSharedSyncToolSpecification(ToolCallback toolCallback,
private static SharedAsyncToolSpecification toSharedAsyncToolSpecification(ToolCallback toolCallback,
MimeType mimeType) {

var tool = McpSchema.Tool.builder()
Expand All @@ -237,20 +236,31 @@ private static SharedSyncToolSpecification toSharedSyncToolSpecification(ToolCal
McpSchema.JsonSchema.class))
.build();

return new SharedSyncToolSpecification(tool, (exchangeOrContext, request) -> {
try {
String callResult = toolCallback.call(ModelOptionsUtils.toJsonString(request.arguments()),
new ToolContext(Map.of(TOOL_CONTEXT_MCP_EXCHANGE_KEY, exchangeOrContext)));
return new SharedAsyncToolSpecification(tool, (exchangeOrContext, request) -> {
final String toolRequest = ModelOptionsUtils.toJsonString(request.arguments());
final ToolContext toolContext = new ToolContext(Map.of(TOOL_CONTEXT_MCP_EXCHANGE_KEY, exchangeOrContext));
final Mono<String> callResult;
if (toolCallback instanceof MethodToolCallback reactiveMethodToolCallback) {
callResult = reactiveMethodToolCallback.callReactive(toolRequest, toolContext);
}
else {
callResult = Mono.fromCallable(() -> toolCallback.call(toolRequest, toolContext));
}
return callResult.map(result -> {
if (mimeType != null && mimeType.toString().startsWith("image")) {
McpSchema.Annotations annotations = new McpSchema.Annotations(List.of(Role.ASSISTANT), null);
return new McpSchema.CallToolResult(
List.of(new McpSchema.ImageContent(annotations, callResult, mimeType.toString())), false);
return McpSchema.CallToolResult.builder()
.addContent(new McpSchema.ImageContent(annotations, result, mimeType.toString()))
.isError(false)
.build();
}
return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(callResult)), false);
}
catch (Exception e) {
return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(e.getMessage())), true);
}
return McpSchema.CallToolResult.builder().addTextContent(result).isError(false).build();
})
.onErrorResume(Exception.class,
error -> Mono.fromSupplier(() -> McpSchema.CallToolResult.builder()
.addTextContent(error.getMessage())
.isError(true)
.build()));
});
}

Expand Down Expand Up @@ -331,7 +341,6 @@ public static McpServerFeatures.AsyncToolSpecification toAsyncToolSpecification(
* This method enables Spring AI tools to be exposed as asynchronous MCP tools that
* can be discovered and invoked by language models. The conversion process:
* <ul>
* <li>First converts the callback to a synchronous specification</li>
* <li>Wraps the synchronous execution in a reactive Mono</li>
* <li>Configures execution on a bounded elastic scheduler for non-blocking
* operation</li>
Expand All @@ -352,26 +361,24 @@ public static McpServerFeatures.AsyncToolSpecification toAsyncToolSpecification(
public static McpServerFeatures.AsyncToolSpecification toAsyncToolSpecification(ToolCallback toolCallback,
MimeType mimeType) {

McpServerFeatures.SyncToolSpecification syncToolSpecification = toSyncToolSpecification(toolCallback, mimeType);
SharedAsyncToolSpecification asyncToolSpecification = toSharedAsyncToolSpecification(toolCallback, mimeType);

return McpServerFeatures.AsyncToolSpecification.builder()
.tool(syncToolSpecification.tool())
.callHandler((exchange, request) -> Mono
.fromCallable(
() -> syncToolSpecification.callHandler().apply(new McpSyncServerExchange(exchange), request))
.tool(asyncToolSpecification.tool())
.callHandler((exchange, request) -> asyncToolSpecification.sharedHandler()
.apply(new McpSyncServerExchange(exchange), request)
.subscribeOn(Schedulers.boundedElastic()))
.build();
}

public static McpStatelessServerFeatures.AsyncToolSpecification toStatelessAsyncToolSpecification(
ToolCallback toolCallback, MimeType mimeType) {

McpStatelessServerFeatures.SyncToolSpecification statelessSyncToolSpecification = toStatelessSyncToolSpecification(
toolCallback, mimeType);
SharedAsyncToolSpecification asyncToolSpecification = toSharedAsyncToolSpecification(toolCallback, mimeType);

return new McpStatelessServerFeatures.AsyncToolSpecification(statelessSyncToolSpecification.tool(),
(context, request) -> Mono
.fromCallable(() -> statelessSyncToolSpecification.callHandler().apply(context, request))
return new McpStatelessServerFeatures.AsyncToolSpecification(asyncToolSpecification.tool(),
(context, request) -> asyncToolSpecification.sharedHandler()
.apply(context, request)
.subscribeOn(Schedulers.boundedElastic()));
}

Expand Down Expand Up @@ -441,13 +448,8 @@ public static List<ToolCallback> getToolCallbacksFromAsyncClients(List<McpAsyncC
return List.of((AsyncMcpToolCallbackProvider.builder().mcpClients(asyncMcpClients).build().getToolCallbacks()));
}

@JsonIgnoreProperties(ignoreUnknown = true)
// @formatter:off
private record Base64Wrapper(@JsonAlias("mimetype") @Nullable MimeType mimeType, @JsonAlias({
"base64", "b64", "imageData" }) @Nullable String data) {
private record SharedAsyncToolSpecification(McpSchema.Tool tool,
BiFunction<Object, CallToolRequest, Mono<McpSchema.CallToolResult>> sharedHandler) {
}

private record SharedSyncToolSpecification(McpSchema.Tool tool,
BiFunction<Object, CallToolRequest, McpSchema.CallToolResult> sharedHandler) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
import java.util.stream.Stream;

import com.fasterxml.jackson.core.type.TypeReference;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.ToolCallback;
Expand Down Expand Up @@ -96,6 +98,10 @@ public String call(String toolInput) {

@Override
public String call(String toolInput, @Nullable ToolContext toolContext) {
return callReactive(toolInput, toolContext).block();
}

public Mono<String> callReactive(String toolInput, @Nullable ToolContext toolContext) {
Assert.hasText(toolInput, "toolInput cannot be null or empty");

logger.debug("Starting execution of tool: {}", this.toolDefinition.name());
Expand All @@ -106,13 +112,13 @@ public String call(String toolInput, @Nullable ToolContext toolContext) {

Object[] methodArguments = buildMethodArguments(toolArguments, toolContext);

Object result = callMethod(methodArguments);

logger.debug("Successful execution of tool: {}", this.toolDefinition.name());
return callMethod(methodArguments).map(result -> {
logger.debug("Successful execution of tool: {}", this.toolDefinition.name());

Type returnType = this.toolMethod.getGenericReturnType();
Type returnType = this.toolMethod.getGenericReturnType();

return this.toolCallResultConverter.convert(result, returnType);
return this.toolCallResultConverter.convert(result, returnType);
});
}

private void validateToolContextSupport(@Nullable ToolContext toolContext) {
Expand Down Expand Up @@ -155,15 +161,16 @@ private Object buildTypedArgument(@Nullable Object value, Type type) {
return JsonParser.fromJson(json, type);
}

@Nullable
private Object callMethod(Object[] methodArguments) {
private Mono<Object> callMethod(Object[] methodArguments) {
if (isObjectNotPublic() || isMethodNotPublic()) {
this.toolMethod.setAccessible(true);
}

Object result;
final Mono<Object> result;
try {
result = this.toolMethod.invoke(this.toolObject, methodArguments);
result = Publisher.class.isAssignableFrom(this.toolMethod.getReturnType())
? Mono.from((Publisher<Object>) this.toolMethod.invoke(this.toolObject, methodArguments))
: Mono.justOrEmpty(this.toolMethod.invoke(this.toolObject, methodArguments));
}
catch (IllegalAccessException ex) {
throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
Expand Down