diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java index 52639f7..16a350b 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java @@ -48,6 +48,8 @@ */ public abstract class AbstractAsyncMcpToolMethodCallback { + protected final Class toolCallExceptionClass; + private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference>() { // No implementation needed }; @@ -58,10 +60,12 @@ public abstract class AbstractAsyncMcpToolMethodCallback { protected final ReturnMode returnMode; - protected AbstractAsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject) { + protected AbstractAsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject, + Class toolCallExceptionClass) { this.toolMethod = toolMethod; this.toolObject = toolObject; this.returnMode = returnMode; + this.toolCallExceptionClass = toolCallExceptionClass; } /** diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java index a6f9c07..9af2562 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java @@ -44,6 +44,8 @@ */ public abstract class AbstractSyncMcpToolMethodCallback { + protected final Class toolCallExceptionClass; + private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference>() { // No implementation needed }; @@ -54,10 +56,12 @@ public abstract class AbstractSyncMcpToolMethodCallback { protected final ReturnMode returnMode; - protected AbstractSyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject) { + protected AbstractSyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject, + Class toolCallExceptionClass) { this.toolMethod = toolMethod; this.toolObject = toolObject; this.returnMode = returnMode; + this.toolCallExceptionClass = toolCallExceptionClass; } /** diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java index 0b9e542..8d9bc42 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java @@ -38,7 +38,12 @@ public final class AsyncMcpToolMethodCallback extends AbstractAsyncMcpToolMethod implements BiFunction> { public AsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject) { - super(returnMode, toolMethod, toolObject); + super(returnMode, toolMethod, toolObject, Exception.class); + } + + public AsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject, + Class toolCallExceptionClass) { + super(returnMode, toolMethod, toolObject, toolCallExceptionClass); } @Override @@ -72,7 +77,10 @@ public Mono apply(McpAsyncServerExchange exchange, CallToolReque } catch (Exception e) { - return this.createErrorResult(e); + if (this.toolCallExceptionClass.isInstance(e)) { + return this.createErrorResult(e); + } + return Mono.error(e); } })); } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java index d1586fa..f97f349 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java @@ -38,7 +38,12 @@ public final class AsyncStatelessMcpToolMethodCallback extends AbstractAsyncMcpT public AsyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod, Object toolObject) { - super(returnMode, toolMethod, toolObject); + super(returnMode, toolMethod, toolObject, Exception.class); + } + + public AsyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod, + Object toolObject, Class toolCallExceptionClass) { + super(returnMode, toolMethod, toolObject, toolCallExceptionClass); } @Override @@ -71,7 +76,10 @@ public Mono apply(McpTransportContext mcpTransportContext, CallT } catch (Exception e) { - return this.createErrorResult(e); + if (this.toolCallExceptionClass.isInstance(e)) { + return this.createErrorResult(e); + } + return Mono.error(e); } })); } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java index 1798da8..344dee8 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java @@ -36,7 +36,12 @@ public final class SyncMcpToolMethodCallback extends AbstractSyncMcpToolMethodCa implements BiFunction { public SyncMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod, Object toolObject) { - super(returnMode, toolMethod, toolObject); + super(returnMode, toolMethod, toolObject, Exception.class); + } + + public SyncMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod, Object toolObject, + Class toolCallExceptionClass) { + super(returnMode, toolMethod, toolObject, toolCallExceptionClass); } @Override @@ -69,7 +74,10 @@ public CallToolResult apply(McpSyncServerExchange exchange, CallToolRequest requ return this.processResult(result); } catch (Exception e) { - return this.createErrorResult(e); + if (this.toolCallExceptionClass.isInstance(e)) { + return this.createErrorResult(e); + } + throw e; } } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java index ad8a0dd..4daee3e 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java @@ -37,7 +37,12 @@ public final class SyncStatelessMcpToolMethodCallback extends AbstractSyncMcpToo public SyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod, Object toolObject) { - super(returnMode, toolMethod, toolObject); + super(returnMode, toolMethod, toolObject, Exception.class); + } + + public SyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod, + Object toolObject, Class toolCallExceptionClass) { + super(returnMode, toolMethod, toolObject, toolCallExceptionClass); } @Override @@ -61,7 +66,10 @@ public CallToolResult apply(McpTransportContext mcpTransportContext, CallToolReq return this.processResult(result); } catch (Exception e) { - return this.createErrorResult(e); + if (this.toolCallExceptionClass.isInstance(e)) { + return this.createErrorResult(e); + } + throw e; } } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/ProvidrerUtils.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/ProvidrerUtils.java new file mode 100644 index 0000000..65f7c8c --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/ProvidrerUtils.java @@ -0,0 +1,36 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import java.lang.reflect.Method; +import java.util.function.Predicate; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class ProvidrerUtils { + + public final static Predicate isReactiveReturnType = method -> Mono.class + .isAssignableFrom(method.getReturnType()) || Flux.class.isAssignableFrom(method.getReturnType()) + || Publisher.class.isAssignableFrom(method.getReturnType()); + + public final static Predicate isNotReactiveReturnType = method -> !Mono.class + .isAssignableFrom(method.getReturnType()) && !Flux.class.isAssignableFrom(method.getReturnType()) + && !Publisher.class.isAssignableFrom(method.getReturnType()); + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/AbstractMcpToolProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/AbstractMcpToolProvider.java new file mode 100644 index 0000000..d3921d7 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/AbstractMcpToolProvider.java @@ -0,0 +1,30 @@ +package org.springaicommunity.mcp.provider.tool; + +import java.lang.reflect.Method; +import java.util.List; + +import io.modelcontextprotocol.util.Assert; +import org.springaicommunity.mcp.annotation.McpTool; + +public abstract class AbstractMcpToolProvider { + + protected final List toolObjects; + + public AbstractMcpToolProvider(List toolObjects) { + Assert.notNull(toolObjects, "toolObjects cannot be null"); + this.toolObjects = toolObjects; + } + + protected Method[] doGetClassMethods(Object bean) { + return bean.getClass().getDeclaredMethods(); + } + + protected McpTool doGetMcpToolAnnotation(Method method) { + return method.getAnnotation(McpTool.class); + } + + protected Class doGetToolCallException() { + return Exception.class; + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/AsyncMcpToolProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/AsyncMcpToolProvider.java index 3d98a1d..e93282e 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/AsyncMcpToolProvider.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/AsyncMcpToolProvider.java @@ -16,12 +16,16 @@ package org.springaicommunity.mcp.provider.tool; -import java.lang.reflect.Method; import java.util.List; import java.util.function.BiFunction; import java.util.stream.Stream; -import org.reactivestreams.Publisher; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springaicommunity.mcp.annotation.McpTool; @@ -30,33 +34,22 @@ import org.springaicommunity.mcp.method.tool.ReturnMode; import org.springaicommunity.mcp.method.tool.utils.ClassUtils; import org.springaicommunity.mcp.method.tool.utils.JsonSchemaGenerator; - -import io.modelcontextprotocol.server.McpAsyncServerExchange; -import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.Utils; -import reactor.core.publisher.Flux; +import org.springaicommunity.mcp.provider.ProvidrerUtils; import reactor.core.publisher.Mono; /** * @author Christian Tzolov */ -public class AsyncMcpToolProvider { +public class AsyncMcpToolProvider extends AbstractMcpToolProvider { private static final Logger logger = LoggerFactory.getLogger(AsyncMcpToolProvider.class); - private final List toolObjects; - /** * Create a new SyncMcpToolProvider. * @param toolObjects the objects containing methods annotated with {@link McpTool} */ public AsyncMcpToolProvider(List toolObjects) { - Assert.notNull(toolObjects, "toolObjects cannot be null"); - this.toolObjects = toolObjects; + super(toolObjects); } /** @@ -68,15 +61,13 @@ public AsyncMcpToolProvider(List toolObjects) { public List getToolSpecifications() { List toolSpecs = this.toolObjects.stream() - .map(toolObject -> Stream.of(doGetClassMethods(toolObject)) + .map(toolObject -> Stream.of(this.doGetClassMethods(toolObject)) .filter(method -> method.isAnnotationPresent(McpTool.class)) - .filter(method -> Mono.class.isAssignableFrom(method.getReturnType()) - || Flux.class.isAssignableFrom(method.getReturnType()) - || Publisher.class.isAssignableFrom(method.getReturnType())) + .filter(ProvidrerUtils.isReactiveReturnType) .sorted((m1, m2) -> m1.getName().compareTo(m2.getName())) .map(mcpToolMethod -> { - var toolJavaAnnotation = doGetMcpToolAnnotation(mcpToolMethod); + var toolJavaAnnotation = this.doGetMcpToolAnnotation(mcpToolMethod); String toolName = Utils.hasText(toolJavaAnnotation.name()) ? toolJavaAnnotation.name() : mcpToolMethod.getName(); @@ -140,7 +131,7 @@ public List getToolSpecifications() { : ReturnMode.TEXT; BiFunction> methodCallback = new AsyncMcpToolMethodCallback( - returnMode, mcpToolMethod, toolObject); + returnMode, mcpToolMethod, toolObject, this.doGetToolCallException()); AsyncToolSpecification toolSpec = AsyncToolSpecification.builder() .tool(tool) @@ -160,12 +151,4 @@ public List getToolSpecifications() { return toolSpecs; } - protected Method[] doGetClassMethods(Object bean) { - return bean.getClass().getDeclaredMethods(); - } - - protected McpTool doGetMcpToolAnnotation(Method method) { - return method.getAnnotation(McpTool.class); - } - } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/AsyncStatelessMcpToolProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/AsyncStatelessMcpToolProvider.java index 668c9ec..375eceb 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/AsyncStatelessMcpToolProvider.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/AsyncStatelessMcpToolProvider.java @@ -16,12 +16,16 @@ package org.springaicommunity.mcp.provider.tool; -import java.lang.reflect.Method; import java.util.List; import java.util.function.BiFunction; import java.util.stream.Stream; -import org.reactivestreams.Publisher; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpStatelessServerFeatures.AsyncToolSpecification; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springaicommunity.mcp.annotation.McpTool; @@ -30,15 +34,7 @@ import org.springaicommunity.mcp.method.tool.ReturnMode; import org.springaicommunity.mcp.method.tool.utils.ClassUtils; import org.springaicommunity.mcp.method.tool.utils.JsonSchemaGenerator; - -import io.modelcontextprotocol.server.McpStatelessServerFeatures.AsyncToolSpecification; -import io.modelcontextprotocol.common.McpTransportContext; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.Utils; -import reactor.core.publisher.Flux; +import org.springaicommunity.mcp.provider.ProvidrerUtils; import reactor.core.publisher.Mono; /** @@ -50,19 +46,16 @@ * * @author Christian Tzolov */ -public class AsyncStatelessMcpToolProvider { +public class AsyncStatelessMcpToolProvider extends AbstractMcpToolProvider { private static final Logger logger = LoggerFactory.getLogger(AsyncStatelessMcpToolProvider.class); - private final List toolObjects; - /** * Create a new AsyncStatelessMcpToolProvider. * @param toolObjects the objects containing methods annotated with {@link McpTool} */ public AsyncStatelessMcpToolProvider(List toolObjects) { - Assert.notNull(toolObjects, "toolObjects cannot be null"); - this.toolObjects = toolObjects; + super(toolObjects); } /** @@ -74,9 +67,7 @@ public List getToolSpecifications() { List toolSpecs = this.toolObjects.stream() .map(toolObject -> Stream.of(doGetClassMethods(toolObject)) .filter(method -> method.isAnnotationPresent(McpTool.class)) - .filter(method -> Mono.class.isAssignableFrom(method.getReturnType()) - || Flux.class.isAssignableFrom(method.getReturnType()) - || Publisher.class.isAssignableFrom(method.getReturnType())) + .filter(ProvidrerUtils.isReactiveReturnType) .sorted((m1, m2) -> m1.getName().compareTo(m2.getName())) .map(mcpToolMethod -> { @@ -145,7 +136,7 @@ public List getToolSpecifications() { : ReturnMode.TEXT; BiFunction> methodCallback = new AsyncStatelessMcpToolMethodCallback( - returnMode, mcpToolMethod, toolObject); + returnMode, mcpToolMethod, toolObject, this.doGetToolCallException()); AsyncToolSpecification toolSpec = AsyncToolSpecification.builder() .tool(tool) @@ -165,12 +156,4 @@ public List getToolSpecifications() { return toolSpecs; } - protected Method[] doGetClassMethods(Object bean) { - return bean.getClass().getDeclaredMethods(); - } - - protected McpTool doGetMcpToolAnnotation(Method method) { - return method.getAnnotation(McpTool.class); - } - } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/SyncMcpToolProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/SyncMcpToolProvider.java index 1caf640..e7b9237 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/SyncMcpToolProvider.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/SyncMcpToolProvider.java @@ -16,8 +16,6 @@ package org.springaicommunity.mcp.provider.tool; -import java.lang.reflect.Method; -import java.util.Arrays; import java.util.List; import java.util.function.BiFunction; import java.util.stream.Stream; @@ -27,7 +25,6 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,24 +33,21 @@ import org.springaicommunity.mcp.method.tool.SyncMcpToolMethodCallback; import org.springaicommunity.mcp.method.tool.utils.ClassUtils; import org.springaicommunity.mcp.method.tool.utils.JsonSchemaGenerator; -import reactor.core.publisher.Mono; +import org.springaicommunity.mcp.provider.ProvidrerUtils; /** * @author Christian Tzolov */ -public class SyncMcpToolProvider { +public class SyncMcpToolProvider extends AbstractMcpToolProvider { private static final Logger logger = LoggerFactory.getLogger(SyncMcpToolProvider.class); - private final List toolObjects; - /** * Create a new SyncMcpToolProvider. * @param toolObjects the objects containing methods annotated with {@link McpTool} */ public SyncMcpToolProvider(List toolObjects) { - Assert.notNull(toolObjects, "toolObjects cannot be null"); - this.toolObjects = toolObjects; + super(toolObjects); } /** @@ -65,34 +59,20 @@ public SyncMcpToolProvider(List toolObjects) { public List getToolSpecifications() { List toolSpecs = this.toolObjects.stream() - .map(toolObject -> Stream.of(doGetClassMethods(toolObject)) + .map(toolObject -> Stream.of(this.doGetClassMethods(toolObject)) .filter(method -> method.isAnnotationPresent(McpTool.class)) - .filter(method -> !Mono.class.isAssignableFrom(method.getReturnType())) + .filter(ProvidrerUtils.isNotReactiveReturnType) .sorted((m1, m2) -> m1.getName().compareTo(m2.getName())) .map(mcpToolMethod -> { - McpTool toolJavaAnnotation = doGetMcpToolAnnotation(mcpToolMethod); + McpTool toolJavaAnnotation = this.doGetMcpToolAnnotation(mcpToolMethod); String toolName = Utils.hasText(toolJavaAnnotation.name()) ? toolJavaAnnotation.name() : mcpToolMethod.getName(); String toolDescription = toolJavaAnnotation.description(); - // Check if method has CallToolRequest parameter - boolean hasCallToolRequestParam = Arrays.stream(mcpToolMethod.getParameterTypes()) - .anyMatch(type -> CallToolRequest.class.isAssignableFrom(type)); - - String inputSchema; - if (hasCallToolRequestParam) { - // For methods with CallToolRequest, generate minimal schema or - // use the one from the request - // The schema generation will handle this appropriately - inputSchema = JsonSchemaGenerator.generateForMethodInput(mcpToolMethod); - logger.debug("Tool method '{}' uses CallToolRequest parameter, using minimal schema", toolName); - } - else { - inputSchema = JsonSchemaGenerator.generateForMethodInput(mcpToolMethod); - } + String inputSchema = JsonSchemaGenerator.generateForMethodInput(mcpToolMethod); var toolBuilder = McpSchema.Tool.builder() .name(toolName) @@ -148,7 +128,7 @@ public List getToolSpecifications() { : ReturnMode.TEXT); BiFunction methodCallback = new SyncMcpToolMethodCallback( - returnMode, mcpToolMethod, toolObject); + returnMode, mcpToolMethod, toolObject, this.doGetToolCallException()); var toolSpec = SyncToolSpecification.builder().tool(tool).callHandler(methodCallback).build(); @@ -165,12 +145,4 @@ public List getToolSpecifications() { return toolSpecs; } - protected Method[] doGetClassMethods(Object bean) { - return bean.getClass().getDeclaredMethods(); - } - - protected McpTool doGetMcpToolAnnotation(Method method) { - return method.getAnnotation(McpTool.class); - } - } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/SyncStatelessMcpToolProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/SyncStatelessMcpToolProvider.java index 17eda58..6a6896d 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/SyncStatelessMcpToolProvider.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/SyncStatelessMcpToolProvider.java @@ -16,50 +16,43 @@ package org.springaicommunity.mcp.provider.tool; -import java.lang.reflect.Method; import java.util.List; import java.util.function.BiFunction; import java.util.stream.Stream; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpStatelessServerFeatures.SyncToolSpecification; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springaicommunity.mcp.annotation.McpTool; -import org.springaicommunity.mcp.method.tool.ReactiveUtils; import org.springaicommunity.mcp.method.tool.ReturnMode; import org.springaicommunity.mcp.method.tool.SyncStatelessMcpToolMethodCallback; import org.springaicommunity.mcp.method.tool.utils.ClassUtils; import org.springaicommunity.mcp.method.tool.utils.JsonSchemaGenerator; - -import io.modelcontextprotocol.server.McpStatelessServerFeatures.SyncToolSpecification; -import io.modelcontextprotocol.common.McpTransportContext; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.Utils; -import reactor.core.publisher.Mono; +import org.springaicommunity.mcp.provider.ProvidrerUtils; /** * Provider for synchronous stateless MCP tool methods. - * + * * This provider creates tool specifications for methods annotated with {@link McpTool} * that are designed to work in a stateless manner using {@link McpTransportContext}. * * @author Christian Tzolov */ -public class SyncStatelessMcpToolProvider { +public class SyncStatelessMcpToolProvider extends AbstractMcpToolProvider { private static final Logger logger = LoggerFactory.getLogger(SyncStatelessMcpToolProvider.class); - private final List toolObjects; - /** * Create a new SyncStatelessMcpToolProvider. * @param toolObjects the objects containing methods annotated with {@link McpTool} */ public SyncStatelessMcpToolProvider(List toolObjects) { - Assert.notNull(toolObjects, "toolObjects cannot be null"); - this.toolObjects = toolObjects; + super(toolObjects); } /** @@ -69,13 +62,13 @@ public SyncStatelessMcpToolProvider(List toolObjects) { public List getToolSpecifications() { List toolSpecs = this.toolObjects.stream() - .map(toolObject -> Stream.of(doGetClassMethods(toolObject)) + .map(toolObject -> Stream.of(this.doGetClassMethods(toolObject)) .filter(method -> method.isAnnotationPresent(McpTool.class)) - .filter(method -> !Mono.class.isAssignableFrom(method.getReturnType())) + .filter(ProvidrerUtils.isNotReactiveReturnType) .sorted((m1, m2) -> m1.getName().compareTo(m2.getName())) .map(mcpToolMethod -> { - var toolJavaAnnotation = doGetMcpToolAnnotation(mcpToolMethod); + var toolJavaAnnotation = this.doGetMcpToolAnnotation(mcpToolMethod); String toolName = Utils.hasText(toolJavaAnnotation.name()) ? toolJavaAnnotation.name() : mcpToolMethod.getName(); @@ -138,7 +131,7 @@ public List getToolSpecifications() { : ReturnMode.TEXT); BiFunction methodCallback = new SyncStatelessMcpToolMethodCallback( - returnMode, mcpToolMethod, toolObject); + returnMode, mcpToolMethod, toolObject, this.doGetToolCallException()); var toolSpec = SyncToolSpecification.builder().tool(tool).callHandler(methodCallback).build(); @@ -155,12 +148,4 @@ public List getToolSpecifications() { return toolSpecs; } - protected Method[] doGetClassMethods(Object bean) { - return bean.getClass().getDeclaredMethods(); - } - - protected McpTool doGetMcpToolAnnotation(Method method) { - return method.getAnnotation(McpTool.class); - } - } diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallbackExceptionHandlingTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallbackExceptionHandlingTests.java new file mode 100644 index 0000000..cde2f89 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallbackExceptionHandlingTests.java @@ -0,0 +1,328 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.method.tool; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.Map; + +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpTool; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Tests for exception handling in {@link SyncMcpToolMethodCallback}. + * + * These tests verify the exception handling behavior in the apply() method, specifically + * the catch block that checks if an exception is an instance of the configured + * toolCallExceptionClass. + * + * @author Christian Tzolov + */ +public class SyncMcpToolMethodCallbackExceptionHandlingTests { + + // Custom exception classes for testing + public static class BusinessException extends Exception { + + public BusinessException(String message) { + super(message); + } + + } + + public static class CustomRuntimeException extends RuntimeException { + + public CustomRuntimeException(String message) { + super(message); + } + + } + + // Test tool provider with various exception-throwing methods + private static class ExceptionTestToolProvider { + + @McpTool(name = "runtime-exception-tool", description = "Tool that throws RuntimeException") + public String runtimeExceptionTool(String input) { + throw new RuntimeException("Runtime error: " + input); + } + + @McpTool(name = "custom-runtime-exception-tool", description = "Tool that throws CustomRuntimeException") + public String customRuntimeExceptionTool(String input) { + throw new CustomRuntimeException("Custom runtime error: " + input); + } + + @McpTool(name = "checked-exception-tool", description = "Tool that throws checked exception") + public String checkedExceptionTool(String input) throws BusinessException { + throw new BusinessException("Business error: " + input); + } + + @McpTool(name = "success-tool", description = "Tool that succeeds") + public String successTool(String input) { + return "Success: " + input; + } + + @McpTool(name = "null-pointer-tool", description = "Tool that throws NullPointerException") + public String nullPointerTool(String input) { + throw new NullPointerException("Null pointer: " + input); + } + + @McpTool(name = "illegal-argument-tool", description = "Tool that throws IllegalArgumentException") + public String illegalArgumentTool(String input) { + throw new IllegalArgumentException("Illegal argument: " + input); + } + + } + + @Test + public void testDefaultConstructor_CatchesAllExceptions() throws Exception { + // Test with default constructor (uses Exception.class) + ExceptionTestToolProvider provider = new ExceptionTestToolProvider(); + Method method = ExceptionTestToolProvider.class.getMethod("runtimeExceptionTool", String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("runtime-exception-tool", Map.of("input", "test")); + + // The RuntimeException thrown by callMethod should be caught and converted to + // error result + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).contains("Error invoking method"); + } + + @Test + public void testExceptionClassConstructor_CatchesSpecifiedExceptions() throws Exception { + // Configure to catch only RuntimeException and its subclasses + ExceptionTestToolProvider provider = new ExceptionTestToolProvider(); + Method method = ExceptionTestToolProvider.class.getMethod("customRuntimeExceptionTool", String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider, + RuntimeException.class); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("custom-runtime-exception-tool", Map.of("input", "test")); + + // The RuntimeException wrapper from callMethod should be caught + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).contains("Error invoking method"); + } + + @Test + public void testNonMatchingExceptionClass_ThrowsException() throws Exception { + // Configure to catch only IllegalArgumentException + ExceptionTestToolProvider provider = new ExceptionTestToolProvider(); + Method method = ExceptionTestToolProvider.class.getMethod("runtimeExceptionTool", String.class); + + // Create callback that only catches IllegalArgumentException + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider, + IllegalArgumentException.class); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("runtime-exception-tool", Map.of("input", "test")); + + // The RuntimeException from callMethod should NOT be caught (not an + // IllegalArgumentException) + assertThatThrownBy(() -> callback.apply(exchange, request)).isInstanceOf(RuntimeException.class) + .hasMessageContaining("Error invoking method"); + } + + @Test + public void testCheckedExceptionHandling_WithExceptionClass() throws Exception { + // Test handling of checked exceptions wrapped in RuntimeException + ExceptionTestToolProvider provider = new ExceptionTestToolProvider(); + Method method = ExceptionTestToolProvider.class.getMethod("checkedExceptionTool", String.class); + + // Configure to catch Exception (which includes RuntimeException) + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider, + Exception.class); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("checked-exception-tool", Map.of("input", "test")); + + // The RuntimeException wrapper should be caught + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).contains("Error invoking method"); + } + + @Test + public void testCheckedExceptionHandling_WithSpecificClass() throws Exception { + // Configure to catch only IllegalArgumentException (not RuntimeException) + ExceptionTestToolProvider provider = new ExceptionTestToolProvider(); + Method method = ExceptionTestToolProvider.class.getMethod("checkedExceptionTool", String.class); + + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider, + IllegalArgumentException.class); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("checked-exception-tool", Map.of("input", "test")); + + // The RuntimeException wrapper should NOT be caught + assertThatThrownBy(() -> callback.apply(exchange, request)).isInstanceOf(RuntimeException.class) + .hasMessageContaining("Error invoking method") + .hasCauseInstanceOf(InvocationTargetException.class); + } + + @Test + public void testSuccessfulExecution_NoExceptionThrown() throws Exception { + // Test that successful execution works normally regardless of exception class + // config + ExceptionTestToolProvider provider = new ExceptionTestToolProvider(); + Method method = ExceptionTestToolProvider.class.getMethod("successTool", String.class); + + // Configure with a specific exception class + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider, + IllegalArgumentException.class); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("success-tool", Map.of("input", "test")); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Success: test"); + } + + @Test + public void testNullPointerException_WithRuntimeExceptionClass() throws Exception { + // Configure to catch RuntimeException (which includes NullPointerException) + ExceptionTestToolProvider provider = new ExceptionTestToolProvider(); + Method method = ExceptionTestToolProvider.class.getMethod("nullPointerTool", String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider, + RuntimeException.class); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("null-pointer-tool", Map.of("input", "test")); + + // Should catch the RuntimeException wrapper + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).contains("Error invoking method"); + } + + @Test + public void testIllegalArgumentException_WithSpecificHandling() throws Exception { + // Configure to catch only RuntimeException + ExceptionTestToolProvider provider = new ExceptionTestToolProvider(); + Method method = ExceptionTestToolProvider.class.getMethod("illegalArgumentTool", String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider, + RuntimeException.class); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("illegal-argument-tool", Map.of("input", "test")); + + // Should catch the RuntimeException wrapper (which wraps + // IllegalArgumentException) + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).contains("Error invoking method"); + } + + @Test + public void testMultipleCallsWithDifferentResults() throws Exception { + // Test that the same callback instance handles different scenarios correctly + ExceptionTestToolProvider provider = new ExceptionTestToolProvider(); + Method successMethod = ExceptionTestToolProvider.class.getMethod("successTool", String.class); + Method exceptionMethod = ExceptionTestToolProvider.class.getMethod("runtimeExceptionTool", String.class); + + // Create callbacks with Exception handling (catches all) + SyncMcpToolMethodCallback successCallback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, successMethod, + provider, Exception.class); + SyncMcpToolMethodCallback exceptionCallback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, exceptionMethod, + provider, Exception.class); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + + // Test success case + CallToolRequest successRequest = new CallToolRequest("success-tool", Map.of("input", "success")); + CallToolResult successResult = successCallback.apply(exchange, successRequest); + assertThat(successResult.isError()).isFalse(); + assertThat(((TextContent) successResult.content().get(0)).text()).isEqualTo("Success: success"); + + // Test exception case + CallToolRequest exceptionRequest = new CallToolRequest("runtime-exception-tool", Map.of("input", "error")); + CallToolResult exceptionResult = exceptionCallback.apply(exchange, exceptionRequest); + assertThat(exceptionResult.isError()).isTrue(); + assertThat(((TextContent) exceptionResult.content().get(0)).text()).contains("Error invoking method"); + } + + @Test + public void testExceptionHierarchy_ParentClassCatchesSubclasses() throws Exception { + // Configure to catch Exception (parent of RuntimeException) + ExceptionTestToolProvider provider = new ExceptionTestToolProvider(); + Method method = ExceptionTestToolProvider.class.getMethod("customRuntimeExceptionTool", String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider, + Exception.class); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("custom-runtime-exception-tool", Map.of("input", "test")); + + // Should catch the RuntimeException (subclass of Exception) + CallToolResult result = callback.apply(exchange, request); + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + } + + @Test + public void testConstructorWithNullExceptionClass_UsesDefault() throws Exception { + // The constructor with 3 parameters uses Exception.class as default + ExceptionTestToolProvider provider = new ExceptionTestToolProvider(); + Method method = ExceptionTestToolProvider.class.getMethod("runtimeExceptionTool", String.class); + + // This constructor uses Exception.class internally + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("runtime-exception-tool", Map.of("input", "test")); + + // Should catch all exceptions (default is Exception.class) + CallToolResult result = callback.apply(exchange, request); + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + } + +}