diff --git a/README.md b/README.md index c53e4cc..d8a2d0a 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,9 @@ The Spring integration module provides seamless integration with Spring AI and S - **`@McpTool`** - Annotates methods that implement MCP tools with automatic JSON schema generation - **`@McpToolParam`** - Annotates tool method parameters with descriptions and requirement specifications +#### Special Parameter Annotations +- **`@McpProgressToken`** - Marks a method parameter to receive the progress token from the request. This parameter is automatically injected and excluded from the generated JSON schema + ### Method Callbacks The modules provide callback implementations for each operation type: @@ -527,6 +530,114 @@ This feature works with all tool callback types: - `SyncStatelessMcpToolMethodCallback` - Synchronous stateless - `AsyncStatelessMcpToolMethodCallback` - Asynchronous stateless +#### @McpProgressToken Support + +The `@McpProgressToken` annotation allows methods to receive progress tokens from MCP requests. This is useful for tracking long-running operations and providing progress updates to clients. + +When a method parameter is annotated with `@McpProgressToken`: +- The parameter automatically receives the progress token value from the request +- The parameter is excluded from the generated JSON schema +- The parameter type should be `String` to receive the token value +- If no progress token is present in the request, `null` is injected + +Example usage with tools: + +```java +@McpTool(name = "long-running-task", description = "Performs a long-running task with progress tracking") +public String performLongTask( + @McpProgressToken String progressToken, + @McpToolParam(description = "Task name", required = true) String taskName, + @McpToolParam(description = "Duration in seconds", required = true) int duration) { + + // Use the progress token to send progress updates + if (progressToken != null) { + // Send progress notifications using the token + sendProgressUpdate(progressToken, 0.0, "Starting task: " + taskName); + + // Simulate work with progress updates + for (int i = 1; i <= duration; i++) { + Thread.sleep(1000); + double progress = (double) i / duration; + sendProgressUpdate(progressToken, progress, "Processing... " + (i * 100 / duration) + "%"); + } + } + + return "Task " + taskName + " completed successfully"; +} + +// Tool with both CallToolRequest and progress token +@McpTool(name = "flexible-task", description = "Flexible task with progress tracking") +public CallToolResult flexibleTask( + @McpProgressToken String progressToken, + CallToolRequest request) { + + // Access progress token for tracking + if (progressToken != null) { + // Track progress for this operation + System.out.println("Progress token: " + progressToken); + } + + // Process the request + Map args = request.arguments(); + return CallToolResult.success("Processed with token: " + progressToken); +} +``` + +The `@McpProgressToken` annotation is also supported in other MCP callback types: + +**Resource callbacks:** +```java +@McpResource(uri = "data://{id}", name = "Data Resource", description = "Resource with progress tracking") +public ReadResourceResult getDataWithProgress( + @McpProgressToken String progressToken, + String id) { + + if (progressToken != null) { + // Use progress token for tracking resource access + trackResourceAccess(progressToken, id); + } + + return new ReadResourceResult(List.of( + new TextResourceContents("data://" + id, "text/plain", "Data for " + id) + )); +} +``` + +**Prompt callbacks:** +```java +@McpPrompt(name = "generate-content", description = "Generate content with progress tracking") +public GetPromptResult generateContent( + @McpProgressToken String progressToken, + @McpArg(name = "topic", required = true) String topic) { + + if (progressToken != null) { + // Track prompt generation progress + System.out.println("Generating prompt with token: " + progressToken); + } + + return new GetPromptResult("Generated Content", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Content about " + topic)))); +} +``` + +**Complete callbacks:** +```java +@McpComplete(prompt = "auto-complete") +public List completeWithProgress( + @McpProgressToken String progressToken, + String prefix) { + + if (progressToken != null) { + // Track completion progress + System.out.println("Completion with token: " + progressToken); + } + + return generateCompletions(prefix); +} +``` + +This feature enables better tracking and monitoring of MCP operations, especially for long-running tasks that need to report progress back to clients. + ### Async Tool Example ```java diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpProgress.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpProgress.java index 4240c7d..8f6f35b 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpProgress.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpProgress.java @@ -16,20 +16,19 @@ * *

* Methods annotated with this annotation can be used to consume progress messages from - * MCP servers. The methods takes a single parameter of type - * {@code ProgressMessageNotification} + * MCP servers. The methods takes a single parameter of type {@code ProgressNotification} * * *

* Example usage:

