diff --git a/README.md b/README.md index e3f3b54..c5bf3b7 100644 --- a/README.md +++ b/README.md @@ -920,18 +920,21 @@ public String processWithContext( // Use exchange for additional operations... } - // Perform elicitation with default message - returns StructuredElicitResult - Optional> result = context.elicit(new TypeReference() {}); - - // Or perform elicitation with custom configuration - returns StructuredElicitResult - Optional> structuredResult = context.elicit( - e -> e.message("Please provide your information").meta("context", "user-registration"), - new TypeReference() {} - ); - - if (structuredResult.isPresent() && structuredResult.get().action() == ElicitResult.Action.ACCEPT) { - UserInfo info = structuredResult.get().structuredContent(); - return "Processed: " + data + " for user " + info.name(); + // Check if elicitation is supported before using it + if (context.elicitEnabled()) { + // Perform elicitation with default message - returns StructuredElicitResult + StructuredElicitResult result = context.elicit(new TypeReference() {}); + + // Or perform elicitation with custom configuration - returns StructuredElicitResult + StructuredElicitResult structuredResult = context.elicit( + e -> e.message("Please provide your information").meta("context", "user-registration"), + new TypeReference() {} + ); + + if (structuredResult.action() == ElicitResult.Action.ACCEPT) { + UserInfo info = structuredResult.structuredContent(); + return "Processed: " + data + " for user " + info.name(); + } } return "Processed: " + data; @@ -962,10 +965,14 @@ public GetPromptResult generateWithContext( // Log prompt generation context.info("Generating prompt for topic: " + topic); - // Perform sampling if needed - Optional samplingResult = context.sample( - "What are the key points about " + topic + "?" - ); + // Check if sampling is supported before using it + if (context.sampleEnabled()) { + // Perform sampling if needed + CreateMessageResult samplingResult = context.sample( + "What are the key points about " + topic + "?" + ); + // Use sampling result... + } String message = "Let's discuss " + topic; return new GetPromptResult("Generated Prompt", @@ -1059,16 +1066,24 @@ public Mono asyncGenerateWithContext( - `log(Consumer)` - Send log messages with custom configuration - `debug(String)`, `info(String)`, `warn(String)`, `error(String)` - Convenience logging methods - `progress(int)`, `progress(Consumer)` - Send progress updates -- `elicit(TypeReference)` - Request user input with default message, returns `StructuredElicitResult` with action, typed content, and metadata -- `elicit(Class)` - Request user input with default message using Class type, returns `StructuredElicitResult` -- `elicit(Consumer, TypeReference)` - Request user input with custom configuration, returns `StructuredElicitResult` -- `elicit(Consumer, Class)` - Request user input with custom configuration using Class type, returns `StructuredElicitResult` -- `elicit(ElicitRequest)` - Request user input with full control over the elicitation request -- `sample(...)` - Request LLM sampling with various configuration options -- `roots()` - Access root directories (returns `Optional`) +- `rootsEnabled()` - Check if roots capability is supported by the client +- `roots()` - Access root directories (throws `IllegalStateException` if not supported) +- `elicitEnabled()` - Check if elicitation capability is supported by the client +- `elicit(TypeReference)` - Request user input with default message, returns `StructuredElicitResult` with action, typed content, and metadata (throws `IllegalStateException` if not supported) +- `elicit(Class)` - Request user input with default message using Class type, returns `StructuredElicitResult` (throws `IllegalStateException` if not supported) +- `elicit(Consumer, TypeReference)` - Request user input with custom configuration, returns `StructuredElicitResult` (throws `IllegalStateException` if not supported) +- `elicit(Consumer, Class)` - Request user input with custom configuration using Class type, returns `StructuredElicitResult` (throws `IllegalStateException` if not supported) +- `elicit(ElicitRequest)` - Request user input with full control over the elicitation request (throws `IllegalStateException` if not supported) +- `sampleEnabled()` - Check if sampling capability is supported by the client +- `sample(...)` - Request LLM sampling with various configuration options (throws `IllegalStateException` if not supported) - `ping()` - Send ping to check connection -`McpAsyncRequestContext` provides the same methods but with reactive return types (`Mono` instead of `T` or `Optional`). +`McpAsyncRequestContext` provides the same methods but with reactive return types (`Mono` instead of `T`). Methods that throw `IllegalStateException` in sync context return `Mono.error(IllegalStateException)` in async context. + +**Important Notes on Capability Checking:** +- Always check capability support using `rootsEnabled()`, `elicitEnabled()`, or `sampleEnabled()` before calling the corresponding methods +- Calling capability methods when not supported will throw `IllegalStateException` (sync) or return `Mono.error()` (async) +- Stateless servers do not support bidirectional operations (roots, elicitation, sampling) and will always return `false` for capability checks This unified context approach simplifies method signatures and provides a consistent API across different operation types and execution modes (stateful vs stateless, sync vs async). @@ -2104,6 +2119,9 @@ public class StatelessResourceProvider { } ``` +**Important Note on Stateless Operations:** +Stateless server methods cannot use bidirectional parameters like `McpSyncRequestContext`, `McpAsyncRequestContext`, `McpSyncServerExchange`, or `McpAsyncServerExchange`. These parameters require client capabilities (roots, elicitation, sampling) that are not available in stateless mode. Methods with these parameters will be automatically filtered out and not registered as stateless operations. + #### Stateless Tool Example ```java diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java index dddcc34..0dab1fc 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java @@ -59,17 +59,31 @@ private DefaultMcpAsyncRequestContext(McpSchema.Request request, McpAsyncServerE // Roots + @Override + public Mono rootsEnabled() { + return Mono.just(!(this.exchange.getClientCapabilities() == null + || this.exchange.getClientCapabilities().roots() == null)); + } + @Override public Mono roots() { - if (this.exchange.getClientCapabilities() == null || this.exchange.getClientCapabilities().roots() == null) { - logger.warn("Roots not supported by the client! Ignoring the roots request for request:" + this.request); - return Mono.empty(); - } - return this.exchange.listRoots(); + return this.rootsEnabled().flatMap(enabled -> { + if (!enabled) { + return Mono.error(new IllegalStateException( + "Roots not supported by the client: " + this.exchange.getClientInfo())); + } + return this.exchange.listRoots(); + }); } // Elicitation + @Override + public Mono elicitEnabled() { + return Mono.just(!(this.exchange.getClientCapabilities() == null + || this.exchange.getClientCapabilities().elicitation() == null)); + } + @Override public Mono> elicit(Consumer spec, TypeReference type) { Assert.notNull(type, "Elicitation response type must not be null"); @@ -112,14 +126,13 @@ public Mono> elicit(Class type) { public Mono elicit(ElicitRequest elicitRequest) { Assert.notNull(elicitRequest, "Elicit request must not be null"); - if (this.exchange.getClientCapabilities() == null - || this.exchange.getClientCapabilities().elicitation() == null) { - logger.warn("Elicitation not supported by the client! Ignoring the elicitation request for request:" - + elicitRequest); - return Mono.empty(); - } - - return this.exchange.createElicitation(elicitRequest); + return this.elicitEnabled().flatMap(enabled -> { + if (!enabled) { + return Mono.error(new IllegalStateException( + "Elicitation not supported by the client: " + this.exchange.getClientInfo())); + } + return this.exchange.createElicitation(elicitRequest); + }); } public Mono elicitationInternal(String message, Type type, Map meta) { @@ -143,6 +156,12 @@ private Map generateElicitSchema(Type type) { // Sampling + @Override + public Mono sampleEnabled() { + return Mono.just(!(this.exchange.getClientCapabilities() == null + || this.exchange.getClientCapabilities().sampling() == null)); + } + @Override public Mono sample(String... messages) { return this.sample(s -> s.message(messages)); @@ -176,14 +195,13 @@ public Mono sample(Consumer samplingSpec) { @Override public Mono sample(CreateMessageRequest createMessageRequest) { - // check if supported - if (this.exchange.getClientCapabilities() == null || this.exchange.getClientCapabilities().sampling() == null) { - logger.warn("Sampling not supported by the client! Ignoring the sampling request for messages:" - + createMessageRequest); - return Mono.empty(); - } - - return this.exchange.createMessage(createMessageRequest); + return this.sampleEnabled().flatMap(enabled -> { + if (!enabled) { + return Mono.error(new IllegalStateException( + "Sampling not supported by the client: " + this.exchange.getClientInfo())); + } + return this.exchange.createMessage(createMessageRequest); + }); } // Progress @@ -317,10 +335,6 @@ public static class Builder { private McpAsyncServerExchange exchange; - private boolean isStateless = false; - - private McpTransportContext transportContext; - private Builder() { } @@ -334,178 +348,10 @@ public Builder exchange(McpAsyncServerExchange exchange) { return this; } - public Builder stateless(boolean isStateless) { - this.isStateless = isStateless; - return this; - } - - public Builder transportContext(McpTransportContext transportContext) { - this.transportContext = transportContext; - return this; - } - public McpAsyncRequestContext build() { - if (this.isStateless) { - return new StatelessAsyncRequestContext(this.request, this.transportContext); - } return new DefaultMcpAsyncRequestContext(this.request, this.exchange); } } - private static class StatelessAsyncRequestContext implements McpAsyncRequestContext { - - private final McpSchema.Request request; - - private McpTransportContext transportContext; - - public StatelessAsyncRequestContext(McpSchema.Request request, McpTransportContext transportContext) { - this.request = request; - this.transportContext = transportContext; - } - - @Override - public Mono roots() { - logger.warn("Roots not supported by the client! Ignoring the roots request"); - return Mono.empty(); - } - - @Override - public Mono> elicit(Consumer spec, TypeReference returnType) { - logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); - return Mono.empty(); - } - - @Override - public Mono> elicit(TypeReference type) { - logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); - return Mono.empty(); - } - - @Override - public Mono> elicit(Consumer spec, Class returnType) { - logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); - return Mono.empty(); - } - - @Override - public Mono> elicit(Class type) { - logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); - return Mono.empty(); - } - - @Override - public Mono elicit(ElicitRequest elicitRequest) { - logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); - return Mono.empty(); - } - - @Override - public Mono sample(String... messages) { - logger.warn("Sampling not supported by the client! Ignoring the sampling request"); - return Mono.empty(); - } - - @Override - public Mono sample(Consumer samplingSpec) { - logger.warn("Sampling not supported by the client! Ignoring the sampling request"); - return Mono.empty(); - } - - @Override - public Mono sample(CreateMessageRequest createMessageRequest) { - logger.warn("Sampling not supported by the client! Ignoring the sampling request"); - return Mono.empty(); - } - - @Override - public Mono progress(int progress) { - logger.warn("Progress not supported by the client! Ignoring the progress request"); - return Mono.empty(); - } - - @Override - public Mono progress(Consumer progressSpec) { - logger.warn("Progress not supported by the client! Ignoring the progress request"); - return Mono.empty(); - } - - @Override - public Mono progress(ProgressNotification progressNotification) { - logger.warn("Progress not supported by the client! Ignoring the progress request"); - return Mono.empty(); - } - - @Override - public Mono ping() { - logger.warn("Ping not supported by the client! Ignoring the ping request"); - return Mono.empty(); - } - - @Override - public Mono log(Consumer logSpec) { - logger.warn("Logging not supported by the client! Ignoring the logging request"); - return Mono.empty(); - } - - @Override - public Mono debug(String message) { - logger.warn("Debug not supported by the client! Ignoring the debug request"); - return Mono.empty(); - } - - @Override - public Mono info(String message) { - logger.warn("Info not supported by the client! Ignoring the info request"); - return Mono.empty(); - } - - @Override - public Mono warn(String message) { - logger.warn("Warn not supported by the client! Ignoring the warn request"); - return Mono.empty(); - } - - @Override - public Mono error(String message) { - logger.warn("Error not supported by the client! Ignoring the error request"); - return Mono.empty(); - } - - // Getters - - public McpSchema.Request request() { - return this.request; - } - - public McpAsyncServerExchange exchange() { - logger.warn("Stateless servers do not support exchange! Returning null"); - return null; - } - - public String sessionId() { - logger.warn("Stateless servers do not support session ID! Returning null"); - return null; - } - - public Implementation clientInfo() { - logger.warn("Stateless servers do not support client info! Returning null"); - return null; - } - - public ClientCapabilities clientCapabilities() { - logger.warn("Stateless servers do not support client capabilities! Returning null"); - return null; - } - - public Map requestMeta() { - return this.request.meta(); - } - - public McpTransportContext transportContext() { - return transportContext; - } - - } - } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java index 45bc302..d612331 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java @@ -6,7 +6,6 @@ import java.lang.reflect.Type; import java.util.Map; -import java.util.Optional; import java.util.function.Consumer; import com.fasterxml.jackson.core.type.TypeReference; @@ -56,111 +55,145 @@ private DefaultMcpSyncRequestContext(McpSchema.Request request, McpSyncServerExc // Roots - public Optional roots() { - if (this.exchange.getClientCapabilities() == null || this.exchange.getClientCapabilities().roots() == null) { - logger.warn("Roots not supported by the client! Ignoring the roots request for request:" + this.request); - return Optional.empty(); + @Override + public boolean rootsEnabled() { + return !(this.exchange.getClientCapabilities() == null + || this.exchange.getClientCapabilities().roots() == null); + } + + @Override + public ListRootsResult roots() { + if (!this.rootsEnabled()) { + throw new IllegalStateException("Roots not supported by the client: " + this.exchange.getClientInfo()); } - return Optional.of(this.exchange.listRoots()); + return this.exchange.listRoots(); } // Elicitation @Override - public Optional> elicit(Class type) { + public boolean elicitEnabled() { + return !(this.exchange.getClientCapabilities() == null + || this.exchange.getClientCapabilities().elicitation() == null); + } + + @Override + public StructuredElicitResult elicit(Class type) { + + if (!this.elicitEnabled()) { + throw new IllegalStateException( + "Elicitation not supported by the client: " + this.exchange.getClientInfo()); + } + Assert.notNull(type, "Elicitation response type must not be null"); - Optional elicitResult = this.elicitationInternal("Please provide the required information.", type, - null); + ElicitResult elicitResult = this.elicitationInternal("Please provide the required information.", type, null); - if (!elicitResult.isPresent() || elicitResult.get().action() != ElicitResult.Action.ACCEPT) { - return Optional.empty(); + if (elicitResult.action() != ElicitResult.Action.ACCEPT) { + return new StructuredElicitResult<>(elicitResult.action(), null, elicitResult.meta()); } - return Optional.of(new StructuredElicitResult<>(elicitResult.get().action(), - JsonParser.convertMapToType(elicitResult.get().content(), type), elicitResult.get().meta())); + return new StructuredElicitResult<>(elicitResult.action(), + JsonParser.convertMapToType(elicitResult.content(), type), elicitResult.meta()); } @Override - public Optional> elicit(TypeReference type) { + public StructuredElicitResult elicit(TypeReference type) { + + if (!this.elicitEnabled()) { + throw new IllegalStateException( + "Elicitation not supported by the client: " + this.exchange.getClientInfo()); + } + Assert.notNull(type, "Elicitation response type must not be null"); - Optional elicitResult = this.elicitationInternal("Please provide the required information.", - type.getType(), null); + ElicitResult elicitResult = this.elicitationInternal("Please provide the required information.", type.getType(), + null); - if (!elicitResult.isPresent() || elicitResult.get().action() != ElicitResult.Action.ACCEPT) { - return Optional.empty(); + if (elicitResult.action() != ElicitResult.Action.ACCEPT) { + return new StructuredElicitResult<>(elicitResult.action(), null, elicitResult.meta()); } - return Optional.of(new StructuredElicitResult<>(elicitResult.get().action(), - JsonParser.convertMapToType(elicitResult.get().content(), type), elicitResult.get().meta())); + return new StructuredElicitResult<>(elicitResult.action(), + JsonParser.convertMapToType(elicitResult.content(), type), elicitResult.meta()); } @Override - public Optional> elicit(Consumer params, Class returnType) { + public StructuredElicitResult elicit(Consumer params, Class returnType) { + + if (!this.elicitEnabled()) { + throw new IllegalStateException( + "Elicitation not supported by the client: " + this.exchange.getClientInfo()); + } + Assert.notNull(returnType, "Elicitation response type must not be null"); Assert.notNull(params, "Elicitation params must not be null"); DefaultElicitationSpec paramSpec = new DefaultElicitationSpec(); + params.accept(paramSpec); - Optional elicitResult = this.elicitationInternal(paramSpec.message(), returnType, - paramSpec.meta()); + ElicitResult elicitResult = this.elicitationInternal(paramSpec.message(), returnType, paramSpec.meta()); - if (!elicitResult.isPresent() || elicitResult.get().action() != ElicitResult.Action.ACCEPT) { - return Optional.empty(); + if (elicitResult.action() != ElicitResult.Action.ACCEPT) { + return new StructuredElicitResult<>(elicitResult.action(), null, null); } - return Optional.of(new StructuredElicitResult<>(elicitResult.get().action(), - JsonParser.convertMapToType(elicitResult.get().content(), returnType), elicitResult.get().meta())); + return new StructuredElicitResult<>(elicitResult.action(), + JsonParser.convertMapToType(elicitResult.content(), returnType), elicitResult.meta()); } @Override - public Optional> elicit(Consumer params, - TypeReference returnType) { + public StructuredElicitResult elicit(Consumer params, TypeReference returnType) { + + if (!this.elicitEnabled()) { + throw new IllegalStateException( + "Elicitation not supported by the client: " + this.exchange.getClientInfo()); + } + Assert.notNull(returnType, "Elicitation response type must not be null"); Assert.notNull(params, "Elicitation params must not be null"); DefaultElicitationSpec paramSpec = new DefaultElicitationSpec(); params.accept(paramSpec); - Optional elicitResult = this.elicitationInternal(paramSpec.message(), returnType.getType(), + ElicitResult elicitResult = this.elicitationInternal(paramSpec.message(), returnType.getType(), paramSpec.meta()); - if (!elicitResult.isPresent() || elicitResult.get().action() != ElicitResult.Action.ACCEPT) { - return Optional.empty(); + if (elicitResult.action() != ElicitResult.Action.ACCEPT) { + return new StructuredElicitResult<>(elicitResult.action(), null, null); } - return Optional.of(new StructuredElicitResult<>(elicitResult.get().action(), - JsonParser.convertMapToType(elicitResult.get().content(), returnType), elicitResult.get().meta())); + return new StructuredElicitResult<>(elicitResult.action(), + JsonParser.convertMapToType(elicitResult.content(), returnType), elicitResult.meta()); } @Override - public Optional elicit(ElicitRequest elicitRequest) { - Assert.notNull(elicitRequest, "Elicit request must not be null"); - - if (this.exchange.getClientCapabilities() == null - || this.exchange.getClientCapabilities().elicitation() == null) { - logger.warn("Elicitation not supported by the client! Ignoring the elicitation request for request:" - + elicitRequest); - return Optional.empty(); + public ElicitResult elicit(ElicitRequest elicitRequest) { + if (!this.elicitEnabled()) { + throw new IllegalStateException( + "Elicitation not supported by the client: " + this.exchange.getClientInfo()); } - ElicitResult elicitResult = this.exchange.createElicitation(elicitRequest); + Assert.notNull(elicitRequest, "Elicit request must not be null"); - return Optional.of(elicitResult); + return this.exchange.createElicitation(elicitRequest); } - private Optional elicitationInternal(String message, Type type, Map meta) { - Assert.hasText(message, "Elicitation message must not be empty"); - Assert.notNull(type, "Elicitation response type must not be null"); + private ElicitResult elicitationInternal(String message, Type type, Map meta) { // TODO add validation for the Elicitation Schema // https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#supported-schema-types Map schema = typeSchemaCache.computeIfAbsent(type, t -> this.generateElicitSchema(t)); - return this.elicit(ElicitRequest.builder().message(message).requestedSchema(schema).meta(meta).build()); + ElicitRequest elicitRequest = ElicitRequest.builder() + .message(message) + .requestedSchema(schema) + .meta(meta) + .build(); + + return this.exchange.createElicitation(elicitRequest); } private Map generateElicitSchema(Type type) { @@ -173,21 +206,30 @@ private Map generateElicitSchema(Type type) { // Sampling @Override - public Optional sample(String... messages) { + public boolean sampleEnabled() { + return !(this.exchange.getClientCapabilities() == null + || this.exchange.getClientCapabilities().sampling() == null); + } + + @Override + public CreateMessageResult sample(String... messages) { return this.sample(s -> s.message(messages)); } @Override - public Optional sample(Consumer samplingSpec) { + public CreateMessageResult sample(Consumer samplingSpec) { + + if (!this.sampleEnabled()) { + throw new IllegalStateException("Sampling not supported by the client: " + this.exchange.getClientInfo()); + } + Assert.notNull(samplingSpec, "Sampling spec consumer must not be null"); + DefaultSamplingSpec spec = new DefaultSamplingSpec(); samplingSpec.accept(spec); var progressToken = this.request.progressToken(); - if (!Utils.hasText(progressToken)) { - logger.warn("Progress notification not supported by the client!"); - } return this.sample(McpSchema.CreateMessageRequest.builder() .messages(spec.messages) .modelPreferences(spec.modelPreferences) @@ -203,16 +245,13 @@ public Optional sample(Consumer samplingSpec) } @Override - public Optional sample(CreateMessageRequest createMessageRequest) { + public CreateMessageResult sample(CreateMessageRequest createMessageRequest) { - // check if supported - if (this.exchange.getClientCapabilities() == null || this.exchange.getClientCapabilities().sampling() == null) { - logger.warn("Sampling not supported by the client! Ignoring the sampling request for messages:" - + createMessageRequest); - return Optional.empty(); + if (!this.sampleEnabled()) { + throw new IllegalStateException("Sampling not supported by the client: " + this.exchange.getClientInfo()); } - return Optional.of(this.exchange.createMessage(createMessageRequest)); + return this.exchange.createMessage(createMessageRequest); } // Progress @@ -342,10 +381,6 @@ public static class Builder { private McpSyncServerExchange exchange; - private McpTransportContext transportContext; - - private boolean isStateless = false; - private Builder() { } @@ -359,171 +394,10 @@ public Builder exchange(McpSyncServerExchange exchange) { return this; } - public Builder transportContext(McpTransportContext transportContext) { - this.transportContext = transportContext; - return this; - } - - public Builder stateless(boolean isStateless) { - this.isStateless = isStateless; - return this; - } - public McpSyncRequestContext build() { - if (this.isStateless) { - return new StatelessMcpSyncRequestContext(this.request, this.transportContext); - } return new DefaultMcpSyncRequestContext(this.request, this.exchange); } } - public final static class StatelessMcpSyncRequestContext implements McpSyncRequestContext { - - private static final Logger logger = LoggerFactory.getLogger(StatelessMcpSyncRequestContext.class); - - private final McpSchema.Request request; - - private final McpTransportContext transportContext; - - private StatelessMcpSyncRequestContext(McpSchema.Request request, McpTransportContext transportContext) { - this.request = request; - this.transportContext = transportContext; - } - - @Override - public Optional roots() { - logger.warn("Roots not supported by the client! Ignoring the roots request"); - return Optional.empty(); - } - - @Override - public Optional> elicit(Class type) { - logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); - return Optional.empty(); - } - - @Override - public Optional> elicit(Consumer params, Class returnType) { - logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); - return Optional.empty(); - } - - @Override - public Optional> elicit(TypeReference type) { - logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); - return Optional.empty(); - } - - @Override - public Optional> elicit(Consumer params, - TypeReference returnType) { - logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); - return Optional.empty(); - } - - @Override - public Optional elicit(ElicitRequest elicitRequest) { - logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); - return Optional.empty(); - } - - @Override - public Optional sample(String... messages) { - logger.warn("Stateless servers do not support sampling! Ignoring the sampling request"); - return Optional.empty(); - } - - @Override - public Optional sample(Consumer samplingSpec) { - logger.warn("Stateless servers do not support sampling! Ignoring the sampling request"); - return Optional.empty(); - } - - @Override - public Optional sample(CreateMessageRequest createMessageRequest) { - logger.warn("Stateless servers do not support sampling! Ignoring the sampling request"); - return Optional.empty(); - } - - @Override - public void progress(int progress) { - logger.warn("Stateless servers do not support progress notifications! Ignoring the progress request"); - } - - @Override - public void progress(Consumer progressSpec) { - logger.warn("Stateless servers do not support progress notifications! Ignoring the progress request"); - } - - @Override - public void progress(ProgressNotification progressNotification) { - logger.warn("Stateless servers do not support progress notifications! Ignoring the progress request"); - } - - @Override - public void ping() { - logger.warn("Stateless servers do not support ping! Ignoring the ping request"); - } - - @Override - public void log(Consumer logSpec) { - logger.warn("Stateless servers do not support logging! Ignoring the logging request"); - } - - @Override - public void debug(String message) { - logger.warn("Stateless servers do not support debugging! Ignoring the debugging request"); - } - - @Override - public void info(String message) { - logger.warn("Stateless servers do not support info logging! Ignoring the info request"); - } - - @Override - public void warn(String message) { - logger.warn("Stateless servers do not support warning logging! Ignoring the warning request"); - } - - @Override - public void error(String message) { - logger.warn("Stateless servers do not support error logging! Ignoring the error request"); - } - - public McpSchema.Request request() { - return this.request; - } - - public McpTransportContext transportContext() { - return transportContext; - } - - public String sessionId() { - logger.warn("Stateless servers do not support session ID! Returning null"); - return null; - } - - public Implementation clientInfo() { - logger.warn("Stateless servers do not support client info! Returning null"); - return null; - } - - public ClientCapabilities clientCapabilities() { - logger.warn("Stateless servers do not support client capabilities! Returning null"); - return null; - } - - public Map requestMeta() { - return this.request.meta(); - } - - @Override - public McpSyncServerExchange exchange() { - logger.warn("Stateless servers do not support exchange! Returning null"); - return null; - } - - } - } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java index 08bad44..d926dbd 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java @@ -4,7 +4,6 @@ package org.springaicommunity.mcp.context; -import java.util.Map; import java.util.function.Consumer; import com.fasterxml.jackson.core.type.TypeReference; @@ -27,11 +26,14 @@ public interface McpAsyncRequestContext extends McpRequestContextTypes rootsEnabled(); + Mono roots(); // -------------------------------------- // Elicitation // -------------------------------------- + Mono elicitEnabled(); Mono> elicit(Class type); @@ -46,6 +48,8 @@ public interface McpAsyncRequestContext extends McpRequestContextTypes sampleEnabled(); + Mono sample(String... messages); Mono sample(Consumer samplingSpec); diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java index 10eb3b6..7899342 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java @@ -24,29 +24,35 @@ public interface McpSyncRequestContext extends McpRequestContextTypes roots(); + boolean rootsEnabled(); + + ListRootsResult roots(); // -------------------------------------- // Elicitation // -------------------------------------- - Optional> elicit(Class type); + boolean elicitEnabled(); + + StructuredElicitResult elicit(Class type); - Optional> elicit(TypeReference type); + StructuredElicitResult elicit(TypeReference type); - Optional> elicit(Consumer params, Class returnType); + StructuredElicitResult elicit(Consumer params, Class returnType); - Optional> elicit(Consumer params, TypeReference returnType); + StructuredElicitResult elicit(Consumer params, TypeReference returnType); - Optional elicit(ElicitRequest elicitRequest); + ElicitResult elicit(ElicitRequest elicitRequest); // -------------------------------------- // Sampling // -------------------------------------- - Optional sample(String... messages); + boolean sampleEnabled(); + + CreateMessageResult sample(String... messages); - Optional sample(Consumer samplingSpec); + CreateMessageResult sample(Consumer samplingSpec); - Optional sample(CreateMessageRequest createMessageRequest); + CreateMessageResult sample(CreateMessageRequest createMessageRequest); // -------------------------------------- // Progress diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractMcpToolMethodCallback.java index 3e9ea3f..72314f2 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractMcpToolMethodCallback.java @@ -44,7 +44,7 @@ * McpSyncServerExchange, or McpAsyncServerExchange) * @author Christian Tzolov */ -public abstract class AbstractMcpToolMethodCallback { +public abstract class AbstractMcpToolMethodCallback> { protected final Method toolMethod; diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java index 06957d9..dac6010 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java @@ -22,7 +22,6 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import org.springaicommunity.mcp.annotation.McpTool; -import org.springaicommunity.mcp.context.DefaultMcpAsyncRequestContext; import org.springaicommunity.mcp.context.McpAsyncRequestContext; import reactor.core.publisher.Mono; @@ -57,12 +56,8 @@ protected boolean isExchangeOrContextType(Class paramType) { @Override protected McpAsyncRequestContext createRequestContext(McpTransportContext exchange, CallToolRequest request) { - - return DefaultMcpAsyncRequestContext.builder() - .request(request) - .transportContext(exchange) - .stateless(true) - .build(); + throw new UnsupportedOperationException( + "Stateless tool methods do not support McpAsyncRequestContext parameter."); } /** diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java index fcb14e3..c7f5244 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java @@ -54,7 +54,6 @@ protected boolean isExchangeOrContextType(Class paramType) { @Override protected McpSyncRequestContext createRequestContext(McpSyncServerExchange exchange, CallToolRequest request) { - return DefaultMcpSyncRequestContext.builder().request(request).exchange(exchange).build(); } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java index 88029e4..64c2850 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java @@ -56,12 +56,8 @@ protected boolean isExchangeOrContextType(Class paramType) { @Override protected McpSyncRequestContext createRequestContext(McpTransportContext exchange, CallToolRequest request) { - - return DefaultMcpSyncRequestContext.builder() - .request(request) - .transportContext(exchange) - .stateless(true) - .build(); + throw new UnsupportedOperationException( + "Stateless tool methods do not support McpSyncRequestContext parameter."); } @Override diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/ConcurrentReferenceHashMap.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/ConcurrentReferenceHashMap.java index 057b252..cbff2b9 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/ConcurrentReferenceHashMap.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/ConcurrentReferenceHashMap.java @@ -672,7 +672,7 @@ private Reference findInChain(Reference ref, Object key, int hash) { return null; } - @SuppressWarnings({ "rawtypes", "unchecked" }) + @SuppressWarnings({ "unchecked" }) private Reference[] createReferenceArray(int size) { return new Reference[size]; } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/McpProviderUtils.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/McpProviderUtils.java index a1dfade..db7bbeb 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/McpProviderUtils.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/McpProviderUtils.java @@ -20,9 +20,13 @@ import java.util.function.Predicate; import java.util.regex.Pattern; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpSyncServerExchange; import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.context.McpAsyncRequestContext; +import org.springaicommunity.mcp.context.McpSyncRequestContext; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -49,7 +53,7 @@ public static Predicate filterNonReactiveReturnTypeMethod() { if (isReactiveReturnType.test(method)) { return true; } - logger.info( + logger.warn( "Sync providers doesn't support reactive return types. Skipping method {} with reactive return type {}", method, method.getReturnType()); return false; @@ -61,11 +65,38 @@ public static Predicate filterReactiveReturnTypeMethod() { if (isNotReactiveReturnType.test(method)) { return true; } - logger.info( + logger.warn( "Sync providers doesn't support reactive return types. Skipping method {} with reactive return type {}", method, method.getReturnType()); return false; }; } + private static boolean hasBidirectionalParameters(Method method) { + + for (Class paramType : method.getParameterTypes()) { + if (McpSyncRequestContext.class.isAssignableFrom(paramType) + || McpAsyncRequestContext.class.isAssignableFrom(paramType) + || McpSyncServerExchange.class.isAssignableFrom(paramType) + || McpAsyncServerExchange.class.isAssignableFrom(paramType)) { + + return true; + } + } + + return false; + } + + public static Predicate filterMethodWithBidirectionalParameters() { + return method -> { + if (!hasBidirectionalParameters(method)) { + return true; + } + logger.warn( + "Stateless servers doesn't support bidirectional parameters. Skipping method {} with bidirectional parameters", + method); + return false; + }; + } + } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/complete/AsyncStatelessMcpCompleteProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/complete/AsyncStatelessMcpCompleteProvider.java index d34c113..0c59858 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/complete/AsyncStatelessMcpCompleteProvider.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/complete/AsyncStatelessMcpCompleteProvider.java @@ -69,6 +69,7 @@ public List getCompleteSpecifications() { .map(completeObject -> Stream.of(doGetClassMethods(completeObject)) .filter(method -> method.isAnnotationPresent(McpComplete.class)) .filter(McpProviderUtils.filterNonReactiveReturnTypeMethod()) + .filter(McpProviderUtils.filterMethodWithBidirectionalParameters()) .sorted((m1, m2) -> m1.getName().compareTo(m2.getName())) .map(mcpCompleteMethod -> { var completeAnnotation = mcpCompleteMethod.getAnnotation(McpComplete.class); diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/complete/SyncStatelessMcpCompleteProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/complete/SyncStatelessMcpCompleteProvider.java index cce8307..8a16160 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/complete/SyncStatelessMcpCompleteProvider.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/complete/SyncStatelessMcpCompleteProvider.java @@ -68,6 +68,7 @@ public List getCompleteSpecifications() { .map(completeObject -> Stream.of(doGetClassMethods(completeObject)) .filter(method -> method.isAnnotationPresent(McpComplete.class)) .filter(McpProviderUtils.filterReactiveReturnTypeMethod()) + .filter(McpProviderUtils.filterMethodWithBidirectionalParameters()) .sorted((m1, m2) -> m1.getName().compareTo(m2.getName())) .map(mcpCompleteMethod -> { var completeAnnotation = mcpCompleteMethod.getAnnotation(McpComplete.class); diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/prompt/AsyncStatelessMcpPromptProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/prompt/AsyncStatelessMcpPromptProvider.java index 458e55b..2da76f7 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/prompt/AsyncStatelessMcpPromptProvider.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/prompt/AsyncStatelessMcpPromptProvider.java @@ -69,6 +69,7 @@ public List getPromptSpecifications() { .map(promptObject -> Stream.of(doGetClassMethods(promptObject)) .filter(method -> method.isAnnotationPresent(McpPrompt.class)) .filter(McpProviderUtils.filterNonReactiveReturnTypeMethod()) + .filter(McpProviderUtils.filterMethodWithBidirectionalParameters()) .sorted((m1, m2) -> m1.getName().compareTo(m2.getName())) .map(mcpPromptMethod -> { var promptAnnotation = mcpPromptMethod.getAnnotation(McpPrompt.class); diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/prompt/SyncStatelessMcpPromptProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/prompt/SyncStatelessMcpPromptProvider.java index 1a7f1bc..9da3e10 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/prompt/SyncStatelessMcpPromptProvider.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/prompt/SyncStatelessMcpPromptProvider.java @@ -68,6 +68,7 @@ public List getPromptSpecifications() { .map(promptObject -> Stream.of(doGetClassMethods(promptObject)) .filter(method -> method.isAnnotationPresent(McpPrompt.class)) .filter(McpProviderUtils.filterReactiveReturnTypeMethod()) + .filter(McpProviderUtils.filterMethodWithBidirectionalParameters()) .sorted((m1, m2) -> m1.getName().compareTo(m2.getName())) .map(mcpPromptMethod -> { var promptAnnotation = mcpPromptMethod.getAnnotation(McpPrompt.class); diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/resource/AsyncStatelessMcpResourceProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/resource/AsyncStatelessMcpResourceProvider.java index 56495b3..24f58b3 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/resource/AsyncStatelessMcpResourceProvider.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/resource/AsyncStatelessMcpResourceProvider.java @@ -71,6 +71,7 @@ public List getResourceSpecifications() { .map(resourceObject -> Stream.of(doGetClassMethods(resourceObject)) .filter(method -> method.isAnnotationPresent(McpResource.class)) .filter(McpProviderUtils.filterNonReactiveReturnTypeMethod()) + .filter(McpProviderUtils.filterMethodWithBidirectionalParameters()) .sorted((m1, m2) -> m1.getName().compareTo(m2.getName())) .map(mcpResourceMethod -> { diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/resource/SyncStatelessMcpResourceProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/resource/SyncStatelessMcpResourceProvider.java index a464422..7ad4faf 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/resource/SyncStatelessMcpResourceProvider.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/resource/SyncStatelessMcpResourceProvider.java @@ -70,6 +70,7 @@ public List getResourceSpecifications() { .map(resourceObject -> Stream.of(doGetClassMethods(resourceObject)) .filter(method -> method.isAnnotationPresent(McpResource.class)) .filter(McpProviderUtils.filterReactiveReturnTypeMethod()) + .filter(McpProviderUtils.filterMethodWithBidirectionalParameters()) .sorted((m1, m2) -> m1.getName().compareTo(m2.getName())) .map(mcpResourceMethod -> { diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/AsyncStatelessMcpToolProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/AsyncStatelessMcpToolProvider.java index 3da8cc8..e97ebe1 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/AsyncStatelessMcpToolProvider.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/AsyncStatelessMcpToolProvider.java @@ -68,6 +68,7 @@ public List getToolSpecifications() { .map(toolObject -> Stream.of(doGetClassMethods(toolObject)) .filter(method -> method.isAnnotationPresent(McpTool.class)) .filter(McpProviderUtils.filterNonReactiveReturnTypeMethod()) + .filter(McpProviderUtils.filterMethodWithBidirectionalParameters()) .sorted((m1, m2) -> m1.getName().compareTo(m2.getName())) .map(mcpToolMethod -> { diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/SyncStatelessMcpToolProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/SyncStatelessMcpToolProvider.java index 1dc1640..4a2151d 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/SyncStatelessMcpToolProvider.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/tool/SyncStatelessMcpToolProvider.java @@ -65,6 +65,7 @@ public List getToolSpecifications() { .map(toolObject -> Stream.of(this.doGetClassMethods(toolObject)) .filter(method -> method.isAnnotationPresent(McpTool.class)) .filter(McpProviderUtils.filterReactiveReturnTypeMethod()) + .filter(McpProviderUtils.filterMethodWithBidirectionalParameters()) .sorted((m1, m2) -> m1.getName().compareTo(m2.getName())) .map(mcpToolMethod -> { @@ -107,8 +108,6 @@ public List getToolSpecifications() { } toolBuilder.title(title); - // ReactiveUtils.isReactiveReturnTypeOfCallToolResult(mcpToolMethod); - // Generate Output Schema from the method return type. // Output schema is not generated for primitive types, void, // CallToolResult, simple value types (String, etc.) diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java index f48a216..6e0c054 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java @@ -112,7 +112,10 @@ public void testRootsWhenSupported() { public void testRootsWhenNotSupported() { when(exchange.getClientCapabilities()).thenReturn(null); - StepVerifier.create(context.roots()).verifyComplete(); + StepVerifier.create(context.roots()).verifyErrorSatisfies(e -> { + assertThat(e).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Roots not supported by the client"); + }); } @Test @@ -121,7 +124,11 @@ public void testRootsWhenCapabilitiesNullRoots() { when(capabilities.roots()).thenReturn(null); when(exchange.getClientCapabilities()).thenReturn(capabilities); - StepVerifier.create(context.roots()).verifyComplete(); + StepVerifier.create(context.roots()).verifyErrorSatisfies(e -> { + assertThat(e).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Roots not supported by the client"); + }); + } // Elicitation Tests @@ -197,16 +204,16 @@ record Person(String name, int age) { @Test public void testElicitationWithNullTypeReference() { - assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { - context.elicit((TypeReference) null); - })).hasMessageContaining("Elicitation response type must not be null"); + assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, + () -> context.elicit((TypeReference) null))) + .hasMessageContaining("Elicitation response type must not be null"); } @Test public void testElicitationWithNullClassType() { - assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { - context.elicit((Class) null); - })).hasMessageContaining("Elicitation response type must not be null"); + assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, + () -> context.elicit((Class) null))) + .hasMessageContaining("Elicitation response type must not be null"); } @Test @@ -229,11 +236,11 @@ public void testElicitationWithNullMessage() { public void testElicitationReturnsEmptyWhenNotSupported() { when(exchange.getClientCapabilities()).thenReturn(null); - Mono>> result = context.elicit(e -> e.message("Test message"), - new TypeReference>() { - }); - - StepVerifier.create(result).verifyComplete(); + StepVerifier.create(context.elicit(e -> e.message("Test message"), new TypeReference>() { + })).verifyErrorSatisfies(e -> { + assertThat(e).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Elicitation not supported by the client"); + }); } @Test @@ -371,9 +378,10 @@ public void testElicitationWhenNotSupported() { .requestedSchema(Map.of("type", "string")) .build(); - Mono result = context.elicit(elicitRequest); - - StepVerifier.create(result).verifyComplete(); + StepVerifier.create(context.elicit(elicitRequest)).verifyErrorSatisfies(e -> { + assertThat(e).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Elicitation not supported by the client"); + }); } // Sampling Tests @@ -450,9 +458,10 @@ public void testSamplingWhenNotSupported() { .maxTokens(500) .build(); - Mono result = context.sample(createRequest); - - StepVerifier.create(result).verifyComplete(); + StepVerifier.create(context.sample(createRequest)).verifyErrorSatisfies(e -> { + assertThat(e).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Sampling not supported by the client"); + }); } // Progress Tests diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java index b2ece1a..0269d57 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java @@ -5,7 +5,7 @@ package org.springaicommunity.mcp.context; import java.util.Map; -import java.util.Optional; +import java.util.function.Consumer; import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.server.McpSyncServerExchange; @@ -27,6 +27,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.springaicommunity.mcp.context.McpRequestContextTypes.ElicitationSpec; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -87,6 +88,32 @@ public void testBuilderWithNullExchange() { // Roots Tests + @Test + public void testRootsEnabledWhenSupported() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + McpSchema.ClientCapabilities.RootCapabilities roots = mock(McpSchema.ClientCapabilities.RootCapabilities.class); + when(capabilities.roots()).thenReturn(roots); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + assertThat(context.rootsEnabled()).isTrue(); + } + + @Test + public void testRootsEnabledWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + assertThat(context.rootsEnabled()).isFalse(); + } + + @Test + public void testRootsEnabledWhenCapabilitiesNullRoots() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + when(capabilities.roots()).thenReturn(null); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + assertThat(context.rootsEnabled()).isFalse(); + } + @Test public void testRootsWhenSupported() { ClientCapabilities capabilities = mock(ClientCapabilities.class); @@ -97,10 +124,10 @@ public void testRootsWhenSupported() { ListRootsResult expectedResult = mock(ListRootsResult.class); when(exchange.listRoots()).thenReturn(expectedResult); - Optional result = context.roots(); + ListRootsResult result = context.roots(); - assertThat(result).isPresent(); - assertThat(result.get()).isEqualTo(expectedResult); + assertThat(result).isNotNull(); + assertThat(result).isEqualTo(expectedResult); verify(exchange).listRoots(); } @@ -108,9 +135,8 @@ public void testRootsWhenSupported() { public void testRootsWhenNotSupported() { when(exchange.getClientCapabilities()).thenReturn(null); - Optional result = context.roots(); - - assertThat(result).isEmpty(); + assertThatThrownBy(() -> context.roots()).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Roots not supported"); } @Test @@ -119,13 +145,38 @@ public void testRootsWhenCapabilitiesNullRoots() { when(capabilities.roots()).thenReturn(null); when(exchange.getClientCapabilities()).thenReturn(capabilities); - Optional result = context.roots(); - - assertThat(result).isEmpty(); + assertThatThrownBy(() -> context.roots()).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Roots not supported"); } // Elicitation Tests + @Test + public void testElicitEnabledWhenSupported() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + assertThat(context.elicitEnabled()).isTrue(); + } + + @Test + public void testElicitEnabledWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + assertThat(context.elicitEnabled()).isFalse(); + } + + @Test + public void testElicitEnabledWhenCapabilitiesNullElicitation() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + when(capabilities.elicitation()).thenReturn(null); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + assertThat(context.elicitEnabled()).isFalse(); + } + @Test public void testElicitationWithTypeAndMessage() { ClientCapabilities capabilities = mock(ClientCapabilities.class); @@ -140,15 +191,15 @@ public void testElicitationWithTypeAndMessage() { when(expectedResult.meta()).thenReturn(null); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); - Optional>> result = context.elicit(e -> e.message("Test message"), + StructuredElicitResult> result = context.elicit(e -> e.message("Test message"), new TypeReference>() { }); - assertThat(result).isPresent(); - assertThat(result.get().action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.get().structuredContent()).isNotNull(); - assertThat(result.get().structuredContent()).containsEntry("name", "John"); - assertThat(result.get().structuredContent()).containsEntry("age", 30); + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.structuredContent()).isNotNull(); + assertThat(result.structuredContent()).containsEntry("name", "John"); + assertThat(result.structuredContent()).containsEntry("age", 30); ArgumentCaptor captor = ArgumentCaptor.forClass(ElicitRequest.class); verify(exchange).createElicitation(captor.capture()); @@ -177,16 +228,16 @@ record Person(String name, int age) { when(expectedResult.meta()).thenReturn(resultMeta); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); - Optional> result = context - .elicit(e -> e.message("Test message").meta(requestMeta), new TypeReference() { - }); + StructuredElicitResult result = context.elicit(e -> e.message("Test message").meta(requestMeta), + new TypeReference() { + }); - assertThat(result).isPresent(); - assertThat(result.get().action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.get().structuredContent()).isNotNull(); - assertThat(result.get().structuredContent().name()).isEqualTo("Jane"); - assertThat(result.get().structuredContent().age()).isEqualTo(25); - assertThat(result.get().meta()).containsEntry("resultKey", "resultValue"); + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.structuredContent()).isNotNull(); + assertThat(result.structuredContent().name()).isEqualTo("Jane"); + assertThat(result.structuredContent().age()).isEqualTo(25); + assertThat(result.meta()).containsEntry("resultKey", "resultValue"); ArgumentCaptor captor = ArgumentCaptor.forClass(ElicitRequest.class); verify(exchange).createElicitation(captor.capture()); @@ -197,24 +248,18 @@ record Person(String name, int age) { @Test public void testElicitationWithNullResponseType() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + assertThatThrownBy(() -> context.elicit((TypeReference) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Elicitation response type must not be null"); } @Test - public void testElicitationWithTypeReturnsEmptyWhenNotSupported() { - when(exchange.getClientCapabilities()).thenReturn(null); - - Optional>> result = context - .elicit(new TypeReference>() { - }); - - assertThat(result).isEmpty(); - } - - @Test - public void testElicitationWithTypeReturnsEmptyWhenActionIsNotAccept() { + public void testElicitationWithTypeWhenActionIsNotAccept() { ClientCapabilities capabilities = mock(ClientCapabilities.class); ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); when(capabilities.elicitation()).thenReturn(elicitation); @@ -222,13 +267,16 @@ public void testElicitationWithTypeReturnsEmptyWhenActionIsNotAccept() { ElicitResult expectedResult = mock(ElicitResult.class); when(expectedResult.action()).thenReturn(ElicitResult.Action.DECLINE); + when(expectedResult.meta()).thenReturn(null); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); - Optional>> result = context.elicit(e -> e.message("Test message"), + StructuredElicitResult> result = context.elicit(e -> e.message("Test message"), new TypeReference>() { }); - assertThat(result).isEmpty(); + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.DECLINE); + assertThat(result.structuredContent()).isNull(); } @Test @@ -251,18 +299,18 @@ record PersonWithAddress(String name, int age, Address address) { when(expectedResult.meta()).thenReturn(null); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); - Optional> result = context - .elicit(e -> e.message("Test message").meta(null), new TypeReference() { - }); + StructuredElicitResult result = context.elicit(e -> e.message("Test message").meta(null), + new TypeReference() { + }); - assertThat(result).isPresent(); - assertThat(result.get().action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.get().structuredContent()).isNotNull(); - assertThat(result.get().structuredContent().name()).isEqualTo("John"); - assertThat(result.get().structuredContent().age()).isEqualTo(30); - assertThat(result.get().structuredContent().address()).isNotNull(); - assertThat(result.get().structuredContent().address().street()).isEqualTo("123 Main St"); - assertThat(result.get().structuredContent().address().city()).isEqualTo("Springfield"); + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.structuredContent()).isNotNull(); + assertThat(result.structuredContent().name()).isEqualTo("John"); + assertThat(result.structuredContent().age()).isEqualTo(30); + assertThat(result.structuredContent().address()).isNotNull(); + assertThat(result.structuredContent().address().street()).isEqualTo("123 Main St"); + assertThat(result.structuredContent().address().city()).isEqualTo("Springfield"); } @Test @@ -280,12 +328,12 @@ public void testElicitationWithTypeHandlesListTypes() { when(expectedResult.meta()).thenReturn(null); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); - Optional>> result = context - .elicit(e -> e.message("Test message").meta(null), new TypeReference>() { - }); + StructuredElicitResult> result = context.elicit(e -> e.message("Test message").meta(null), + new TypeReference>() { + }); - assertThat(result).isPresent(); - assertThat(result.get().structuredContent()).containsKey("items"); + assertThat(result).isNotNull(); + assertThat(result.structuredContent()).containsKey("items"); } @Test @@ -301,13 +349,13 @@ public void testElicitationWithTypeReference() { when(expectedResult.content()).thenReturn(contentMap); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); - Optional>> result = context - .elicit(e -> e.message("Test message").meta(null), new TypeReference>() { - }); + StructuredElicitResult> result = context.elicit(e -> e.message("Test message").meta(null), + new TypeReference>() { + }); - assertThat(result).isPresent(); - assertThat(result.get().structuredContent()).containsEntry("result", "success"); - assertThat(result.get().structuredContent()).containsEntry("data", "test value"); + assertThat(result).isNotNull(); + assertThat(result.structuredContent()).containsEntry("result", "success"); + assertThat(result.structuredContent()).containsEntry("data", "test value"); } @Test @@ -325,28 +373,62 @@ public void testElicitationWithRequest() { when(exchange.createElicitation(elicitRequest)).thenReturn(expectedResult); - Optional result = context.elicit(elicitRequest); + ElicitResult result = context.elicit(elicitRequest); - assertThat(result).isPresent(); - assertThat(result.get()).isEqualTo(expectedResult); + assertThat(result).isNotNull(); + assertThat(result).isEqualTo(expectedResult); } @Test public void testElicitationWhenNotSupported() { when(exchange.getClientCapabilities()).thenReturn(null); - ElicitRequest elicitRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema(Map.of("type", "string")) - .build(); + assertThatThrownBy(() -> context.elicit((ElicitRequest) null)).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Elicitation not supported by the clien"); + + assertThatThrownBy(() -> context.elicit((Consumer) null, (TypeReference) null)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Elicitation not supported by the clien"); - Optional result = context.elicit(elicitRequest); + assertThatThrownBy(() -> context.elicit((Consumer) null, (Class) null)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Elicitation not supported by the clien"); - assertThat(result).isEmpty(); + assertThatThrownBy(() -> context.elicit((TypeReference) null)).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Elicitation not supported by the clien"); + + assertThatThrownBy(() -> context.elicit((Class) null)).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Elicitation not supported by the clien"); } // Sampling Tests + @Test + public void testSampleEnabledWhenSupported() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Sampling sampling = mock(ClientCapabilities.Sampling.class); + when(capabilities.sampling()).thenReturn(sampling); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + assertThat(context.sampleEnabled()).isTrue(); + } + + @Test + public void testSampleEnabledWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + assertThat(context.sampleEnabled()).isFalse(); + } + + @Test + public void testSampleEnabledWhenCapabilitiesNullSampling() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + when(capabilities.sampling()).thenReturn(null); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + assertThat(context.sampleEnabled()).isFalse(); + } + @Test public void testSamplingWithMessages() { ClientCapabilities capabilities = mock(ClientCapabilities.class); @@ -357,10 +439,10 @@ public void testSamplingWithMessages() { CreateMessageResult expectedResult = mock(CreateMessageResult.class); when(exchange.createMessage(any(CreateMessageRequest.class))).thenReturn(expectedResult); - Optional result = context.sample("Message 1", "Message 2"); + CreateMessageResult result = context.sample("Message 1", "Message 2"); - assertThat(result).isPresent(); - assertThat(result.get()).isEqualTo(expectedResult); + assertThat(result).isNotNull(); + assertThat(result).isEqualTo(expectedResult); } @Test @@ -373,15 +455,15 @@ public void testSamplingWithConsumer() { CreateMessageResult expectedResult = mock(CreateMessageResult.class); when(exchange.createMessage(any(CreateMessageRequest.class))).thenReturn(expectedResult); - Optional result = context.sample(spec -> { + CreateMessageResult result = context.sample(spec -> { spec.message(new TextContent("Test message")); spec.systemPrompt("System prompt"); spec.temperature(0.7); spec.maxTokens(100); }); - assertThat(result).isPresent(); - assertThat(result.get()).isEqualTo(expectedResult); + assertThat(result).isNotNull(); + assertThat(result).isEqualTo(expectedResult); ArgumentCaptor captor = ArgumentCaptor.forClass(CreateMessageRequest.class); verify(exchange).createMessage(captor.capture()); @@ -407,10 +489,10 @@ public void testSamplingWithRequest() { when(exchange.createMessage(createRequest)).thenReturn(expectedResult); - Optional result = context.sample(createRequest); + CreateMessageResult result = context.sample(createRequest); - assertThat(result).isPresent(); - assertThat(result.get()).isEqualTo(expectedResult); + assertThat(result).isNotNull(); + assertThat(result).isEqualTo(expectedResult); } @Test @@ -422,9 +504,14 @@ public void testSamplingWhenNotSupported() { .maxTokens(500) .build(); - Optional result = context.sample(createRequest); + assertThatThrownBy(() -> context.sample(createRequest)).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Sampling not supported by the client"); + + assertThatThrownBy(() -> context.sample("Message 1")).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Sampling not supported by the client"); - assertThat(result).isEmpty(); + assertThatThrownBy(() -> context.sample(spec -> spec.message("Test"))).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Sampling not supported by the client"); } // Progress Tests diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/McpProviderUtilsTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/McpProviderUtilsTests.java new file mode 100644 index 0000000..5099d60 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/McpProviderUtilsTests.java @@ -0,0 +1,436 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.function.Predicate; + +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import org.springaicommunity.mcp.context.McpAsyncRequestContext; +import org.springaicommunity.mcp.context.McpSyncRequestContext; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link McpProviderUtils}. + * + * @author Christian Tzolov + */ +public class McpProviderUtilsTests { + + // Test classes for method reflection tests + static class TestMethods { + + public String nonReactiveMethod() { + return "test"; + } + + public Mono monoMethod() { + return Mono.just("test"); + } + + public Flux fluxMethod() { + return Flux.just("test"); + } + + public Publisher publisherMethod() { + return Mono.just("test"); + } + + public void voidMethod() { + } + + public List listMethod() { + return List.of("test"); + } + + public String methodWithSyncContext(McpSyncRequestContext context) { + return "test"; + } + + public String methodWithAsyncContext(McpAsyncRequestContext context) { + return "test"; + } + + public String methodWithSyncExchange(McpSyncServerExchange exchange) { + return "test"; + } + + public String methodWithAsyncExchange(McpAsyncServerExchange exchange) { + return "test"; + } + + public String methodWithMultipleParams(String param1, McpSyncRequestContext context, int param2) { + return "test"; + } + + public String methodWithoutBidirectionalParams(String param1, int param2) { + return "test"; + } + + } + + // URI Template Tests + + @Test + public void testIsUriTemplateWithSimpleVariable() { + assertThat(McpProviderUtils.isUriTemplate("/api/{id}")).isTrue(); + } + + @Test + public void testIsUriTemplateWithMultipleVariables() { + assertThat(McpProviderUtils.isUriTemplate("/api/{userId}/posts/{postId}")).isTrue(); + } + + @Test + public void testIsUriTemplateWithVariableAtStart() { + assertThat(McpProviderUtils.isUriTemplate("{id}/details")).isTrue(); + } + + @Test + public void testIsUriTemplateWithVariableAtEnd() { + assertThat(McpProviderUtils.isUriTemplate("/api/users/{id}")).isTrue(); + } + + @Test + public void testIsUriTemplateWithComplexVariableName() { + assertThat(McpProviderUtils.isUriTemplate("/api/{user_id}")).isTrue(); + assertThat(McpProviderUtils.isUriTemplate("/api/{userId123}")).isTrue(); + } + + @Test + public void testIsUriTemplateWithNoVariables() { + assertThat(McpProviderUtils.isUriTemplate("/api/users")).isFalse(); + } + + @Test + public void testIsUriTemplateWithEmptyString() { + assertThat(McpProviderUtils.isUriTemplate("")).isFalse(); + } + + @Test + public void testIsUriTemplateWithOnlySlashes() { + assertThat(McpProviderUtils.isUriTemplate("/")).isFalse(); + assertThat(McpProviderUtils.isUriTemplate("//")).isFalse(); + } + + @Test + public void testIsUriTemplateWithIncompleteBraces() { + assertThat(McpProviderUtils.isUriTemplate("/api/{id")).isFalse(); + assertThat(McpProviderUtils.isUriTemplate("/api/id}")).isFalse(); + } + + @Test + public void testIsUriTemplateWithEmptyBraces() { + assertThat(McpProviderUtils.isUriTemplate("/api/{}")).isFalse(); + } + + @Test + public void testIsUriTemplateWithNestedPath() { + assertThat(McpProviderUtils.isUriTemplate("/api/v1/users/{userId}/posts/{postId}/comments")).isTrue(); + } + + // Reactive Return Type Predicate Tests + + @Test + public void testIsReactiveReturnTypeWithMono() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("monoMethod"); + assertThat(McpProviderUtils.isReactiveReturnType.test(method)).isTrue(); + } + + @Test + public void testIsReactiveReturnTypeWithFlux() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("fluxMethod"); + assertThat(McpProviderUtils.isReactiveReturnType.test(method)).isTrue(); + } + + @Test + public void testIsReactiveReturnTypeWithPublisher() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("publisherMethod"); + assertThat(McpProviderUtils.isReactiveReturnType.test(method)).isTrue(); + } + + @Test + public void testIsReactiveReturnTypeWithNonReactive() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("nonReactiveMethod"); + assertThat(McpProviderUtils.isReactiveReturnType.test(method)).isFalse(); + } + + @Test + public void testIsReactiveReturnTypeWithVoid() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("voidMethod"); + assertThat(McpProviderUtils.isReactiveReturnType.test(method)).isFalse(); + } + + @Test + public void testIsReactiveReturnTypeWithList() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("listMethod"); + assertThat(McpProviderUtils.isReactiveReturnType.test(method)).isFalse(); + } + + // Non-Reactive Return Type Predicate Tests + + @Test + public void testIsNotReactiveReturnTypeWithMono() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("monoMethod"); + assertThat(McpProviderUtils.isNotReactiveReturnType.test(method)).isFalse(); + } + + @Test + public void testIsNotReactiveReturnTypeWithFlux() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("fluxMethod"); + assertThat(McpProviderUtils.isNotReactiveReturnType.test(method)).isFalse(); + } + + @Test + public void testIsNotReactiveReturnTypeWithPublisher() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("publisherMethod"); + assertThat(McpProviderUtils.isNotReactiveReturnType.test(method)).isFalse(); + } + + @Test + public void testIsNotReactiveReturnTypeWithNonReactive() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("nonReactiveMethod"); + assertThat(McpProviderUtils.isNotReactiveReturnType.test(method)).isTrue(); + } + + @Test + public void testIsNotReactiveReturnTypeWithVoid() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("voidMethod"); + assertThat(McpProviderUtils.isNotReactiveReturnType.test(method)).isTrue(); + } + + @Test + public void testIsNotReactiveReturnTypeWithList() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("listMethod"); + assertThat(McpProviderUtils.isNotReactiveReturnType.test(method)).isTrue(); + } + + // Filter Non-Reactive Return Type Method Tests + + @Test + public void testFilterNonReactiveReturnTypeMethodWithReactiveType() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("monoMethod"); + Predicate filter = McpProviderUtils.filterNonReactiveReturnTypeMethod(); + assertThat(filter.test(method)).isTrue(); + } + + @Test + public void testFilterNonReactiveReturnTypeMethodWithNonReactiveType() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("nonReactiveMethod"); + Predicate filter = McpProviderUtils.filterNonReactiveReturnTypeMethod(); + // This should return false and log a warning + assertThat(filter.test(method)).isFalse(); + } + + @Test + public void testFilterNonReactiveReturnTypeMethodWithFlux() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("fluxMethod"); + Predicate filter = McpProviderUtils.filterNonReactiveReturnTypeMethod(); + assertThat(filter.test(method)).isTrue(); + } + + @Test + public void testFilterNonReactiveReturnTypeMethodWithPublisher() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("publisherMethod"); + Predicate filter = McpProviderUtils.filterNonReactiveReturnTypeMethod(); + assertThat(filter.test(method)).isTrue(); + } + + // Filter Reactive Return Type Method Tests + + @Test + public void testFilterReactiveReturnTypeMethodWithReactiveType() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("monoMethod"); + Predicate filter = McpProviderUtils.filterReactiveReturnTypeMethod(); + // This should return false and log a warning + assertThat(filter.test(method)).isFalse(); + } + + @Test + public void testFilterReactiveReturnTypeMethodWithNonReactiveType() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("nonReactiveMethod"); + Predicate filter = McpProviderUtils.filterReactiveReturnTypeMethod(); + assertThat(filter.test(method)).isTrue(); + } + + @Test + public void testFilterReactiveReturnTypeMethodWithFlux() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("fluxMethod"); + Predicate filter = McpProviderUtils.filterReactiveReturnTypeMethod(); + // This should return false and log a warning + assertThat(filter.test(method)).isFalse(); + } + + @Test + public void testFilterReactiveReturnTypeMethodWithPublisher() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("publisherMethod"); + Predicate filter = McpProviderUtils.filterReactiveReturnTypeMethod(); + // This should return false and log a warning + assertThat(filter.test(method)).isFalse(); + } + + @Test + public void testFilterReactiveReturnTypeMethodWithVoid() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("voidMethod"); + Predicate filter = McpProviderUtils.filterReactiveReturnTypeMethod(); + assertThat(filter.test(method)).isTrue(); + } + + // Filter Method With Bidirectional Parameters Tests + + @Test + public void testFilterMethodWithBidirectionalParametersWithSyncContext() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("methodWithSyncContext", McpSyncRequestContext.class); + Predicate filter = McpProviderUtils.filterMethodWithBidirectionalParameters(); + // This should return false and log a warning + assertThat(filter.test(method)).isFalse(); + } + + @Test + public void testFilterMethodWithBidirectionalParametersWithAsyncContext() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("methodWithAsyncContext", McpAsyncRequestContext.class); + Predicate filter = McpProviderUtils.filterMethodWithBidirectionalParameters(); + // This should return false and log a warning + assertThat(filter.test(method)).isFalse(); + } + + @Test + public void testFilterMethodWithBidirectionalParametersWithSyncExchange() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("methodWithSyncExchange", McpSyncServerExchange.class); + Predicate filter = McpProviderUtils.filterMethodWithBidirectionalParameters(); + // This should return false and log a warning + assertThat(filter.test(method)).isFalse(); + } + + @Test + public void testFilterMethodWithBidirectionalParametersWithAsyncExchange() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("methodWithAsyncExchange", McpAsyncServerExchange.class); + Predicate filter = McpProviderUtils.filterMethodWithBidirectionalParameters(); + // This should return false and log a warning + assertThat(filter.test(method)).isFalse(); + } + + @Test + public void testFilterMethodWithBidirectionalParametersWithMultipleParams() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("methodWithMultipleParams", String.class, + McpSyncRequestContext.class, int.class); + Predicate filter = McpProviderUtils.filterMethodWithBidirectionalParameters(); + // This should return false because it has a bidirectional parameter + assertThat(filter.test(method)).isFalse(); + } + + @Test + public void testFilterMethodWithBidirectionalParametersWithoutBidirectionalParams() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("methodWithoutBidirectionalParams", String.class, int.class); + Predicate filter = McpProviderUtils.filterMethodWithBidirectionalParameters(); + assertThat(filter.test(method)).isTrue(); + } + + @Test + public void testFilterMethodWithBidirectionalParametersWithNoParams() throws NoSuchMethodException { + Method method = TestMethods.class.getMethod("nonReactiveMethod"); + Predicate filter = McpProviderUtils.filterMethodWithBidirectionalParameters(); + assertThat(filter.test(method)).isTrue(); + } + + // Combined Filter Tests + + @Test + public void testCombinedFiltersForStatelessSyncProvider() throws NoSuchMethodException { + // Stateless sync providers should filter out: + // 1. Methods with reactive return types + // 2. Methods with bidirectional parameters + + Method validMethod = TestMethods.class.getMethod("methodWithoutBidirectionalParams", String.class, int.class); + Method reactiveMethod = TestMethods.class.getMethod("monoMethod"); + Method bidirectionalMethod = TestMethods.class.getMethod("methodWithSyncContext", McpSyncRequestContext.class); + + Predicate reactiveFilter = McpProviderUtils.filterReactiveReturnTypeMethod(); + Predicate bidirectionalFilter = McpProviderUtils.filterMethodWithBidirectionalParameters(); + Predicate combinedFilter = reactiveFilter.and(bidirectionalFilter); + + assertThat(combinedFilter.test(validMethod)).isTrue(); + assertThat(combinedFilter.test(reactiveMethod)).isFalse(); + assertThat(combinedFilter.test(bidirectionalMethod)).isFalse(); + } + + @Test + public void testCombinedFiltersForStatelessAsyncProvider() throws NoSuchMethodException { + // Stateless async providers should filter out: + // 1. Methods with non-reactive return types + // 2. Methods with bidirectional parameters + + Method validMethod = TestMethods.class.getMethod("monoMethod"); + Method nonReactiveMethod = TestMethods.class.getMethod("nonReactiveMethod"); + Method bidirectionalMethod = TestMethods.class.getMethod("methodWithAsyncContext", + McpAsyncRequestContext.class); + + Predicate nonReactiveFilter = McpProviderUtils.filterNonReactiveReturnTypeMethod(); + Predicate bidirectionalFilter = McpProviderUtils.filterMethodWithBidirectionalParameters(); + Predicate combinedFilter = nonReactiveFilter.and(bidirectionalFilter); + + assertThat(combinedFilter.test(validMethod)).isTrue(); + assertThat(combinedFilter.test(nonReactiveMethod)).isFalse(); + assertThat(combinedFilter.test(bidirectionalMethod)).isFalse(); + } + + // Edge Case Tests + + @Test + public void testIsUriTemplateWithSpecialCharacters() { + assertThat(McpProviderUtils.isUriTemplate("/api/{user-id}")).isTrue(); + assertThat(McpProviderUtils.isUriTemplate("/api/{user.id}")).isTrue(); + } + + @Test + public void testIsUriTemplateWithQueryParameters() { + // Query parameters are not URI template variables + assertThat(McpProviderUtils.isUriTemplate("/api/users?id={id}")).isTrue(); + } + + @Test + public void testIsUriTemplateWithFragment() { + assertThat(McpProviderUtils.isUriTemplate("/api/users#{id}")).isTrue(); + } + + @Test + public void testIsUriTemplateWithMultipleConsecutiveVariables() { + assertThat(McpProviderUtils.isUriTemplate("/{id}{name}")).isTrue(); + } + + @Test + public void testPredicatesAreReusable() throws NoSuchMethodException { + // Test that predicates can be reused multiple times + Predicate filter = McpProviderUtils.filterReactiveReturnTypeMethod(); + + Method method1 = TestMethods.class.getMethod("nonReactiveMethod"); + Method method2 = TestMethods.class.getMethod("monoMethod"); + Method method3 = TestMethods.class.getMethod("listMethod"); + + assertThat(filter.test(method1)).isTrue(); + assertThat(filter.test(method2)).isFalse(); + assertThat(filter.test(method3)).isTrue(); + } + +}