Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
*/
public abstract class AbstractAsyncMcpToolMethodCallback<T> {

protected final Class<? extends Throwable> toolCallExceptionClass;

private static final TypeReference<Map<String, Object>> MAP_TYPE_REFERENCE = new TypeReference<Map<String, Object>>() {
// No implementation needed
};
Expand All @@ -58,10 +60,12 @@ public abstract class AbstractAsyncMcpToolMethodCallback<T> {

protected final ReturnMode returnMode;

protected AbstractAsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject) {
protected AbstractAsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject,
Class<? extends Throwable> toolCallExceptionClass) {
this.toolMethod = toolMethod;
this.toolObject = toolObject;
this.returnMode = returnMode;
this.toolCallExceptionClass = toolCallExceptionClass;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
*/
public abstract class AbstractSyncMcpToolMethodCallback<T> {

protected final Class<? extends Throwable> toolCallExceptionClass;

private static final TypeReference<Map<String, Object>> MAP_TYPE_REFERENCE = new TypeReference<Map<String, Object>>() {
// No implementation needed
};
Expand All @@ -54,10 +56,12 @@ public abstract class AbstractSyncMcpToolMethodCallback<T> {

protected final ReturnMode returnMode;

protected AbstractSyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject) {
protected AbstractSyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject,
Class<? extends Throwable> toolCallExceptionClass) {
this.toolMethod = toolMethod;
this.toolObject = toolObject;
this.returnMode = returnMode;
this.toolCallExceptionClass = toolCallExceptionClass;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ public final class AsyncMcpToolMethodCallback extends AbstractAsyncMcpToolMethod
implements BiFunction<McpAsyncServerExchange, CallToolRequest, Mono<CallToolResult>> {

public AsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject) {
super(returnMode, toolMethod, toolObject);
super(returnMode, toolMethod, toolObject, Exception.class);
}

public AsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject,
Class<? extends Throwable> toolCallExceptionClass) {
super(returnMode, toolMethod, toolObject, toolCallExceptionClass);
}

@Override
Expand Down Expand Up @@ -72,7 +77,10 @@ public Mono<CallToolResult> apply(McpAsyncServerExchange exchange, CallToolReque

}
catch (Exception e) {
return this.createErrorResult(e);
if (this.toolCallExceptionClass.isInstance(e)) {
return this.createErrorResult(e);
}
return Mono.error(e);
}
}));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ public final class AsyncStatelessMcpToolMethodCallback extends AbstractAsyncMcpT

public AsyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod,
Object toolObject) {
super(returnMode, toolMethod, toolObject);
super(returnMode, toolMethod, toolObject, Exception.class);
}

public AsyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod,
Object toolObject, Class<? extends Throwable> toolCallExceptionClass) {
super(returnMode, toolMethod, toolObject, toolCallExceptionClass);
}

@Override
Expand Down Expand Up @@ -71,7 +76,10 @@ public Mono<CallToolResult> apply(McpTransportContext mcpTransportContext, CallT

}
catch (Exception e) {
return this.createErrorResult(e);
if (this.toolCallExceptionClass.isInstance(e)) {
return this.createErrorResult(e);
}
return Mono.error(e);
}
}));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@ public final class SyncMcpToolMethodCallback extends AbstractSyncMcpToolMethodCa
implements BiFunction<McpSyncServerExchange, CallToolRequest, CallToolResult> {

public SyncMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod, Object toolObject) {
super(returnMode, toolMethod, toolObject);
super(returnMode, toolMethod, toolObject, Exception.class);
}

public SyncMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod, Object toolObject,
Class<? extends Throwable> toolCallExceptionClass) {
super(returnMode, toolMethod, toolObject, toolCallExceptionClass);
}

@Override
Expand Down Expand Up @@ -69,7 +74,10 @@ public CallToolResult apply(McpSyncServerExchange exchange, CallToolRequest requ
return this.processResult(result);
}
catch (Exception e) {
return this.createErrorResult(e);
if (this.toolCallExceptionClass.isInstance(e)) {
return this.createErrorResult(e);
}
throw e;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ public final class SyncStatelessMcpToolMethodCallback extends AbstractSyncMcpToo

public SyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod,
Object toolObject) {
super(returnMode, toolMethod, toolObject);
super(returnMode, toolMethod, toolObject, Exception.class);
}

public SyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod,
Object toolObject, Class<? extends Throwable> toolCallExceptionClass) {
super(returnMode, toolMethod, toolObject, toolCallExceptionClass);
}