{@code
  * @McpProgress
- * public void handleProgressMessage(ProgressMessageNotification notification) {
+ * public void handleProgressMessage(ProgressNotification notification) {
  *     // Handle the notification *
  * }
* * @author Christian Tzolov * - * @see io.modelcontextprotocol.spec.McpSchema.ProgressMessageNotification + * @see io.modelcontextprotocol.spec.McpSchema.ProgressNotification */ @Target({ ElementType.METHOD, ElementType.ANNOTATION_TYPE }) @Retention(RetentionPolicy.RUNTIME) diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpProgressToken.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpProgressToken.java new file mode 100644 index 0000000..9002150 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpProgressToken.java @@ -0,0 +1,24 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Used to annotate method parameter that should hold the progress token value as received + * from the requester. + * + * @author Christian Tzolov + */ +@Target({ ElementType.METHOD, ElementType.ANNOTATION_TYPE, ElementType.PARAMETER }) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface McpProgressToken { + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/complete/AbstractMcpCompleteMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/complete/AbstractMcpCompleteMethodCallback.java index fc63b35..c039234 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/complete/AbstractMcpCompleteMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/complete/AbstractMcpCompleteMethodCallback.java @@ -11,6 +11,7 @@ import org.springaicommunity.mcp.annotation.CompleteAdapter; import org.springaicommunity.mcp.annotation.McpComplete; +import org.springaicommunity.mcp.annotation.McpProgressToken; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CompleteReference; @@ -126,20 +127,41 @@ protected void validateMethod(Method method) { protected void validateParameters(Method method) { Parameter[] parameters = method.getParameters(); - // Check parameter count - must have at most 3 parameters - if (parameters.length > 3) { - throw new IllegalArgumentException("Method can have at most 3 input parameters: " + method.getName() - + " in " + method.getDeclaringClass().getName() + " has " + parameters.length + " parameters"); + // Count non-progress-token parameters + int nonProgressTokenParamCount = 0; + for (Parameter param : parameters) { + if (!param.isAnnotationPresent(McpProgressToken.class)) { + nonProgressTokenParamCount++; + } + } + + // Check parameter count - must have at most 3 non-progress-token parameters + if (nonProgressTokenParamCount > 3) { + throw new IllegalArgumentException( + "Method can have at most 3 input parameters (excluding @McpProgressToken): " + method.getName() + + " in " + method.getDeclaringClass().getName() + " has " + nonProgressTokenParamCount + + " parameters"); } // Check parameter types boolean hasExchangeParam = false; boolean hasRequestParam = false; boolean hasArgumentParam = false; + boolean hasProgressTokenParam = false; for (Parameter param : parameters) { Class paramType = param.getType(); + // Skip @McpProgressToken annotated parameters from validation + if (param.isAnnotationPresent(McpProgressToken.class)) { + if (hasProgressTokenParam) { + throw new IllegalArgumentException("Method cannot have more than one @McpProgressToken parameter: " + + method.getName() + " in " + method.getDeclaringClass().getName()); + } + hasProgressTokenParam = true; + continue; + } + if (isExchangeType(paramType)) { if (hasExchangeParam) { throw new IllegalArgumentException("Method cannot have more than one exchange parameter: " @@ -184,7 +206,22 @@ protected Object[] buildArgs(Method method, Object exchange, CompleteRequest req Parameter[] parameters = method.getParameters(); Object[] args = new Object[parameters.length]; + // First, handle @McpProgressToken annotated parameters for (int i = 0; i < parameters.length; i++) { + if (parameters[i].isAnnotationPresent(McpProgressToken.class)) { + // CompleteRequest doesn't have a progressToken method in the current spec + // Set to null for now - this would need to be updated when the spec + // supports it + args[i] = null; + } + } + + for (int i = 0; i < parameters.length; i++) { + // Skip if already set (e.g., @McpProgressToken) + if (args[i] != null || parameters[i].isAnnotationPresent(McpProgressToken.class)) { + continue; + } + Parameter param = parameters[i]; Class paramType = param.getType(); diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AbstractMcpPromptMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AbstractMcpPromptMethodCallback.java index 7f9625c..6df44b5 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AbstractMcpPromptMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AbstractMcpPromptMethodCallback.java @@ -10,6 +10,7 @@ import java.util.Map; import org.springaicommunity.mcp.annotation.McpArg; +import org.springaicommunity.mcp.annotation.McpProgressToken; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; @@ -87,10 +88,21 @@ protected void validateParameters(Method method) { boolean hasExchangeParam = false; boolean hasRequestParam = false; boolean hasMapParam = false; + boolean hasProgressTokenParam = false; for (java.lang.reflect.Parameter param : parameters) { Class paramType = param.getType(); + // Skip @McpProgressToken annotated parameters from validation + if (param.isAnnotationPresent(McpProgressToken.class)) { + if (hasProgressTokenParam) { + throw new IllegalArgumentException("Method cannot have more than one @McpProgressToken parameter: " + + method.getName() + " in " + method.getDeclaringClass().getName()); + } + hasProgressTokenParam = true; + continue; + } + if (isExchangeOrContextType(paramType)) { if (hasExchangeParam) { throw new IllegalArgumentException("Method cannot have more than one exchange parameter: " @@ -130,7 +142,23 @@ protected Object[] buildArgs(Method method, Object exchange, GetPromptRequest re java.lang.reflect.Parameter[] parameters = method.getParameters(); Object[] args = new Object[parameters.length]; + // First, handle @McpProgressToken annotated parameters + for (int i = 0; i < parameters.length; i++) { + if (parameters[i].isAnnotationPresent(McpProgressToken.class)) { + // GetPromptRequest doesn't have a progressToken method in the current + // spec + // Set to null for now - this would need to be updated when the spec + // supports it + args[i] = null; + } + } + for (int i = 0; i < parameters.length; i++) { + // Skip if already set (e.g., @McpProgressToken) + if (args[i] != null || parameters[i].isAnnotationPresent(McpProgressToken.class)) { + continue; + } + java.lang.reflect.Parameter param = parameters[i]; Class paramType = param.getType(); diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/AbstractMcpResourceMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/AbstractMcpResourceMethodCallback.java index 7d3ec89..d45427d 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/AbstractMcpResourceMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/AbstractMcpResourceMethodCallback.java @@ -10,6 +10,8 @@ import java.util.List; import java.util.Map; +import org.springaicommunity.mcp.annotation.McpProgressToken; + import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest; import io.modelcontextprotocol.util.Assert; @@ -142,12 +144,20 @@ protected void validateMethod(Method method) { protected void validateParametersWithoutUriVariables(Method method) { Parameter[] parameters = method.getParameters(); - // Check parameter count - must have at most 2 parameters - if (parameters.length > 2) { + // Count parameters excluding @McpProgressToken annotated ones + int nonProgressTokenParamCount = 0; + for (Parameter param : parameters) { + if (!param.isAnnotationPresent(McpProgressToken.class)) { + nonProgressTokenParamCount++; + } + } + + // Check parameter count - must have at most 2 non-progress-token parameters + if (nonProgressTokenParamCount > 2) { throw new IllegalArgumentException( - "Method can have at most 2 input parameters when no URI variables are present: " + method.getName() - + " in " + method.getDeclaringClass().getName() + " has " + parameters.length - + " parameters"); + "Method can have at most 2 input parameters (excluding @McpProgressToken) when no URI variables are present: " + + method.getName() + " in " + method.getDeclaringClass().getName() + " has " + + nonProgressTokenParamCount + " non-progress-token parameters"); } // Check parameter types @@ -156,6 +166,11 @@ protected void validateParametersWithoutUriVariables(Method method) { boolean hasRequestOrUriParam = false; for (Parameter param : parameters) { + // Skip @McpProgressToken annotated parameters + if (param.isAnnotationPresent(McpProgressToken.class)) { + continue; + } + Class paramType = param.getType(); if (isExchangeOrContextType(paramType)) { @@ -177,13 +192,13 @@ else if (ReadResourceRequest.class.isAssignableFrom(paramType) } else { throw new IllegalArgumentException( - "Method parameters must be exchange, ReadResourceRequest, or String when no URI variables are present: " + "Method parameters must be exchange, ReadResourceRequest, String, or @McpProgressToken when no URI variables are present: " + method.getName() + " in " + method.getDeclaringClass().getName() + " has parameter of type " + paramType.getName()); } } - if (!hasValidParams && parameters.length > 0) { + if (!hasValidParams && nonProgressTokenParamCount > 0) { throw new IllegalArgumentException( "Method must have either ReadResourceRequest or String parameter when no URI variables are present: " + method.getName() + " in " + method.getDeclaringClass().getName()); @@ -199,17 +214,23 @@ else if (ReadResourceRequest.class.isAssignableFrom(paramType) protected void validateParametersWithUriVariables(Method method) { Parameter[] parameters = method.getParameters(); - // Count special parameters (exchange and request) + // Count special parameters (exchange, request, and progress token) int exchangeParamCount = 0; int requestParamCount = 0; + int progressTokenParamCount = 0; for (Parameter param : parameters) { - Class paramType = param.getType(); - if (isExchangeOrContextType(paramType)) { - exchangeParamCount++; + if (param.isAnnotationPresent(McpProgressToken.class)) { + progressTokenParamCount++; } - else if (ReadResourceRequest.class.isAssignableFrom(paramType)) { - requestParamCount++; + else { + Class paramType = param.getType(); + if (isExchangeOrContextType(paramType)) { + exchangeParamCount++; + } + else if (ReadResourceRequest.class.isAssignableFrom(paramType)) { + requestParamCount++; + } } } @@ -226,7 +247,7 @@ else if (ReadResourceRequest.class.isAssignableFrom(paramType)) { } // Calculate how many parameters should be for URI variables - int specialParamCount = exchangeParamCount + requestParamCount; + int specialParamCount = exchangeParamCount + requestParamCount + progressTokenParamCount; int uriVarParamCount = parameters.length - specialParamCount; // Check if we have the right number of parameters for URI variables @@ -239,6 +260,11 @@ else if (ReadResourceRequest.class.isAssignableFrom(paramType)) { // Check that all non-special parameters are String type (for URI variables) for (Parameter param : parameters) { + // Skip @McpProgressToken annotated parameters + if (param.isAnnotationPresent(McpProgressToken.class)) { + continue; + } + Class paramType = param.getType(); if (!isExchangeOrContextType(paramType) && !ReadResourceRequest.class.isAssignableFrom(paramType) && !String.class.isAssignableFrom(paramType)) { @@ -253,7 +279,7 @@ else if (ReadResourceRequest.class.isAssignableFrom(paramType)) { * Builds the arguments array for invoking the method. *

* This method constructs an array of arguments based on the method's parameter types - * and the available values (exchange, request, URI variables). + * and the available values (exchange, request, URI variables, progress token). * @param method The method to build arguments for * @param exchange The server exchange * @param request The resource request @@ -265,6 +291,14 @@ protected Object[] buildArgs(Method method, Object exchange, ReadResourceRequest Parameter[] parameters = method.getParameters(); Object[] args = new Object[parameters.length]; + // First, handle @McpProgressToken annotated parameters + for (int i = 0; i < parameters.length; i++) { + if (parameters[i].isAnnotationPresent(McpProgressToken.class)) { + // Get progress token from request + args[i] = request != null ? request.progressToken() : null; + } + } + if (!this.uriVariables.isEmpty()) { this.buildArgsWithUriVariables(parameters, args, exchange, request, uriVariableValues); } @@ -290,8 +324,14 @@ protected void buildArgsWithUriVariables(Parameter[] parameters, Object[] args, // Track which URI variables have been assigned List assignedVariables = new ArrayList<>(); - // First pass: assign special parameters (exchange and request) + // First pass: assign special parameters (exchange, request, and skip progress + // token) for (int i = 0; i < parameters.length; i++) { + // Skip if parameter is annotated with @McpProgressToken (already handled) + if (parameters[i].isAnnotationPresent(McpProgressToken.class)) { + continue; + } + Class paramType = parameters[i].getType(); if (isExchangeOrContextType(paramType)) { args[i] = exchange; @@ -304,8 +344,9 @@ else if (ReadResourceRequest.class.isAssignableFrom(paramType)) { // Second pass: assign URI variables to the remaining parameters int variableIndex = 0; for (int i = 0; i < parameters.length; i++) { - // Skip parameters that already have values (exchange or request) - if (args[i] != null) { + // Skip if parameter is annotated with @McpProgressToken (already handled) + // or if it's already assigned (exchange or request) + if (parameters[i].isAnnotationPresent(McpProgressToken.class) || args[i] != null) { continue; } @@ -336,6 +377,11 @@ else if (ReadResourceRequest.class.isAssignableFrom(paramType)) { protected void buildArgsWithoutUriVariables(Parameter[] parameters, Object[] args, Object exchange, ReadResourceRequest request) { for (int i = 0; i < parameters.length; i++) { + // Skip if parameter is annotated with @McpProgressToken (already handled) + if (parameters[i].isAnnotationPresent(McpProgressToken.class)) { + continue; + } + Parameter param = parameters[i]; Class paramType = param.getType(); diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java index 5e32cd8..692d28d 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java @@ -23,6 +23,7 @@ import java.util.stream.Stream; import org.reactivestreams.Publisher; +import org.springaicommunity.mcp.annotation.McpProgressToken; import org.springaicommunity.mcp.annotation.McpTool; import org.springaicommunity.mcp.method.tool.utils.JsonParser; @@ -85,17 +86,6 @@ protected Object callMethod(Object[] methodArguments) { return result; } - /** - * Builds the method arguments from the context and tool input arguments. - * @param exchangeOrContext The exchange or context object (e.g., - * McpAsyncServerExchange or McpTransportContext) - * @param toolInputArguments The input arguments from the tool request - * @return An array of method arguments - */ - protected Object[] buildMethodArguments(T exchangeOrContext, Map toolInputArguments) { - return buildMethodArguments(exchangeOrContext, toolInputArguments, null); - } - /** * Builds the method arguments from the context, tool input arguments, and optionally * the full request. @@ -108,6 +98,12 @@ protected Object[] buildMethodArguments(T exchangeOrContext, Map protected Object[] buildMethodArguments(T exchangeOrContext, Map toolInputArguments, CallToolRequest request) { return Stream.of(this.toolMethod.getParameters()).map(parameter -> { + // Check if parameter is annotated with @McpProgressToken + if (parameter.isAnnotationPresent(McpProgressToken.class)) { + // Return the progress token from the request + return request != null ? request.progressToken() : null; + } + // Check if parameter is CallToolRequest type if (CallToolRequest.class.isAssignableFrom(parameter.getType())) { return request; diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java index 6ca579f..b208be1 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java @@ -22,6 +22,7 @@ import java.util.Map; import java.util.stream.Stream; +import org.springaicommunity.mcp.annotation.McpProgressToken; import org.springaicommunity.mcp.annotation.McpTool; import org.springaicommunity.mcp.method.tool.utils.JsonParser; @@ -81,17 +82,6 @@ protected Object callMethod(Object[] methodArguments) { return result; } - /** - * Builds the method arguments from the context and tool input arguments. - * @param exchangeOrContext The exchange or context object (e.g., - * McpSyncServerExchange or McpTransportContext) - * @param toolInputArguments The input arguments from the tool request - * @return An array of method arguments - */ - protected Object[] buildMethodArguments(T exchangeOrContext, Map toolInputArguments) { - return buildMethodArguments(exchangeOrContext, toolInputArguments, null); - } - /** * Builds the method arguments from the context, tool input arguments, and optionally * the full request. @@ -104,6 +94,12 @@ protected Object[] buildMethodArguments(T exchangeOrContext, Map protected Object[] buildMethodArguments(T exchangeOrContext, Map toolInputArguments, CallToolRequest request) { return Stream.of(this.toolMethod.getParameters()).map(parameter -> { + // Check if parameter is annotated with @McpProgressToken + if (parameter.isAnnotationPresent(McpProgressToken.class)) { + // Return the progress token from the request + return request != null ? request.progressToken() : null; + } + // Check if parameter is CallToolRequest type if (CallToolRequest.class.isAssignableFrom(parameter.getType())) { return request; diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonSchemaGenerator.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonSchemaGenerator.java index 6f84b58..a911df1 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonSchemaGenerator.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonSchemaGenerator.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; +import org.springaicommunity.mcp.annotation.McpProgressToken; import org.springaicommunity.mcp.annotation.McpToolParam; import com.fasterxml.jackson.annotation.JsonProperty; @@ -101,13 +102,15 @@ private static String internalGenerateFromMethodArguments(Method method) { // If method has CallToolRequest, return minimal schema if (hasCallToolRequestParam) { - // Check if there are other parameters besides CallToolRequest and exchange - // types + // Check if there are other parameters besides CallToolRequest, exchange + // types, + // and @McpProgressToken annotated parameters boolean hasOtherParams = Arrays.stream(method.getParameters()).anyMatch(param -> { Class type = param.getType(); return !CallToolRequest.class.isAssignableFrom(type) && !McpSyncServerExchange.class.isAssignableFrom(type) - && !McpAsyncServerExchange.class.isAssignableFrom(type); + && !McpAsyncServerExchange.class.isAssignableFrom(type) + && !param.isAnnotationPresent(McpProgressToken.class); }); // If only CallToolRequest (and possibly exchange), return empty schema @@ -128,9 +131,15 @@ private static String internalGenerateFromMethodArguments(Method method) { List required = new ArrayList<>(); for (int i = 0; i < method.getParameterCount(); i++) { - String parameterName = method.getParameters()[i].getName(); + Parameter parameter = method.getParameters()[i]; + String parameterName = parameter.getName(); Type parameterType = method.getGenericParameterTypes()[i]; + // Skip parameters annotated with @McpProgressToken + if (parameter.isAnnotationPresent(McpProgressToken.class)) { + continue; + } + // Skip special parameter types if (parameterType instanceof Class parameterClass && (ClassUtils.isAssignable(McpSyncServerExchange.class, parameterClass) diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/complete/SyncMcpCompleteMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/complete/SyncMcpCompleteMethodCallbackTests.java index 48a640a..301c504 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/complete/SyncMcpCompleteMethodCallbackTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/complete/SyncMcpCompleteMethodCallbackTests.java @@ -16,7 +16,7 @@ import io.modelcontextprotocol.spec.McpSchema.ResourceReference; import org.junit.jupiter.api.Test; import org.springaicommunity.mcp.annotation.McpComplete; -import org.springaicommunity.mcp.method.complete.SyncMcpCompleteMethodCallback; +import org.springaicommunity.mcp.annotation.McpProgressToken; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -106,6 +106,25 @@ public CompleteResult duplicateArgumentParameters(CompleteRequest.CompleteArgume return new CompleteResult(new CompleteCompletion(List.of(), 0, false)); } + public CompleteResult getCompletionWithProgressToken(@McpProgressToken String progressToken, + CompleteRequest request) { + String tokenInfo = progressToken != null ? " (token: " + progressToken + ")" : " (no token)"; + return new CompleteResult(new CompleteCompletion( + List.of("Completion with progress" + tokenInfo + " for: " + request.argument().value()), 1, false)); + } + + public CompleteResult getCompletionWithMixedAndProgress(McpSyncServerExchange exchange, + @McpProgressToken String progressToken, String value, CompleteRequest request) { + String tokenInfo = progressToken != null ? " (token: " + progressToken + ")" : " (no token)"; + return new CompleteResult(new CompleteCompletion(List.of("Mixed completion" + tokenInfo + " with value: " + + value + " and request: " + request.argument().value()), 1, false)); + } + + public CompleteResult duplicateProgressTokenParameters(@McpProgressToken String token1, + @McpProgressToken String token2) { + return new CompleteResult(new CompleteCompletion(List.of(), 0, false)); + } + } // Helper method to create a mock McpComplete annotation @@ -487,4 +506,71 @@ public void testNullRequest() throws Exception { .hasMessageContaining("Request must not be null"); } + @Test + public void testCallbackWithProgressToken() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("getCompletionWithProgressToken", String.class, + CompleteRequest.class); + + BiFunction callback = SyncMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + CompleteResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + // Since CompleteRequest doesn't have progressToken, it should be null + assertThat(result.completion().values().get(0)).isEqualTo("Completion with progress (no token) for: value"); + } + + @Test + public void testCallbackWithMixedAndProgressToken() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("getCompletionWithMixedAndProgress", + McpSyncServerExchange.class, String.class, String.class, CompleteRequest.class); + + BiFunction callback = SyncMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + CompleteResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + // Since CompleteRequest doesn't have progressToken, it should be null + assertThat(result.completion().values().get(0)) + .isEqualTo("Mixed completion (no token) with value: value and request: value"); + } + + @Test + public void testDuplicateProgressTokenParameters() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("duplicateProgressTokenParameters", String.class, + String.class); + + assertThatThrownBy(() -> SyncMcpCompleteMethodCallback.builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method cannot have more than one @McpProgressToken parameter"); + } + } diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/prompt/SyncMcpPromptMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/prompt/SyncMcpPromptMethodCallbackTests.java index ff8374a..129ea39 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/prompt/SyncMcpPromptMethodCallbackTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/prompt/SyncMcpPromptMethodCallbackTests.java @@ -16,6 +16,7 @@ import org.junit.jupiter.api.Test; import org.springaicommunity.mcp.annotation.McpArg; +import org.springaicommunity.mcp.annotation.McpProgressToken; import org.springaicommunity.mcp.annotation.McpPrompt; import io.modelcontextprotocol.server.McpSyncServerExchange; @@ -110,6 +111,29 @@ public GetPromptResult duplicateMapParameters(Map args1, Map callback = SyncMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + Map args = new HashMap<>(); + args.put("name", "John"); + // Note: GetPromptRequest doesn't have progressToken in current spec, so it will + // be null + GetPromptRequest request = new GetPromptRequest("progress-token", args); + + GetPromptResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Progress token prompt"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + // Since GetPromptRequest doesn't have progressToken, it should be null + assertThat(((TextContent) message.content()).text()).isEqualTo("Hello John (no token)"); + } + + @Test + public void testCallbackWithMixedAndProgressToken() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getPromptWithMixedAndProgress", McpSyncServerExchange.class, + String.class, String.class, GetPromptRequest.class); + + Prompt prompt = createTestPrompt("mixed-with-progress", "A prompt with mixed args and progress token"); + + BiFunction callback = SyncMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("mixed-with-progress", args); + + GetPromptResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Mixed with progress prompt"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + // Since GetPromptRequest doesn't have progressToken, it should be null + assertThat(((TextContent) message.content()).text()) + .isEqualTo("Hello John from mixed-with-progress (no token)"); + } + + @Test + public void testDuplicateProgressTokenParameters() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("duplicateProgressTokenParameters", String.class, + String.class); + + Prompt prompt = createTestPrompt("invalid", "Invalid parameters"); + + assertThatThrownBy( + () -> SyncMcpPromptMethodCallback.builder().method(method).bean(provider).prompt(prompt).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method cannot have more than one @McpProgressToken parameter"); + } + } diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/AsyncMcpResourceMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/AsyncMcpResourceMethodCallbackTests.java index 4d46a77..4163bef 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/AsyncMcpResourceMethodCallbackTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/AsyncMcpResourceMethodCallbackTests.java @@ -18,6 +18,7 @@ import io.modelcontextprotocol.util.McpUriTemplateManager; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpProgressToken; import org.springaicommunity.mcp.annotation.McpResource; import org.springaicommunity.mcp.annotation.ResourceAdaptor; @@ -27,6 +28,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Tests for {@link AsyncMcpResourceMethodCallback}. @@ -43,6 +45,49 @@ public ReadResourceResult getResourceWithRequest(ReadResourceRequest request) { List.of(new TextResourceContents(request.uri(), "text/plain", "Content for " + request.uri()))); } + // Methods for testing @McpProgressToken + public ReadResourceResult getResourceWithProgressToken(@McpProgressToken String progressToken, + ReadResourceRequest request) { + String content = "Content with progress token: " + progressToken + " for " + request.uri(); + return new ReadResourceResult(List.of(new TextResourceContents(request.uri(), "text/plain", content))); + } + + public Mono getResourceWithProgressTokenAsync(@McpProgressToken String progressToken, + ReadResourceRequest request) { + String content = "Async content with progress token: " + progressToken + " for " + request.uri(); + return Mono + .just(new ReadResourceResult(List.of(new TextResourceContents(request.uri(), "text/plain", content)))); + } + + public ReadResourceResult getResourceWithProgressTokenOnly(@McpProgressToken String progressToken) { + String content = "Content with only progress token: " + progressToken; + return new ReadResourceResult(List.of(new TextResourceContents("test://resource", "text/plain", content))); + } + + @McpResource(uri = "users/{userId}/posts/{postId}") + public ReadResourceResult getResourceWithProgressTokenAndUriVariables(@McpProgressToken String progressToken, + String userId, String postId) { + String content = "User: " + userId + ", Post: " + postId + ", Progress: " + progressToken; + return new ReadResourceResult( + List.of(new TextResourceContents("users/" + userId + "/posts/" + postId, "text/plain", content))); + } + + public Mono getResourceWithExchangeAndProgressToken(McpAsyncServerExchange exchange, + @McpProgressToken String progressToken, ReadResourceRequest request) { + String content = "Async content with exchange and progress token: " + progressToken + " for " + + request.uri(); + return Mono + .just(new ReadResourceResult(List.of(new TextResourceContents(request.uri(), "text/plain", content)))); + } + + public ReadResourceResult getResourceWithMultipleProgressTokens(@McpProgressToken String progressToken1, + @McpProgressToken String progressToken2, ReadResourceRequest request) { + // This should only use the first progress token + String content = "Content with progress tokens: " + progressToken1 + " and " + progressToken2 + " for " + + request.uri(); + return new ReadResourceResult(List.of(new TextResourceContents(request.uri(), "text/plain", content))); + } + public ReadResourceResult getResourceWithExchange(McpAsyncServerExchange exchange, ReadResourceRequest request) { return new ReadResourceResult(List.of(new TextResourceContents(request.uri(), "text/plain", @@ -569,4 +614,212 @@ public McpUriTemplateManager create(String uriTemplate) { .verify(); } + // Tests for @McpProgressToken functionality + @Test + public void testCallbackWithProgressToken() throws Exception { + TestAsyncResourceProvider provider = new TestAsyncResourceProvider(); + Method method = TestAsyncResourceProvider.class.getMethod("getResourceWithProgressToken", String.class, + ReadResourceRequest.class); + + BiFunction> callback = AsyncMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + ReadResourceRequest request = mock(ReadResourceRequest.class); + when(request.uri()).thenReturn("test/resource"); + when(request.progressToken()).thenReturn("progress-123"); + + Mono resultMono = callback.apply(exchange, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Content with progress token: progress-123 for test/resource"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithProgressTokenAsync() throws Exception { + TestAsyncResourceProvider provider = new TestAsyncResourceProvider(); + Method method = TestAsyncResourceProvider.class.getMethod("getResourceWithProgressTokenAsync", String.class, + ReadResourceRequest.class); + + BiFunction> callback = AsyncMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + ReadResourceRequest request = mock(ReadResourceRequest.class); + when(request.uri()).thenReturn("test/resource"); + when(request.progressToken()).thenReturn("progress-456"); + + Mono resultMono = callback.apply(exchange, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()) + .isEqualTo("Async content with progress token: progress-456 for test/resource"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithProgressTokenNull() throws Exception { + TestAsyncResourceProvider provider = new TestAsyncResourceProvider(); + Method method = TestAsyncResourceProvider.class.getMethod("getResourceWithProgressToken", String.class, + ReadResourceRequest.class); + + BiFunction> callback = AsyncMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + ReadResourceRequest request = mock(ReadResourceRequest.class); + when(request.uri()).thenReturn("test/resource"); + when(request.progressToken()).thenReturn(null); + + Mono resultMono = callback.apply(exchange, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Content with progress token: null for test/resource"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithProgressTokenOnly() throws Exception { + TestAsyncResourceProvider provider = new TestAsyncResourceProvider(); + Method method = TestAsyncResourceProvider.class.getMethod("getResourceWithProgressTokenOnly", String.class); + + BiFunction> callback = AsyncMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + ReadResourceRequest request = mock(ReadResourceRequest.class); + when(request.uri()).thenReturn("test/resource"); + when(request.progressToken()).thenReturn("progress-789"); + + Mono resultMono = callback.apply(exchange, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Content with only progress token: progress-789"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithProgressTokenAndUriVariables() throws Exception { + TestAsyncResourceProvider provider = new TestAsyncResourceProvider(); + Method method = TestAsyncResourceProvider.class.getMethod("getResourceWithProgressTokenAndUriVariables", + String.class, String.class, String.class); + McpResource resourceAnnotation = method.getAnnotation(McpResource.class); + + BiFunction> callback = AsyncMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(resourceAnnotation)) + .build(); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + ReadResourceRequest request = mock(ReadResourceRequest.class); + when(request.uri()).thenReturn("users/123/posts/456"); + when(request.progressToken()).thenReturn("progress-abc"); + + Mono resultMono = callback.apply(exchange, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("User: 123, Post: 456, Progress: progress-abc"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithExchangeAndProgressToken() throws Exception { + TestAsyncResourceProvider provider = new TestAsyncResourceProvider(); + Method method = TestAsyncResourceProvider.class.getMethod("getResourceWithExchangeAndProgressToken", + McpAsyncServerExchange.class, String.class, ReadResourceRequest.class); + + BiFunction> callback = AsyncMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + ReadResourceRequest request = mock(ReadResourceRequest.class); + when(request.uri()).thenReturn("test/resource"); + when(request.progressToken()).thenReturn("progress-def"); + + Mono resultMono = callback.apply(exchange, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()) + .isEqualTo("Async content with exchange and progress token: progress-def for test/resource"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithMultipleProgressTokens() throws Exception { + TestAsyncResourceProvider provider = new TestAsyncResourceProvider(); + Method method = TestAsyncResourceProvider.class.getMethod("getResourceWithMultipleProgressTokens", String.class, + String.class, ReadResourceRequest.class); + + BiFunction> callback = AsyncMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + ReadResourceRequest request = mock(ReadResourceRequest.class); + when(request.uri()).thenReturn("test/resource"); + when(request.progressToken()).thenReturn("progress-first"); + + Mono resultMono = callback.apply(exchange, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + // Both progress tokens should receive the same value from the request + assertThat(textContent.text()) + .isEqualTo("Content with progress tokens: progress-first and progress-first for test/resource"); + }).verifyComplete(); + } + } diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/SyncMcpResourceMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/SyncMcpResourceMethodCallbackTests.java index 49d64c2..1641bbf 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/SyncMcpResourceMethodCallbackTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/SyncMcpResourceMethodCallbackTests.java @@ -7,12 +7,14 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import java.lang.reflect.Method; import java.util.List; import java.util.function.BiFunction; import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpProgressToken; import org.springaicommunity.mcp.annotation.McpResource; import org.springaicommunity.mcp.annotation.ResourceAdaptor; @@ -37,6 +39,47 @@ public ReadResourceResult getResourceWithRequest(ReadResourceRequest request) { List.of(new TextResourceContents(request.uri(), "text/plain", "Content for " + request.uri()))); } + // Methods for testing @McpProgressToken + public ReadResourceResult getResourceWithProgressToken(@McpProgressToken String progressToken, + ReadResourceRequest request) { + String content = "Content with progress token: " + progressToken + " for " + request.uri(); + return new ReadResourceResult(List.of(new TextResourceContents(request.uri(), "text/plain", content))); + } + + public ReadResourceResult getResourceWithProgressTokenOnly(@McpProgressToken String progressToken) { + String content = "Content with only progress token: " + progressToken; + return new ReadResourceResult(List.of(new TextResourceContents("test://resource", "text/plain", content))); + } + + @McpResource(uri = "users/{userId}/posts/{postId}") + public ReadResourceResult getResourceWithProgressTokenAndUriVariables(@McpProgressToken String progressToken, + String userId, String postId) { + String content = "User: " + userId + ", Post: " + postId + ", Progress: " + progressToken; + return new ReadResourceResult( + List.of(new TextResourceContents("users/" + userId + "/posts/" + postId, "text/plain", content))); + } + + public ReadResourceResult getResourceWithExchangeAndProgressToken(McpSyncServerExchange exchange, + @McpProgressToken String progressToken, ReadResourceRequest request) { + String content = "Content with exchange and progress token: " + progressToken + " for " + request.uri(); + return new ReadResourceResult(List.of(new TextResourceContents(request.uri(), "text/plain", content))); + } + + public ReadResourceResult getResourceWithMultipleProgressTokens(@McpProgressToken String progressToken1, + @McpProgressToken String progressToken2, ReadResourceRequest request) { + // This should only use the first progress token + String content = "Content with progress tokens: " + progressToken1 + " and " + progressToken2 + " for " + + request.uri(); + return new ReadResourceResult(List.of(new TextResourceContents(request.uri(), "text/plain", content))); + } + + @McpResource(uri = "users/{userId}") + public ReadResourceResult getResourceWithProgressTokenAndMixedParams(@McpProgressToken String progressToken, + String userId) { + String content = "User: " + userId + ", Progress: " + progressToken; + return new ReadResourceResult(List.of(new TextResourceContents("users/" + userId, "text/plain", content))); + } + public ReadResourceResult getResourceWithExchange(McpSyncServerExchange exchange, ReadResourceRequest request) { return new ReadResourceResult(List.of(new TextResourceContents(request.uri(), "text/plain", "Content with exchange for " + request.uri()))); @@ -610,4 +653,198 @@ public void testMethodWithoutMcpResourceAnnotation() throws Exception { .hasMessageContaining("URI must not be null or empty"); } + // Tests for @McpProgressToken functionality + @Test + public void testCallbackWithProgressToken() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getResourceWithProgressToken", String.class, + ReadResourceRequest.class); + + BiFunction callback = SyncMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + ReadResourceRequest request = mock(ReadResourceRequest.class); + when(request.uri()).thenReturn("test/resource"); + when(request.progressToken()).thenReturn("progress-123"); + + ReadResourceResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Content with progress token: progress-123 for test/resource"); + } + + @Test + public void testCallbackWithProgressTokenNull() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getResourceWithProgressToken", String.class, + ReadResourceRequest.class); + + BiFunction callback = SyncMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + ReadResourceRequest request = mock(ReadResourceRequest.class); + when(request.uri()).thenReturn("test/resource"); + when(request.progressToken()).thenReturn(null); + + ReadResourceResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Content with progress token: null for test/resource"); + } + + @Test + public void testCallbackWithProgressTokenOnly() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getResourceWithProgressTokenOnly", String.class); + + BiFunction callback = SyncMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + ReadResourceRequest request = mock(ReadResourceRequest.class); + when(request.uri()).thenReturn("test/resource"); + when(request.progressToken()).thenReturn("progress-456"); + + ReadResourceResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Content with only progress token: progress-456"); + } + + @Test + public void testCallbackWithProgressTokenAndUriVariables() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getResourceWithProgressTokenAndUriVariables", + String.class, String.class, String.class); + McpResource resourceAnnotation = method.getAnnotation(McpResource.class); + + BiFunction callback = SyncMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(resourceAnnotation)) + .build(); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + ReadResourceRequest request = mock(ReadResourceRequest.class); + when(request.uri()).thenReturn("users/123/posts/456"); + when(request.progressToken()).thenReturn("progress-789"); + + ReadResourceResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("User: 123, Post: 456, Progress: progress-789"); + } + + @Test + public void testCallbackWithExchangeAndProgressToken() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getResourceWithExchangeAndProgressToken", + McpSyncServerExchange.class, String.class, ReadResourceRequest.class); + + BiFunction callback = SyncMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + ReadResourceRequest request = mock(ReadResourceRequest.class); + when(request.uri()).thenReturn("test/resource"); + when(request.progressToken()).thenReturn("progress-abc"); + + ReadResourceResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()) + .isEqualTo("Content with exchange and progress token: progress-abc for test/resource"); + } + + @Test + public void testCallbackWithMultipleProgressTokens() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getResourceWithMultipleProgressTokens", String.class, + String.class, ReadResourceRequest.class); + + BiFunction callback = SyncMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + ReadResourceRequest request = mock(ReadResourceRequest.class); + when(request.uri()).thenReturn("test/resource"); + when(request.progressToken()).thenReturn("progress-first"); + + ReadResourceResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + // Both progress tokens should receive the same value from the request + assertThat(textContent.text()) + .isEqualTo("Content with progress tokens: progress-first and progress-first for test/resource"); + } + + @Test + public void testCallbackWithProgressTokenAndMixedParams() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getResourceWithProgressTokenAndMixedParams", String.class, + String.class); + McpResource resourceAnnotation = method.getAnnotation(McpResource.class); + + BiFunction callback = SyncMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(resourceAnnotation)) + .build(); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + ReadResourceRequest request = mock(ReadResourceRequest.class); + when(request.uri()).thenReturn("users/john"); + when(request.progressToken()).thenReturn("progress-xyz"); + + ReadResourceResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("User: john, Progress: progress-xyz"); + } + } diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/CallToolRequestSupportTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/CallToolRequestSupportTests.java index 3a99d44..ab66fb5 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/CallToolRequestSupportTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/CallToolRequestSupportTests.java @@ -25,6 +25,7 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.TextContent; import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpProgressToken; import org.springaicommunity.mcp.annotation.McpTool; import org.springaicommunity.mcp.annotation.McpToolParam; import org.springaicommunity.mcp.method.tool.utils.JsonSchemaGenerator; @@ -121,6 +122,33 @@ public CallToolResult validateSchema(CallToolRequest request) { .build(); } + /** + * Tool with @McpProgressToken parameter + */ + @McpTool(name = "progress-token-tool", description = "Tool with progress token") + public CallToolResult progressTokenTool( + @McpToolParam(description = "Input parameter", required = true) String input, + @McpProgressToken String progressToken) { + return CallToolResult.builder() + .addTextContent("Input: " + input + ", Progress Token: " + progressToken) + .build(); + } + + /** + * Tool with mixed special parameters including @McpProgressToken + */ + @McpTool(name = "mixed-special-params-tool", description = "Tool with all special parameters") + public CallToolResult mixedSpecialParamsTool(McpSyncServerExchange exchange, CallToolRequest request, + @McpProgressToken String progressToken, + @McpToolParam(description = "Regular parameter", required = true) String regularParam) { + + return CallToolResult.builder() + .addTextContent(String.format("Exchange: %s, Request: %s, Token: %s, Param: %s", + exchange != null ? "present" : "null", request != null ? request.name() : "null", + progressToken != null ? progressToken : "null", regularParam)) + .build(); + } + /** * Regular tool without CallToolRequest for comparison */ @@ -359,7 +387,7 @@ public void testSyncMcpToolProviderWithCallToolRequest() { var toolSpecs = toolProvider.getToolSpecifications(); // Should have all tools registered - assertThat(toolSpecs).hasSize(6); // All 6 tools from the provider + assertThat(toolSpecs).hasSize(8); // All 8 tools from the provider // Find the dynamic tool var dynamicToolSpec = toolSpecs.stream() @@ -430,4 +458,144 @@ public void testCallToolRequestParameterInjection() throws Exception { assertThat(((TextContent) result.content().get(0)).text()).contains("tool: dynamic-tool"); } + @Test + public void testProgressTokenParameterInjection() throws Exception { + // Test that @McpProgressToken parameter receives the progress token from request + CallToolRequestTestProvider provider = new CallToolRequestTestProvider(); + Method method = CallToolRequestTestProvider.class.getMethod("progressTokenTool", String.class, String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + + // Create request with progress token + CallToolRequest request = CallToolRequest.builder() + .name("progress-token-tool") + .arguments(Map.of("input", "test-input")) + .progressToken("test-progress-token-123") + .build(); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(((TextContent) result.content().get(0)).text()) + .isEqualTo("Input: test-input, Progress Token: test-progress-token-123"); + } + + @Test + public void testProgressTokenParameterWithNullToken() throws Exception { + // Test that @McpProgressToken parameter handles null progress token + CallToolRequestTestProvider provider = new CallToolRequestTestProvider(); + Method method = CallToolRequestTestProvider.class.getMethod("progressTokenTool", String.class, String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + + // Create request without progress token + CallToolRequest request = new CallToolRequest("progress-token-tool", Map.of("input", "test-input")); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Input: test-input, Progress Token: null"); + } + + @Test + public void testMixedSpecialParameters() throws Exception { + // Test tool with all types of special parameters + CallToolRequestTestProvider provider = new CallToolRequestTestProvider(); + Method method = CallToolRequestTestProvider.class.getMethod("mixedSpecialParamsTool", + McpSyncServerExchange.class, CallToolRequest.class, String.class, String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + + CallToolRequest request = CallToolRequest.builder() + .name("mixed-special-params-tool") + .arguments(Map.of("regularParam", "test-value")) + .progressToken("progress-123") + .build(); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(((TextContent) result.content().get(0)).text()) + .isEqualTo("Exchange: present, Request: mixed-special-params-tool, Token: progress-123, Param: test-value"); + } + + @Test + public void testJsonSchemaGenerationExcludesProgressToken() throws Exception { + // Test that schema generation excludes @McpProgressToken parameters + Method progressTokenMethod = CallToolRequestTestProvider.class.getMethod("progressTokenTool", String.class, + String.class); + String progressTokenSchema = JsonSchemaGenerator.generateForMethodInput(progressTokenMethod); + + // Parse the schema + JsonNode schemaNode = objectMapper.readTree(progressTokenSchema); + + // Should only have the 'input' parameter, not the progressToken + assertThat(schemaNode.has("properties")).isTrue(); + JsonNode properties = schemaNode.get("properties"); + assertThat(properties.has("input")).isTrue(); + assertThat(properties.has("progressToken")).isFalse(); + assertThat(properties.size()).isEqualTo(1); + + // Check required array + assertThat(schemaNode.has("required")).isTrue(); + JsonNode required = schemaNode.get("required"); + assertThat(required.size()).isEqualTo(1); + assertThat(required.get(0).asText()).isEqualTo("input"); + } + + @Test + public void testJsonSchemaGenerationForMixedSpecialParameters() throws Exception { + // Test schema generation for method with all special parameters + Method mixedMethod = CallToolRequestTestProvider.class.getMethod("mixedSpecialParamsTool", + McpSyncServerExchange.class, CallToolRequest.class, String.class, String.class); + String mixedSchema = JsonSchemaGenerator.generateForMethodInput(mixedMethod); + + // Parse the schema + JsonNode schemaNode = objectMapper.readTree(mixedSchema); + + // Should only have the 'regularParam' parameter + assertThat(schemaNode.has("properties")).isTrue(); + JsonNode properties = schemaNode.get("properties"); + assertThat(properties.has("regularParam")).isTrue(); + assertThat(properties.has("progressToken")).isFalse(); + assertThat(properties.size()).isEqualTo(1); + + // Check required array + assertThat(schemaNode.has("required")).isTrue(); + JsonNode required = schemaNode.get("required"); + assertThat(required.size()).isEqualTo(1); + assertThat(required.get(0).asText()).isEqualTo("regularParam"); + } + + @Test + public void testSyncMcpToolProviderWithProgressToken() { + // Test that SyncMcpToolProvider handles @McpProgressToken tools correctly + CallToolRequestTestProvider provider = new CallToolRequestTestProvider(); + SyncMcpToolProvider toolProvider = new SyncMcpToolProvider(List.of(provider)); + + var toolSpecs = toolProvider.getToolSpecifications(); + + // Find the progress token tool + var progressTokenToolSpec = toolSpecs.stream() + .filter(spec -> spec.tool().name().equals("progress-token-tool")) + .findFirst() + .orElse(null); + + assertThat(progressTokenToolSpec).isNotNull(); + assertThat(progressTokenToolSpec.tool().description()).isEqualTo("Tool with progress token"); + + // The input schema should only contain the regular parameter + var inputSchema = progressTokenToolSpec.tool().inputSchema(); + assertThat(inputSchema).isNotNull(); + String schemaStr = inputSchema.toString(); + assertThat(schemaStr).contains("input"); + assertThat(schemaStr).doesNotContain("progressToken"); + } + }