Skip to content

Commit 0508e4c

Browse files
authored
refactor: enhance MCP method callbacks with improved type safety and error handling (#59)
Those changes affect the Prompt and Resources Method callback implementations only - Add abstract methods `validateParamType` and `assignExchangeType` for better parameter handling - Implement strict type validation for exchange parameters (Sync vs Async vs Transport Context) - Replace custom exceptions (`McpPromptMethodException`, `McpResourceMethodException`) with standardized `McpError` - Add comprehensive parameter type validation with detailed error messages - Enhance exchange type assignment logic with proper type checking - Update all concrete implementations (Sync, Async, Stateless variants) for both prompt and resource callbacks - Add test coverage for new validation logic and error scenarios Signed-off-by: Christian Tzolov <christian.tzolov@broadcom.com>
1 parent 793fc00 commit 0508e4c

19 files changed

+1106
-100
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*/
4+
5+
package org.springaicommunity.mcp;
6+
7+
import java.util.Objects;
8+
9+
public class ErrorUtils {
10+
11+
public static Throwable findCauseUsingPlainJava(Throwable throwable) {
12+
Objects.requireNonNull(throwable);
13+
Throwable rootCause = throwable;
14+
while (rootCause.getCause() != null && rootCause.getCause() != rootCause) {
15+
rootCause = rootCause.getCause();
16+
}
17+
return rootCause;
18+
}
19+
20+
}

mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AbstractMcpPromptMethodCallback.java

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
import org.springaicommunity.mcp.annotation.McpArg;
1313
import org.springaicommunity.mcp.annotation.McpMeta;
1414
import org.springaicommunity.mcp.annotation.McpProgressToken;
15-
15+
import io.modelcontextprotocol.common.McpTransportContext;
16+
import io.modelcontextprotocol.server.McpAsyncServerExchange;
17+
import io.modelcontextprotocol.server.McpSyncServerExchange;
1618
import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest;
1719
import io.modelcontextprotocol.spec.McpSchema.GetPromptResult;
1820
import io.modelcontextprotocol.spec.McpSchema.Prompt;
@@ -75,7 +77,10 @@ protected void validateMethod(Method method) {
7577
* @return true if the parameter type is compatible with the exchange type, false
7678
* otherwise
7779
*/
78-
protected abstract boolean isExchangeOrContextType(Class<?> paramType);
80+
protected abstract boolean isSupportedExchangeOrContextType(Class<?> paramType);
81+
82+
protected void validateParamType(Class<?> paramType) {
83+
}
7984

8085
/**
8186
* Validates method parameters.
@@ -95,6 +100,8 @@ protected void validateParameters(Method method) {
95100
for (java.lang.reflect.Parameter param : parameters) {
96101
Class<?> paramType = param.getType();
97102

103+
this.validateParamType(paramType);
104+
98105
// Skip @McpProgressToken annotated parameters from validation
99106
if (param.isAnnotationPresent(McpProgressToken.class)) {
100107
if (hasProgressTokenParam) {
@@ -115,7 +122,7 @@ protected void validateParameters(Method method) {
115122
continue;
116123
}
117124

118-
if (isExchangeOrContextType(paramType)) {
125+
if (isSupportedExchangeOrContextType(paramType)) {
119126
if (hasExchangeParam) {
120127
throw new IllegalArgumentException("Method cannot have more than one exchange parameter: "
121128
+ method.getName() + " in " + method.getDeclaringClass().getName());
@@ -140,6 +147,8 @@ else if (Map.class.isAssignableFrom(paramType)) {
140147
}
141148
}
142149

150+
protected abstract Object assignExchangeType(Class<?> paramType, Object exchange);
151+
143152
/**
144153
* Builds the arguments array for invoking the method.
145154
* <p>
@@ -182,8 +191,11 @@ protected Object[] buildArgs(Method method, Object exchange, GetPromptRequest re
182191
java.lang.reflect.Parameter param = parameters[i];
183192
Class<?> paramType = param.getType();
184193

185-
if (isExchangeOrContextType(paramType)) {
186-
args[i] = exchange;
194+
if (McpTransportContext.class.isAssignableFrom(paramType)
195+
|| McpSyncServerExchange.class.isAssignableFrom(paramType)
196+
|| McpAsyncServerExchange.class.isAssignableFrom(paramType)) {
197+
198+
args[i] = this.assignExchangeType(paramType, exchange);
187199
}
188200
else if (GetPromptRequest.class.isAssignableFrom(paramType)) {
189201
args[i] = request;
@@ -367,30 +379,4 @@ protected void validate() {
367379

368380
}
369381

370-
/**
371-
* Exception thrown when there is an error invoking a prompt method.
372-
*/
373-
public static class McpPromptMethodException extends RuntimeException {
374-
375-
private static final long serialVersionUID = 1L;
376-
377-
/**
378-
* Constructs a new exception with the specified detail message and cause.
379-
* @param message The detail message
380-
* @param cause The cause
381-
*/
382-
public McpPromptMethodException(String message, Throwable cause) {
383-
super(message, cause);
384-
}
385-
386-
/**
387-
* Constructs a new exception with the specified detail message.
388-
* @param message The detail message
389-
*/
390-
public McpPromptMethodException(String message) {
391-
super(message);
392-
}
393-
394-
}
395-
396382
}

mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AsyncMcpPromptMethodCallback.java

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77
import java.lang.reflect.Method;
88
import java.util.function.BiFunction;
99

10+
import org.springaicommunity.mcp.ErrorUtils;
1011
import org.springaicommunity.mcp.annotation.McpPrompt;
11-
12+
import io.modelcontextprotocol.common.McpTransportContext;
1213
import io.modelcontextprotocol.server.McpAsyncServerExchange;
14+
import io.modelcontextprotocol.server.McpSyncServerExchange;
15+
import io.modelcontextprotocol.spec.McpError;
16+
import io.modelcontextprotocol.spec.McpSchema.ErrorCodes;
1317
import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest;
1418
import io.modelcontextprotocol.spec.McpSchema.GetPromptResult;
1519
import reactor.core.publisher.Mono;
@@ -31,6 +35,48 @@ private AsyncMcpPromptMethodCallback(Builder builder) {
3135
super(builder.method, builder.bean, builder.prompt);
3236
}
3337

38+
@Override
39+
protected void validateParamType(Class<?> paramType) {
40+
41+
if (McpSyncServerExchange.class.isAssignableFrom(paramType)) {
42+
throw new IllegalArgumentException("Async prompt method must not declare parameter of type: "
43+
+ paramType.getName() + ". Use McpAsyncServerExchange instead." + " Method: "
44+
+ this.method.getName() + " in " + this.method.getDeclaringClass().getName());
45+
}
46+
}
47+
48+
@Override
49+
protected Object assignExchangeType(Class<?> paramType, Object exchange) {
50+
51+
if (McpTransportContext.class.isAssignableFrom(paramType)) {
52+
if (exchange instanceof McpTransportContext transportContext) {
53+
return transportContext;
54+
}
55+
else if (exchange instanceof McpSyncServerExchange syncServerExchange) {
56+
throw new IllegalArgumentException("Unsupported Async exchange type: "
57+
+ syncServerExchange.getClass().getName() + " for Async method: " + method.getName() + " in "
58+
+ method.getDeclaringClass().getName());
59+
60+
}
61+
else if (exchange instanceof McpAsyncServerExchange asyncServerExchange) {
62+
return asyncServerExchange.transportContext();
63+
}
64+
}
65+
else if (McpAsyncServerExchange.class.isAssignableFrom(paramType)) {
66+
if (exchange instanceof McpAsyncServerExchange asyncServerExchange) {
67+
return asyncServerExchange;
68+
}
69+
70+
throw new IllegalArgumentException(
71+
"Unsupported exchange type: " + (exchange != null ? exchange.getClass().getName() : "null")
72+
+ " for Async method: " + method.getName() + " in " + method.getDeclaringClass().getName());
73+
}
74+
75+
throw new IllegalArgumentException(
76+
"Unsupported exchange type: " + (exchange != null ? exchange.getClass().getName() : "null")
77+
+ " for method: " + method.getName() + " in " + method.getDeclaringClass().getName());
78+
}
79+
3480
/**
3581
* Apply the callback to the given exchange and request.
3682
* <p>
@@ -69,15 +115,24 @@ public Mono<GetPromptResult> apply(McpAsyncServerExchange exchange, GetPromptReq
69115
}
70116
}
71117
catch (Exception e) {
72-
return Mono
73-
.error(new McpPromptMethodException("Error invoking prompt method: " + this.method.getName(), e));
118+
if (e instanceof McpError mcpError && mcpError.getJsonRpcError() != null) {
119+
return Mono.error(mcpError);
120+
}
121+
122+
return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS)
123+
.message("Error invoking prompt method: " + this.method.getName() + " in "
124+
+ this.bean.getClass().getName() + ". /nCause: "
125+
+ ErrorUtils.findCauseUsingPlainJava(e).getMessage())
126+
.data(ErrorUtils.findCauseUsingPlainJava(e).getMessage())
127+
.build());
74128
}
75129
});
76130
}
77131

78132
@Override
79-
protected boolean isExchangeOrContextType(Class<?> paramType) {
80-
return McpAsyncServerExchange.class.isAssignableFrom(paramType);
133+
protected boolean isSupportedExchangeOrContextType(Class<?> paramType) {
134+
return (McpAsyncServerExchange.class.isAssignableFrom(paramType)
135+
|| McpTransportContext.class.isAssignableFrom(paramType));
81136
}
82137

83138
@Override

mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AsyncStatelessMcpPromptMethodCallback.java

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,13 @@
88
import java.util.List;
99
import java.util.function.BiFunction;
1010

11+
import org.springaicommunity.mcp.ErrorUtils;
1112
import org.springaicommunity.mcp.annotation.McpPrompt;
1213
import io.modelcontextprotocol.common.McpTransportContext;
14+
import io.modelcontextprotocol.server.McpAsyncServerExchange;
15+
import io.modelcontextprotocol.server.McpSyncServerExchange;
16+
import io.modelcontextprotocol.spec.McpError;
17+
import io.modelcontextprotocol.spec.McpSchema.ErrorCodes;
1318
import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest;
1419
import io.modelcontextprotocol.spec.McpSchema.GetPromptResult;
1520
import io.modelcontextprotocol.spec.McpSchema.PromptMessage;
@@ -32,6 +37,42 @@ private AsyncStatelessMcpPromptMethodCallback(Builder builder) {
3237
super(builder.method, builder.bean, builder.prompt);
3338
}
3439

40+
@Override
41+
protected void validateParamType(Class<?> paramType) {
42+
43+
if (McpSyncServerExchange.class.isAssignableFrom(paramType)
44+
|| McpAsyncServerExchange.class.isAssignableFrom(paramType)) {
45+
46+
throw new IllegalArgumentException(
47+
"Stateless Streamable-Http prompt method must not declare parameter of type: " + paramType.getName()
48+
+ ". Use McpTransportContext instead." + " Method: " + this.method.getName() + " in "
49+
+ this.method.getDeclaringClass().getName());
50+
}
51+
}
52+
53+
@Override
54+
protected Object assignExchangeType(Class<?> paramType, Object exchange) {
55+
56+
if (McpTransportContext.class.isAssignableFrom(paramType)) {
57+
if (exchange instanceof McpTransportContext transportContext) {
58+
return transportContext;
59+
}
60+
else if (exchange instanceof McpSyncServerExchange syncServerExchange) {
61+
throw new IllegalArgumentException("Unsupported Sync exchange type: "
62+
+ syncServerExchange.getClass().getName() + " for Sync method: " + method.getName() + " in "
63+
+ method.getDeclaringClass().getName());
64+
65+
}
66+
else if (exchange instanceof McpAsyncServerExchange asyncServerExchange) {
67+
return asyncServerExchange.transportContext();
68+
}
69+
}
70+
71+
throw new IllegalArgumentException(
72+
"Unsupported exchange type: " + (exchange != null ? exchange.getClass().getName() : "null")
73+
+ " for method: " + method.getName() + " in " + method.getDeclaringClass().getName());
74+
}
75+
3576
/**
3677
* Apply the callback to the given context and request.
3778
* <p>
@@ -70,14 +111,23 @@ public Mono<GetPromptResult> apply(McpTransportContext context, GetPromptRequest
70111
}
71112
}
72113
catch (Exception e) {
73-
return Mono
74-
.error(new McpPromptMethodException("Error invoking prompt method: " + this.method.getName(), e));
114+
115+
if (e instanceof McpError mcpError && mcpError.getJsonRpcError() != null) {
116+
return Mono.error(mcpError);
117+
}
118+
119+
return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS)
120+
.message("Error invoking prompt method: " + this.method.getName() + " in "
121+
+ this.bean.getClass().getName() + ". /nCause: "
122+
+ ErrorUtils.findCauseUsingPlainJava(e).getMessage())
123+
.data(ErrorUtils.findCauseUsingPlainJava(e).getMessage())
124+
.build());
75125
}
76126
});
77127
}
78128

79129
@Override
80-
protected boolean isExchangeOrContextType(Class<?> paramType) {
130+
protected boolean isSupportedExchangeOrContextType(Class<?> paramType) {
81131
return McpTransportContext.class.isAssignableFrom(paramType);
82132
}
83133

mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/SyncMcpPromptMethodCallback.java

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,13 @@
88
import java.util.List;
99
import java.util.function.BiFunction;
1010

11+
import org.springaicommunity.mcp.ErrorUtils;
1112
import org.springaicommunity.mcp.annotation.McpPrompt;
12-
13+
import io.modelcontextprotocol.common.McpTransportContext;
14+
import io.modelcontextprotocol.server.McpAsyncServerExchange;
1315
import io.modelcontextprotocol.server.McpSyncServerExchange;
16+
import io.modelcontextprotocol.spec.McpError;
17+
import io.modelcontextprotocol.spec.McpSchema.ErrorCodes;
1418
import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest;
1519
import io.modelcontextprotocol.spec.McpSchema.GetPromptResult;
1620
import io.modelcontextprotocol.spec.McpSchema.PromptMessage;
@@ -31,6 +35,47 @@ private SyncMcpPromptMethodCallback(Builder builder) {
3135
super(builder.method, builder.bean, builder.prompt);
3236
}
3337

38+
@Override
39+
protected void validateParamType(Class<?> paramType) {
40+
41+
if (McpAsyncServerExchange.class.isAssignableFrom(paramType)) {
42+
throw new IllegalArgumentException("Sync prompt method must not declare parameter of type: "
43+
+ paramType.getName() + ". Use McpSyncServerExchange instead." + " Method: " + this.method.getName()
44+
+ " in " + this.method.getDeclaringClass().getName());
45+
}
46+
}
47+
48+
@Override
49+
protected Object assignExchangeType(Class<?> paramType, Object exchange) {
50+
51+
if (McpTransportContext.class.isAssignableFrom(paramType)) {
52+
if (exchange instanceof McpTransportContext transportContext) {
53+
return transportContext;
54+
}
55+
else if (exchange instanceof McpSyncServerExchange syncServerExchange) {
56+
return syncServerExchange.transportContext();
57+
}
58+
else if (exchange instanceof McpAsyncServerExchange asyncServerExchange) {
59+
throw new IllegalArgumentException("Unsupported Async exchange type: "
60+
+ asyncServerExchange.getClass().getName() + " for Sync method: " + method.getName() + " in "
61+
+ method.getDeclaringClass().getName());
62+
}
63+
}
64+
else if (McpSyncServerExchange.class.isAssignableFrom(paramType)) {
65+
if (exchange instanceof McpSyncServerExchange syncServerExchange) {
66+
return syncServerExchange;
67+
}
68+
69+
throw new IllegalArgumentException(
70+
"Unsupported exchange type: " + (exchange != null ? exchange.getClass().getName() : "null")
71+
+ " for Sync method: " + method.getName() + " in " + method.getDeclaringClass().getName());
72+
}
73+
74+
throw new IllegalArgumentException(
75+
"Unsupported exchange type: " + (exchange != null ? exchange.getClass().getName() : "null")
76+
+ " for method: " + method.getName() + " in " + method.getDeclaringClass().getName());
77+
}
78+
3479
/**
3580
* Apply the callback to the given exchange and request.
3681
* <p>
@@ -62,13 +107,23 @@ public GetPromptResult apply(McpSyncServerExchange exchange, GetPromptRequest re
62107
return promptResult;
63108
}
64109
catch (Exception e) {
65-
throw new McpPromptMethodException("Error invoking prompt method: " + this.method.getName(), e);
110+
if (e instanceof McpError mcpError && mcpError.getJsonRpcError() != null) {
111+
throw mcpError;
112+
}
113+
114+
throw McpError.builder(ErrorCodes.INVALID_PARAMS)
115+
.message("Error invoking prompt method: " + this.method.getName() + " in "
116+
+ this.bean.getClass().getName() + "./nCause: "
117+
+ ErrorUtils.findCauseUsingPlainJava(e).getMessage())
118+
.data(ErrorUtils.findCauseUsingPlainJava(e).getMessage())
119+
.build();
66120
}
67121
}
68122

69123
@Override
70-
protected boolean isExchangeOrContextType(Class<?> paramType) {
71-
return McpSyncServerExchange.class.isAssignableFrom(paramType);
124+
protected boolean isSupportedExchangeOrContextType(Class<?> paramType) {
125+
return (McpSyncServerExchange.class.isAssignableFrom(paramType)
126+
|| McpTransportContext.class.isAssignableFrom(paramType));
72127
}
73128

74129
@Override

0 commit comments

Comments
 (0)