@Override
Expand All @@ -61,7 +66,10 @@ public CallToolResult apply(McpTransportContext mcpTransportContext, CallToolReq
return this.processResult(result);
}
catch (Exception e) {
return this.createErrorResult(e);
if (this.toolCallExceptionClass.isInstance(e)) {
return this.createErrorResult(e);
}
throw e;
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright 2025-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springaicommunity.mcp.provider;

import java.lang.reflect.Method;
import java.util.function.Predicate;

import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public class ProvidrerUtils {

public final static Predicate<Method> isReactiveReturnType = method -> Mono.class
.isAssignableFrom(method.getReturnType()) || Flux.class.isAssignableFrom(method.getReturnType())
|| Publisher.class.isAssignableFrom(method.getReturnType());

public final static Predicate<Method> isNotReactiveReturnType = method -> !Mono.class
.isAssignableFrom(method.getReturnType()) && !Flux.class.isAssignableFrom(method.getReturnType())
&& !Publisher.class.isAssignableFrom(method.getReturnType());

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.springaicommunity.mcp.provider.tool;

import java.lang.reflect.Method;
import java.util.List;

import io.modelcontextprotocol.util.Assert;
import org.springaicommunity.mcp.annotation.McpTool;

public abstract class AbstractMcpToolProvider {

protected final List<Object> toolObjects;

public AbstractMcpToolProvider(List<Object> toolObjects) {
Assert.notNull(toolObjects, "toolObjects cannot be null");
this.toolObjects = toolObjects;
}

protected Method[] doGetClassMethods(Object bean) {
return bean.getClass().getDeclaredMethods();
}

protected McpTool doGetMcpToolAnnotation(Method method) {
return method.getAnnotation(McpTool.class);
}

protected Class<? extends Throwable> doGetToolCallException() {
return Exception.class;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@

package org.springaicommunity.mcp.provider.tool;

import java.lang.reflect.Method;
import java.util.List;
import java.util.function.BiFunction;
import java.util.stream.Stream;

import org.reactivestreams.Publisher;
import io.modelcontextprotocol.server.McpAsyncServerExchange;
import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import io.modelcontextprotocol.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springaicommunity.mcp.annotation.McpTool;
Expand All @@ -30,33 +34,22 @@
import org.springaicommunity.mcp.method.tool.ReturnMode;
import org.springaicommunity.mcp.method.tool.utils.ClassUtils;
import org.springaicommunity.mcp.method.tool.utils.JsonSchemaGenerator;

import io.modelcontextprotocol.server.McpAsyncServerExchange;
import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;
import reactor.core.publisher.Flux;
import org.springaicommunity.mcp.provider.ProvidrerUtils;
import reactor.core.publisher.Mono;

/**
* @author Christian Tzolov
*/
public class AsyncMcpToolProvider {
public class AsyncMcpToolProvider extends AbstractMcpToolProvider {

private static final Logger logger = LoggerFactory.getLogger(AsyncMcpToolProvider.class);

private final List<Object> toolObjects;

/**
* Create a new SyncMcpToolProvider.
* @param toolObjects the objects containing methods annotated with {@link McpTool}
*/
public AsyncMcpToolProvider(List<Object> toolObjects) {
Assert.notNull(toolObjects, "toolObjects cannot be null");
this.toolObjects = toolObjects;
super(toolObjects);
}

/**
Expand All @@ -68,15 +61,13 @@ public AsyncMcpToolProvider(List<Object> toolObjects) {
public List<AsyncToolSpecification> getToolSpecifications() {

List<AsyncToolSpecification> toolSpecs = this.toolObjects.stream()
.map(toolObject -> Stream.of(doGetClassMethods(toolObject))
.map(toolObject -> Stream.of(this.doGetClassMethods(toolObject))
.filter(method -> method.isAnnotationPresent(McpTool.class))
.filter(method -> Mono.class.isAssignableFrom(method.getReturnType())
|| Flux.class.isAssignableFrom(method.getReturnType())
|| Publisher.class.isAssignableFrom(method.getReturnType()))
.filter(ProvidrerUtils.isReactiveReturnType)
.sorted((m1, m2) -> m1.getName().compareTo(m2.getName()))
.map(mcpToolMethod -> {

var toolJavaAnnotation = doGetMcpToolAnnotation(mcpToolMethod);
var toolJavaAnnotation = this.doGetMcpToolAnnotation(mcpToolMethod);

String toolName = Utils.hasText(toolJavaAnnotation.name()) ? toolJavaAnnotation.name()
: mcpToolMethod.getName();
Expand Down Expand Up @@ -140,7 +131,7 @@ public List<AsyncToolSpecification> getToolSpecifications() {
: ReturnMode.TEXT;

BiFunction<McpAsyncServerExchange, CallToolRequest, Mono<CallToolResult>> methodCallback = new AsyncMcpToolMethodCallback(
returnMode, mcpToolMethod, toolObject);
returnMode, mcpToolMethod, toolObject, this.doGetToolCallException());

AsyncToolSpecification toolSpec = AsyncToolSpecification.builder()
.tool(tool)
Expand All @@ -160,12 +151,4 @@ public List<AsyncToolSpecification> getToolSpecifications() {
return toolSpecs;
}

protected Method[] doGetClassMethods(Object bean) {
return bean.getClass().getDeclaredMethods();
}

protected McpTool doGetMcpToolAnnotation(Method method) {
return method.getAnnotation(McpTool.class);
}

}
Loading