diff --git a/README.md b/README.md index 83fa872..b9fc714 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ To use the MCP Annotations core module in your project, add the following depend org.springaicommunity mcp-annotations - 0.1.0 + 0.2.0-SNAPSHOT ``` @@ -48,7 +48,7 @@ To use the Spring integration module, add the following dependency: corg.springaicommunity spring-ai-mcp-annotations - 0.1.0 + 0.2.0-SNAPSHOT ``` @@ -89,8 +89,9 @@ The core module provides a set of annotations and callback implementations for p 1. **Complete** - For auto-completion functionality in prompts and URI templates 2. **Prompt** - For generating prompt messages 3. **Resource** - For accessing resources via URI templates -4. **Logging Consumer** - For handling logging message notifications -5. **Sampling** - For handling sampling requests +4. **Tool** - For implementing MCP tools with automatic JSON schema generation +5. **Logging Consumer** - For handling logging message notifications +6. **Sampling** - For handling sampling requests Each operation type has both synchronous and asynchronous implementations, allowing for flexible integration with different application architectures. @@ -105,6 +106,8 @@ The Spring integration module provides seamless integration with Spring AI and S - **`@McpComplete`** - Annotates methods that provide completion functionality for prompts or URI templates - **`@McpPrompt`** - Annotates methods that generate prompt messages - **`@McpResource`** - Annotates methods that provide access to resources +- **`@McpTool`** - Annotates methods that implement MCP tools with automatic JSON schema generation +- **`@McpToolParam`** - Annotates tool method parameters with descriptions and requirement specifications - **`@McpLoggingConsumer`** - Annotates methods that handle logging message notifications from MCP servers - **`@McpSampling`** - Annotates methods that handle sampling requests from MCP servers - **`@McpArg`** - Annotates method parameters as MCP arguments @@ -133,6 +136,10 @@ The modules provide callback implementations for each operation type: - `SyncMcpLoggingConsumerMethodCallback` - Synchronous implementation - `AsyncMcpLoggingConsumerMethodCallback` - Asynchronous implementation using Reactor's Mono +#### Tool +- `SyncMcpToolMethodCallback` - Synchronous implementation for tool method callbacks +- `AsyncMcpToolMethodCallback` - Asynchronous implementation using Reactor's Mono + #### Sampling - `AbstractMcpSamplingMethodCallback` - Base class for sampling method callbacks - `SyncMcpSamplingMethodCallback` - Synchronous implementation @@ -145,6 +152,8 @@ The project includes provider classes that scan for annotated methods and create - `SyncMcpCompletionProvider` - Processes `@McpComplete` annotations for synchronous operations - `SyncMcpPromptProvider` - Processes `@McpPrompt` annotations for synchronous operations - `SyncMcpResourceProvider` - Processes `@McpResource` annotations for synchronous operations +- `SyncMcpToolProvider` - Processes `@McpTool` annotations for synchronous operations +- `AsyncMcpToolProvider` - Processes `@McpTool` annotations for asynchronous operations - `SyncMcpLoggingConsumerProvider` - Processes `@McpLoggingConsumer` annotations for synchronous operations - `AsyncMcpLoggingConsumerProvider` - Processes `@McpLoggingConsumer` annotations for asynchronous operations - `SyncMcpSamplingProvider` - Processes `@McpSampling` annotations for synchronous operations @@ -325,6 +334,166 @@ public class MyResourceProvider { } ``` +### Tool Example + +```java +public class CalculatorToolProvider { + + @McpTool(name = "add", description = "Add two numbers together") + public int add( + @McpToolParam(description = "First number to add", required = true) int a, + @McpToolParam(description = "Second number to add", required = true) int b) { + return a + b; + } + + @McpTool(name = "multiply", description = "Multiply two numbers") + public double multiply( + @McpToolParam(description = "First number", required = true) double x, + @McpToolParam(description = "Second number", required = true) double y) { + return x * y; + } + + @McpTool(name = "calculate-area", + description = "Calculate the area of a rectangle", + annotations = @McpTool.McpAnnotations( + title = "Rectangle Area Calculator", + readOnlyHint = true, + destructiveHint = false, + idempotentHint = true + )) + public AreaResult calculateRectangleArea( + @McpToolParam(description = "Width of the rectangle", required = true) double width, + @McpToolParam(description = "Height of the rectangle", required = true) double height) { + + double area = width * height; + return new AreaResult(area, "square units"); + } + + @McpTool(name = "process-data", description = "Process data with exchange context") + public String processData( + McpSyncServerExchange exchange, + @McpToolParam(description = "Data to process", required = true) String data) { + + exchange.loggingNotification(LoggingMessageNotification.builder() + .level(LoggingLevel.INFO) + .data("Processing data: " + data) + .build()); + + return "Processed: " + data.toUpperCase(); + } + + // Async tool example + @McpTool(name = "async-calculation", description = "Perform async calculation") + public Mono asyncCalculation( + @McpToolParam(description = "Input value", required = true) int value) { + return Mono.fromCallable(() -> { + // Simulate some async work + Thread.sleep(100); + return "Async result: " + (value * 2); + }).subscribeOn(Schedulers.boundedElastic()); + } + + public static class AreaResult { + public double area; + public String unit; + + public AreaResult(double area, String unit) { + this.area = area; + this.unit = unit; + } + } +} +``` + +### Async Tool Example + +```java +public class AsyncToolProvider { + + @McpTool(name = "fetch-data", description = "Fetch data asynchronously") + public Mono fetchData( + @McpToolParam(description = "Data ID to fetch", required = true) String dataId, + @McpToolParam(description = "Include metadata", required = false) Boolean includeMetadata) { + + return Mono.fromCallable(() -> { + // Simulate async data fetching + DataResponse response = new DataResponse(); + response.id = dataId; + response.data = "Sample data for " + dataId; + response.metadata = Boolean.TRUE.equals(includeMetadata) ? + Map.of("timestamp", System.currentTimeMillis()) : null; + return response; + }).subscribeOn(Schedulers.boundedElastic()); + } + + @McpTool(name = "stream-process", description = "Process data stream") + public Flux streamProcess( + @McpToolParam(description = "Number of items to process", required = true) int count) { + + return Flux.range(1, count) + .map(i -> "Processed item " + i) + .delayElements(Duration.ofMillis(100)); + } + + public static class DataResponse { + public String id; + public String data; + public Map metadata; + } +} +``` + +### Mcp Server with Tool capabilities + +```java +public class McpServerFactory { + + public McpSyncServer createMcpServerWithTools( + CalculatorToolProvider calculatorProvider, + MyResourceProvider resourceProvider) { + + List toolSpecifications = + new SyncMcpToolProvider(List.of(calculatorProvider)).getToolSpecifications(); + + List resourceSpecifications = + new SyncMcpResourceProvider(List.of(resourceProvider)).getResourceSpecifications(); + + // Create a server with tool support + McpSyncServer syncServer = McpServer.sync(transportProvider) + .serverInfo("calculator-server", "1.0.0") + .capabilities(ServerCapabilities.builder() + .tools(true) // Enable tool support + .resources(true) // Enable resource support + .logging() // Enable logging support + .build()) + .tools(toolSpecifications) + .resources(resourceSpecifications) + .build(); + + return syncServer; + } + + public McpAsyncServer createAsyncMcpServerWithTools( + AsyncToolProvider asyncToolProvider) { + + List asyncToolSpecifications = + new AsyncMcpToolProvider(List.of(asyncToolProvider)).getToolSpecifications(); + + // Create an async server with tool support + McpAsyncServer asyncServer = McpServer.async(transportProvider) + .serverInfo("async-tool-server", "1.0.0") + .capabilities(ServerCapabilities.builder() + .tools(true) // Enable tool support + .logging() // Enable logging support + .build()) + .tools(asyncToolSpecifications) + .build(); + + return asyncServer; + } +} +``` + ### Mcp Server with Resource, Prompt and Completion capabilities ```java @@ -506,6 +675,18 @@ public class McpConfig { return SpringAiMcpAnnotationProvider.createSyncResourceSpecifications(resourceProviders); } + @Bean + public List syncToolSpecifications( + List toolProviders) { + return SpringAiMcpAnnotationProvider.createSyncToolSpecifications(toolProviders); + } + + @Bean + public List asyncToolSpecifications( + List asyncToolProviders) { + return SpringAiMcpAnnotationProvider.createAsyncToolSpecifications(asyncToolProviders); + } + @Bean public List> syncLoggingConsumers( List loggingHandlers) { @@ -533,6 +714,7 @@ public class McpConfig { - **Builder pattern for callback creation** - Clean and fluent API for creating method callbacks - **Comprehensive validation** - Ensures method signatures are compatible with MCP operations - **URI template support** - Powerful URI template handling for resource and completion operations +- **Tool support with automatic JSON schema generation** - Create MCP tools with automatic input/output schema generation from method signatures - **Logging consumer support** - Handle logging message notifications from MCP servers - **Sampling support** - Handle sampling requests from MCP servers - **Spring integration** - Seamless integration with Spring Framework and Spring AI diff --git a/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/AsyncMcpAnnotationProvider.java b/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/AsyncMcpAnnotationProvider.java index a3ea96f..9a29a2c 100644 --- a/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/AsyncMcpAnnotationProvider.java +++ b/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/AsyncMcpAnnotationProvider.java @@ -21,7 +21,9 @@ import org.springaicommunity.mcp.provider.AsyncMcpLoggingConsumerProvider; import org.springaicommunity.mcp.provider.AsyncMcpSamplingProvider; +import org.springaicommunity.mcp.provider.AsyncMcpToolProvider; +import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; @@ -58,6 +60,19 @@ protected Method[] doGetClassMethods(Object bean) { } + private static class SpringAiAsyncMcpToolProvider extends AsyncMcpToolProvider { + + public SpringAiAsyncMcpToolProvider(List toolObjects) { + super(toolObjects); + } + + @Override + protected Method[] doGetClassMethods(Object bean) { + return AnnotationProviderUtil.beanMethods(bean); + } + + } + public static List>> createAsyncLoggingConsumers( List loggingObjects) { return new SpringAiAsyncMcpLoggingConsumerProvider(loggingObjects).getLoggingConsumers(); @@ -68,4 +83,8 @@ public static Function> createAs return new SpringAiAsyncMcpSamplingProvider(samplingObjects).getSamplingHandler(); } + public static List createAsyncToolSpecifications(List toolObjects) { + return new SpringAiAsyncMcpToolProvider(toolObjects).getToolSpecifications(); + } + } diff --git a/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/SyncMcpAnnotationProvider.java b/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/SyncMcpAnnotationProvider.java index 4b1842e..c91a6c0 100644 --- a/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/SyncMcpAnnotationProvider.java +++ b/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/SyncMcpAnnotationProvider.java @@ -20,20 +20,20 @@ import java.util.function.Consumer; import java.util.function.Function; -import org.springaicommunity.mcp.provider.AsyncMcpSamplingProvider; import org.springaicommunity.mcp.provider.SyncMcpCompletionProvider; import org.springaicommunity.mcp.provider.SyncMcpLoggingConsumerProvider; import org.springaicommunity.mcp.provider.SyncMcpPromptProvider; import org.springaicommunity.mcp.provider.SyncMcpResourceProvider; import org.springaicommunity.mcp.provider.SyncMcpSamplingProvider; +import org.springaicommunity.mcp.provider.SyncMcpToolProvider; import io.modelcontextprotocol.server.McpServerFeatures.SyncCompletionSpecification; import io.modelcontextprotocol.server.McpServerFeatures.SyncPromptSpecification; import io.modelcontextprotocol.server.McpServerFeatures.SyncResourceSpecification; +import io.modelcontextprotocol.server.McpServerFeatures.SyncToolSpecification; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; -import reactor.core.publisher.Mono; /** * @author Christian Tzolov @@ -53,6 +53,19 @@ protected Method[] doGetClassMethods(Object bean) { }; + private static class SpringAiSyncToolProvider extends SyncMcpToolProvider { + + public SpringAiSyncToolProvider(List toolObjects) { + super(toolObjects); + } + + @Override + protected Method[] doGetClassMethods(Object bean) { + return AnnotationProviderUtil.beanMethods(bean); + } + + } + private static class SpringAiSyncMcpPromptProvider extends SyncMcpPromptProvider { public SpringAiSyncMcpPromptProvider(List promptObjects) { @@ -105,17 +118,8 @@ protected Method[] doGetClassMethods(Object bean) { } - private static class SpringAiAsyncMcpSamplingProvider extends AsyncMcpSamplingProvider { - - public SpringAiAsyncMcpSamplingProvider(List samplingObjects) { - super(samplingObjects); - } - - @Override - protected Method[] doGetClassMethods(Object bean) { - return AnnotationProviderUtil.beanMethods(bean); - } - + public static List createSyncToolSpecifications(List toolObjects) { + return new SpringAiSyncToolProvider(toolObjects).getToolSpecifications(); } public static List createSyncCompleteSpecifications(List completeObjects) { @@ -139,9 +143,4 @@ public static Function createSyncSamp return new SpringAiSyncMcpSamplingProvider(samplingObjects).getSamplingHandler(); } - public static Function> createAsyncSamplingHandler( - List samplingObjects) { - return new SpringAiAsyncMcpSamplingProvider(samplingObjects).getSamplingHandler(); - } - } diff --git a/mcp-annotations/pom.xml b/mcp-annotations/pom.xml index 5ae83c3..f448f80 100644 --- a/mcp-annotations/pom.xml +++ b/mcp-annotations/pom.xml @@ -22,6 +22,12 @@ git@github.com/spring-ai-community/mcp-annotations.git + + 4.38.0 + 2.2.34 + 2.19.2 + + @@ -30,6 +36,43 @@ ${mcp.java.sdk.version} + + com.github.victools + jsonschema-module-swagger-2 + ${jsonschema.version} + + + + com.github.victools + jsonschema-generator + ${jsonschema.version} + + + + + com.github.victools + jsonschema-module-jackson + ${jsonschema.version} + + + + io.swagger.core.v3 + swagger-annotations-jakarta + ${swagger-annotations.version} + + + + com.fasterxml.jackson.core + jackson-databind + ${jackson-databind.version} + + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + ${jackson-databind.version} + + io.modelcontextprotocol.sdk mcp-test diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpTool.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpTool.java new file mode 100644 index 0000000..3a1876d --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpTool.java @@ -0,0 +1,90 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ +package org.springaicommunity.mcp.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * @author Christian Tzolov + */ +@Target({ ElementType.METHOD, ElementType.ANNOTATION_TYPE }) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface McpTool { + + /** + * The name of the tool. If not provided, the method name will be used. + */ + String name() default ""; + + /** + * The description of the tool. If not provided, the method name will be used. + */ + String description() default ""; + + /** + * Additional hints for clients. + */ + McpAnnotations annotations() default @McpAnnotations; + + /** + * If true, the tool will generate an output schema for non-primitive output types. If + * false, the tool will not generate an output schema. + */ + boolean generateOutputSchema() default true; + + /** + * Additional properties describing a Tool to clients. + * + * all properties in ToolAnnotations are hints. They are not guaranteed to provide a + * faithful description of tool behavior (including descriptive properties like + * title). + * + * Clients should never make tool use decisions based on ToolAnnotations received from + * untrusted servers. + */ + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.ANNOTATION_TYPE) + public @interface McpAnnotations { + + /** + * A human-readable title for the tool. + */ + String title() default ""; + + /** + * If true, the tool does not modify its environment. + */ + boolean readOnlyHint() default false; + + /** + * If true, the tool may perform destructive updates to its environment. If false, + * the tool performs only additive updates. + * + * (This property is meaningful only when readOnlyHint == false) + */ + boolean destructiveHint() default true; + + /** + * If true, calling the tool repeatedly with the same arguments will have no + * additional effect on the its environment. + * + * (This property is meaningful only when readOnlyHint == false) + */ + boolean idempotentHint() default false; + + /** + * If true, this tool may interact with an “open world” of external entities. If + * false, the tool’s domain of interaction is closed. For example, the world of a + * web search tool is open, whereas that of a memory tool is not. + */ + boolean openWorldHint() default true; + + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpToolParam.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpToolParam.java new file mode 100644 index 0000000..8cfd5c7 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpToolParam.java @@ -0,0 +1,31 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * @author Christian Tzolov + */ +@Target({ ElementType.PARAMETER, ElementType.FIELD, ElementType.ANNOTATION_TYPE }) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface McpToolParam { + + /** + * Whether the tool argument is required. + */ + boolean required() default true; + + /** + * The description of the tool argument. + */ + String description() default ""; + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/ResourceAdaptor.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/ResourceAdaptor.java index 98900d7..5b4f350 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/ResourceAdaptor.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/ResourceAdaptor.java @@ -14,12 +14,24 @@ private ResourceAdaptor() { } public static McpSchema.Resource asResource(McpResource mcpResource) { - return new McpSchema.Resource(mcpResource.uri(), mcpResource.name(), mcpResource.description(), - mcpResource.mimeType(), null); + String name = mcpResource.name(); + if (name == null || name.isEmpty()) { + name = "resource"; // Default name when not specified + } + return McpSchema.Resource.builder() + .uri(mcpResource.uri()) + .name(name) + .description(mcpResource.description()) + .mimeType(mcpResource.mimeType()) + .build(); } public static McpSchema.ResourceTemplate asResourceTemplate(McpResource mcpResource) { - return new McpSchema.ResourceTemplate(mcpResource.uri(), mcpResource.name(), mcpResource.description(), + String name = mcpResource.name(); + if (name == null || name.isEmpty()) { + name = "resource"; // Default name when not specified + } + return new McpSchema.ResourceTemplate(mcpResource.uri(), name, mcpResource.description(), mcpResource.mimeType(), null); } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/SyncMcpResourceMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/SyncMcpResourceMethodCallback.java index e1c21e7..37ab4bd 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/SyncMcpResourceMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/SyncMcpResourceMethodCallback.java @@ -92,14 +92,10 @@ public static class Builder extends AbstractBuilder returnType = method.getReturnType(); @@ -137,12 +128,6 @@ protected void validateReturnType(Method method) { } } - /** - * Checks if a parameter type is compatible with the exchange type. - * @param paramType The parameter type to check - * @return true if the parameter type is compatible with the exchange type, false - * otherwise - */ @Override protected boolean isExchangeType(Class paramType) { return McpSyncServerExchange.class.isAssignableFrom(paramType); diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java new file mode 100644 index 0000000..05a7ad2 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java @@ -0,0 +1,249 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.method.tool; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Type; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.stream.Stream; + +import org.reactivestreams.Publisher; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.method.tool.utils.JsonParser; + +import com.fasterxml.jackson.core.type.TypeReference; + +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Class for creating Function callbacks around tool methods. + * + * This class provides a way to convert methods annotated with {@link McpTool} into + * callback functions that can be used to handle tool requests. + * + * @author Christian Tzolov + */ +public final class AsyncMcpToolMethodCallback + implements BiFunction> { + + private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference>() { + // No implementation needed + }; + + private final Method toolMethod; + + private final Object toolObject; + + private ReturnMode returnMode; + + public AsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject) { + this.toolMethod = toolMethod; + this.toolObject = toolObject; + this.returnMode = returnMode; + } + + /** + * Apply the callback to the given request. + *

+ * This method builds the arguments for the method call, invokes the method, and + * returns the result. + * @param request The tool call request, must not be null + * @return The result of the method invocation + */ + @Override + public Mono apply(McpAsyncServerExchange exchange, CallToolRequest request) { + + if (request == null) { + return Mono.error(new IllegalArgumentException("Request must not be null")); + } + + return Mono.defer(() -> { + try { + // Build arguments for the method call + Object[] args = this.buildMethodArguments(exchange, request.arguments()); + + // Invoke the method + Object result = this.callMethod(args); + + // Handle reactive types - method return types should always be reactive + return this.convertToCallToolResult(result); + + } + catch (Exception e) { + return Mono.just(CallToolResult.builder() + .isError(true) + .addTextContent("Error invoking method: %s".formatted(e.getMessage())) + .build()); + } + }); + } + + /** + * Convert reactive types to Mono + */ + private Mono convertToCallToolResult(Object result) { + // Handle Mono types + if (result instanceof Mono) { + + Mono monoResult = (Mono) result; + + // Check if the Mono contains CallToolResult + if (ReactiveUtils.isReactiveReturnTypeOfCallToolResult(this.toolMethod)) { + return (Mono) monoResult; + } + + // Handle Mono for VOID return type + if (ReactiveUtils.isReactiveReturnTypeOfVoid(this.toolMethod)) { + return monoResult + .then(Mono.just(CallToolResult.builder().addTextContent(JsonParser.toJson("Done")).build())); + } + + // Handle other Mono types - map the emitted value to CallToolResult + return monoResult.map(this::mapValueToCallToolResult) + .onErrorResume(e -> Mono.just(CallToolResult.builder() + .isError(true) + .addTextContent("Error invoking method: %s".formatted(e.getMessage())) + .build())); + } + + // Handle Flux by taking the first element + if (result instanceof Flux) { + Flux fluxResult = (Flux) result; + + // Check if the Flux contains CallToolResult + if (ReactiveUtils.isReactiveReturnTypeOfCallToolResult(this.toolMethod)) { + return ((Flux) fluxResult).next(); + } + + // Handle Mono for VOID return type + if (ReactiveUtils.isReactiveReturnTypeOfVoid(this.toolMethod)) { + return fluxResult + .then(Mono.just(CallToolResult.builder().addTextContent(JsonParser.toJson("Done")).build())); + } + + // Handle other Flux types by taking the first element and mapping + return fluxResult.next() + .map(this::mapValueToCallToolResult) + .onErrorResume(e -> Mono.just(CallToolResult.builder() + .isError(true) + .addTextContent("Error invoking method: %s".formatted(e.getMessage())) + .build())); + } + + // Handle other Publisher types + if (result instanceof Publisher) { + Publisher publisherResult = (Publisher) result; + Mono monoFromPublisher = Mono.from(publisherResult); + + // Check if the Publisher contains CallToolResult + if (ReactiveUtils.isReactiveReturnTypeOfCallToolResult(this.toolMethod)) { + return (Mono) monoFromPublisher; + } + + // Handle Mono for VOID return type + if (ReactiveUtils.isReactiveReturnTypeOfVoid(this.toolMethod)) { + return monoFromPublisher + .then(Mono.just(CallToolResult.builder().addTextContent(JsonParser.toJson("Done")).build())); + } + + // Handle other Publisher types by mapping the emitted value + return monoFromPublisher.map(this::mapValueToCallToolResult) + .onErrorResume(e -> Mono.just(CallToolResult.builder() + .isError(true) + .addTextContent("Error invoking method: %s".formatted(e.getMessage())) + .build())); + } + + // This should not happen in async context, but handle as fallback + throw new IllegalStateException( + "Expected reactive return type but got: " + (result != null ? result.getClass().getName() : "null")); + } + + /** + * Map individual values to CallToolResult + */ + private CallToolResult mapValueToCallToolResult(Object value) { + if (value instanceof CallToolResult) { + return (CallToolResult) value; + } + + if (returnMode == ReturnMode.VOID) { + return CallToolResult.builder().addTextContent(JsonParser.toJson("Done")).build(); + } + else if (this.returnMode == ReturnMode.STRUCTURED) { + String jsonOutput = JsonParser.toJson(value); + Map structuredOutput = JsonParser.fromJson(jsonOutput, MAP_TYPE_REFERENCE); + return CallToolResult.builder().structuredContent(structuredOutput).build(); + } + + // Default to text output + return CallToolResult.builder().addTextContent(value != null ? value.toString() : "null").build(); + } + + private Object callMethod(Object[] methodArguments) { + + this.toolMethod.setAccessible(true); + + Object result; + try { + result = this.toolMethod.invoke(this.toolObject, methodArguments); + } + catch (IllegalAccessException ex) { + throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex); + } + catch (InvocationTargetException ex) { + throw new RuntimeException("Error invoking method: " + this.toolMethod.getName(), ex); + // throw new ToolExecutionException(this.toolDefinition, ex.getCause()); + } + return result; + } + + private Object[] buildMethodArguments(McpAsyncServerExchange exchange, Map toolInputArguments) { + return Stream.of(this.toolMethod.getParameters()).map(parameter -> { + Object rawArgument = toolInputArguments.get(parameter.getName()); + if (isExchangeType(parameter.getType())) { + return exchange; + } + return buildTypedArgument(rawArgument, parameter.getParameterizedType()); + }).toArray(); + } + + private Object buildTypedArgument(Object value, Type type) { + if (value == null) { + return null; + } + + if (type instanceof Class) { + return JsonParser.toTypedObject(value, (Class) type); + } + + // For generic types, use the fromJson method that accepts Type + String json = JsonParser.toJson(value); + return JsonParser.fromJson(json, type); + } + + protected boolean isExchangeType(Class paramType) { + return McpAsyncServerExchange.class.isAssignableFrom(paramType); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/ReactiveUtils.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/ReactiveUtils.java new file mode 100644 index 0000000..cf1e434 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/ReactiveUtils.java @@ -0,0 +1,120 @@ +package org.springaicommunity.mcp.method.tool; + +import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.util.Map; +import java.util.Optional; + +import org.reactivestreams.Publisher; +import org.springaicommunity.mcp.method.tool.utils.ConcurrentReferenceHashMap; + +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class ReactiveUtils { + + private static final Map isReactiveOfVoidCache = new ConcurrentReferenceHashMap<>(256); + + private static final Map isReactiveOfCallToolResultCache = new ConcurrentReferenceHashMap<>(256); + + /** + * Check if the given type is a reactive type containing Void (e.g., Mono, + * Flux, Publisher) + */ + public static boolean isReactiveReturnTypeOfVoid(Method method) { + Type returnType = method.getGenericReturnType(); + if (isReactiveOfVoidCache.containsKey(returnType)) { + return isReactiveOfVoidCache.get(returnType); + } + + boolean isReactiveOfVoid = false; + if (returnType instanceof ParameterizedType) { + ParameterizedType parameterizedType = (ParameterizedType) returnType; + Type rawType = parameterizedType.getRawType(); + + // Check if raw type is a reactive type (Mono, Flux, or Publisher) + if (rawType instanceof Class) { + Class rawClass = (Class) rawType; + if (Mono.class.isAssignableFrom(rawClass) || Flux.class.isAssignableFrom(rawClass) + || Publisher.class.isAssignableFrom(rawClass)) { + + Type[] typeArguments = parameterizedType.getActualTypeArguments(); + if (typeArguments.length == 1) { + Type typeArgument = typeArguments[0]; + if (typeArgument instanceof Class) { + isReactiveOfVoid = Void.class.equals(typeArgument) || void.class.equals(typeArgument); + } + } + } + } + } + + isReactiveOfVoidCache.putIfAbsent(returnType, isReactiveOfVoid); + + return isReactiveOfVoid; + } + + /** + * Check if the given type is a reactive type containing CallToolResult (e.g., + * Mono, Flux, Publisher) + */ + public static boolean isReactiveReturnTypeOfCallToolResult(Method method) { + + Type returnType = method.getGenericReturnType(); + + if (isReactiveOfCallToolResultCache.containsKey(returnType)) { + return isReactiveOfCallToolResultCache.get(returnType); + } + boolean isReactiveOfCallToolResult = false; + if (returnType instanceof ParameterizedType) { + ParameterizedType parameterizedType = (ParameterizedType) returnType; + Type rawType = parameterizedType.getRawType(); + + // Check if raw type is a reactive type (Mono, Flux, or Publisher) + if (rawType instanceof Class) { + Class rawClass = (Class) rawType; + if (Mono.class.isAssignableFrom(rawClass) || Flux.class.isAssignableFrom(rawClass) + || Publisher.class.isAssignableFrom(rawClass)) { + + Type[] typeArguments = parameterizedType.getActualTypeArguments(); + if (typeArguments.length == 1) { + Type typeArgument = typeArguments[0]; + if (typeArgument instanceof Class) { + isReactiveOfCallToolResult = CallToolResult.class.isAssignableFrom((Class) typeArgument); + } + } + } + } + } + + isReactiveOfCallToolResultCache.putIfAbsent(returnType, isReactiveOfCallToolResult); + + return isReactiveOfCallToolResult; + } + + public static Optional getReactiveReturnTypeArgument(Method method) { + + Type returnType = method.getGenericReturnType(); + + if (returnType instanceof ParameterizedType) { + ParameterizedType parameterizedType = (ParameterizedType) returnType; + Type rawType = parameterizedType.getRawType(); + + // Check if raw type is a reactive type (Mono, Flux, or Publisher) + if (rawType instanceof Class) { + Class rawClass = (Class) rawType; + if (Mono.class.isAssignableFrom(rawClass) || Flux.class.isAssignableFrom(rawClass) + || Publisher.class.isAssignableFrom(rawClass)) { + + return Optional.of(parameterizedType.getActualTypeArguments()[0]); + } + } + } + + return Optional.empty(); + + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/ReturnMode.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/ReturnMode.java new file mode 100644 index 0000000..4ea28f9 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/ReturnMode.java @@ -0,0 +1,22 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springaicommunity.mcp.method.tool; + +public enum ReturnMode { + + VOID, STRUCTURED, TEXT; + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java new file mode 100644 index 0000000..63b7637 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java @@ -0,0 +1,158 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.method.tool; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Type; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.stream.Stream; + +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.method.tool.utils.JsonParser; + +import com.fasterxml.jackson.core.type.TypeReference; + +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; + +/** + * Class for creating Function callbacks around tool methods. + * + * This class provides a way to convert methods annotated with {@link McpTool} into + * callback functions that can be used to handle tool requests. + * + * @author Christian Tzolov + */ +public final class SyncMcpToolMethodCallback + implements BiFunction { + + private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference>() { + // No implementation needed + }; + + private final Method toolMethod; + + private final Object toolObject; + + private ReturnMode returnMode; + + public SyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject) { + this.toolMethod = toolMethod; + this.toolObject = toolObject; + this.returnMode = returnMode; + } + + /** + * Apply the callback to the given request. + *

+ * This method builds the arguments for the method call, invokes the method, and + * returns the result. + * @param request The tool call request, must not be null + * @return The result of the method invocation + */ + @Override + public CallToolResult apply(McpSyncServerExchange exchange, CallToolRequest request) { + + if (request == null) { + throw new IllegalArgumentException("Request must not be null"); + } + + try { + // Build arguments for the method call + Object[] args = this.buildMethodArguments(exchange, request.arguments()); + + // Invoke the method + Object result = this.callMethod(args); + + // Return the result + if (result instanceof CallToolResult) { + return (CallToolResult) result; + } + + if (returnMode == ReturnMode.VOID) { + return CallToolResult.builder().addTextContent(JsonParser.toJson("Done")).build(); + } + else if (this.returnMode == ReturnMode.STRUCTURED) { + + String jsonOutput = JsonParser.toJson(result); + Map structuredOutput = JsonParser.fromJson(jsonOutput, MAP_TYPE_REFERENCE); + + return CallToolResult.builder().structuredContent(structuredOutput).build(); + } + + // Default to text output + return CallToolResult.builder().addTextContent(result != null ? result.toString() : "null").build(); + + } + catch (Exception e) { + return CallToolResult.builder() + .isError(true) + .addTextContent("Error invoking method: %s".formatted(e.getMessage())) + .build(); + } + } + + private Object callMethod(Object[] methodArguments) { + + this.toolMethod.setAccessible(true); + + Object result; + try { + result = this.toolMethod.invoke(this.toolObject, methodArguments); + } + catch (IllegalAccessException ex) { + throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex); + } + catch (InvocationTargetException ex) { + throw new RuntimeException("Error invoking method: " + this.toolMethod.getName(), ex); + // throw new ToolExecutionException(this.toolDefinition, ex.getCause()); + } + return result; + } + + private Object[] buildMethodArguments(McpSyncServerExchange exchange, Map toolInputArguments) { + return Stream.of(this.toolMethod.getParameters()).map(parameter -> { + Object rawArgument = toolInputArguments.get(parameter.getName()); + if (isExchangeType(parameter.getType())) { + return exchange; + } + return buildTypedArgument(rawArgument, parameter.getParameterizedType()); + }).toArray(); + } + + private Object buildTypedArgument(Object value, Type type) { + if (value == null) { + return null; + } + + if (type instanceof Class) { + return JsonParser.toTypedObject(value, (Class) type); + } + + // For generic types, use the fromJson method that accepts Type + String json = JsonParser.toJson(value); + return JsonParser.fromJson(json, type); + } + + protected boolean isExchangeType(Class paramType) { + return McpSyncServerExchange.class.isAssignableFrom(paramType); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/ClassUtils.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/ClassUtils.java new file mode 100644 index 0000000..5072e1e --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/ClassUtils.java @@ -0,0 +1,172 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.method.tool.utils; + +import java.io.File; +import java.net.InetAddress; +import java.net.URI; +import java.net.URL; +import java.nio.charset.Charset; +import java.nio.file.Path; +import java.time.ZoneId; +import java.time.temporal.Temporal; +import java.util.Currency; +import java.util.Date; +import java.util.IdentityHashMap; +import java.util.Locale; +import java.util.Map; +import java.util.TimeZone; +import java.util.UUID; +import java.util.regex.Pattern; + +import io.modelcontextprotocol.util.Assert; + +public abstract class ClassUtils { + + /** + * Map with primitive wrapper type as key and corresponding primitive type as value, + * for example: {@code Integer.class -> int.class}. + */ + private static final Map, Class> primitiveWrapperTypeMap = new IdentityHashMap<>(9); + + /** + * Map with primitive type as key and corresponding wrapper type as value, for + * example: {@code int.class -> Integer.class}. + */ + private static final Map, Class> primitiveTypeToWrapperMap = new IdentityHashMap<>(9); + + static { + primitiveWrapperTypeMap.put(Boolean.class, boolean.class); + primitiveWrapperTypeMap.put(Byte.class, byte.class); + primitiveWrapperTypeMap.put(Character.class, char.class); + primitiveWrapperTypeMap.put(Double.class, double.class); + primitiveWrapperTypeMap.put(Float.class, float.class); + primitiveWrapperTypeMap.put(Integer.class, int.class); + primitiveWrapperTypeMap.put(Long.class, long.class); + primitiveWrapperTypeMap.put(Short.class, short.class); + primitiveWrapperTypeMap.put(Void.class, void.class); + + // Map entry iteration is less expensive to initialize than forEach with lambdas + for (Map.Entry, Class> entry : primitiveWrapperTypeMap.entrySet()) { + primitiveTypeToWrapperMap.put(entry.getValue(), entry.getKey()); + } + } + + /** + * Check if the given class represents a primitive wrapper, i.e. Boolean, Byte, + * Character, Short, Integer, Long, Float, Double, or Void. + * @param clazz the class to check + * @return whether the given class is a primitive wrapper class + */ + public static boolean isPrimitiveWrapper(Class clazz) { + Assert.notNull(clazz, "Class must not be null"); + return primitiveWrapperTypeMap.containsKey(clazz); + } + + /** + * Check if the given class represents a primitive (i.e. boolean, byte, char, short, + * int, long, float, or double), {@code void}, or a wrapper for those types (i.e. + * Boolean, Byte, Character, Short, Integer, Long, Float, Double, or Void). + * @param clazz the class to check + * @return {@code true} if the given class represents a primitive, void, or a wrapper + * class + */ + public static boolean isPrimitiveOrWrapper(Class clazz) { + Assert.notNull(clazz, "Class must not be null"); + return (clazz.isPrimitive() || isPrimitiveWrapper(clazz)); + } + + /** + * Resolve the given class if it is a primitive class, returning the corresponding + * primitive wrapper type instead. + * @param clazz the class to check + * @return the original class, or a primitive wrapper for the original primitive type + */ + @SuppressWarnings("NullAway") + public static Class resolvePrimitiveIfNecessary(Class clazz) { + Assert.notNull(clazz, "Class must not be null"); + return (clazz.isPrimitive() && clazz != void.class ? primitiveTypeToWrapperMap.get(clazz) : clazz); + } + + /** + * Determine if the given type represents either {@code Void} or {@code void}. + * @param type the type to check + * @return {@code true} if the type represents {@code Void} or {@code void} + * @since 6.1.4 + * @see Void + * @see Void#TYPE + */ + public static boolean isVoidType(Class type) { + return (type == void.class || type == Void.class); + } + + /** + * Delegate for {@link org.springframework.beans.BeanUtils#isSimpleValueType}. Also + * used by {@link ObjectUtils#nullSafeConciseToString}. + *

+ * Check if the given type represents a common "simple" value type: primitive or + * primitive wrapper, {@link Enum}, {@link String} or other {@link CharSequence}, + * {@link Number}, {@link Date}, {@link Temporal}, {@link ZoneId}, {@link TimeZone}, + * {@link File}, {@link Path}, {@link URI}, {@link URL}, {@link InetAddress}, + * {@link Charset}, {@link Currency}, {@link Locale}, {@link UUID}, {@link Pattern}, + * or {@link Class}. + *

+ * {@code Void} and {@code void} are not considered simple value types. + * @param type the type to check + * @return whether the given type represents a "simple" value type, suggesting + * value-based data binding and {@code toString} output + * @since 6.1 + */ + public static boolean isSimpleValueType(Class type) { + return (!isVoidType(type) && (isPrimitiveOrWrapper(type) || Enum.class.isAssignableFrom(type) + || CharSequence.class.isAssignableFrom(type) || Number.class.isAssignableFrom(type) + || Date.class.isAssignableFrom(type) || Temporal.class.isAssignableFrom(type) + || ZoneId.class.isAssignableFrom(type) || TimeZone.class.isAssignableFrom(type) + || File.class.isAssignableFrom(type) || Path.class.isAssignableFrom(type) + || Charset.class.isAssignableFrom(type) || Currency.class.isAssignableFrom(type) + || InetAddress.class.isAssignableFrom(type) || URI.class == type || URL.class == type + || UUID.class == type || Locale.class == type || Pattern.class == type || Class.class == type)); + } + + /** + * Check if the right-hand side type may be assigned to the left-hand side type, + * assuming setting by reflection. Considers primitive wrapper classes as assignable + * to the corresponding primitive types. + * @param lhsType the target type (left-hand side (LHS) type) + * @param rhsType the value type (right-hand side (RHS) type) that should be assigned + * to the target type + * @return {@code true} if {@code rhsType} is assignable to {@code lhsType} + * @see TypeUtils#isAssignable(java.lang.reflect.Type, java.lang.reflect.Type) + */ + public static boolean isAssignable(Class lhsType, Class rhsType) { + Assert.notNull(lhsType, "Left-hand side type must not be null"); + Assert.notNull(rhsType, "Right-hand side type must not be null"); + if (lhsType.isAssignableFrom(rhsType)) { + return true; + } + if (lhsType.isPrimitive()) { + Class resolvedPrimitive = primitiveWrapperTypeMap.get(rhsType); + return (lhsType == resolvedPrimitive); + } + else if (rhsType.isPrimitive()) { + Class resolvedWrapper = primitiveTypeToWrapperMap.get(rhsType); + return (resolvedWrapper != null && lhsType.isAssignableFrom(resolvedWrapper)); + } + return false; + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/ConcurrentReferenceHashMap.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/ConcurrentReferenceHashMap.java new file mode 100644 index 0000000..057b252 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/ConcurrentReferenceHashMap.java @@ -0,0 +1,1179 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.method.tool.utils; + +import java.lang.ref.ReferenceQueue; +import java.lang.ref.SoftReference; +import java.lang.ref.WeakReference; +import java.lang.reflect.Array; +import java.util.AbstractMap; +import java.util.AbstractSet; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; + +import io.modelcontextprotocol.util.Assert; +import reactor.util.annotation.Nullable; + +/** + * A {@link ConcurrentHashMap} that uses {@link ReferenceType#SOFT soft} or + * {@linkplain ReferenceType#WEAK weak} references for both {@code keys} and + * {@code values}. + * + *

+ * This class can be used as an alternative to + * {@code Collections.synchronizedMap(new WeakHashMap>())} in order to + * support better performance when accessed concurrently. This implementation follows the + * same design constraints as {@link ConcurrentHashMap} with the exception that + * {@code null} values and {@code null} keys are supported. + * + *

+ * NOTE: The use of references means that there is no guarantee that items placed + * into the map will be subsequently available. The garbage collector may discard + * references at any time, so it may appear that an unknown thread is silently removing + * entries. + * + *

+ * If not explicitly specified, this implementation will use {@linkplain SoftReference + * soft entry references}. + * + * @author Phillip Webb + * @author Juergen Hoeller + * @author Brian Clozel + * @since 3.2 + * @param the key type + * @param the value type + */ +public class ConcurrentReferenceHashMap extends AbstractMap implements ConcurrentMap { + + private static final int DEFAULT_INITIAL_CAPACITY = 16; + + private static final float DEFAULT_LOAD_FACTOR = 0.75f; + + private static final int DEFAULT_CONCURRENCY_LEVEL = 16; + + private static final ReferenceType DEFAULT_REFERENCE_TYPE = ReferenceType.SOFT; + + private static final int MAXIMUM_CONCURRENCY_LEVEL = 1 << 16; + + private static final int MAXIMUM_SEGMENT_SIZE = 1 << 30; + + /** + * Array of segments indexed using the high order bits from the hash. + */ + private final Segment[] segments; + + /** + * When the average number of references per table exceeds this value resize will be + * attempted. + */ + private final float loadFactor; + + /** + * The reference type: SOFT or WEAK. + */ + private final ReferenceType referenceType; + + /** + * The shift value used to calculate the size of the segments array and an index from + * the hash. + */ + private final int shift; + + /** + * Late binding entry set. + */ + private volatile Set> entrySet; + + /** + * Create a new {@code ConcurrentReferenceHashMap} instance. + */ + public ConcurrentReferenceHashMap() { + this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL, DEFAULT_REFERENCE_TYPE); + } + + /** + * Create a new {@code ConcurrentReferenceHashMap} instance. + * @param initialCapacity the initial capacity of the map + */ + public ConcurrentReferenceHashMap(int initialCapacity) { + this(initialCapacity, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL, DEFAULT_REFERENCE_TYPE); + } + + /** + * Create a new {@code ConcurrentReferenceHashMap} instance. + * @param initialCapacity the initial capacity of the map + * @param loadFactor the load factor. When the average number of references per table + * exceeds this value resize will be attempted + */ + public ConcurrentReferenceHashMap(int initialCapacity, float loadFactor) { + this(initialCapacity, loadFactor, DEFAULT_CONCURRENCY_LEVEL, DEFAULT_REFERENCE_TYPE); + } + + /** + * Create a new {@code ConcurrentReferenceHashMap} instance. + * @param initialCapacity the initial capacity of the map + * @param concurrencyLevel the expected number of threads that will concurrently write + * to the map + */ + public ConcurrentReferenceHashMap(int initialCapacity, int concurrencyLevel) { + this(initialCapacity, DEFAULT_LOAD_FACTOR, concurrencyLevel, DEFAULT_REFERENCE_TYPE); + } + + /** + * Create a new {@code ConcurrentReferenceHashMap} instance. + * @param initialCapacity the initial capacity of the map + * @param referenceType the reference type used for entries (soft or weak) + */ + public ConcurrentReferenceHashMap(int initialCapacity, ReferenceType referenceType) { + this(initialCapacity, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL, referenceType); + } + + /** + * Create a new {@code ConcurrentReferenceHashMap} instance. + * @param initialCapacity the initial capacity of the map + * @param loadFactor the load factor. When the average number of references per table + * exceeds this value, resize will be attempted. + * @param concurrencyLevel the expected number of threads that will concurrently write + * to the map + */ + public ConcurrentReferenceHashMap(int initialCapacity, float loadFactor, int concurrencyLevel) { + this(initialCapacity, loadFactor, concurrencyLevel, DEFAULT_REFERENCE_TYPE); + } + + /** + * Create a new {@code ConcurrentReferenceHashMap} instance. + * @param initialCapacity the initial capacity of the map + * @param loadFactor the load factor. When the average number of references per table + * exceeds this value, resize will be attempted. + * @param concurrencyLevel the expected number of threads that will concurrently write + * to the map + * @param referenceType the reference type used for entries (soft or weak) + */ + @SuppressWarnings("unchecked") + public ConcurrentReferenceHashMap(int initialCapacity, float loadFactor, int concurrencyLevel, + ReferenceType referenceType) { + + // Assert.isTrue(initialCapacity >= 0, "Initial capacity must not be negative"); + // Assert.isTrue(loadFactor > 0f, "Load factor must be positive"); + // Assert.isTrue(concurrencyLevel > 0, "Concurrency level must be positive"); + Assert.notNull(referenceType, "Reference type must not be null"); + this.loadFactor = loadFactor; + this.shift = calculateShift(concurrencyLevel, MAXIMUM_CONCURRENCY_LEVEL); + int size = 1 << this.shift; + this.referenceType = referenceType; + int roundedUpSegmentCapacity = (int) ((initialCapacity + size - 1L) / size); + int initialSize = 1 << calculateShift(roundedUpSegmentCapacity, MAXIMUM_SEGMENT_SIZE); + Segment[] segments = (Segment[]) Array.newInstance(Segment.class, size); + int resizeThreshold = (int) (initialSize * getLoadFactor()); + for (int i = 0; i < segments.length; i++) { + segments[i] = new Segment(initialSize, resizeThreshold); + } + this.segments = segments; + } + + protected final float getLoadFactor() { + return this.loadFactor; + } + + protected final int getSegmentsSize() { + return this.segments.length; + } + + protected final Segment getSegment(int index) { + return this.segments[index]; + } + + /** + * Factory method that returns the {@link ReferenceManager}. This method will be + * called once for each {@link Segment}. + * @return a new reference manager + */ + protected ReferenceManager createReferenceManager() { + return new ReferenceManager(); + } + + /** + * Get the hash for a given object, apply an additional hash function to reduce + * collisions. This implementation uses the same Wang/Jenkins algorithm as + * {@link ConcurrentHashMap}. Subclasses can override to provide alternative hashing. + * @param o the object to hash (may be null) + * @return the resulting hash code + */ + protected int getHash(Object o) { + int hash = (o != null ? o.hashCode() : 0); + hash += (hash << 15) ^ 0xffffcd7d; + hash ^= (hash >>> 10); + hash += (hash << 3); + hash ^= (hash >>> 6); + hash += (hash << 2) + (hash << 14); + hash ^= (hash >>> 16); + return hash; + } + + @Override + public V get(Object key) { + Reference ref = getReference(key, Restructure.WHEN_NECESSARY); + Entry entry = (ref != null ? ref.get() : null); + return (entry != null ? entry.getValue() : null); + } + + @Override + public V getOrDefault(Object key, V defaultValue) { + Reference ref = getReference(key, Restructure.WHEN_NECESSARY); + Entry entry = (ref != null ? ref.get() : null); + return (entry != null ? entry.getValue() : defaultValue); + } + + @Override + public boolean containsKey(Object key) { + Reference ref = getReference(key, Restructure.WHEN_NECESSARY); + Entry entry = (ref != null ? ref.get() : null); + return (entry != null && nullSafeEquals(entry.getKey(), key)); + } + + /** + * Return a {@link Reference} to the {@link Entry} for the specified {@code key}, or + * {@code null} if not found. + * @param key the key (can be {@code null}) + * @param restructure types of restructure allowed during this call + * @return the reference, or {@code null} if not found + */ + + protected final Reference getReference(Object key, Restructure restructure) { + int hash = getHash(key); + return getSegmentForHash(hash).getReference(key, hash, restructure); + } + + @Override + + public V put(K key, V value) { + return put(key, value, true); + } + + @Override + + public V putIfAbsent(K key, V value) { + return put(key, value, false); + } + + private V put(final K key, final V value, final boolean overwriteExisting) { + return doTask(key, new Task(TaskOption.RESTRUCTURE_BEFORE, TaskOption.RESIZE) { + @Override + + protected V execute(Reference ref, Entry entry, Entries entries) { + if (entry != null) { + V oldValue = entry.getValue(); + if (overwriteExisting) { + entry.setValue(value); + } + return oldValue; + } + // Assert.state(entries != null, "No entries segment"); + entries.add(value); + return null; + } + }); + } + + @Override + + public V remove(Object key) { + return doTask(key, new Task(TaskOption.RESTRUCTURE_AFTER, TaskOption.SKIP_IF_EMPTY) { + @Override + + protected V execute(Reference ref, Entry entry) { + if (entry != null) { + if (ref != null) { + ref.release(); + } + return entry.value; + } + return null; + } + }); + } + + @Override + public boolean remove(Object key, final Object value) { + Boolean result = doTask(key, new Task(TaskOption.RESTRUCTURE_AFTER, TaskOption.SKIP_IF_EMPTY) { + @Override + protected Boolean execute(Reference ref, Entry entry) { + if (entry != null && nullSafeEquals(entry.getValue(), value)) { + if (ref != null) { + ref.release(); + } + return true; + } + return false; + } + }); + return (Boolean.TRUE.equals(result)); + } + + @Override + public boolean replace(K key, final V oldValue, final V newValue) { + Boolean result = doTask(key, new Task(TaskOption.RESTRUCTURE_BEFORE, TaskOption.SKIP_IF_EMPTY) { + @Override + protected Boolean execute(Reference ref, Entry entry) { + if (entry != null && nullSafeEquals(entry.getValue(), oldValue)) { + entry.setValue(newValue); + return true; + } + return false; + } + }); + return (Boolean.TRUE.equals(result)); + } + + @Override + + public V replace(K key, final V value) { + return doTask(key, new Task(TaskOption.RESTRUCTURE_BEFORE, TaskOption.SKIP_IF_EMPTY) { + @Override + + protected V execute(Reference ref, Entry entry) { + if (entry != null) { + V oldValue = entry.getValue(); + entry.setValue(value); + return oldValue; + } + return null; + } + }); + } + + @Override + public void clear() { + for (Segment segment : this.segments) { + segment.clear(); + } + } + + /** + * Remove any entries that have been garbage collected and are no longer referenced. + * Under normal circumstances garbage collected entries are automatically purged as + * items are added or removed from the Map. This method can be used to force a purge, + * and is useful when the Map is read frequently but updated less often. + */ + public void purgeUnreferencedEntries() { + for (Segment segment : this.segments) { + segment.restructureIfNecessary(false); + } + } + + @Override + public int size() { + int size = 0; + for (Segment segment : this.segments) { + size += segment.getCount(); + } + return size; + } + + @Override + public boolean isEmpty() { + for (Segment segment : this.segments) { + if (segment.getCount() > 0) { + return false; + } + } + return true; + } + + @Override + public Set> entrySet() { + Set> entrySet = this.entrySet; + if (entrySet == null) { + entrySet = new EntrySet(); + this.entrySet = entrySet; + } + return entrySet; + } + + private T doTask(Object key, Task task) { + int hash = getHash(key); + return getSegmentForHash(hash).doTask(hash, key, task); + } + + private Segment getSegmentForHash(int hash) { + return this.segments[(hash >>> (32 - this.shift)) & (this.segments.length - 1)]; + } + + /** + * Calculate a shift value that can be used to create a power-of-two value between the + * specified maximum and minimum values. + * @param minimumValue the minimum value + * @param maximumValue the maximum value + * @return the calculated shift (use {@code 1 << shift} to obtain a value) + */ + protected static int calculateShift(int minimumValue, int maximumValue) { + int shift = 0; + int value = 1; + while (value < minimumValue && value < maximumValue) { + value <<= 1; + shift++; + } + return shift; + } + + /** + * Various reference types supported by this map. + */ + public enum ReferenceType { + + /** Use {@link SoftReference SoftReferences}. */ + SOFT, + + /** Use {@link WeakReference WeakReferences}. */ + WEAK + + } + + /** + * A single segment used to divide the map to allow better concurrent performance. + */ + @SuppressWarnings("serial") + protected final class Segment extends ReentrantLock { + + private final ReferenceManager referenceManager; + + private final int initialSize; + + /** + * Array of references indexed using the low order bits from the hash. This + * property should only be set along with {@code resizeThreshold}. + */ + private volatile Reference[] references; + + /** + * The total number of references contained in this segment. This includes chained + * references and references that have been garbage collected but not purged. + */ + private final AtomicInteger count = new AtomicInteger(); + + /** + * The threshold when resizing of the references should occur. When {@code count} + * exceeds this value references will be resized. + */ + private int resizeThreshold; + + public Segment(int initialSize, int resizeThreshold) { + this.referenceManager = createReferenceManager(); + this.initialSize = initialSize; + this.references = createReferenceArray(initialSize); + this.resizeThreshold = resizeThreshold; + } + + public Reference getReference(Object key, int hash, Restructure restructure) { + if (restructure == Restructure.WHEN_NECESSARY) { + restructureIfNecessary(false); + } + if (this.count.get() == 0) { + return null; + } + // Use a local copy to protect against other threads writing + Reference[] references = this.references; + int index = getIndex(hash, references); + Reference head = references[index]; + return findInChain(head, key, hash); + } + + /** + * Apply an update operation to this segment. The segment will be locked during + * the update. + * @param hash the hash of the key + * @param key the key + * @param task the update operation + * @return the result of the operation + */ + + public T doTask(final int hash, final Object key, final Task task) { + boolean resize = task.hasOption(TaskOption.RESIZE); + if (task.hasOption(TaskOption.RESTRUCTURE_BEFORE)) { + restructureIfNecessary(resize); + } + if (task.hasOption(TaskOption.SKIP_IF_EMPTY) && this.count.get() == 0) { + return task.execute(null, null, null); + } + lock(); + try { + final int index = getIndex(hash, this.references); + final Reference head = this.references[index]; + Reference ref = findInChain(head, key, hash); + Entry entry = (ref != null ? ref.get() : null); + Entries entries = value -> { + @SuppressWarnings("unchecked") + Entry newEntry = new Entry<>((K) key, value); + Reference newReference = Segment.this.referenceManager.createReference(newEntry, hash, head); + Segment.this.references[index] = newReference; + Segment.this.count.incrementAndGet(); + }; + return task.execute(ref, entry, entries); + } + finally { + unlock(); + if (task.hasOption(TaskOption.RESTRUCTURE_AFTER)) { + restructureIfNecessary(resize); + } + } + } + + /** + * Clear all items from this segment. + */ + public void clear() { + if (this.count.get() == 0) { + return; + } + lock(); + try { + this.references = createReferenceArray(this.initialSize); + this.resizeThreshold = (int) (this.references.length * getLoadFactor()); + this.count.set(0); + } + finally { + unlock(); + } + } + + /** + * Restructure the underlying data structure when it becomes necessary. This + * method can increase the size of the references table as well as purge any + * references that have been garbage collected. + * @param allowResize if resizing is permitted + */ + void restructureIfNecessary(boolean allowResize) { + int currCount = this.count.get(); + boolean needsResize = allowResize && (currCount > 0 && currCount >= this.resizeThreshold); + Reference ref = this.referenceManager.pollForPurge(); + if (ref != null || (needsResize)) { + restructure(allowResize, ref); + } + } + + private void restructure(boolean allowResize, Reference ref) { + boolean needsResize; + lock(); + try { + int expectedCount = this.count.get(); + Set> toPurge = Collections.emptySet(); + if (ref != null) { + toPurge = new HashSet<>(); + while (ref != null) { + toPurge.add(ref); + ref = this.referenceManager.pollForPurge(); + } + } + expectedCount -= toPurge.size(); + + // Estimate new count, taking into account count inside lock and items + // that + // will be purged. + needsResize = (expectedCount > 0 && expectedCount >= this.resizeThreshold); + boolean resizing = false; + int restructureSize = this.references.length; + if (allowResize && needsResize && restructureSize < MAXIMUM_SEGMENT_SIZE) { + restructureSize <<= 1; + resizing = true; + } + + int newCount = 0; + // Restructure the resized reference array + if (resizing) { + Reference[] restructured = createReferenceArray(restructureSize); + for (Reference reference : this.references) { + ref = reference; + while (ref != null) { + if (!toPurge.contains(ref)) { + Entry entry = ref.get(); + // Also filter out null references that are now null + // they should be polled from the queue in a later + // restructure call. + if (entry != null) { + int index = getIndex(ref.getHash(), restructured); + restructured[index] = this.referenceManager.createReference(entry, ref.getHash(), + restructured[index]); + newCount++; + } + } + ref = ref.getNext(); + } + } + // Replace volatile members + this.references = restructured; + this.resizeThreshold = (int) (this.references.length * getLoadFactor()); + } + // Restructure the existing reference array "in place" + else { + for (int i = 0; i < this.references.length; i++) { + Reference purgedRef = null; + ref = this.references[i]; + while (ref != null) { + if (!toPurge.contains(ref)) { + Entry entry = ref.get(); + // Also filter out null references that are now null + // they should be polled from the queue in a later + // restructure call. + if (entry != null) { + purgedRef = this.referenceManager.createReference(entry, ref.getHash(), purgedRef); + } + newCount++; + } + ref = ref.getNext(); + } + this.references[i] = purgedRef; + } + } + this.count.set(Math.max(newCount, 0)); + } + finally { + unlock(); + } + } + + private Reference findInChain(Reference ref, Object key, int hash) { + Reference currRef = ref; + while (currRef != null) { + if (currRef.getHash() == hash) { + Entry entry = currRef.get(); + if (entry != null) { + K entryKey = entry.getKey(); + if (nullSafeEquals(entryKey, key)) { + return currRef; + } + } + } + currRef = currRef.getNext(); + } + return null; + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private Reference[] createReferenceArray(int size) { + return new Reference[size]; + } + + private int getIndex(int hash, Reference[] references) { + return (hash & (references.length - 1)); + } + + /** + * Return the size of the current references array. + */ + public int getSize() { + return this.references.length; + } + + /** + * Return the total number of references in this segment. + */ + public int getCount() { + return this.count.get(); + } + + } + + /** + * A reference to an {@link Entry} contained in the map. Implementations are usually + * wrappers around specific Java reference implementations (for example, + * {@link SoftReference}). + * + * @param the key type + * @param the value type + */ + protected interface Reference { + + /** + * Return the referenced entry, or {@code null} if the entry is no longer + * available. + */ + + Entry get(); + + /** + * Return the hash for the reference. + */ + int getHash(); + + /** + * Return the next reference in the chain, or {@code null} if none. + */ + + Reference getNext(); + + /** + * Release this entry and ensure that it will be returned from + * {@code ReferenceManager#pollForPurge()}. + */ + void release(); + + } + + /** + * A single map entry. + * + * @param the key type + * @param the value type + */ + protected static final class Entry implements Map.Entry { + + private final K key; + + private volatile V value; + + public Entry(K key, V value) { + this.key = key; + this.value = value; + } + + @Override + + public K getKey() { + return this.key; + } + + @Override + + public V getValue() { + return this.value; + } + + @Override + + public V setValue(V value) { + V previous = this.value; + this.value = value; + return previous; + } + + @Override + public boolean equals(Object other) { + return (this == other || (other instanceof Map.Entry that && nullSafeEquals(getKey(), that.getKey()) + && nullSafeEquals(getValue(), that.getValue()))); + } + + @Override + public int hashCode() { + return (nullSafeHashCode(this.key) ^ nullSafeHashCode(this.value)); + } + + @Override + public String toString() { + return (this.key + "=" + this.value); + } + + } + + /** + * A task that can be {@link Segment#doTask run} against a {@link Segment}. + */ + private abstract class Task { + + private final EnumSet options; + + public Task(TaskOption... options) { + this.options = (options.length == 0 ? EnumSet.noneOf(TaskOption.class) : EnumSet.of(options[0], options)); + } + + public boolean hasOption(TaskOption option) { + return this.options.contains(option); + } + + /** + * Execute the task. + * @param ref the found reference (or {@code null}) + * @param entry the found entry (or {@code null}) + * @param entries access to the underlying entries + * @return the result of the task + * @see #execute(Reference, Entry) + */ + + protected T execute(Reference ref, Entry entry, Entries entries) { + return execute(ref, entry); + } + + /** + * Convenience method that can be used for tasks that do not need access to + * {@link Entries}. + * @param ref the found reference (or {@code null}) + * @param entry the found entry (or {@code null}) + * @return the result of the task + * @see #execute(Reference, Entry, Entries) + */ + + protected T execute(Reference ref, Entry entry) { + return null; + } + + } + + /** + * Various options supported by a {@code Task}. + */ + private enum TaskOption { + + RESTRUCTURE_BEFORE, RESTRUCTURE_AFTER, SKIP_IF_EMPTY, RESIZE + + } + + /** + * Allows a task access to {@link ConcurrentReferenceHashMap.Segment} entries. + */ + private interface Entries { + + /** + * Add a new entry with the specified value. + * @param value the value to add + */ + void add(V value); + + } + + /** + * Internal entry-set implementation. + */ + private class EntrySet extends AbstractSet> { + + @Override + public Iterator> iterator() { + return new EntryIterator(); + } + + @Override + public boolean contains(Object o) { + if (o instanceof Map.Entry entry) { + Reference ref = ConcurrentReferenceHashMap.this.getReference(entry.getKey(), Restructure.NEVER); + Entry otherEntry = (ref != null ? ref.get() : null); + if (otherEntry != null) { + return nullSafeEquals(entry.getValue(), otherEntry.getValue()); + } + } + return false; + } + + @Override + public boolean remove(Object o) { + if (o instanceof Map.Entry entry) { + return ConcurrentReferenceHashMap.this.remove(entry.getKey(), entry.getValue()); + } + return false; + } + + @Override + public int size() { + return ConcurrentReferenceHashMap.this.size(); + } + + @Override + public void clear() { + ConcurrentReferenceHashMap.this.clear(); + } + + } + + /** + * Internal entry iterator implementation. + */ + private class EntryIterator implements Iterator> { + + private int segmentIndex; + + private int referenceIndex; + + private Reference[] references; + + private Reference reference; + + private Entry next; + + private Entry last; + + public EntryIterator() { + moveToNextSegment(); + } + + @Override + public boolean hasNext() { + getNextIfNecessary(); + return (this.next != null); + } + + @Override + public Entry next() { + getNextIfNecessary(); + if (this.next == null) { + throw new NoSuchElementException(); + } + this.last = this.next; + this.next = null; + return this.last; + } + + private void getNextIfNecessary() { + while (this.next == null) { + moveToNextReference(); + if (this.reference == null) { + return; + } + this.next = this.reference.get(); + } + } + + private void moveToNextReference() { + if (this.reference != null) { + this.reference = this.reference.getNext(); + } + while (this.reference == null && this.references != null) { + if (this.referenceIndex >= this.references.length) { + moveToNextSegment(); + this.referenceIndex = 0; + } + else { + this.reference = this.references[this.referenceIndex]; + this.referenceIndex++; + } + } + } + + private void moveToNextSegment() { + this.reference = null; + this.references = null; + if (this.segmentIndex < ConcurrentReferenceHashMap.this.segments.length) { + this.references = ConcurrentReferenceHashMap.this.segments[this.segmentIndex].references; + this.segmentIndex++; + } + } + + @Override + public void remove() { + // Assert.state(this.last != null, "No element to remove"); + ConcurrentReferenceHashMap.this.remove(this.last.getKey()); + this.last = null; + } + + } + + /** + * The types of restructuring that can be performed. + */ + protected enum Restructure { + + WHEN_NECESSARY, NEVER + + } + + /** + * Strategy class used to manage {@link Reference References}. This class can be + * overridden if alternative reference types need to be supported. + */ + protected class ReferenceManager { + + private final ReferenceQueue> queue = new ReferenceQueue<>(); + + /** + * Factory method used to create a new {@link Reference}. + * @param entry the entry contained in the reference + * @param hash the hash + * @param next the next reference in the chain, or {@code null} if none + * @return a new {@link Reference} + */ + public Reference createReference(Entry entry, int hash, Reference next) { + if (ConcurrentReferenceHashMap.this.referenceType == ReferenceType.WEAK) { + return new WeakEntryReference<>(entry, hash, next, this.queue); + } + return new SoftEntryReference<>(entry, hash, next, this.queue); + } + + /** + * Return any reference that has been garbage collected and can be purged from the + * underlying structure or {@code null} if no references need purging. This method + * must be thread safe and ideally should not block when returning {@code null}. + * References should be returned once and only once. + * @return a reference to purge or {@code null} + */ + @SuppressWarnings("unchecked") + + public Reference pollForPurge() { + return (Reference) this.queue.poll(); + } + + } + + /** + * Internal {@link Reference} implementation for {@link SoftReference SoftReferences}. + */ + private static final class SoftEntryReference extends SoftReference> implements Reference { + + private final int hash; + + private final Reference nextReference; + + public SoftEntryReference(Entry entry, int hash, Reference next, + ReferenceQueue> queue) { + + super(entry, queue); + this.hash = hash; + this.nextReference = next; + } + + @Override + public int getHash() { + return this.hash; + } + + @Override + + public Reference getNext() { + return this.nextReference; + } + + @Override + public void release() { + enqueue(); + } + + } + + /** + * Internal {@link Reference} implementation for {@link WeakReference WeakReferences}. + */ + private static final class WeakEntryReference extends WeakReference> implements Reference { + + private final int hash; + + private final Reference nextReference; + + public WeakEntryReference(Entry entry, int hash, Reference next, + ReferenceQueue> queue) { + + super(entry, queue); + this.hash = hash; + this.nextReference = next; + } + + @Override + public int getHash() { + return this.hash; + } + + @Override + + public Reference getNext() { + return this.nextReference; + } + + @Override + public void release() { + enqueue(); + } + + } + + public static boolean nullSafeEquals(@Nullable Object o1, @Nullable Object o2) { + if (o1 == o2) { + return true; + } + if (o1 == null || o2 == null) { + return false; + } + if (o1.equals(o2)) { + return true; + } + if (o1.getClass().isArray() && o2.getClass().isArray()) { + return arrayEquals(o1, o2); + } + return false; + } + + private static boolean arrayEquals(Object o1, Object o2) { + if (o1 instanceof Object[] objects1 && o2 instanceof Object[] objects2) { + return Arrays.equals(objects1, objects2); + } + if (o1 instanceof boolean[] booleans1 && o2 instanceof boolean[] booleans2) { + return Arrays.equals(booleans1, booleans2); + } + if (o1 instanceof byte[] bytes1 && o2 instanceof byte[] bytes2) { + return Arrays.equals(bytes1, bytes2); + } + if (o1 instanceof char[] chars1 && o2 instanceof char[] chars2) { + return Arrays.equals(chars1, chars2); + } + if (o1 instanceof double[] doubles1 && o2 instanceof double[] doubles2) { + return Arrays.equals(doubles1, doubles2); + } + if (o1 instanceof float[] floats1 && o2 instanceof float[] floats2) { + return Arrays.equals(floats1, floats2); + } + if (o1 instanceof int[] ints1 && o2 instanceof int[] ints2) { + return Arrays.equals(ints1, ints2); + } + if (o1 instanceof long[] longs1 && o2 instanceof long[] longs2) { + return Arrays.equals(longs1, longs2); + } + if (o1 instanceof short[] shorts1 && o2 instanceof short[] shorts2) { + return Arrays.equals(shorts1, shorts2); + } + return false; + } + + public static int nullSafeHashCode(Object obj) { + if (obj == null) { + return 0; + } + if (obj.getClass().isArray()) { + if (obj instanceof Object[] objects) { + return Arrays.hashCode(objects); + } + if (obj instanceof boolean[] booleans) { + return Arrays.hashCode(booleans); + } + if (obj instanceof byte[] bytes) { + return Arrays.hashCode(bytes); + } + if (obj instanceof char[] chars) { + return Arrays.hashCode(chars); + } + if (obj instanceof double[] doubles) { + return Arrays.hashCode(doubles); + } + if (obj instanceof float[] floats) { + return Arrays.hashCode(floats); + } + if (obj instanceof int[] ints) { + return Arrays.hashCode(ints); + } + if (obj instanceof long[] longs) { + return Arrays.hashCode(longs); + } + if (obj instanceof short[] shorts) { + return Arrays.hashCode(shorts); + } + } + return obj.hashCode(); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonParser.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonParser.java new file mode 100644 index 0000000..90b48d9 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonParser.java @@ -0,0 +1,174 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.method.tool.utils; + +import java.lang.reflect.Type; +import java.math.BigDecimal; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.databind.json.JsonMapper; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; + +import io.modelcontextprotocol.util.Assert; + +/** + * Utilities to perform parsing operations between JSON and Java. + */ +public final class JsonParser { + + private static final ObjectMapper OBJECT_MAPPER = JsonMapper.builder() + .disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES) + .disable(SerializationFeature.FAIL_ON_EMPTY_BEANS) + .disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS) + .addModule(new JavaTimeModule()) + .build(); + + private JsonParser() { + } + + /** + * Returns a Jackson {@link ObjectMapper} instance tailored for JSON-parsing + * operations for tool calling and structured output. + */ + public static ObjectMapper getObjectMapper() { + return OBJECT_MAPPER; + } + + /** + * Converts a JSON string to a Java object. + */ + public static T fromJson(String json, Class type) { + Assert.notNull(json, "json cannot be null"); + Assert.notNull(type, "type cannot be null"); + + try { + return OBJECT_MAPPER.readValue(json, type); + } + catch (JsonProcessingException ex) { + throw new IllegalStateException("Conversion from JSON to %s failed".formatted(type.getName()), ex); + } + } + + /** + * Converts a JSON string to a Java object. + */ + public static T fromJson(String json, Type type) { + Assert.notNull(json, "json cannot be null"); + Assert.notNull(type, "type cannot be null"); + + try { + return OBJECT_MAPPER.readValue(json, OBJECT_MAPPER.constructType(type)); + } + catch (JsonProcessingException ex) { + throw new IllegalStateException("Conversion from JSON to %s failed".formatted(type.getTypeName()), ex); + } + } + + /** + * Converts a JSON string to a Java object. + */ + public static T fromJson(String json, TypeReference type) { + Assert.notNull(json, "json cannot be null"); + Assert.notNull(type, "type cannot be null"); + + try { + return OBJECT_MAPPER.readValue(json, type); + } + catch (JsonProcessingException ex) { + throw new IllegalStateException("Conversion from JSON to %s failed".formatted(type.getType().getTypeName()), + ex); + } + } + + /** + * Checks if a string is a valid JSON string. + */ + private static boolean isValidJson(String input) { + try { + OBJECT_MAPPER.readTree(input); + return true; + } + catch (JsonProcessingException e) { + return false; + } + } + + /** + * Converts a Java object to a JSON string if it's not already a valid JSON string. + */ + public static String toJson(Object object) { + if (object instanceof String && isValidJson((String) object)) { + return (String) object; + } + try { + return OBJECT_MAPPER.writeValueAsString(object); + } + catch (JsonProcessingException ex) { + throw new IllegalStateException("Conversion from Object to JSON failed", ex); + } + } + + /** + * Convert a Java Object to a typed Object. Based on the implementation in + * MethodToolCallback. + */ + @SuppressWarnings({ "rawtypes", "unchecked" }) + public static Object toTypedObject(Object value, Class type) { + Assert.notNull(value, "value cannot be null"); + Assert.notNull(type, "type cannot be null"); + + var javaType = ClassUtils.resolvePrimitiveIfNecessary(type); + + if (javaType == String.class) { + return value.toString(); + } + else if (javaType == Byte.class) { + return Byte.parseByte(value.toString()); + } + else if (javaType == Integer.class) { + BigDecimal bigDecimal = new BigDecimal(value.toString()); + return bigDecimal.intValueExact(); + } + else if (javaType == Short.class) { + return Short.parseShort(value.toString()); + } + else if (javaType == Long.class) { + BigDecimal bigDecimal = new BigDecimal(value.toString()); + return bigDecimal.longValueExact(); + } + else if (javaType == Double.class) { + return Double.parseDouble(value.toString()); + } + else if (javaType == Float.class) { + return Float.parseFloat(value.toString()); + } + else if (javaType == Boolean.class) { + return Boolean.parseBoolean(value.toString()); + } + else if (javaType.isEnum()) { + return Enum.valueOf((Class) javaType, value.toString()); + } + + String json = JsonParser.toJson(value); + return JsonParser.fromJson(json, javaType); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonSchemaGenerator.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonSchemaGenerator.java new file mode 100644 index 0000000..2a33dbd --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonSchemaGenerator.java @@ -0,0 +1,194 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.method.tool.utils; + +import java.lang.reflect.Method; +import java.lang.reflect.Parameter; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.springaicommunity.mcp.annotation.McpToolParam; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.github.victools.jsonschema.generator.Option; +import com.github.victools.jsonschema.generator.OptionPreset; +import com.github.victools.jsonschema.generator.SchemaGenerator; +import com.github.victools.jsonschema.generator.SchemaGeneratorConfig; +import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder; +import com.github.victools.jsonschema.generator.SchemaVersion; +import com.github.victools.jsonschema.module.jackson.JacksonModule; +import com.github.victools.jsonschema.module.jackson.JacksonOption; +import com.github.victools.jsonschema.module.swagger2.Swagger2Module; + +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import io.swagger.v3.oas.annotations.media.Schema; +import reactor.util.annotation.Nullable; + +import com.github.victools.jsonschema.generator.Module; + +public class JsonSchemaGenerator { + + private static final boolean PROPERTY_REQUIRED_BY_DEFAULT = true; + + private static final SchemaGenerator TYPE_SCHEMA_GENERATOR; + + private static final SchemaGenerator SUBTYPE_SCHEMA_GENERATOR; + + private static final Map methodSchemaCache = new ConcurrentReferenceHashMap<>(256); + + private static final Map, String> classSchemaCache = new ConcurrentReferenceHashMap<>(256); + + /* + * Initialize JSON Schema generators. + */ + static { + Module jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED); + Module openApiModule = new Swagger2Module(); + Module springAiSchemaModule = PROPERTY_REQUIRED_BY_DEFAULT ? new SpringAiSchemaModule() + : new SpringAiSchemaModule(SpringAiSchemaModule.Option.PROPERTY_REQUIRED_FALSE_BY_DEFAULT); + + SchemaGeneratorConfigBuilder schemaGeneratorConfigBuilder = new SchemaGeneratorConfigBuilder( + SchemaVersion.DRAFT_2020_12, OptionPreset.PLAIN_JSON) + .with(jacksonModule) + .with(openApiModule) + .with(springAiSchemaModule) + .with(Option.EXTRA_OPEN_API_FORMAT_VALUES) + .with(Option.STANDARD_FORMATS) + .with(Option.PLAIN_DEFINITION_KEYS); + + SchemaGeneratorConfig typeSchemaGeneratorConfig = schemaGeneratorConfigBuilder.build(); + TYPE_SCHEMA_GENERATOR = new SchemaGenerator(typeSchemaGeneratorConfig); + + SchemaGeneratorConfig subtypeSchemaGeneratorConfig = schemaGeneratorConfigBuilder + .without(Option.SCHEMA_VERSION_INDICATOR) + .build(); + SUBTYPE_SCHEMA_GENERATOR = new SchemaGenerator(subtypeSchemaGeneratorConfig); + } + + public static String generateForMethodInput(Method method) { + Assert.notNull(method, "method cannot be null"); + return methodSchemaCache.computeIfAbsent(method, JsonSchemaGenerator::internalGenerateFromMethodArguments); + } + + private static String internalGenerateFromMethodArguments(Method method) { + + ObjectNode schema = JsonParser.getObjectMapper().createObjectNode(); + schema.put("$schema", SchemaVersion.DRAFT_2020_12.getIdentifier()); + schema.put("type", "object"); + + ObjectNode properties = schema.putObject("properties"); + List required = new ArrayList<>(); + + for (int i = 0; i < method.getParameterCount(); i++) { + String parameterName = method.getParameters()[i].getName(); + Type parameterType = method.getGenericParameterTypes()[i]; + if (parameterType instanceof Class parameterClass + && (ClassUtils.isAssignable(McpSyncServerExchange.class, parameterClass) + || ClassUtils.isAssignable(McpAsyncServerExchange.class, parameterClass))) { + continue; + } + if (isMethodParameterRequired(method, i)) { + required.add(parameterName); + } + ObjectNode parameterNode = SUBTYPE_SCHEMA_GENERATOR.generateSchema(parameterType); + String parameterDescription = getMethodParameterDescription(method, i); + if (Utils.hasText(parameterDescription)) { + parameterNode.put("description", parameterDescription); + } + properties.set(parameterName, parameterNode); + } + + var requiredArray = schema.putArray("required"); + required.forEach(requiredArray::add); + + return schema.toPrettyString(); + } + + public static String generateFromClass(Class clazz) { + Assert.notNull(clazz, "clazz cannot be null"); + return classSchemaCache.computeIfAbsent(clazz, JsonSchemaGenerator::internalGenerateFromClass); + } + + private static String internalGenerateFromClass(Class clazz) { + SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12, + OptionPreset.PLAIN_JSON); + SchemaGeneratorConfig config = configBuilder.with(Option.EXTRA_OPEN_API_FORMAT_VALUES) + .without(Option.FLATTENED_ENUMS_FROM_TOSTRING) + .build(); + + SchemaGenerator generator = new SchemaGenerator(config); + JsonNode jsonSchema = generator.generateSchema(clazz); + return jsonSchema.toPrettyString(); + } + + private static boolean isMethodParameterRequired(Method method, int index) { + Parameter parameter = method.getParameters()[index]; + + var toolParamAnnotation = parameter.getAnnotation(McpToolParam.class); + if (toolParamAnnotation != null) { + return toolParamAnnotation.required(); + } + + var propertyAnnotation = parameter.getAnnotation(JsonProperty.class); + if (propertyAnnotation != null) { + return propertyAnnotation.required(); + } + + var schemaAnnotation = parameter.getAnnotation(Schema.class); + if (schemaAnnotation != null) { + return schemaAnnotation.requiredMode() == Schema.RequiredMode.REQUIRED + || schemaAnnotation.requiredMode() == Schema.RequiredMode.AUTO || schemaAnnotation.required(); + } + + var nullableAnnotation = parameter.getAnnotation(Nullable.class); + if (nullableAnnotation != null) { + return false; + } + + return PROPERTY_REQUIRED_BY_DEFAULT; + } + + private static String getMethodParameterDescription(Method method, int index) { + Parameter parameter = method.getParameters()[index]; + + var toolParamAnnotation = parameter.getAnnotation(McpToolParam.class); + if (toolParamAnnotation != null && Utils.hasText(toolParamAnnotation.description())) { + return toolParamAnnotation.description(); + } + + var jacksonAnnotation = parameter.getAnnotation(JsonPropertyDescription.class); + if (jacksonAnnotation != null && Utils.hasText(jacksonAnnotation.value())) { + return jacksonAnnotation.value(); + } + + var schemaAnnotation = parameter.getAnnotation(Schema.class); + if (schemaAnnotation != null && Utils.hasText(schemaAnnotation.description())) { + return schemaAnnotation.description(); + } + + return null; + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/SpringAiSchemaModule.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/SpringAiSchemaModule.java new file mode 100644 index 0000000..5b6694d --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/SpringAiSchemaModule.java @@ -0,0 +1,122 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.method.tool.utils; + +import java.util.stream.Stream; + +import org.springaicommunity.mcp.annotation.McpToolParam; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.github.victools.jsonschema.generator.FieldScope; +import com.github.victools.jsonschema.generator.MemberScope; +import com.github.victools.jsonschema.generator.Module; +import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder; +import com.github.victools.jsonschema.generator.SchemaGeneratorConfigPart; + +import io.modelcontextprotocol.util.Utils; +import io.swagger.v3.oas.annotations.media.Schema; + +/** + * JSON Schema Generator Module for Spring AI. + *

+ * This module provides a set of customizations to the JSON Schema generator to support + * the Spring AI framework. It allows to extract descriptions from + * {@code @ToolParam(description = ...)} annotations and to determine whether a property + * is required based on the presence of a series of annotations. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public final class SpringAiSchemaModule implements Module { + + private final boolean requiredByDefault; + + public SpringAiSchemaModule(Option... options) { + this.requiredByDefault = Stream.of(options) + .noneMatch(option -> option == Option.PROPERTY_REQUIRED_FALSE_BY_DEFAULT); + } + + @Override + public void applyToConfigBuilder(SchemaGeneratorConfigBuilder builder) { + this.applyToConfigBuilder(builder.forFields()); + } + + private void applyToConfigBuilder(SchemaGeneratorConfigPart configPart) { + configPart.withDescriptionResolver(this::resolveDescription); + configPart.withRequiredCheck(this::checkRequired); + } + + /** + * Extract description from {@code @ToolParam(description = ...)} for the given field. + */ + private String resolveDescription(MemberScope member) { + var toolParamAnnotation = member.getAnnotationConsideringFieldAndGetter(McpToolParam.class); + if (toolParamAnnotation != null && Utils.hasText(toolParamAnnotation.description())) { + return toolParamAnnotation.description(); + } + return null; + } + + /** + * Determines whether a property is required based on the presence of a series of + * annotations. + *

+ *

    + *
  • {@code @ToolParam(required = ...)}
  • + *
  • {@code @JsonProperty(required = ...)}
  • + *
  • {@code @Schema(required = ...)}
  • + *
  • {@code @Nullable}
  • + *
+ *

+ * If none of these annotations are present, the default behavior is to consider the + * property as required, unless the {@link Option#PROPERTY_REQUIRED_FALSE_BY_DEFAULT} + * option is set. + */ + private boolean checkRequired(MemberScope member) { + var toolParamAnnotation = member.getAnnotationConsideringFieldAndGetter(McpToolParam.class); + if (toolParamAnnotation != null) { + return toolParamAnnotation.required(); + } + + var propertyAnnotation = member.getAnnotationConsideringFieldAndGetter(JsonProperty.class); + if (propertyAnnotation != null) { + return propertyAnnotation.required(); + } + + var schemaAnnotation = member.getAnnotationConsideringFieldAndGetter(Schema.class); + if (schemaAnnotation != null) { + return schemaAnnotation.requiredMode() == Schema.RequiredMode.REQUIRED + || schemaAnnotation.requiredMode() == Schema.RequiredMode.AUTO || schemaAnnotation.required(); + } + + return this.requiredByDefault; + } + + /** + * Options for customizing the behavior of the module. + */ + public enum Option { + + /** + * Properties are only required if marked as such via one of the supported + * annotations. + */ + PROPERTY_REQUIRED_FALSE_BY_DEFAULT + + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncMcpToolProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncMcpToolProvider.java new file mode 100644 index 0000000..4da33d2 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncMcpToolProvider.java @@ -0,0 +1,154 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.function.BiFunction; +import java.util.stream.Stream; + +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.method.tool.AsyncMcpToolMethodCallback; +import org.springaicommunity.mcp.method.tool.ReactiveUtils; +import org.springaicommunity.mcp.method.tool.ReturnMode; +import org.springaicommunity.mcp.method.tool.utils.ClassUtils; +import org.springaicommunity.mcp.method.tool.utils.JsonSchemaGenerator; + +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * @author Christian Tzolov + */ +public class AsyncMcpToolProvider { + + private static final Logger logger = LoggerFactory.getLogger(AsyncMcpToolProvider.class); + + private final List toolObjects; + + /** + * Create a new SyncMcpToolProvider. + * @param toolObjects the objects containing methods annotated with {@link McpTool} + */ + public AsyncMcpToolProvider(List toolObjects) { + Assert.notNull(toolObjects, "toolObjects cannot be null"); + this.toolObjects = toolObjects; + } + + /** + * Get the tool handler. + * @return the tool handler + * @throws IllegalStateException if no tool methods are found or if multiple tool + * methods are found + */ + public List getToolSpecifications() { + + List toolSpecs = this.toolObjects.stream() + .map(toolObject -> Stream.of(doGetClassMethods(toolObject)) + .filter(method -> method.isAnnotationPresent(McpTool.class)) + .filter(method -> Mono.class.isAssignableFrom(method.getReturnType()) + || Flux.class.isAssignableFrom(method.getReturnType()) + || Publisher.class.isAssignableFrom(method.getReturnType())) + .map(mcpToolMethod -> { + + var toolAnnotation = doGetMcpToolAnnotation(mcpToolMethod); + + String toolName = Utils.hasText(toolAnnotation.name()) ? toolAnnotation.name() + : mcpToolMethod.getName(); + + String toolDescrption = toolAnnotation.description(); + + String inputSchema = JsonSchemaGenerator.generateForMethodInput(mcpToolMethod); + + var toolBuilder = McpSchema.Tool.builder() + .name(toolName) + .description(toolDescrption) + .inputSchema(inputSchema); + + // Tool annotations + if (toolAnnotation.annotations() != null) { + var toolAnnotations = toolAnnotation.annotations(); + toolBuilder.annotations(new McpSchema.ToolAnnotations(toolAnnotations.title(), + toolAnnotations.readOnlyHint(), toolAnnotations.destructiveHint(), + toolAnnotations.idempotentHint(), toolAnnotations.openWorldHint(), null)); + } + + // Generate Output Schema from the method return type. + // Output schema is not generated for primitive types, void, + // CallToolResult, simple value types (String, etc.) + // or if generateOutputSchema attribute is set to false. + + if (toolAnnotation.generateOutputSchema() + && !ReactiveUtils.isReactiveReturnTypeOfVoid(mcpToolMethod) + && !ReactiveUtils.isReactiveReturnTypeOfCallToolResult(mcpToolMethod)) { + + ReactiveUtils.getReactiveReturnTypeArgument(mcpToolMethod).ifPresent(typeArgument -> { + Class methodReturnType = typeArgument instanceof Class ? (Class) typeArgument + : null; + if (!ClassUtils.isPrimitiveOrWrapper(methodReturnType) + && !ClassUtils.isSimpleValueType(methodReturnType)) { + toolBuilder + .outputSchema(JsonSchemaGenerator.generateFromClass((Class) typeArgument)); + } + }); + } + var tool = toolBuilder.build(); + + ReturnMode returnMode = tool.outputSchema() != null ? ReturnMode.STRUCTURED + : ReactiveUtils.isReactiveReturnTypeOfVoid(mcpToolMethod) ? ReturnMode.VOID + : ReturnMode.TEXT; + + BiFunction> methodCallback = new AsyncMcpToolMethodCallback( + returnMode, mcpToolMethod, toolObject); + + AsyncToolSpecification toolSpec = AsyncToolSpecification.builder() + .tool(tool) + .callHandler(methodCallback) + .build(); + + return toolSpec; + }) + .toList()) + .flatMap(List::stream) + .toList(); + + if (toolSpecs.isEmpty()) { + logger.warn("No tool methods found in the provided tool objects: {}", this.toolObjects); + } + + return toolSpecs; + } + + protected Method[] doGetClassMethods(Object bean) { + return bean.getClass().getDeclaredMethods(); + } + + protected McpTool doGetMcpToolAnnotation(Method method) { + return method.getAnnotation(McpTool.class); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncMcpResourceProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncMcpResourceProvider.java index 2e8e9f1..b3991ff 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncMcpResourceProvider.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncMcpResourceProvider.java @@ -42,7 +42,7 @@ public SyncMcpResourceProvider(List resourceObjects) { public List getResourceSpecifications() { List methodCallbacks = this.resourceObjects.stream() - .map(resourceObject -> Stream.of(doGetClassMethods(resourceObject)) + .map(resourceObject -> Stream.of(this.doGetClassMethods(resourceObject)) .filter(resourceMethod -> resourceMethod.isAnnotationPresent(McpResource.class)) .filter(method -> !Mono.class.isAssignableFrom(method.getReturnType())) .map(mcpResourceMethod -> { @@ -52,7 +52,13 @@ public List getResourceSpecifications() { var name = getName(mcpResourceMethod, resourceAnnotation); var description = resourceAnnotation.description(); var mimeType = resourceAnnotation.mimeType(); - var mcpResource = new McpSchema.Resource(uri, name, description, mimeType, null); + + var mcpResource = McpSchema.Resource.builder() + .uri(uri) + .name(name) + .description(description) + .mimeType(mimeType) + .build(); var methodCallback = SyncMcpResourceMethodCallback.builder() .method(mcpResourceMethod) diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncMcpToolProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncMcpToolProvider.java new file mode 100644 index 0000000..25cee26 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncMcpToolProvider.java @@ -0,0 +1,144 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.function.BiFunction; +import java.util.stream.Stream; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.method.tool.ReactiveUtils; +import org.springaicommunity.mcp.method.tool.ReturnMode; +import org.springaicommunity.mcp.method.tool.SyncMcpToolMethodCallback; +import org.springaicommunity.mcp.method.tool.utils.ClassUtils; +import org.springaicommunity.mcp.method.tool.utils.JsonSchemaGenerator; + +import io.modelcontextprotocol.server.McpServerFeatures.SyncToolSpecification; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.publisher.Mono; + +/** + * @author Christian Tzolov + */ +public class SyncMcpToolProvider { + + private static final Logger logger = LoggerFactory.getLogger(SyncMcpToolProvider.class); + + private final List toolObjects; + + /** + * Create a new SyncMcpToolProvider. + * @param toolObjects the objects containing methods annotated with {@link McpTool} + */ + public SyncMcpToolProvider(List toolObjects) { + Assert.notNull(toolObjects, "toolObjects cannot be null"); + this.toolObjects = toolObjects; + } + + /** + * Get the tool handler. + * @return the tool handler + * @throws IllegalStateException if no tool methods are found or if multiple tool + * methods are found + */ + public List getToolSpecifications() { + + List toolSpecs = this.toolObjects.stream() + .map(toolObject -> Stream.of(doGetClassMethods(toolObject)) + .filter(method -> method.isAnnotationPresent(McpTool.class)) + .filter(method -> !Mono.class.isAssignableFrom(method.getReturnType())) + .map(mcpToolMethod -> { + + var toolAnnotation = doGetMcpToolAnnotation(mcpToolMethod); + + String toolName = Utils.hasText(toolAnnotation.name()) ? toolAnnotation.name() + : mcpToolMethod.getName(); + + String toolDescrption = toolAnnotation.description(); + + String inputSchema = JsonSchemaGenerator.generateForMethodInput(mcpToolMethod); + + var toolBuilder = McpSchema.Tool.builder() + .name(toolName) + .description(toolDescrption) + .inputSchema(inputSchema); + + // Tool annotations + if (toolAnnotation.annotations() != null) { + var toolAnnotations = toolAnnotation.annotations(); + toolBuilder.annotations(new McpSchema.ToolAnnotations(toolAnnotations.title(), + toolAnnotations.readOnlyHint(), toolAnnotations.destructiveHint(), + toolAnnotations.idempotentHint(), toolAnnotations.openWorldHint(), null)); + } + + ReactiveUtils.isReactiveReturnTypeOfCallToolResult(mcpToolMethod); + // Generate Output Schema from the method return type. + // Output schema is not generated for primitive types, void, + // CallToolResult, simple value types (String, etc.) + // or if generateOutputSchema attribute is set to false. + Class methodReturnType = mcpToolMethod.getReturnType(); + if (toolAnnotation.generateOutputSchema() && methodReturnType != null + && methodReturnType != CallToolResult.class && methodReturnType != Void.class + && methodReturnType != void.class && !ClassUtils.isPrimitiveOrWrapper(methodReturnType) + && !ClassUtils.isSimpleValueType(methodReturnType)) { + + toolBuilder.outputSchema(JsonSchemaGenerator.generateFromClass(methodReturnType)); + } + + var tool = toolBuilder.build(); + + boolean useStructuredOtput = tool.outputSchema() != null; + + ReturnMode returnMode = useStructuredOtput ? ReturnMode.STRUCTURED + : (methodReturnType == Void.TYPE || methodReturnType == void.class ? ReturnMode.VOID + : ReturnMode.TEXT); + + BiFunction methodCallback = new SyncMcpToolMethodCallback( + returnMode, mcpToolMethod, toolObject); + + var toolSpec = SyncToolSpecification.builder().tool(tool).callHandler(methodCallback).build(); + + return toolSpec; + }) + .toList()) + .flatMap(List::stream) + .toList(); + + if (toolSpecs.isEmpty()) { + logger.warn("No tool methods found in the provided tool objects: {}", this.toolObjects); + } + + return toolSpecs; + } + + protected Method[] doGetClassMethods(Object bean) { + return bean.getClass().getDeclaredMethods(); + } + + protected McpTool doGetMcpToolAnnotation(Method method) { + return method.getAnnotation(McpTool.class); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/SyncMcpResourceMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/SyncMcpResourceMethodCallbackTests.java index cc67d1e..49d64c2 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/SyncMcpResourceMethodCallbackTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/SyncMcpResourceMethodCallbackTests.java @@ -4,24 +4,24 @@ package org.springaicommunity.mcp.method.resource; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + import java.lang.reflect.Method; import java.util.List; import java.util.function.BiFunction; +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpResource; +import org.springaicommunity.mcp.annotation.ResourceAdaptor; + import io.modelcontextprotocol.server.McpSyncServerExchange; import io.modelcontextprotocol.spec.McpSchema.BlobResourceContents; import io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest; import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; import io.modelcontextprotocol.spec.McpSchema.ResourceContents; import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; -import org.junit.jupiter.api.Test; -import org.springaicommunity.mcp.annotation.McpResource; -import org.springaicommunity.mcp.annotation.ResourceAdaptor; -import org.springaicommunity.mcp.method.resource.SyncMcpResourceMethodCallback; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.mock; /** * Tests for {@link SyncMcpResourceMethodCallback}. @@ -140,12 +140,12 @@ public String uri() { @Override public String name() { - return ""; + return "testResource"; } @Override public String description() { - return ""; + return "Test resource description"; } @Override @@ -410,17 +410,17 @@ public String uri() { @Override public String name() { - return ""; + return "testResourceWithExtraVariables"; } @Override public String description() { - return ""; + return "Test resource with extra URI variables"; } @Override public String mimeType() { - return ""; + return "text/plain"; } }; diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallbackTests.java new file mode 100644 index 0000000..b39c07f --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallbackTests.java @@ -0,0 +1,736 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.tool; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.annotation.McpToolParam; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link AsyncMcpToolMethodCallback}. + * + * @author Christian Tzolov + */ +public class AsyncMcpToolMethodCallbackTests { + + private static class TestAsyncToolProvider { + + @McpTool(name = "simple-mono-tool", description = "A simple mono tool") + public Mono simpleMonoTool(String input) { + return Mono.just("Processed: " + input); + } + + @McpTool(name = "simple-flux-tool", description = "A simple flux tool") + public Flux simpleFluxTool(String input) { + return Flux.just("Processed: " + input); + } + + @McpTool(name = "simple-publisher-tool", description = "A simple publisher tool") + public Publisher simplePublisherTool(String input) { + return Mono.just("Processed: " + input); + } + + @McpTool(name = "math-mono-tool", description = "A math mono tool") + public Mono addNumbersMono(int a, int b) { + return Mono.just(a + b); + } + + @McpTool(name = "complex-mono-tool", description = "A complex mono tool") + public Mono complexMonoTool(String name, int age, boolean active) { + return Mono.just(CallToolResult.builder() + .addTextContent("Name: " + name + ", Age: " + age + ", Active: " + active) + .build()); + } + + @McpTool(name = "complex-flux-tool", description = "A complex flux tool") + public Flux complexFluxTool(String name, int age, boolean active) { + return Flux.just(CallToolResult.builder() + .addTextContent("Name: " + name + ", Age: " + age + ", Active: " + active) + .build()); + } + + @McpTool(name = "exchange-mono-tool", description = "Mono tool with exchange parameter") + public Mono monoToolWithExchange(McpAsyncServerExchange exchange, String message) { + return Mono.just("Exchange tool: " + message); + } + + @McpTool(name = "list-mono-tool", description = "Mono tool with list parameter") + public Mono processListMono(List items) { + return Mono.just("Items: " + String.join(", ", items)); + } + + @McpTool(name = "object-mono-tool", description = "Mono tool with object parameter") + public Mono processObjectMono(TestObject obj) { + return Mono.just("Object: " + obj.name + " - " + obj.value); + } + + @McpTool(name = "optional-params-mono-tool", description = "Mono tool with optional parameters") + public Mono monoToolWithOptionalParams(@McpToolParam(required = true) String required, + @McpToolParam(required = false) String optional) { + return Mono.just("Required: " + required + ", Optional: " + (optional != null ? optional : "null")); + } + + @McpTool(name = "no-params-mono-tool", description = "Mono tool with no parameters") + public Mono noParamsMonoTool() { + return Mono.just("No parameters needed"); + } + + @McpTool(name = "exception-mono-tool", description = "Mono tool that throws exception") + public Mono exceptionMonoTool(String input) { + return Mono.error(new RuntimeException("Tool execution failed: " + input)); + } + + @McpTool(name = "null-return-mono-tool", description = "Mono tool that returns null") + public Mono nullReturnMonoTool() { + return Mono.just((String) null); + } + + @McpTool(name = "void-mono-tool", description = "Mono tool") + public Mono voidMonoTool(String input) { + return Mono.empty(); + } + + @McpTool(name = "void-flux-tool", description = "Flux tool") + public Flux voidFluxTool(String input) { + return Flux.empty(); + } + + @McpTool(name = "enum-mono-tool", description = "Mono tool with enum parameter") + public Mono enumMonoTool(TestEnum enumValue) { + return Mono.just("Enum: " + enumValue.name()); + } + + @McpTool(name = "primitive-types-mono-tool", description = "Mono tool with primitive types") + public Mono primitiveTypesMonoTool(boolean flag, byte b, short s, int i, long l, float f, double d) { + return Mono.just(String.format("Primitives: %b, %d, %d, %d, %d, %.1f, %.1f", flag, b, s, i, l, f, d)); + } + + @McpTool(name = "return-object-mono-tool", description = "Mono tool that returns a complex object") + public Mono returnObjectMonoTool(String name, int value) { + return Mono.just(new TestObject(name, value)); + } + + @McpTool(name = "delayed-mono-tool", description = "Mono tool with delay") + public Mono delayedMonoTool(String input) { + return Mono.just("Delayed: " + input); + } + + @McpTool(name = "empty-mono-tool", description = "Mono tool that returns empty") + public Mono emptyMonoTool() { + return Mono.empty(); + } + + @McpTool(name = "multiple-flux-tool", description = "Flux tool that emits multiple values") + public Flux multipleFluxTool(String prefix) { + return Flux.just(prefix + "1", prefix + "2", prefix + "3"); + } + + @McpTool(name = "private-mono-tool", description = "Private mono tool") + private Mono privateMonoTool(String input) { + return Mono.just("Private: " + input); + } + + // Non-reactive method that should cause error in async context + @McpTool(name = "non-reactive-tool", description = "Non-reactive tool") + public String nonReactiveTool(String input) { + return "Non-reactive: " + input; + } + + } + + public static class TestObject { + + public String name; + + public int value; + + public TestObject() { + } + + public TestObject(String name, int value) { + this.name = name; + this.value = value; + } + + } + + public enum TestEnum { + + OPTION_A, OPTION_B, OPTION_C + + } + + @Test + public void testSimpleMonoToolCallback() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("simpleMonoTool", String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("simple-mono-tool", Map.of("input", "test message")); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Processed: test message"); + }).verifyComplete(); + } + + @Test + public void testSimpleFluxToolCallback() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("simpleFluxTool", String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("simple-flux-tool", Map.of("input", "test message")); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Processed: test message"); + }).verifyComplete(); + } + + @Test + public void testSimplePublisherToolCallback() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("simplePublisherTool", String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("simple-publisher-tool", Map.of("input", "test message")); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Processed: test message"); + }).verifyComplete(); + } + + @Test + public void testMathMonoToolCallback() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("addNumbersMono", int.class, int.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("math-mono-tool", Map.of("a", 5, "b", 3)); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("8"); + }).verifyComplete(); + } + + @Test + public void testMonoToolThatThrowsException() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("exceptionMonoTool", String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("exception-mono-tool", Map.of("input", "test")); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).contains("Error invoking method"); + }).verifyComplete(); + } + + @Test + public void testComplexFluxToolCallback() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("complexFluxTool", String.class, int.class, + boolean.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("complex-flux-tool", + Map.of("name", "Alice", "age", 25, "active", false)); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Name: Alice, Age: 25, Active: false"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithExchangeParameter() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("monoToolWithExchange", McpAsyncServerExchange.class, + String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("exchange-mono-tool", Map.of("message", "hello")); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Exchange tool: hello"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithListParameter() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("processListMono", List.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("list-mono-tool", + Map.of("items", List.of("item1", "item2", "item3"))); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Items: item1, item2, item3"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithObjectParameter() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("processObjectMono", TestObject.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("object-mono-tool", + Map.of("obj", Map.of("name", "test", "value", 42))); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Object: test - 42"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithNoParameters() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("noParamsMonoTool"); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("no-params-mono-tool", Map.of()); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("No parameters needed"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithEnumParameter() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("enumMonoTool", TestEnum.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("enum-mono-tool", Map.of("enumValue", "OPTION_B")); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Enum: OPTION_B"); + }).verifyComplete(); + } + + @Test + public void testComplexMonoToolCallback() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("complexMonoTool", String.class, int.class, + boolean.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("complex-mono-tool", + Map.of("name", "John", "age", 30, "active", true)); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Name: John, Age: 30, Active: true"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithMissingParameters() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("simpleMonoTool", String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("simple-mono-tool", Map.of()); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Processed: null"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithPrimitiveTypes() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("primitiveTypesMonoTool", boolean.class, byte.class, + short.class, int.class, long.class, float.class, double.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("primitive-types-mono-tool", + Map.of("flag", true, "b", 1, "s", 2, "i", 3, "l", 4L, "f", 5.5f, "d", 6.6)); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()) + .isEqualTo("Primitives: true, 1, 2, 3, 4, 5.5, 6.6"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithNullParameters() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("simpleMonoTool", String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + Map args = new java.util.HashMap<>(); + args.put("input", null); + CallToolRequest request = new CallToolRequest("simple-mono-tool", args); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Processed: null"); + }).verifyComplete(); + } + + @Test + public void testMonoToolThatReturnsNull() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("nullReturnMonoTool"); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("null-return-mono-tool", Map.of()); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()) + .isEqualTo("Error invoking method: Error invoking method: nullReturnMonoTool"); + }).verifyComplete(); + } + + @Test + public void testVoidMonoTool() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("voidMonoTool", String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.VOID, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("void-mono-tool", Map.of("input", "test")); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("\"Done\""); + }).verifyComplete(); + } + + @Test + public void testVoidFluxTool() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("voidFluxTool", String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.VOID, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("void-flux-tool", Map.of("input", "test")); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("\"Done\""); + }).verifyComplete(); + } + + @Test + public void testPrivateMonoToolMethod() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getDeclaredMethod("privateMonoTool", String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("private-mono-tool", Map.of("input", "test")); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Private: test"); + }).verifyComplete(); + } + + @Test + public void testNullRequest() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("simpleMonoTool", String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + + StepVerifier.create(callback.apply(exchange, null)).expectError(IllegalArgumentException.class).verify(); + } + + @Test + public void testMonoToolReturningComplexObject() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("returnObjectMonoTool", String.class, int.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.STRUCTURED, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("return-object-mono-tool", Map.of("name", "test", "value", 42)); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).isEmpty(); + assertThat(result.structuredContent()).isNotNull(); + assertThat(result.structuredContent()).containsEntry("name", "test"); + assertThat(result.structuredContent()).containsEntry("value", 42); + }).verifyComplete(); + } + + @Test + public void testEmptyMonoTool() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("emptyMonoTool"); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("empty-mono-tool", Map.of()); + + StepVerifier.create(callback.apply(exchange, request)).verifyComplete(); + } + + @Test + public void testMultipleFluxTool() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("multipleFluxTool", String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("multiple-flux-tool", Map.of("prefix", "item")); + + // Flux tools should take the first element + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("item1"); + }).verifyComplete(); + } + + @Test + public void testNonReactiveToolShouldFail() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("nonReactiveTool", String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("non-reactive-tool", Map.of("input", "test")); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).contains("Error invoking method"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithInvalidJsonConversion() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("processObjectMono", TestObject.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + // Pass invalid object structure that can't be converted to TestObject + CallToolRequest request = new CallToolRequest("object-mono-tool", Map.of("obj", "invalid-object-string")); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).contains("Error invoking method"); + }).verifyComplete(); + } + + @Test + public void testConstructorParameters() { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethods()[0]; // Any method + + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + // Verify that the callback was created successfully + assertThat(callback).isNotNull(); + } + + @Test + public void testIsExchangeType() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("simpleMonoTool", String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + // Test that McpAsyncServerExchange is recognized as exchange type + assertThat(callback.isExchangeType(McpAsyncServerExchange.class)).isTrue(); + + // Test that other types are not recognized as exchange type + assertThat(callback.isExchangeType(String.class)).isFalse(); + assertThat(callback.isExchangeType(Integer.class)).isFalse(); + assertThat(callback.isExchangeType(Object.class)).isFalse(); + } + + @Test + public void testMonoToolWithOptionalParameters() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("monoToolWithOptionalParams", String.class, String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("optional-params-mono-tool", + Map.of("required", "test", "optional", "optional-value")); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()) + .isEqualTo("Required: test, Optional: optional-value"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithOptionalParametersMissing() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("monoToolWithOptionalParams", String.class, String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("optional-params-mono-tool", Map.of("required", "test")); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Required: test, Optional: null"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithStructuredOutput() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("processObjectMono", TestObject.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("object-mono-tool", + Map.of("obj", Map.of("name", "test", "value", 42))); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Object: test - 42"); + }).verifyComplete(); + } + + @Test + public void testCallbackReturnsCallToolResult() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("complexMonoTool", String.class, int.class, + boolean.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("complex-mono-tool", + Map.of("name", "Alice", "age", 25, "active", false)); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Name: Alice, Age: 25, Active: false"); + }).verifyComplete(); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallbackTests.java new file mode 100644 index 0000000..9d0661a --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallbackTests.java @@ -0,0 +1,510 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.tool; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.annotation.McpToolParam; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link SyncMcpToolMethodCallback}. + * + * @author Christian Tzolov + */ +public class SyncMcpToolMethodCallbackTests { + + private static class TestToolProvider { + + @McpTool(name = "simple-tool", description = "A simple tool") + public String simpleTool(String input) { + return "Processed: " + input; + } + + @McpTool(name = "math-tool", description = "A math tool") + public int addNumbers(int a, int b) { + return a + b; + } + + @McpTool(name = "complex-tool", description = "A complex tool") + public CallToolResult complexTool(String name, int age, boolean active) { + return CallToolResult.builder() + .addTextContent("Name: " + name + ", Age: " + age + ", Active: " + active) + .build(); + } + + @McpTool(name = "exchange-tool", description = "Tool with exchange parameter") + public String toolWithExchange(McpSyncServerExchange exchange, String message) { + return "Exchange tool: " + message; + } + + @McpTool(name = "list-tool", description = "Tool with list parameter") + public String processList(List items) { + return "Items: " + String.join(", ", items); + } + + @McpTool(name = "object-tool", description = "Tool with object parameter") + public String processObject(TestObject obj) { + return "Object: " + obj.name + " - " + obj.value; + } + + @McpTool(name = "optional-params-tool", description = "Tool with optional parameters") + public String toolWithOptionalParams(@McpToolParam(required = true) String required, + @McpToolParam(required = false) String optional) { + return "Required: " + required + ", Optional: " + (optional != null ? optional : "null"); + } + + @McpTool(name = "no-params-tool", description = "Tool with no parameters") + public String noParamsTool() { + return "No parameters needed"; + } + + @McpTool(name = "exception-tool", description = "Tool that throws exception") + public String exceptionTool(String input) { + throw new RuntimeException("Tool execution failed: " + input); + } + + @McpTool(name = "null-return-tool", description = "Tool that returns null") + public String nullReturnTool() { + return null; + } + + public String nonAnnotatedTool(String input) { + return "Non-annotated: " + input; + } + + @McpTool(name = "private-tool", description = "Private tool") + private String privateTool(String input) { + return "Private: " + input; + } + + @McpTool(name = "enum-tool", description = "Tool with enum parameter") + public String enumTool(TestEnum enumValue) { + return "Enum: " + enumValue.name(); + } + + @McpTool(name = "primitive-types-tool", description = "Tool with primitive types") + public String primitiveTypesTool(boolean flag, byte b, short s, int i, long l, float f, double d) { + return String.format("Primitives: %b, %d, %d, %d, %d, %.1f, %.1f", flag, b, s, i, l, f, d); + } + + @McpTool(name = "return-object-tool", description = "Tool that returns a complex object") + public TestObject returnObjectTool(String name, int value) { + return new TestObject(name, value); + } + + } + + public static class TestObject { + + public String name; + + public int value; + + public TestObject() { + } + + public TestObject(String name, int value) { + this.name = name; + this.value = value; + } + + } + + public enum TestEnum { + + OPTION_A, OPTION_B, OPTION_C + + } + + @Test + public void testSimpleToolCallback() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("simpleTool", String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("simple-tool", Map.of("input", "test message")); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Processed: test message"); + } + + @Test + public void testMathToolCallback() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("addNumbers", int.class, int.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("math-tool", Map.of("a", 5, "b", 3)); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("8"); + } + + @Test + public void testComplexToolCallback() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("complexTool", String.class, int.class, boolean.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("complex-tool", + Map.of("name", "John", "age", 30, "active", true)); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Name: John, Age: 30, Active: true"); + } + + @Test + public void testToolWithExchangeParameter() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("toolWithExchange", McpSyncServerExchange.class, String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("exchange-tool", Map.of("message", "hello")); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Exchange tool: hello"); + } + + @Test + public void testToolWithListParameter() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("processList", List.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("list-tool", Map.of("items", List.of("item1", "item2", "item3"))); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Items: item1, item2, item3"); + } + + @Test + public void testToolWithObjectParameter() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("processObject", TestObject.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("object-tool", + Map.of("obj", Map.of("name", "test", "value", 42))); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Object: test - 42"); + } + + @Test + public void testToolWithNoParameters() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("noParamsTool"); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("no-params-tool", Map.of()); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("No parameters needed"); + } + + @Test + public void testToolWithEnumParameter() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("enumTool", TestEnum.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("enum-tool", Map.of("enumValue", "OPTION_B")); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Enum: OPTION_B"); + } + + @Test + public void testToolWithPrimitiveTypes() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("primitiveTypesTool", boolean.class, byte.class, short.class, + int.class, long.class, float.class, double.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("primitive-types-tool", + Map.of("flag", true, "b", 1, "s", 2, "i", 3, "l", 4L, "f", 5.5f, "d", 6.6)); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Primitives: true, 1, 2, 3, 4, 5.5, 6.6"); + } + + @Test + public void testToolWithNullParameters() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("simpleTool", String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + Map args = new java.util.HashMap<>(); + args.put("input", null); + CallToolRequest request = new CallToolRequest("simple-tool", args); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Processed: null"); + } + + @Test + public void testToolWithMissingParameters() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("simpleTool", String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("simple-tool", Map.of()); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Processed: null"); + } + + @Test + public void testToolThatThrowsException() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("exceptionTool", String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("exception-tool", Map.of("input", "test")); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).contains("Error invoking method"); + // The actual error message format may be different, so let's just check for the + // method name + assertThat(((TextContent) result.content().get(0)).text()).contains("exceptionTool"); + } + + @Test + public void testToolThatReturnsNull() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("nullReturnTool"); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("null-return-tool", Map.of()); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("null"); + } + + @Test + public void testPrivateToolMethod() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getDeclaredMethod("privateTool", String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("private-tool", Map.of("input", "test")); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Private: test"); + } + + @Test + public void testNullRequest() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("simpleTool", String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + + assertThatThrownBy(() -> callback.apply(exchange, null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Request must not be null"); + } + + @Test + public void testCallbackReturnsCallToolResult() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("complexTool", String.class, int.class, boolean.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("complex-tool", + Map.of("name", "Alice", "age", 25, "active", false)); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Name: Alice, Age: 25, Active: false"); + } + + @Test + public void testIsExchangeType() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("simpleTool", String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + // Test that McpSyncServerExchange is recognized as exchange type + assertThat(callback.isExchangeType(McpSyncServerExchange.class)).isTrue(); + + // Test that other types are not recognized as exchange type + assertThat(callback.isExchangeType(String.class)).isFalse(); + assertThat(callback.isExchangeType(Integer.class)).isFalse(); + assertThat(callback.isExchangeType(Object.class)).isFalse(); + } + + @Test + public void testToolWithInvalidJsonConversion() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("processObject", TestObject.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + // Pass invalid object structure that can't be converted to TestObject + CallToolRequest request = new CallToolRequest("object-tool", Map.of("obj", "invalid-object-string")); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).contains("Error invoking method"); + } + + @Test + public void testConstructorParameters() { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethods()[0]; // Any method + + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + // Verify that the callback was created successfully + assertThat(callback).isNotNull(); + } + + @Test + public void testToolWithStructuredOutput() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("processObject", TestObject.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("object-tool", + Map.of("obj", Map.of("name", "test", "value", 42))); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Object: test - 42"); + } + + @Test + public void testToolReturningComplexObject() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("returnObjectTool", String.class, int.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.STRUCTURED, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("return-object-tool", Map.of("name", "test", "value", 42)); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + // For complex return types (non-primitive, non-wrapper, non-CallToolResult), + // the new implementation should return structured content + assertThat(result.content()).isEmpty(); + assertThat(result.structuredContent()).isNotNull(); + assertThat(result.structuredContent()).containsEntry("name", "test"); + assertThat(result.structuredContent()).containsEntry("value", 42); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncMcpToolProviderTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncMcpToolProviderTests.java new file mode 100644 index 0000000..6869f63 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncMcpToolProviderTests.java @@ -0,0 +1,427 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpTool; + +import io.modelcontextprotocol.server.McpServerFeatures.SyncToolSpecification; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import reactor.core.publisher.Mono; + +/** + * Tests for {@link SyncMcpToolProvider}. + * + * @author Christian Tzolov + */ +public class SyncMcpToolProviderTests { + + @Test + void testConstructorWithNullToolObjects() { + assertThatThrownBy(() -> new SyncMcpToolProvider(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("toolObjects cannot be null"); + } + + @Test + void testGetToolSpecificationsWithSingleValidTool() { + // Create a class with only one valid tool method + class SingleValidTool { + + @McpTool(name = "test-tool", description = "A test tool") + public String testTool(String input) { + return "Processed: " + input; + } + + } + + SingleValidTool toolObject = new SingleValidTool(); + SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject)); + + List toolSpecs = provider.getToolSpecifications(); + + assertThat(toolSpecs).isNotNull(); + assertThat(toolSpecs).hasSize(1); + + SyncToolSpecification toolSpec = toolSpecs.get(0); + assertThat(toolSpec.tool().name()).isEqualTo("test-tool"); + assertThat(toolSpec.tool().description()).isEqualTo("A test tool"); + assertThat(toolSpec.tool().inputSchema()).isNotNull(); + assertThat(toolSpec.callHandler()).isNotNull(); + + // Test that the handler works + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("test-tool", Map.of("input", "hello")); + CallToolResult result = toolSpec.callHandler().apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Processed: hello"); + } + + @Test + void testGetToolSpecificationsWithCustomToolName() { + class CustomNameTool { + + @McpTool(name = "custom-name", description = "Custom named tool") + public String methodWithDifferentName(String input) { + return "Custom: " + input; + } + + } + + CustomNameTool toolObject = new CustomNameTool(); + SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject)); + + List toolSpecs = provider.getToolSpecifications(); + + assertThat(toolSpecs).hasSize(1); + assertThat(toolSpecs.get(0).tool().name()).isEqualTo("custom-name"); + assertThat(toolSpecs.get(0).tool().description()).isEqualTo("Custom named tool"); + } + + @Test + void testGetToolSpecificationsWithDefaultToolName() { + class DefaultNameTool { + + @McpTool(description = "Tool with default name") + public String defaultNameMethod(String input) { + return "Default: " + input; + } + + } + + DefaultNameTool toolObject = new DefaultNameTool(); + SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject)); + + List toolSpecs = provider.getToolSpecifications(); + + assertThat(toolSpecs).hasSize(1); + assertThat(toolSpecs.get(0).tool().name()).isEqualTo("defaultNameMethod"); + assertThat(toolSpecs.get(0).tool().description()).isEqualTo("Tool with default name"); + } + + @Test + void testGetToolSpecificationsWithEmptyToolName() { + class EmptyNameTool { + + @McpTool(name = "", description = "Tool with empty name") + public String emptyNameMethod(String input) { + return "Empty: " + input; + } + + } + + EmptyNameTool toolObject = new EmptyNameTool(); + SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject)); + + List toolSpecs = provider.getToolSpecifications(); + + assertThat(toolSpecs).hasSize(1); + assertThat(toolSpecs.get(0).tool().name()).isEqualTo("emptyNameMethod"); + assertThat(toolSpecs.get(0).tool().description()).isEqualTo("Tool with empty name"); + } + + @Test + void testGetToolSpecificationsFiltersOutMonoReturnTypes() { + class MonoReturnTool { + + @McpTool(name = "mono-tool", description = "Tool returning Mono") + public Mono monoTool(String input) { + return Mono.just("Mono: " + input); + } + + @McpTool(name = "sync-tool", description = "Synchronous tool") + public String syncTool(String input) { + return "Sync: " + input; + } + + } + + MonoReturnTool toolObject = new MonoReturnTool(); + SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject)); + + List toolSpecs = provider.getToolSpecifications(); + + assertThat(toolSpecs).hasSize(1); + assertThat(toolSpecs.get(0).tool().name()).isEqualTo("sync-tool"); + assertThat(toolSpecs.get(0).tool().description()).isEqualTo("Synchronous tool"); + } + + @Test + void testGetToolSpecificationsWithMultipleToolMethods() { + class MultipleToolMethods { + + @McpTool(name = "tool1", description = "First tool") + public String firstTool(String input) { + return "First: " + input; + } + + @McpTool(name = "tool2", description = "Second tool") + public String secondTool(String input) { + return "Second: " + input; + } + + } + + MultipleToolMethods toolObject = new MultipleToolMethods(); + SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject)); + + List toolSpecs = provider.getToolSpecifications(); + + assertThat(toolSpecs).hasSize(2); + assertThat(toolSpecs.get(0).tool().name()).isIn("tool1", "tool2"); + assertThat(toolSpecs.get(1).tool().name()).isIn("tool1", "tool2"); + assertThat(toolSpecs.get(0).tool().name()).isNotEqualTo(toolSpecs.get(1).tool().name()); + } + + @Test + void testGetToolSpecificationsWithMultipleToolObjects() { + class FirstToolObject { + + @McpTool(name = "first-tool", description = "First tool") + public String firstTool(String input) { + return "First: " + input; + } + + } + + class SecondToolObject { + + @McpTool(name = "second-tool", description = "Second tool") + public String secondTool(String input) { + return "Second: " + input; + } + + } + + FirstToolObject firstObject = new FirstToolObject(); + SecondToolObject secondObject = new SecondToolObject(); + SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(firstObject, secondObject)); + + List toolSpecs = provider.getToolSpecifications(); + + assertThat(toolSpecs).hasSize(2); + assertThat(toolSpecs.get(0).tool().name()).isIn("first-tool", "second-tool"); + assertThat(toolSpecs.get(1).tool().name()).isIn("first-tool", "second-tool"); + assertThat(toolSpecs.get(0).tool().name()).isNotEqualTo(toolSpecs.get(1).tool().name()); + } + + @Test + void testGetToolSpecificationsWithMixedMethods() { + class MixedMethods { + + @McpTool(name = "valid-tool", description = "Valid tool") + public String validTool(String input) { + return "Valid: " + input; + } + + public String nonAnnotatedMethod(String input) { + return "Non-annotated: " + input; + } + + @McpTool(name = "mono-tool", description = "Mono tool") + public Mono monoTool(String input) { + return Mono.just("Mono: " + input); + } + + } + + MixedMethods toolObject = new MixedMethods(); + SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject)); + + List toolSpecs = provider.getToolSpecifications(); + + assertThat(toolSpecs).hasSize(1); + assertThat(toolSpecs.get(0).tool().name()).isEqualTo("valid-tool"); + assertThat(toolSpecs.get(0).tool().description()).isEqualTo("Valid tool"); + } + + @Test + void testGetToolSpecificationsWithComplexParameters() { + class ComplexParameterTool { + + @McpTool(name = "complex-tool", description = "Tool with complex parameters") + public String complexTool(String name, int age, boolean active, List tags) { + return String.format("Name: %s, Age: %d, Active: %b, Tags: %s", name, age, active, + String.join(",", tags)); + } + + } + + ComplexParameterTool toolObject = new ComplexParameterTool(); + SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject)); + + List toolSpecs = provider.getToolSpecifications(); + + assertThat(toolSpecs).hasSize(1); + assertThat(toolSpecs.get(0).tool().name()).isEqualTo("complex-tool"); + assertThat(toolSpecs.get(0).tool().description()).isEqualTo("Tool with complex parameters"); + assertThat(toolSpecs.get(0).tool().inputSchema()).isNotNull(); + + // Test that the handler works with complex parameters + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("complex-tool", + Map.of("name", "John", "age", 30, "active", true, "tags", List.of("tag1", "tag2"))); + CallToolResult result = toolSpecs.get(0).callHandler().apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()) + .isEqualTo("Name: John, Age: 30, Active: true, Tags: tag1,tag2"); + } + + @Test + void testGetToolSpecificationsWithNoParameters() { + class NoParameterTool { + + @McpTool(name = "no-param-tool", description = "Tool with no parameters") + public String noParamTool() { + return "No parameters needed"; + } + + } + + NoParameterTool toolObject = new NoParameterTool(); + SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject)); + + List toolSpecs = provider.getToolSpecifications(); + + assertThat(toolSpecs).hasSize(1); + assertThat(toolSpecs.get(0).tool().name()).isEqualTo("no-param-tool"); + assertThat(toolSpecs.get(0).tool().description()).isEqualTo("Tool with no parameters"); + + // Test that the handler works with no parameters + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("no-param-tool", Map.of()); + CallToolResult result = toolSpecs.get(0).callHandler().apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("No parameters needed"); + } + + @Test + void testGetToolSpecificationsWithCallToolResultReturn() { + class CallToolResultTool { + + @McpTool(name = "result-tool", description = "Tool returning CallToolResult") + public CallToolResult resultTool(String message) { + return CallToolResult.builder().addTextContent("Result: " + message).build(); + } + + } + + CallToolResultTool toolObject = new CallToolResultTool(); + SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject)); + + List toolSpecs = provider.getToolSpecifications(); + + assertThat(toolSpecs).hasSize(1); + assertThat(toolSpecs.get(0).tool().name()).isEqualTo("result-tool"); + assertThat(toolSpecs.get(0).tool().description()).isEqualTo("Tool returning CallToolResult"); + + // Test that the handler works with CallToolResult return type + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("result-tool", Map.of("message", "test")); + CallToolResult result = toolSpecs.get(0).callHandler().apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Result: test"); + } + + @Test + void testGetToolSpecificationsWithPrivateMethod() { + class PrivateMethodTool { + + @McpTool(name = "private-tool", description = "Private tool method") + private String privateTool(String input) { + return "Private: " + input; + } + + } + + PrivateMethodTool toolObject = new PrivateMethodTool(); + SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject)); + + List toolSpecs = provider.getToolSpecifications(); + + assertThat(toolSpecs).hasSize(1); + assertThat(toolSpecs.get(0).tool().name()).isEqualTo("private-tool"); + assertThat(toolSpecs.get(0).tool().description()).isEqualTo("Private tool method"); + + // Test that the handler works with private methods + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("private-tool", Map.of("input", "test")); + CallToolResult result = toolSpecs.get(0).callHandler().apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Private: test"); + } + + @Test + void testGetToolSpecificationsJsonSchemaGeneration() { + class SchemaTestTool { + + @McpTool(name = "schema-tool", description = "Tool for schema testing") + public String schemaTool(String requiredParam, Integer optionalParam) { + return "Schema test: " + requiredParam + ", " + optionalParam; + } + + } + + SchemaTestTool toolObject = new SchemaTestTool(); + SyncMcpToolProvider provider = new SyncMcpToolProvider(List.of(toolObject)); + + List toolSpecs = provider.getToolSpecifications(); + + assertThat(toolSpecs).hasSize(1); + SyncToolSpecification toolSpec = toolSpecs.get(0); + + assertThat(toolSpec.tool().name()).isEqualTo("schema-tool"); + assertThat(toolSpec.tool().description()).isEqualTo("Tool for schema testing"); + assertThat(toolSpec.tool().inputSchema()).isNotNull(); + + // The input schema should be a valid JSON string containing parameter names + String schemaString = toolSpec.tool().inputSchema().toString(); + assertThat(schemaString).isNotEmpty(); + assertThat(schemaString).contains("requiredParam"); + assertThat(schemaString).contains("optionalParam"); + } + +}