diff --git a/README.md b/README.md index c69cdaa..e3f3b54 100644 --- a/README.md +++ b/README.md @@ -120,11 +120,15 @@ Each operation type has both synchronous and asynchronous implementations, allow - **`@McpToolParam`** - Annotates tool method parameters with descriptions and requirement specifications #### Special Parameters and Annotations -- **`@McpProgressToken`** - Marks a method parameter to receive the progress token from the request. This parameter is automatically injected and excluded from the generated JSON schema -- **`McpMeta`** - Special parameter type that provides access to metadata from MCP requests, notifications, and results. This parameter is automatically injected and excluded from parameter count limits and JSON schema generation -- **`McpSyncServerExchange`** - Special parameter type for stateful synchronous operations that provides access to server exchange functionality including logging notifications, progress updates, and other server-side operations. This parameter is automatically injected and excluded from JSON schema generation -- **`McpAsyncServerExchange`** - Special parameter type for stateful asynchronous operations that provides access to server exchange functionality with reactive support. This parameter is automatically injected and excluded from JSON schema generation +- **`McpSyncRequestContext`** - Special parameter type for synchronous operations that provides a unified interface for accessing MCP request context, including the original request, server exchange (for stateful operations), transport context (for stateless operations), and convenient methods for logging, progress, sampling, and elicitation. This parameter is automatically injected and excluded from JSON schema generation +- **`McpAsyncRequestContext`** - Special parameter type for asynchronous operations that provides the same unified interface as `McpSyncRequestContext` but with reactive (Mono-based) return types. This parameter is automatically injected and excluded from JSON schema generation +- **(Deprecated and replaced by `McpSyncRequestContext`) `McpSyncServerExchange`** - Special parameter type for stateful synchronous operations that provides access to server exchange functionality including logging notifications, progress updates, and other server-side operations. This parameter is automatically injected and excluded from JSON schema generation. +- **(Deprecated and replaced by `McpAsyncRequestContext`) `McpAsyncServerExchange`** - Special parameter type for stateful asynchronous operations that provides access to server exchange functionality with reactive support. This parameter is automatically injected and excluded from JSON schema generation - **`McpTransportContext`** - Special parameter type for stateless operations that provides lightweight access to transport-level context without full server exchange functionality. This parameter is automatically injected and excluded from JSON schema generation +- **(Deprecated. Handled internally by `McpSyncRequestContext` and `McpAsyncRequestContext`)`@McpProgressToken`** - Marks a method parameter to receive the progress token from the request. This parameter is automatically injected and excluded from the generated JSON schema +**Note:** if using the `McpSyncRequestContext` or `McpAsyncRequestContext` the progress token is handled internally. +- **`McpMeta`** - Special parameter type that provides access to metadata from MCP requests, notifications, and results. This parameter is automatically injected and excluded from parameter count limits and JSON schema generation. +**Note:** if using the McpSyncRequestContext or McpAsyncRequestContext the meta can be obatined via `requestMeta()` instead. ### Method Callbacks @@ -870,6 +874,204 @@ public List smartComplete( This feature enables context-aware MCP operations where the behavior can be customized based on client-provided metadata such as user identity, preferences, session information, or any other contextual data. +#### McpRequestContext Support + +The library provides unified request context interfaces (`McpSyncRequestContext` and `McpAsyncRequestContext`) that offer a higher-level abstraction over the underlying MCP infrastructure. These context objects provide convenient access to: + +- The original request (CallToolRequest, ReadResourceRequest, etc.) +- Server exchange (for stateful operations) or transport context (for stateless operations) +- Convenient methods for logging, progress updates, sampling, elicitation, and more + +**Key Benefits:** +- **Unified API**: Single parameter type works for both stateful and stateless operations +- **Convenience Methods**: Built-in helpers for common operations like logging and progress tracking +- **Type Safety**: Strongly-typed access to request data and context +- **Automatic Injection**: Context is automatically created and injected by the framework + +When a method parameter is of type `McpSyncRequestContext` or `McpAsyncRequestContext`: +- The parameter is automatically injected with the appropriate context implementation +- The parameter is excluded from JSON schema generation +- For stateful operations, the context provides access to `McpSyncServerExchange` or `McpAsyncServerExchange` +- For stateless operations, the context provides access to `McpTransportContext` + +**Synchronous Context Example:** + +```java +public record UserInfo(String name, String email, Number age) {} + +@McpTool(name = "process-with-context", description = "Process data with unified context") +public String processWithContext( + McpSyncRequestContext context, + @McpToolParam(description = "Data to process", required = true) String data) { + + // Access the original request + CallToolRequest request = (CallToolRequest) context.request(); + + // Log information + context.info("Processing data: " + data); + + // Send progress updates + context.progress(50); // 50% complete + + // Check if running in stateful mode + if (!context.isStateless()) { + // Access server exchange for stateful operations + McpSyncServerExchange exchange = context.exchange().orElseThrow(); + // Use exchange for additional operations... + } + + // Perform elicitation with default message - returns StructuredElicitResult + Optional> result = context.elicit(new TypeReference() {}); + + // Or perform elicitation with custom configuration - returns StructuredElicitResult + Optional> structuredResult = context.elicit( + e -> e.message("Please provide your information").meta("context", "user-registration"), + new TypeReference() {} + ); + + if (structuredResult.isPresent() && structuredResult.get().action() == ElicitResult.Action.ACCEPT) { + UserInfo info = structuredResult.get().structuredContent(); + return "Processed: " + data + " for user " + info.name(); + } + + return "Processed: " + data; +} + +@McpResource(uri = "data://{id}", name = "Data Resource", description = "Resource with context") +public ReadResourceResult getDataWithContext( + McpSyncRequestContext context, + String id) { + + // Log the resource access + context.debug("Accessing resource: " + id); + + // Access metadata from the request + Map metadata = context.request()._meta(); + + String content = "Data for " + id; + return new ReadResourceResult(List.of( + new TextResourceContents("data://" + id, "text/plain", content) + )); +} + +@McpPrompt(name = "generate-with-context", description = "Generate prompt with context") +public GetPromptResult generateWithContext( + McpSyncRequestContext context, + @McpArg(name = "topic", required = true) String topic) { + + // Log prompt generation + context.info("Generating prompt for topic: " + topic); + + // Perform sampling if needed + Optional samplingResult = context.sample( + "What are the key points about " + topic + "?" + ); + + String message = "Let's discuss " + topic; + return new GetPromptResult("Generated Prompt", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent(message)))); +} +``` + +**Asynchronous Context Example:** + +```java +public record UserInfo(String name, String email, int age) {} + +@McpTool(name = "async-process-with-context", description = "Async process with unified context") +public Mono asyncProcessWithContext( + McpAsyncRequestContext context, + @McpToolParam(description = "Data to process", required = true) String data) { + + return Mono.fromCallable(() -> { + // Access the original request + CallToolRequest request = (CallToolRequest) context.request(); + return data; + }) + .flatMap(processedData -> { + // Log information (returns Mono) + return context.info("Processing data: " + processedData) + .thenReturn(processedData); + }) + .flatMap(processedData -> { + // Send progress updates (returns Mono) + return context.progress(50) + .thenReturn(processedData); + }) + .flatMap(processedData -> { + // Perform elicitation with default message - returns Mono + return context.elicitation(new TypeReference() {}) + .map(userInfo -> "Processed: " + processedData + " for user " + userInfo.name()); + }) + .switchIfEmpty(Mono.fromCallable(() -> { + // Or perform elicitation with custom message and metadata - returns Mono> + return context.elicitation( + new TypeReference() {}, + "Please provide your information", + Map.of("context", "user-registration") + ) + .filter(result -> result.action() == ElicitResult.Action.ACCEPT) + .map(result -> "Processed: " + data + " for user " + result.structuredContent().name()) + .defaultIfEmpty("Processed: " + data); + }).flatMap(mono -> mono)); +} + +@McpResource(uri = "async-data://{id}", name = "Async Data Resource", + description = "Async resource with context") +public Mono getAsyncDataWithContext( + McpAsyncRequestContext context, + String id) { + + // Log the resource access (returns Mono) + return context.debug("Accessing async resource: " + id) + .then(Mono.fromCallable(() -> { + String content = "Async data for " + id; + return new ReadResourceResult(List.of( + new TextResourceContents("async-data://" + id, "text/plain", content) + )); + })); +} + +@McpPrompt(name = "async-generate-with-context", + description = "Async generate prompt with context") +public Mono asyncGenerateWithContext( + McpAsyncRequestContext context, + @McpArg(name = "topic", required = true) String topic) { + + // Log prompt generation and perform sampling + return context.info("Generating async prompt for topic: " + topic) + .then(context.sampling("What are the key points about " + topic + "?")) + .map(samplingResult -> { + String message = "Let's discuss " + topic; + return new GetPromptResult("Generated Async Prompt", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent(message)))); + }); +} +``` + +**Available Context Methods:** + +`McpSyncRequestContext` provides: +- `request()` - Access the original request object +- `exchange()` - Access the server exchange (for stateful operations) +- `transportContext()` - Access the transport context (for stateless operations) +- `isStateless()` - Check if running in stateless mode +- `log(Consumer)` - Send log messages with custom configuration +- `debug(String)`, `info(String)`, `warn(String)`, `error(String)` - Convenience logging methods +- `progress(int)`, `progress(Consumer)` - Send progress updates +- `elicit(TypeReference)` - Request user input with default message, returns `StructuredElicitResult` with action, typed content, and metadata +- `elicit(Class)` - Request user input with default message using Class type, returns `StructuredElicitResult` +- `elicit(Consumer, TypeReference)` - Request user input with custom configuration, returns `StructuredElicitResult` +- `elicit(Consumer, Class)` - Request user input with custom configuration using Class type, returns `StructuredElicitResult` +- `elicit(ElicitRequest)` - Request user input with full control over the elicitation request +- `sample(...)` - Request LLM sampling with various configuration options +- `roots()` - Access root directories (returns `Optional`) +- `ping()` - Send ping to check connection + +`McpAsyncRequestContext` provides the same methods but with reactive return types (`Mono` instead of `T` or `Optional`). + +This unified context approach simplifies method signatures and provides a consistent API across different operation types and execution modes (stateful vs stateless, sync vs async). + ### Async Tool Example ```java @@ -1771,7 +1973,7 @@ public class AsyncElicitationHandler { public class MyMcpClient { public static McpSyncClient createSyncClientWithElicitation(ElicitationHandler elicitationHandler) { - Function elicitationHandler = + Function elicitationHandlerFunc = new SyncMcpElicitationProvider(List.of(elicitationHandler)).getElicitationHandler(); McpSyncClient client = McpClient.sync(transport) @@ -1779,14 +1981,14 @@ public class MyMcpClient { .elicitation() // Enable elicitation support // Other capabilities... .build()) - .elicitationHandler(elicitationHandler) + .elicitationHandler(elicitationHandlerFunc) .build(); return client; } public static McpAsyncClient createAsyncClientWithElicitation(AsyncElicitationHandler asyncElicitationHandler) { - Function> elicitationHandler = + Function> elicitationHandlerFunc = new AsyncMcpElicitationProvider(List.of(asyncElicitationHandler)).getElicitationHandler(); McpAsyncClient client = McpClient.async(transport) @@ -1794,7 +1996,7 @@ public class MyMcpClient { .elicitation() // Enable elicitation support // Other capabilities... .build()) - .elicitationHandler(elicitationHandler) + .elicitationHandler(elicitationHandlerFunc) .build(); return client; diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultElicitationSpec.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultElicitationSpec.java new file mode 100644 index 0000000..674f53d --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultElicitationSpec.java @@ -0,0 +1,48 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.HashMap; +import java.util.Map; + +import org.springaicommunity.mcp.context.McpRequestContextTypes.ElicitationSpec; + +public class DefaultElicitationSpec implements ElicitationSpec { + + protected String message; + + protected Map meta = new HashMap<>(); + + protected String message() { + return message; + } + + protected Map meta() { + return meta; + } + + @Override + public ElicitationSpec message(String message) { + this.message = message; + return this; + } + + @Override + public ElicitationSpec meta(Map m) { + if (m != null) { + this.meta.putAll(m); + } + return this; + } + + @Override + public ElicitationSpec meta(String k, Object v) { + if (k != null && v != null) { + this.meta.put(k, v); + } + return this; + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultLoggingSpec.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultLoggingSpec.java new file mode 100644 index 0000000..85c5e83 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultLoggingSpec.java @@ -0,0 +1,60 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.HashMap; +import java.util.Map; + +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import org.springaicommunity.mcp.context.McpRequestContextTypes.LoggingSpec; + +/** + * @author Christian Tzolov + */ +public class DefaultLoggingSpec implements LoggingSpec { + + protected String message; + + protected String logger; + + protected LoggingLevel level = LoggingLevel.INFO; + + protected Map meta = new HashMap<>(); + + @Override + public LoggingSpec message(String message) { + this.message = message; + return this; + } + + @Override + public LoggingSpec logger(String logger) { + this.logger = logger; + return this; + } + + @Override + public LoggingSpec level(LoggingLevel level) { + this.level = level; + return this; + } + + @Override + public LoggingSpec meta(Map m) { + if (m != null) { + this.meta.putAll(m); + } + return this; + } + + @Override + public LoggingSpec meta(String k, Object v) { + if (k != null && v != null) { + this.meta.put(k, v); + } + return this; + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java new file mode 100644 index 0000000..716c195 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java @@ -0,0 +1,518 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.lang.reflect.Type; +import java.util.Map; +import java.util.function.Consumer; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.Implementation; +import io.modelcontextprotocol.spec.McpSchema.ListRootsResult; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.method.tool.utils.ConcurrentReferenceHashMap; +import org.springaicommunity.mcp.method.tool.utils.JsonParser; +import org.springaicommunity.mcp.method.tool.utils.JsonSchemaGenerator; +import reactor.core.publisher.Mono; + +/** + * Async (Reactor) implementation of McpAsyncRequestContext that returns Mono of value + * types. + * + * @author Christian Tzolov + */ +public class DefaultMcpAsyncRequestContext implements McpAsyncRequestContext { + + private static final Logger logger = LoggerFactory.getLogger(DefaultMcpAsyncRequestContext.class); + + private static final Map> typeSchemaCache = new ConcurrentReferenceHashMap<>(256); + + private static TypeReference> MAP_TYPE_REF = new TypeReference>() { + }; + + private final McpSchema.Request request; + + private final McpAsyncServerExchange exchange; + + private DefaultMcpAsyncRequestContext(McpSchema.Request request, McpAsyncServerExchange exchange) { + Assert.notNull(request, "Request must not be null"); + Assert.notNull(exchange, "Exchange must not be null"); + this.request = request; + this.exchange = exchange; + } + + // Roots + + @Override + public Mono roots() { + if (this.exchange.getClientCapabilities() == null || this.exchange.getClientCapabilities().roots() == null) { + logger.warn("Roots not supported by the client! Ignoring the roots request for request:" + this.request); + return Mono.empty(); + } + return this.exchange.listRoots(); + } + + // Elicitation + + @Override + public Mono> elicit(Consumer spec, TypeReference type) { + Assert.notNull(type, "Elicitation response type must not be null"); + Assert.notNull(spec, "Elicitation spec consumer must not be null"); + DefaultElicitationSpec elicitationSpec = new DefaultElicitationSpec(); + spec.accept(elicitationSpec); + return this.elicitationInternal(elicitationSpec.message, type.getType(), elicitationSpec.meta) + .map(er -> new StructuredElicitResult(er.action(), convertMapToType(er.content(), type), er.meta())); + } + + @Override + public Mono> elicit(Consumer spec, Class type) { + Assert.notNull(type, "Elicitation response type must not be null"); + Assert.notNull(spec, "Elicitation spec consumer must not be null"); + DefaultElicitationSpec elicitationSpec = new DefaultElicitationSpec(); + spec.accept(elicitationSpec); + return this.elicitationInternal(elicitationSpec.message, type, elicitationSpec.meta) + .map(er -> new StructuredElicitResult(er.action(), convertMapToType(er.content(), type), er.meta())); + } + + @Override + public Mono> elicit(TypeReference type) { + Assert.notNull(type, "Elicitation response type must not be null"); + return this.elicitationInternal("Please provide the required information.", type.getType(), null) + .map(er -> new StructuredElicitResult(er.action(), convertMapToType(er.content(), type), er.meta())); + } + + @Override + public Mono> elicit(Class type) { + Assert.notNull(type, "Elicitation response type must not be null"); + return this.elicitationInternal("Please provide the required information.", type, null) + .map(er -> new StructuredElicitResult(er.action(), convertMapToType(er.content(), type), er.meta())); + } + + @Override + public Mono elicit(ElicitRequest elicitRequest) { + Assert.notNull(elicitRequest, "Elicit request must not be null"); + + if (this.exchange.getClientCapabilities() == null + || this.exchange.getClientCapabilities().elicitation() == null) { + logger.warn("Elicitation not supported by the client! Ignoring the elicitation request for request:" + + elicitRequest); + return Mono.empty(); + } + + return this.exchange.createElicitation(elicitRequest); + } + + public Mono elicitationInternal(String message, Type type, Map meta) { + Assert.hasText(message, "Elicitation message must not be empty"); + Assert.notNull(type, "Elicitation response type must not be null"); + + Map schema = typeSchemaCache.computeIfAbsent(type, t -> this.generateElicitSchema(t)); + + return this.elicit(ElicitRequest.builder().message(message).requestedSchema(schema).meta(meta).build()); + } + + private Map generateElicitSchema(Type type) { + Map schema = JsonParser.fromJson(JsonSchemaGenerator.generateFromType(type), MAP_TYPE_REF); + // remove as elicitation schema does not support it + schema.remove("$schema"); + return schema; + } + + private static T convertMapToType(Map map, Class targetType) { + ObjectMapper mapper = new ObjectMapper(); + JavaType javaType = mapper.getTypeFactory().constructType(targetType); + return mapper.convertValue(map, javaType); + } + + private static T convertMapToType(Map map, TypeReference targetType) { + ObjectMapper mapper = new ObjectMapper(); + JavaType javaType = mapper.getTypeFactory().constructType(targetType); + return mapper.convertValue(map, javaType); + } + + // Sampling + + @Override + public Mono sample(String... messages) { + return this.sample(s -> s.message(messages)); + } + + @Override + public Mono sample(Consumer samplingSpec) { + Assert.notNull(samplingSpec, "Sampling spec consumer must not be null"); + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + samplingSpec.accept(spec); + + var progressToken = this.request.progressToken(); + + if (!Utils.hasText(progressToken)) { + logger.warn("Progress notification not supported by the client!"); + } + return this.sample(McpSchema.CreateMessageRequest.builder() + .messages(spec.messages) + .modelPreferences(spec.modelPreferences) + .systemPrompt(spec.systemPrompt) + .temperature(spec.temperature) + .maxTokens(spec.maxTokens != null && spec.maxTokens > 0 ? spec.maxTokens : 500) + .stopSequences(spec.stopSequences.isEmpty() ? null : spec.stopSequences) + .includeContext(spec.includeContextStrategy) + .meta(spec.metadata.isEmpty() ? null : spec.metadata) + .progressToken(progressToken) + .meta(spec.meta.isEmpty() ? null : spec.meta) + .build()); + } + + @Override + public Mono sample(CreateMessageRequest createMessageRequest) { + + // check if supported + if (this.exchange.getClientCapabilities() == null || this.exchange.getClientCapabilities().sampling() == null) { + logger.warn("Sampling not supported by the client! Ignoring the sampling request for messages:" + + createMessageRequest); + return Mono.empty(); + } + + return this.exchange.createMessage(createMessageRequest); + } + + // Progress + + @Override + public Mono progress(int percentage) { + Assert.isTrue(percentage >= 0 && percentage <= 100, "Percentage must be between 0 and 100"); + return this.progress(p -> p.progress(percentage / 100.0).total(1.0).message(null)); + } + + @Override + public Mono progress(Consumer progressSpec) { + + Assert.notNull(progressSpec, "Progress spec consumer must not be null"); + DefaultProgressSpec spec = new DefaultProgressSpec(); + + progressSpec.accept(spec); + + if (!Utils.hasText(this.request.progressToken())) { + logger.warn("Progress notification not supported by the client!"); + return Mono.empty(); + } + + return this.progress(new ProgressNotification(this.request.progressToken(), spec.progress, spec.total, + spec.message, spec.meta)); + } + + @Override + public Mono progress(ProgressNotification progressNotification) { + return this.exchange.progressNotification(progressNotification).then(Mono.empty()); + } + + // Ping + + @Override + public Mono ping() { + return this.exchange.ping(); + } + + // Logging + + @Override + public Mono log(Consumer logSpec) { + Assert.notNull(logSpec, "Logging spec consumer must not be null"); + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + logSpec.accept(spec); + + return this.exchange + .loggingNotification(LoggingMessageNotification.builder() + .data(spec.message) + .level(spec.level) + .logger(spec.logger) + .meta(spec.meta) + .build()) + .then(); + } + + @Override + public Mono debug(String message) { + return this.logInternal(message, LoggingLevel.DEBUG); + } + + @Override + public Mono info(String message) { + return this.logInternal(message, LoggingLevel.INFO); + } + + @Override + public Mono warn(String message) { + return this.logInternal(message, LoggingLevel.WARNING); + } + + @Override + public Mono error(String message) { + return this.logInternal(message, LoggingLevel.ERROR); + } + + private Mono logInternal(String message, LoggingLevel level) { + Assert.hasText(message, "Log message must not be empty"); + return this.exchange + .loggingNotification(LoggingMessageNotification.builder().data(message).level(level).build()) + .then(); + } + + // Getters + + @Override + public McpSchema.Request request() { + return this.request; + } + + @Override + public McpAsyncServerExchange exchange() { + return this.exchange; + } + + @Override + public String sessionId() { + return this.exchange.sessionId(); + } + + @Override + public Implementation clientInfo() { + return this.exchange.getClientInfo(); + } + + @Override + public ClientCapabilities clientCapabilities() { + return this.exchange.getClientCapabilities(); + } + + @Override + public Map requestMeta() { + return this.request.meta(); + } + + @Override + public McpTransportContext transportContext() { + return this.exchange.transportContext(); + } + + // Builder + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private McpSchema.Request request; + + private McpAsyncServerExchange exchange; + + private boolean isStateless = false; + + private McpTransportContext transportContext; + + private Builder() { + } + + public Builder request(McpSchema.Request request) { + this.request = request; + return this; + } + + public Builder exchange(McpAsyncServerExchange exchange) { + this.exchange = exchange; + return this; + } + + public Builder stateless(boolean isStateless) { + this.isStateless = isStateless; + return this; + } + + public Builder transportContext(McpTransportContext transportContext) { + this.transportContext = transportContext; + return this; + } + + public McpAsyncRequestContext build() { + if (this.isStateless) { + return new StatelessAsyncRequestContext(this.request, this.transportContext); + } + return new DefaultMcpAsyncRequestContext(this.request, this.exchange); + } + + } + + private static class StatelessAsyncRequestContext implements McpAsyncRequestContext { + + private final McpSchema.Request request; + + private McpTransportContext transportContext; + + public StatelessAsyncRequestContext(McpSchema.Request request, McpTransportContext transportContext) { + this.request = request; + this.transportContext = transportContext; + } + + @Override + public Mono roots() { + logger.warn("Roots not supported by the client! Ignoring the roots request"); + return Mono.empty(); + } + + @Override + public Mono> elicit(Consumer spec, TypeReference returnType) { + logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); + return Mono.empty(); + } + + @Override + public Mono> elicit(TypeReference type) { + logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); + return Mono.empty(); + } + + @Override + public Mono> elicit(Consumer spec, Class returnType) { + logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); + return Mono.empty(); + } + + @Override + public Mono> elicit(Class type) { + logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); + return Mono.empty(); + } + + @Override + public Mono elicit(ElicitRequest elicitRequest) { + logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); + return Mono.empty(); + } + + @Override + public Mono sample(String... messages) { + logger.warn("Sampling not supported by the client! Ignoring the sampling request"); + return Mono.empty(); + } + + @Override + public Mono sample(Consumer samplingSpec) { + logger.warn("Sampling not supported by the client! Ignoring the sampling request"); + return Mono.empty(); + } + + @Override + public Mono sample(CreateMessageRequest createMessageRequest) { + logger.warn("Sampling not supported by the client! Ignoring the sampling request"); + return Mono.empty(); + } + + @Override + public Mono progress(int progress) { + logger.warn("Progress not supported by the client! Ignoring the progress request"); + return Mono.empty(); + } + + @Override + public Mono progress(Consumer progressSpec) { + logger.warn("Progress not supported by the client! Ignoring the progress request"); + return Mono.empty(); + } + + @Override + public Mono progress(ProgressNotification progressNotification) { + logger.warn("Progress not supported by the client! Ignoring the progress request"); + return Mono.empty(); + } + + @Override + public Mono ping() { + logger.warn("Ping not supported by the client! Ignoring the ping request"); + return Mono.empty(); + } + + @Override + public Mono log(Consumer logSpec) { + logger.warn("Logging not supported by the client! Ignoring the logging request"); + return Mono.empty(); + } + + @Override + public Mono debug(String message) { + logger.warn("Debug not supported by the client! Ignoring the debug request"); + return Mono.empty(); + } + + @Override + public Mono info(String message) { + logger.warn("Info not supported by the client! Ignoring the info request"); + return Mono.empty(); + } + + @Override + public Mono warn(String message) { + logger.warn("Warn not supported by the client! Ignoring the warn request"); + return Mono.empty(); + } + + @Override + public Mono error(String message) { + logger.warn("Error not supported by the client! Ignoring the error request"); + return Mono.empty(); + } + + // Getters + + public McpSchema.Request request() { + return this.request; + } + + public McpAsyncServerExchange exchange() { + logger.warn("Stateless servers do not support exchange! Returning null"); + return null; + } + + public String sessionId() { + logger.warn("Stateless servers do not support session ID! Returning null"); + return null; + } + + public Implementation clientInfo() { + logger.warn("Stateless servers do not support client info! Returning null"); + return null; + } + + public ClientCapabilities clientCapabilities() { + logger.warn("Stateless servers do not support client capabilities! Returning null"); + return null; + } + + public Map requestMeta() { + return this.request.meta(); + } + + public McpTransportContext transportContext() { + return transportContext; + } + + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java new file mode 100644 index 0000000..0450b8a --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java @@ -0,0 +1,540 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.lang.reflect.Type; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.Implementation; +import io.modelcontextprotocol.spec.McpSchema.ListRootsResult; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.method.tool.utils.ConcurrentReferenceHashMap; +import org.springaicommunity.mcp.method.tool.utils.JsonParser; +import org.springaicommunity.mcp.method.tool.utils.JsonSchemaGenerator; + +/** + * @author Christian Tzolov + */ +public class DefaultMcpSyncRequestContext implements McpSyncRequestContext { + + private static final Logger logger = LoggerFactory.getLogger(DefaultMcpSyncRequestContext.class); + + private static final Map> typeSchemaCache = new ConcurrentReferenceHashMap<>(256); + + private static TypeReference> MAP_TYPE_REF = new TypeReference>() { + }; + + private final McpSchema.Request request; + + private final McpSyncServerExchange exchange; + + private DefaultMcpSyncRequestContext(McpSchema.Request request, McpSyncServerExchange exchange) { + Assert.notNull(request, "Request must not be null"); + Assert.notNull(exchange, "Exchange must not be null"); + this.request = request; + this.exchange = exchange; + } + + // Roots + + public Optional roots() { + if (this.exchange.getClientCapabilities() == null || this.exchange.getClientCapabilities().roots() == null) { + logger.warn("Roots not supported by the client! Ignoring the roots request for request:" + this.request); + return Optional.empty(); + } + return Optional.of(this.exchange.listRoots()); + } + + // Elicitation + + @Override + public Optional> elicit(Class type) { + Assert.notNull(type, "Elicitation response type must not be null"); + + Optional elicitResult = this.elicitationInternal("Please provide the required information.", type, + null); + + if (!elicitResult.isPresent() || elicitResult.get().action() != ElicitResult.Action.ACCEPT) { + return Optional.empty(); + } + + return Optional.of(new StructuredElicitResult<>(elicitResult.get().action(), + convertMapToType(elicitResult.get().content(), type), elicitResult.get().meta())); + } + + @Override + public Optional> elicit(TypeReference type) { + Assert.notNull(type, "Elicitation response type must not be null"); + + Optional elicitResult = this.elicitationInternal("Please provide the required information.", + type.getType(), null); + + if (!elicitResult.isPresent() || elicitResult.get().action() != ElicitResult.Action.ACCEPT) { + return Optional.empty(); + } + + return Optional.of(new StructuredElicitResult<>(elicitResult.get().action(), + convertMapToType(elicitResult.get().content(), type), elicitResult.get().meta())); + } + + @Override + public Optional> elicit(Consumer params, Class returnType) { + Assert.notNull(returnType, "Elicitation response type must not be null"); + Assert.notNull(params, "Elicitation params must not be null"); + + DefaultElicitationSpec paramSpec = new DefaultElicitationSpec(); + params.accept(paramSpec); + + Optional elicitResult = this.elicitationInternal(paramSpec.message(), returnType, + paramSpec.meta()); + + if (!elicitResult.isPresent() || elicitResult.get().action() != ElicitResult.Action.ACCEPT) { + return Optional.empty(); + } + + return Optional.of(new StructuredElicitResult<>(elicitResult.get().action(), + convertMapToType(elicitResult.get().content(), returnType), elicitResult.get().meta())); + } + + @Override + public Optional> elicit(Consumer params, + TypeReference returnType) { + Assert.notNull(returnType, "Elicitation response type must not be null"); + Assert.notNull(params, "Elicitation params must not be null"); + + DefaultElicitationSpec paramSpec = new DefaultElicitationSpec(); + params.accept(paramSpec); + + Optional elicitResult = this.elicitationInternal(paramSpec.message(), returnType.getType(), + paramSpec.meta()); + + if (!elicitResult.isPresent() || elicitResult.get().action() != ElicitResult.Action.ACCEPT) { + return Optional.empty(); + } + + return Optional.of(new StructuredElicitResult<>(elicitResult.get().action(), + convertMapToType(elicitResult.get().content(), returnType), elicitResult.get().meta())); + } + + @Override + public Optional elicit(ElicitRequest elicitRequest) { + Assert.notNull(elicitRequest, "Elicit request must not be null"); + + if (this.exchange.getClientCapabilities() == null + || this.exchange.getClientCapabilities().elicitation() == null) { + logger.warn("Elicitation not supported by the client! Ignoring the elicitation request for request:" + + elicitRequest); + return Optional.empty(); + } + + ElicitResult elicitResult = this.exchange.createElicitation(elicitRequest); + + return Optional.of(elicitResult); + } + + private Optional elicitationInternal(String message, Type type, Map meta) { + Assert.hasText(message, "Elicitation message must not be empty"); + Assert.notNull(type, "Elicitation response type must not be null"); + + Map schema = typeSchemaCache.computeIfAbsent(type, t -> this.generateElicitSchema(t)); + + return this.elicit(ElicitRequest.builder().message(message).requestedSchema(schema).meta(meta).build()); + } + + private Map generateElicitSchema(Type type) { + Map schema = JsonParser.fromJson(JsonSchemaGenerator.generateFromType(type), MAP_TYPE_REF); + // remove $schema as elicitation schema does not support it + schema.remove("$schema"); + return schema; + } + + private static T convertMapToType(Map map, Class targetType) { + ObjectMapper mapper = new ObjectMapper(); + JavaType javaType = mapper.getTypeFactory().constructType(targetType); + return mapper.convertValue(map, javaType); + } + + private static T convertMapToType(Map map, TypeReference targetType) { + ObjectMapper mapper = new ObjectMapper(); + JavaType javaType = mapper.getTypeFactory().constructType(targetType); + return mapper.convertValue(map, javaType); + } + + // Sampling + + @Override + public Optional sample(String... messages) { + return this.sample(s -> s.message(messages)); + } + + @Override + public Optional sample(Consumer samplingSpec) { + Assert.notNull(samplingSpec, "Sampling spec consumer must not be null"); + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + samplingSpec.accept(spec); + + var progressToken = this.request.progressToken(); + + if (!Utils.hasText(progressToken)) { + logger.warn("Progress notification not supported by the client!"); + } + return this.sample(McpSchema.CreateMessageRequest.builder() + .messages(spec.messages) + .modelPreferences(spec.modelPreferences) + .systemPrompt(spec.systemPrompt) + .temperature(spec.temperature) + .maxTokens(spec.maxTokens != null && spec.maxTokens > 0 ? spec.maxTokens : 500) + .stopSequences(spec.stopSequences.isEmpty() ? null : spec.stopSequences) + .includeContext(spec.includeContextStrategy) + .meta(spec.metadata.isEmpty() ? null : spec.metadata) + .progressToken(progressToken) + .meta(spec.meta.isEmpty() ? null : spec.meta) + .build()); + } + + @Override + public Optional sample(CreateMessageRequest createMessageRequest) { + + // check if supported + if (this.exchange.getClientCapabilities() == null || this.exchange.getClientCapabilities().sampling() == null) { + logger.warn("Sampling not supported by the client! Ignoring the sampling request for messages:" + + createMessageRequest); + return Optional.empty(); + } + + return Optional.of(this.exchange.createMessage(createMessageRequest)); + } + + // Progress + + @Override + public void progress(int percentage) { + Assert.isTrue(percentage >= 0 && percentage <= 100, "Percentage must be between 0 and 100"); + this.progress(p -> p.progress(percentage / 100.0).total(1.0).message(null)); + } + + @Override + public void progress(Consumer progressSpec) { + + Assert.notNull(progressSpec, "Progress spec consumer must not be null"); + DefaultProgressSpec spec = new DefaultProgressSpec(); + + progressSpec.accept(spec); + + if (!Utils.hasText(this.request.progressToken())) { + logger.warn("Progress notification not supported by the client!"); + return; + } + + this.progress(new ProgressNotification(this.request.progressToken(), spec.progress, spec.total, spec.message, + spec.meta)); + } + + @Override + public void progress(ProgressNotification progressNotification) { + this.exchange.progressNotification(progressNotification); + } + + // Ping + + @Override + public void ping() { + this.exchange.ping(); + } + + // Logging + + @Override + public void log(Consumer logSpec) { + Assert.notNull(logSpec, "Logging spec consumer must not be null"); + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + logSpec.accept(spec); + + this.exchange.loggingNotification(LoggingMessageNotification.builder() + .data(spec.message) + .level(spec.level) + .logger(spec.logger) + .meta(spec.meta) + .build()); + } + + @Override + public void debug(String message) { + this.logInternal(message, LoggingLevel.DEBUG); + } + + @Override + public void info(String message) { + this.logInternal(message, LoggingLevel.INFO); + } + + @Override + public void warn(String message) { + this.logInternal(message, LoggingLevel.WARNING); + } + + @Override + public void error(String message) { + this.logInternal(message, LoggingLevel.ERROR); + } + + private void logInternal(String message, LoggingLevel level) { + Assert.hasText(message, "Log message must not be empty"); + this.exchange.loggingNotification(LoggingMessageNotification.builder().data(message).level(level).build()); + } + + // Getters + + @Override + public McpSchema.Request request() { + return this.request; + } + + @Override + public McpSyncServerExchange exchange() { + return this.exchange; + } + + @Override + public String sessionId() { + return this.exchange.sessionId(); + } + + @Override + public Implementation clientInfo() { + return this.exchange.getClientInfo(); + } + + @Override + public ClientCapabilities clientCapabilities() { + return this.exchange.getClientCapabilities(); + } + + @Override + public Map requestMeta() { + return this.request.meta(); + } + + @Override + public McpTransportContext transportContext() { + return this.exchange.transportContext(); + } + + // Builder + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private McpSchema.Request request; + + private McpSyncServerExchange exchange; + + private McpTransportContext transportContext; + + private boolean isStateless = false; + + private Builder() { + } + + public Builder request(McpSchema.Request request) { + this.request = request; + return this; + } + + public Builder exchange(McpSyncServerExchange exchange) { + this.exchange = exchange; + return this; + } + + public Builder transportContext(McpTransportContext transportContext) { + this.transportContext = transportContext; + return this; + } + + public Builder stateless(boolean isStateless) { + this.isStateless = isStateless; + return this; + } + + public McpSyncRequestContext build() { + if (this.isStateless) { + return new StatelessMcpSyncRequestContext(this.request, this.transportContext); + } + return new DefaultMcpSyncRequestContext(this.request, this.exchange); + } + + } + + public final static class StatelessMcpSyncRequestContext implements McpSyncRequestContext { + + private static final Logger logger = LoggerFactory.getLogger(StatelessMcpSyncRequestContext.class); + + private final McpSchema.Request request; + + private final McpTransportContext transportContext; + + private StatelessMcpSyncRequestContext(McpSchema.Request request, McpTransportContext transportContext) { + this.request = request; + this.transportContext = transportContext; + } + + @Override + public Optional roots() { + logger.warn("Roots not supported by the client! Ignoring the roots request"); + return Optional.empty(); + } + + @Override + public Optional> elicit(Class type) { + logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); + return Optional.empty(); + } + + @Override + public Optional> elicit(Consumer params, Class returnType) { + logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); + return Optional.empty(); + } + + @Override + public Optional> elicit(TypeReference type) { + logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); + return Optional.empty(); + } + + @Override + public Optional> elicit(Consumer params, + TypeReference returnType) { + logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); + return Optional.empty(); + } + + @Override + public Optional elicit(ElicitRequest elicitRequest) { + logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); + return Optional.empty(); + } + + @Override + public Optional sample(String... messages) { + logger.warn("Stateless servers do not support sampling! Ignoring the sampling request"); + return Optional.empty(); + } + + @Override + public Optional sample(Consumer samplingSpec) { + logger.warn("Stateless servers do not support sampling! Ignoring the sampling request"); + return Optional.empty(); + } + + @Override + public Optional sample(CreateMessageRequest createMessageRequest) { + logger.warn("Stateless servers do not support sampling! Ignoring the sampling request"); + return Optional.empty(); + } + + @Override + public void progress(int progress) { + logger.warn("Stateless servers do not support progress notifications! Ignoring the progress request"); + } + + @Override + public void progress(Consumer progressSpec) { + logger.warn("Stateless servers do not support progress notifications! Ignoring the progress request"); + } + + @Override + public void progress(ProgressNotification progressNotification) { + logger.warn("Stateless servers do not support progress notifications! Ignoring the progress request"); + } + + @Override + public void ping() { + logger.warn("Stateless servers do not support ping! Ignoring the ping request"); + } + + @Override + public void log(Consumer logSpec) { + logger.warn("Stateless servers do not support logging! Ignoring the logging request"); + } + + @Override + public void debug(String message) { + logger.warn("Stateless servers do not support debugging! Ignoring the debugging request"); + } + + @Override + public void info(String message) { + logger.warn("Stateless servers do not support info logging! Ignoring the info request"); + } + + @Override + public void warn(String message) { + logger.warn("Stateless servers do not support warning logging! Ignoring the warning request"); + } + + @Override + public void error(String message) { + logger.warn("Stateless servers do not support error logging! Ignoring the error request"); + } + + public McpSchema.Request request() { + return this.request; + } + + public McpTransportContext transportContext() { + return transportContext; + } + + public String sessionId() { + logger.warn("Stateless servers do not support session ID! Returning null"); + return null; + } + + public Implementation clientInfo() { + logger.warn("Stateless servers do not support client info! Returning null"); + return null; + } + + public ClientCapabilities clientCapabilities() { + logger.warn("Stateless servers do not support client capabilities! Returning null"); + return null; + } + + public Map requestMeta() { + return this.request.meta(); + } + + @Override + public McpSyncServerExchange exchange() { + logger.warn("Stateless servers do not support exchange! Returning null"); + return null; + } + + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultProgressSpec.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultProgressSpec.java new file mode 100644 index 0000000..65af97f --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultProgressSpec.java @@ -0,0 +1,59 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.HashMap; +import java.util.Map; + +import org.springaicommunity.mcp.context.McpRequestContextTypes.ProgressSpec; + +/** + * @author Christian Tzolov + */ +public class DefaultProgressSpec implements ProgressSpec { + + protected double progress = 0.0; + + protected double total = 1.0; + + protected String message; + + protected Map meta = new HashMap<>(); + + @Override + public ProgressSpec progress(double progress) { + this.progress = progress; + return this; + } + + @Override + public ProgressSpec total(double total) { + this.total = total; + return this; + } + + @Override + public ProgressSpec message(String message) { + this.message = message; + return this; + } + + @Override + public ProgressSpec meta(Map m) { + if (m != null) { + this.meta.putAll(m); + } + return this; + } + + @Override + public ProgressSpec meta(String k, Object v) { + if (k != null && v != null) { + this.meta.put(k, v); + } + return this; + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultSamplingSpec.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultSamplingSpec.java new file mode 100644 index 0000000..1801b7f --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultSamplingSpec.java @@ -0,0 +1,199 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import io.modelcontextprotocol.spec.McpSchema.AudioContent; +import io.modelcontextprotocol.spec.McpSchema.Content; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest.ContextInclusionStrategy; +import io.modelcontextprotocol.spec.McpSchema.EmbeddedResource; +import io.modelcontextprotocol.spec.McpSchema.ImageContent; +import io.modelcontextprotocol.spec.McpSchema.ModelHint; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; +import io.modelcontextprotocol.spec.McpSchema.ResourceLink; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.util.Assert; +import org.springaicommunity.mcp.context.McpRequestContextTypes.ModelPreferenceSpec; +import org.springaicommunity.mcp.context.McpRequestContextTypes.SamplingSpec; + +/** + * @author Christian Tzolov + */ +public class DefaultSamplingSpec implements SamplingSpec { + + protected List messages = new ArrayList<>(); + + protected ModelPreferences modelPreferences; + + protected String systemPrompt; + + protected Double temperature; + + protected Integer maxTokens; + + protected List stopSequences = new ArrayList<>(); + + protected Map metadata = new HashMap<>(); + + protected Map meta = new HashMap<>(); + + protected ContextInclusionStrategy includeContextStrategy = ContextInclusionStrategy.NONE; + + @Override + public SamplingSpec message(ResourceLink... content) { + return this.messageInternal(content); + } + + @Override + public SamplingSpec message(EmbeddedResource... content) { + return this.messageInternal(content); + } + + @Override + public SamplingSpec message(AudioContent... content) { + return this.messageInternal(content); + } + + @Override + public SamplingSpec message(ImageContent... content) { + return this.messageInternal(content); + } + + @Override + public SamplingSpec message(TextContent... content) { + return this.messageInternal(content); + } + + private SamplingSpec messageInternal(Content... content) { + this.messages.addAll(List.of(content).stream().map(c -> new SamplingMessage(Role.USER, c)).toList()); + return this; + } + + @Override + public SamplingSpec message(SamplingMessage... message) { + this.messages.addAll(List.of(message)); + return this; + } + + @Override + public SamplingSpec modelPreferences(Consumer modelPreferenceSpec) { + var modelPreferencesSpec = new DefaultModelPreferenceSpec(); + modelPreferenceSpec.accept(modelPreferencesSpec); + + this.modelPreferences = ModelPreferences.builder() + .hints(modelPreferencesSpec.modelHints) + .costPriority(modelPreferencesSpec.costPriority) + .speedPriority(modelPreferencesSpec.speedPriority) + .intelligencePriority(modelPreferencesSpec.intelligencePriority) + .build(); + return this; + } + + @Override + public SamplingSpec systemPrompt(String systemPrompt) { + this.systemPrompt = systemPrompt; + return this; + } + + @Override + public SamplingSpec includeContextStrategy(ContextInclusionStrategy includeContextStrategy) { + this.includeContextStrategy = includeContextStrategy; + return this; + } + + @Override + public SamplingSpec temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + @Override + public SamplingSpec maxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + @Override + public SamplingSpec stopSequences(String... stopSequences) { + this.stopSequences.addAll(List.of(stopSequences)); + return this; + } + + @Override + public SamplingSpec metadata(Map m) { + this.metadata.putAll(m); + return this; + } + + @Override + public SamplingSpec metadata(String k, Object v) { + this.metadata.put(k, v); + return this; + } + + @Override + public SamplingSpec meta(Map m) { + this.meta.putAll(m); + return this; + } + + @Override + public SamplingSpec meta(String k, Object v) { + this.meta.put(k, v); + return this; + } + + public static class DefaultModelPreferenceSpec implements ModelPreferenceSpec { + + private List modelHints = new ArrayList<>(); + + private Double costPriority; + + private Double speedPriority; + + private Double intelligencePriority; + + @Override + public ModelPreferenceSpec modelHints(String... models) { + Assert.notNull(models, "Models must not be null"); + this.modelHints.addAll(List.of(models).stream().map(ModelHint::new).toList()); + return this; + } + + @Override + public ModelPreferenceSpec modelHint(String modelHint) { + Assert.notNull(modelHint, "Model hint must not be null"); + this.modelHints.add(new ModelHint(modelHint)); + return this; + } + + @Override + public ModelPreferenceSpec costPriority(Double costPriority) { + this.costPriority = costPriority; + return this; + } + + @Override + public ModelPreferenceSpec speedPriority(Double speedPriority) { + this.speedPriority = speedPriority; + return this; + } + + @Override + public ModelPreferenceSpec intelligencePriority(Double intelligencePriority) { + this.intelligencePriority = intelligencePriority; + return this; + } + + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java new file mode 100644 index 0000000..08bad44 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java @@ -0,0 +1,82 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.Map; +import java.util.function.Consumer; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.ListRootsResult; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import reactor.core.publisher.Mono; + +/** + * Async (Reactor) version of McpSyncRequestContext that returns Mono of value types. + * + * @author Christian Tzolov + */ +public interface McpAsyncRequestContext extends McpRequestContextTypes { + + // -------------------------------------- + // Roots + // -------------------------------------- + Mono roots(); + + // -------------------------------------- + // Elicitation + // -------------------------------------- + + Mono> elicit(Class type); + + Mono> elicit(TypeReference type); + + Mono> elicit(Consumer spec, TypeReference returnType); + + Mono> elicit(Consumer spec, Class returnType); + + Mono elicit(ElicitRequest elicitRequest); + + // -------------------------------------- + // Sampling + // -------------------------------------- + Mono sample(String... messages); + + Mono sample(Consumer samplingSpec); + + Mono sample(CreateMessageRequest createMessageRequest); + + // -------------------------------------- + // Progress + // -------------------------------------- + Mono progress(int progress); + + Mono progress(Consumer progressSpec); + + Mono progress(ProgressNotification progressNotification); + + // -------------------------------------- + // Ping + // -------------------------------------- + Mono ping(); + + // -------------------------------------- + // Logging + // -------------------------------------- + Mono log(Consumer logSpec); + + Mono debug(String message); + + Mono info(String message); + + Mono warn(String message); + + Mono error(String message); + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpRequestContextTypes.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpRequestContextTypes.java new file mode 100644 index 0000000..70c9b40 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpRequestContextTypes.java @@ -0,0 +1,150 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.AudioContent; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest.ContextInclusionStrategy; +import io.modelcontextprotocol.spec.McpSchema.EmbeddedResource; +import io.modelcontextprotocol.spec.McpSchema.ImageContent; +import io.modelcontextprotocol.spec.McpSchema.Implementation; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.ResourceLink; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.TextContent; + +/** + * @author Christian Tzolov + */ +public interface McpRequestContextTypes { + + interface ElicitationSpec { + + ElicitationSpec message(String message); + + ElicitationSpec meta(Map m); + + ElicitationSpec meta(String k, Object v); + + } + + // -------------------------------------- + // Sampling + // -------------------------------------- + + interface ModelPreferenceSpec { + + ModelPreferenceSpec modelHints(String... models); + + ModelPreferenceSpec modelHint(String modelHint); + + ModelPreferenceSpec costPriority(Double costPriority); + + ModelPreferenceSpec speedPriority(Double speedPriority); + + ModelPreferenceSpec intelligencePriority(Double intelligencePriority); + + } + + interface SamplingSpec { + + SamplingSpec message(ResourceLink... content); + + SamplingSpec message(EmbeddedResource... content); + + SamplingSpec message(AudioContent... content); + + SamplingSpec message(ImageContent... content); + + SamplingSpec message(TextContent... content); + + default SamplingSpec message(String... text) { + return message(List.of(text).stream().map(t -> new TextContent(t)).toList().toArray(new TextContent[0])); + } + + SamplingSpec message(SamplingMessage... message); + + SamplingSpec modelPreferences(Consumer modelPreferenceSpec); + + SamplingSpec systemPrompt(String systemPrompt); + + SamplingSpec includeContextStrategy(ContextInclusionStrategy includeContextStrategy); + + SamplingSpec temperature(Double temperature); + + SamplingSpec maxTokens(Integer maxTokens); + + SamplingSpec stopSequences(String... stopSequences); + + SamplingSpec metadata(Map m); + + SamplingSpec metadata(String k, Object v); + + SamplingSpec meta(Map m); + + SamplingSpec meta(String k, Object v); + + } + + // -------------------------------------- + // Progress + // -------------------------------------- + + interface ProgressSpec { + + ProgressSpec progress(double progress); + + ProgressSpec total(double total); + + ProgressSpec message(String message); + + ProgressSpec meta(Map m); + + ProgressSpec meta(String k, Object v); + + } + + // -------------------------------------- + // Logging + // -------------------------------------- + + interface LoggingSpec { + + LoggingSpec message(String message); + + LoggingSpec logger(String logger); + + LoggingSpec level(LoggingLevel level); + + LoggingSpec meta(Map m); + + LoggingSpec meta(String k, Object v); + + } + + // -------------------------------------- + // Getters + // -------------------------------------- + McpSchema.Request request(); + + ET exchange(); + + String sessionId(); + + Implementation clientInfo(); + + ClientCapabilities clientCapabilities(); + + Map requestMeta(); + + McpTransportContext transportContext(); + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java new file mode 100644 index 0000000..38e74f1 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java @@ -0,0 +1,78 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.Optional; +import java.util.function.Consumer; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.ListRootsResult; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; + +/** + * @author Christian Tzolov + */ +public interface McpSyncRequestContext extends McpRequestContextTypes { + + // -------------------------------------- + // Roots + // -------------------------------------- + Optional roots(); + + // -------------------------------------- + // Elicitation + // -------------------------------------- + Optional> elicit(Class type); + + Optional> elicit(TypeReference type); + + Optional> elicit(Consumer params, Class returnType); + + Optional> elicit(Consumer params, TypeReference returnType); + + Optional elicit(ElicitRequest elicitRequest); + + // -------------------------------------- + // Sampling + // -------------------------------------- + Optional sample(String... messages); + + Optional sample(Consumer samplingSpec); + + Optional sample(CreateMessageRequest createMessageRequest); + + // -------------------------------------- + // Progress + // -------------------------------------- + void progress(int progress); + + void progress(Consumer progressSpec); + + void progress(ProgressNotification progressNotification); + + // -------------------------------------- + // Ping + // -------------------------------------- + void ping(); + + // -------------------------------------- + // Logging + // -------------------------------------- + void log(Consumer logSpec); + + void debug(String message); + + void info(String message); + + void warn(String message); + + void error(String message); + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/StructuredElicitResult.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/StructuredElicitResult.java new file mode 100644 index 0000000..0afd5f5 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/StructuredElicitResult.java @@ -0,0 +1,19 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.Map; + +import io.modelcontextprotocol.spec.McpSchema.ElicitResult.Action; + +/** + * A record representing the result of a structured elicit action. + * + * @param the type of the structured content + * @author Christian Tzolov + */ +public record StructuredElicitResult(Action action, T structuredContent, Map meta) { + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java index b1e19be..43cec0a 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java @@ -22,6 +22,7 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import org.reactivestreams.Publisher; import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.context.McpRequestContextTypes; import org.springaicommunity.mcp.method.tool.utils.JsonParser; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -37,7 +38,8 @@ * McpTransportContext) * @author Christian Tzolov */ -public abstract class AbstractAsyncMcpToolMethodCallback extends AbstractMcpToolMethodCallback { +public abstract class AbstractAsyncMcpToolMethodCallback> + extends AbstractMcpToolMethodCallback { protected final Class toolCallExceptionClass; diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractMcpToolMethodCallback.java index f5fa170..3e9ea3f 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractMcpToolMethodCallback.java @@ -28,6 +28,9 @@ import org.springaicommunity.mcp.annotation.McpMeta; import org.springaicommunity.mcp.annotation.McpProgressToken; import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.context.McpAsyncRequestContext; +import org.springaicommunity.mcp.context.McpRequestContextTypes; +import org.springaicommunity.mcp.context.McpSyncRequestContext; import org.springaicommunity.mcp.method.tool.utils.JsonParser; /** @@ -41,7 +44,7 @@ * McpSyncServerExchange, or McpAsyncServerExchange) * @author Christian Tzolov */ -public abstract class AbstractMcpToolMethodCallback { +public abstract class AbstractMcpToolMethodCallback { protected final Method toolMethod; @@ -89,7 +92,15 @@ protected Object callMethod(Object[] methodArguments) { */ protected Object[] buildMethodArguments(T exchangeOrContext, Map toolInputArguments, CallToolRequest request) { + return Stream.of(this.toolMethod.getParameters()).map(parameter -> { + + if (McpSyncRequestContext.class.isAssignableFrom(parameter.getType()) + || McpAsyncRequestContext.class.isAssignableFrom(parameter.getType())) { + + return this.createRequestContext(exchangeOrContext, request); + } + // Check if parameter is annotated with @McpProgressToken if (parameter.isAnnotationPresent(McpProgressToken.class)) { // Return the progress token from the request @@ -203,4 +214,6 @@ protected Throwable findCauseUsingPlainJava(Throwable throwable) { return rootCause; } + protected abstract RC createRequestContext(T exchange, CallToolRequest request); + } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java index d05d016..0ad552f 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java @@ -20,6 +20,7 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import org.springaicommunity.mcp.context.McpRequestContextTypes; /** * Abstract base class for creating Function callbacks around synchronous tool methods. @@ -33,7 +34,8 @@ * McpSyncServerExchange) * @author Christian Tzolov */ -public abstract class AbstractSyncMcpToolMethodCallback extends AbstractAsyncMcpToolMethodCallback { +public abstract class AbstractSyncMcpToolMethodCallback> + extends AbstractAsyncMcpToolMethodCallback { protected AbstractSyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject, Class toolCallExceptionClass) { diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java index ab51cab..d8cdfd8 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java @@ -19,11 +19,12 @@ import java.lang.reflect.Method; import java.util.function.BiFunction; -import org.springaicommunity.mcp.annotation.McpTool; - import io.modelcontextprotocol.server.McpAsyncServerExchange; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.context.DefaultMcpAsyncRequestContext; +import org.springaicommunity.mcp.context.McpAsyncRequestContext; import reactor.core.publisher.Mono; /** @@ -34,7 +35,8 @@ * * @author Christian Tzolov */ -public final class AsyncMcpToolMethodCallback extends AbstractAsyncMcpToolMethodCallback +public final class AsyncMcpToolMethodCallback + extends AbstractAsyncMcpToolMethodCallback implements BiFunction> { public AsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject) { @@ -48,7 +50,14 @@ public AsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Obje @Override protected boolean isExchangeOrContextType(Class paramType) { - return McpAsyncServerExchange.class.isAssignableFrom(paramType); + return McpAsyncServerExchange.class.isAssignableFrom(paramType) + || McpAsyncRequestContext.class.isAssignableFrom(paramType); + } + + @Override + protected McpAsyncRequestContext createRequestContext(McpAsyncServerExchange exchange, CallToolRequest request) { + + return DefaultMcpAsyncRequestContext.builder().request(request).exchange(exchange).build(); } /** diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java index 300f9cf..06957d9 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java @@ -18,10 +18,12 @@ import java.util.function.BiFunction; -import org.springaicommunity.mcp.annotation.McpTool; import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.context.DefaultMcpAsyncRequestContext; +import org.springaicommunity.mcp.context.McpAsyncRequestContext; import reactor.core.publisher.Mono; /** @@ -33,7 +35,8 @@ * * @author Christian Tzolov */ -public final class AsyncStatelessMcpToolMethodCallback extends AbstractAsyncMcpToolMethodCallback +public final class AsyncStatelessMcpToolMethodCallback + extends AbstractAsyncMcpToolMethodCallback implements BiFunction> { public AsyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod, @@ -48,7 +51,18 @@ public AsyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.refl @Override protected boolean isExchangeOrContextType(Class paramType) { - return McpTransportContext.class.isAssignableFrom(paramType); + return McpTransportContext.class.isAssignableFrom(paramType) + || McpAsyncRequestContext.class.isAssignableFrom(paramType); + } + + @Override + protected McpAsyncRequestContext createRequestContext(McpTransportContext exchange, CallToolRequest request) { + + return DefaultMcpAsyncRequestContext.builder() + .request(request) + .transportContext(exchange) + .stateless(true) + .build(); } /** diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java index eaba1b2..fcb14e3 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java @@ -19,7 +19,8 @@ import java.util.function.BiFunction; import org.springaicommunity.mcp.annotation.McpTool; - +import org.springaicommunity.mcp.context.DefaultMcpSyncRequestContext; +import org.springaicommunity.mcp.context.McpSyncRequestContext; import io.modelcontextprotocol.server.McpSyncServerExchange; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -32,7 +33,8 @@ * * @author Christian Tzolov */ -public final class SyncMcpToolMethodCallback extends AbstractSyncMcpToolMethodCallback +public final class SyncMcpToolMethodCallback + extends AbstractSyncMcpToolMethodCallback implements BiFunction { public SyncMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod, Object toolObject) { @@ -46,7 +48,14 @@ public SyncMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method @Override protected boolean isExchangeOrContextType(Class paramType) { - return McpSyncServerExchange.class.isAssignableFrom(paramType); + return McpSyncServerExchange.class.isAssignableFrom(paramType) + || McpSyncRequestContext.class.isAssignableFrom(paramType); + } + + @Override + protected McpSyncRequestContext createRequestContext(McpSyncServerExchange exchange, CallToolRequest request) { + + return DefaultMcpSyncRequestContext.builder().request(request).exchange(exchange).build(); } /** diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java index a17cc9f..88029e4 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java @@ -18,10 +18,12 @@ import java.util.function.BiFunction; -import org.springaicommunity.mcp.annotation.McpTool; import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.context.DefaultMcpSyncRequestContext; +import org.springaicommunity.mcp.context.McpSyncRequestContext; /** * Class for creating Function callbacks around tool methods. @@ -32,7 +34,8 @@ * @author James Ward * @author Christian Tzolov */ -public final class SyncStatelessMcpToolMethodCallback extends AbstractSyncMcpToolMethodCallback +public final class SyncStatelessMcpToolMethodCallback + extends AbstractSyncMcpToolMethodCallback implements BiFunction { public SyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod, @@ -47,7 +50,18 @@ public SyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.refle @Override protected boolean isExchangeOrContextType(Class paramType) { - return McpTransportContext.class.isAssignableFrom(paramType); + return McpTransportContext.class.isAssignableFrom(paramType) + || McpSyncRequestContext.class.isAssignableFrom(paramType); + } + + @Override + protected McpSyncRequestContext createRequestContext(McpTransportContext exchange, CallToolRequest request) { + + return DefaultMcpSyncRequestContext.builder() + .request(request) + .transportContext(exchange) + .stateless(true) + .build(); } @Override diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonSchemaGenerator.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonSchemaGenerator.java index 814e2da..a480d67 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonSchemaGenerator.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonSchemaGenerator.java @@ -27,7 +27,8 @@ import org.springaicommunity.mcp.annotation.McpMeta; import org.springaicommunity.mcp.annotation.McpProgressToken; import org.springaicommunity.mcp.annotation.McpToolParam; - +import org.springaicommunity.mcp.context.McpAsyncRequestContext; +import org.springaicommunity.mcp.context.McpSyncRequestContext; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; import com.fasterxml.jackson.databind.JsonNode; @@ -110,7 +111,9 @@ private static String internalGenerateFromMethodArguments(Method method) { // @McpProgressToken annotated parameters, and McpMeta parameters boolean hasOtherParams = Arrays.stream(method.getParameters()).anyMatch(param -> { Class type = param.getType(); - return !CallToolRequest.class.isAssignableFrom(type) + return !McpSyncRequestContext.class.isAssignableFrom(type) + && !McpAsyncRequestContext.class.isAssignableFrom(type) + && !CallToolRequest.class.isAssignableFrom(type) && !McpSyncServerExchange.class.isAssignableFrom(type) && !McpAsyncServerExchange.class.isAssignableFrom(type) && !param.isAnnotationPresent(McpProgressToken.class) && !McpMeta.class.isAssignableFrom(type); @@ -150,7 +153,9 @@ private static String internalGenerateFromMethodArguments(Method method) { // Skip special parameter types if (parameterType instanceof Class parameterClass - && (ClassUtils.isAssignable(McpSyncServerExchange.class, parameterClass) + && (ClassUtils.isAssignable(McpSyncRequestContext.class, parameterClass) + || ClassUtils.isAssignable(McpAsyncRequestContext.class, parameterClass) + || ClassUtils.isAssignable(McpSyncServerExchange.class, parameterClass) || ClassUtils.isAssignable(McpAsyncServerExchange.class, parameterClass) || ClassUtils.isAssignable(CallToolRequest.class, parameterClass))) { continue; diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultLoggingSpecTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultLoggingSpecTests.java new file mode 100644 index 0000000..8b13f08 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultLoggingSpecTests.java @@ -0,0 +1,146 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.Map; + +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link DefaultLoggingSpec}. + * + * @author Christian Tzolov + */ +public class DefaultLoggingSpecTests { + + @Test + public void testMessageSetting() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.message("Test log message"); + + assertThat(spec.message).isEqualTo("Test log message"); + } + + @Test + public void testLoggerSetting() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.logger("test-logger"); + + assertThat(spec.logger).isEqualTo("test-logger"); + } + + @Test + public void testLevelSetting() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.level(LoggingLevel.ERROR); + + assertThat(spec.level).isEqualTo(LoggingLevel.ERROR); + } + + @Test + public void testDefaultLevel() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + assertThat(spec.level).isEqualTo(LoggingLevel.INFO); + } + + @Test + public void testMetaWithMap() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + Map metaMap = Map.of("key1", "value1", "key2", "value2"); + + spec.meta(metaMap); + + assertThat(spec.meta).containsEntry("key1", "value1").containsEntry("key2", "value2"); + } + + @Test + public void testMetaWithNullMap() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.meta((Map) null); + + assertThat(spec.meta).isEmpty(); + } + + @Test + public void testMetaWithKeyValue() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.meta("key", "value"); + + assertThat(spec.meta).containsEntry("key", "value"); + } + + @Test + public void testMetaWithNullKey() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.meta(null, "value"); + + assertThat(spec.meta).isEmpty(); + } + + @Test + public void testMetaWithNullValue() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.meta("key", null); + + assertThat(spec.meta).isEmpty(); + } + + @Test + public void testMetaMultipleEntries() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.meta("key1", "value1").meta("key2", "value2").meta("key3", "value3"); + + assertThat(spec.meta).hasSize(3) + .containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .containsEntry("key3", "value3"); + } + + @Test + public void testFluentInterface() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + McpRequestContextTypes.LoggingSpec result = spec.message("Test message") + .logger("test-logger") + .level(LoggingLevel.DEBUG) + .meta("key", "value"); + + assertThat(result).isSameAs(spec); + assertThat(spec.message).isEqualTo("Test message"); + assertThat(spec.logger).isEqualTo("test-logger"); + assertThat(spec.level).isEqualTo(LoggingLevel.DEBUG); + assertThat(spec.meta).containsEntry("key", "value"); + } + + @Test + public void testAllLoggingLevels() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.level(LoggingLevel.DEBUG); + assertThat(spec.level).isEqualTo(LoggingLevel.DEBUG); + + spec.level(LoggingLevel.INFO); + assertThat(spec.level).isEqualTo(LoggingLevel.INFO); + + spec.level(LoggingLevel.WARNING); + assertThat(spec.level).isEqualTo(LoggingLevel.WARNING); + + spec.level(LoggingLevel.ERROR); + assertThat(spec.level).isEqualTo(LoggingLevel.ERROR); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java new file mode 100644 index 0000000..f48a216 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java @@ -0,0 +1,689 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.Map; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.Implementation; +import io.modelcontextprotocol.spec.McpSchema.ListRootsResult; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link DefaultMcpAsyncRequestContext}. + * + * @author Christian Tzolov + */ +public class DefaultMcpAsyncRequestContextTests { + + private CallToolRequest request; + + private McpAsyncServerExchange exchange; + + private McpAsyncRequestContext context; + + @BeforeEach + public void setUp() { + request = new CallToolRequest("test-tool", Map.of()); + exchange = mock(McpAsyncServerExchange.class); + context = DefaultMcpAsyncRequestContext.builder().request(request).exchange(exchange).build(); + } + + // Builder Tests + + @Test + public void testBuilderWithValidParameters() { + CallToolRequest testRequest = new CallToolRequest("test-tool", Map.of()); + McpAsyncRequestContext ctx = DefaultMcpAsyncRequestContext.builder() + .request(testRequest) + .exchange(exchange) + .build(); + + assertThat(ctx).isNotNull(); + assertThat(ctx.request()).isEqualTo(testRequest); + assertThat(ctx.exchange()).isEqualTo(exchange); + } + + @Test + public void testBuilderWithNullRequest() { + StepVerifier + .create(Mono + .fromCallable(() -> DefaultMcpAsyncRequestContext.builder().request(null).exchange(exchange).build())) + .expectErrorMatches(throwable -> throwable instanceof IllegalArgumentException + && throwable.getMessage().contains("Request must not be null")) + .verify(); + } + + @Test + public void testBuilderWithNullExchange() { + CallToolRequest testRequest = new CallToolRequest("test-tool", Map.of()); + StepVerifier + .create(Mono.fromCallable( + () -> DefaultMcpAsyncRequestContext.builder().request(testRequest).exchange(null).build())) + .expectErrorMatches(throwable -> throwable instanceof IllegalArgumentException + && throwable.getMessage().contains("Exchange must not be null")) + .verify(); + } + + // Roots Tests + + @Test + public void testRootsWhenSupported() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + McpSchema.ClientCapabilities.RootCapabilities roots = mock(McpSchema.ClientCapabilities.RootCapabilities.class); + when(capabilities.roots()).thenReturn(roots); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + ListRootsResult expectedResult = mock(ListRootsResult.class); + when(exchange.listRoots()).thenReturn(Mono.just(expectedResult)); + + StepVerifier.create(context.roots()).expectNext(expectedResult).verifyComplete(); + + verify(exchange).listRoots(); + } + + @Test + public void testRootsWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + StepVerifier.create(context.roots()).verifyComplete(); + } + + @Test + public void testRootsWhenCapabilitiesNullRoots() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + when(capabilities.roots()).thenReturn(null); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + StepVerifier.create(context.roots()).verifyComplete(); + } + + // Elicitation Tests + + @Test + public void testElicitationWithMessageAndMeta() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + Map contentMap = Map.of("name", "John", "age", 30); + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(null); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); + + Mono>> result = context.elicit(e -> e.message("Test message"), + new TypeReference>() { + }); + + StepVerifier.create(result).assertNext(structuredResult -> { + assertThat(structuredResult.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(structuredResult.structuredContent()).isNotNull(); + assertThat(structuredResult.structuredContent()).containsEntry("name", "John"); + assertThat(structuredResult.structuredContent()).containsEntry("age", 30); + }).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ElicitRequest.class); + verify(exchange).createElicitation(captor.capture()); + + ElicitRequest capturedRequest = captor.getValue(); + assertThat(capturedRequest.message()).isEqualTo("Test message"); + assertThat(capturedRequest.requestedSchema()).isNotNull(); + } + + @Test + public void testElicitationWithMetadata() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + record Person(String name, int age) { + } + + Map contentMap = Map.of("name", "Jane", "age", 25); + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(null); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); + + Map meta = Map.of("key", "value"); + Mono> result = context.elicit(e -> e.message("Test message").meta(meta), + new TypeReference() { + }); + + StepVerifier.create(result).assertNext(structuredResult -> { + assertThat(structuredResult.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(structuredResult.structuredContent()).isNotNull(); + assertThat(structuredResult.structuredContent().name()).isEqualTo("Jane"); + assertThat(structuredResult.structuredContent().age()).isEqualTo(25); + }).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ElicitRequest.class); + verify(exchange).createElicitation(captor.capture()); + + ElicitRequest capturedRequest = captor.getValue(); + assertThat(capturedRequest.meta()).containsEntry("key", "value"); + } + + @Test + public void testElicitationWithNullTypeReference() { + assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { + context.elicit((TypeReference) null); + })).hasMessageContaining("Elicitation response type must not be null"); + } + + @Test + public void testElicitationWithNullClassType() { + assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { + context.elicit((Class) null); + })).hasMessageContaining("Elicitation response type must not be null"); + } + + @Test + public void testElicitationWithEmptyMessage() { + assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { + context.elicit(e -> e.message("").meta(null), new TypeReference() { + }); + })).hasMessageContaining("Elicitation message must not be empty"); + } + + @Test + public void testElicitationWithNullMessage() { + assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { + context.elicit(e -> e.message(null).meta(null), new TypeReference() { + }); + })).hasMessageContaining("Elicitation message must not be empty"); + } + + @Test + public void testElicitationReturnsEmptyWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + Mono>> result = context.elicit(e -> e.message("Test message"), + new TypeReference>() { + }); + + StepVerifier.create(result).verifyComplete(); + } + + @Test + public void testElicitationReturnsResultWhenActionIsNotAccept() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + Map contentMap = Map.of(); + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.DECLINE); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(null); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); + + Mono>> result = context.elicit(e -> e.message("Test message"), + new TypeReference>() { + }); + + StepVerifier.create(result).assertNext(structuredResult -> { + assertThat(structuredResult.action()).isEqualTo(ElicitResult.Action.DECLINE); + assertThat(structuredResult.structuredContent()).isNotNull(); + }).verifyComplete(); + } + + @Test + public void testElicitationConvertsComplexTypes() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + record Address(String street, String city) { + } + record PersonWithAddress(String name, int age, Address address) { + } + + Map addressMap = Map.of("street", "123 Main St", "city", "Springfield"); + Map contentMap = Map.of("name", "John", "age", 30, "address", addressMap); + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(null); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); + + Mono> result = context.elicit(e -> e.message("Test message"), + new TypeReference() { + }); + + StepVerifier.create(result).assertNext(structuredResult -> { + assertThat(structuredResult.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(structuredResult.structuredContent()).isNotNull(); + assertThat(structuredResult.structuredContent().name()).isEqualTo("John"); + assertThat(structuredResult.structuredContent().age()).isEqualTo(30); + assertThat(structuredResult.structuredContent().address()).isNotNull(); + assertThat(structuredResult.structuredContent().address().street()).isEqualTo("123 Main St"); + assertThat(structuredResult.structuredContent().address().city()).isEqualTo("Springfield"); + }).verifyComplete(); + } + + @Test + public void testElicitationHandlesListTypes() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + Map contentMap = Map.of("items", + java.util.List.of(Map.of("name", "Item1"), Map.of("name", "Item2"))); + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(null); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); + + Mono>> result = context.elicit(e -> e.message("Test message"), + new TypeReference>() { + }); + + StepVerifier.create(result).assertNext(structuredResult -> { + assertThat(structuredResult.structuredContent()).containsKey("items"); + }).verifyComplete(); + } + + @Test + public void testElicitationWithTypeReference() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + Map contentMap = Map.of("result", "success", "data", "test value"); + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); + + Mono>> result = context + .elicit(new TypeReference>() { + }); + + StepVerifier.create(result).assertNext(map -> { + assertThat(map.structuredContent()).containsEntry("result", "success"); + assertThat(map.structuredContent()).containsEntry("data", "test value"); + }).verifyComplete(); + } + + @Test + public void testElicitationWithRequest() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + ElicitResult expectedResult = mock(ElicitResult.class); + ElicitRequest elicitRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "string")) + .build(); + + when(exchange.createElicitation(elicitRequest)).thenReturn(Mono.just(expectedResult)); + + Mono result = context.elicit(elicitRequest); + + StepVerifier.create(result).expectNext(expectedResult).verifyComplete(); + } + + @Test + public void testElicitationWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + ElicitRequest elicitRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "string")) + .build(); + + Mono result = context.elicit(elicitRequest); + + StepVerifier.create(result).verifyComplete(); + } + + // Sampling Tests + + @Test + public void testSamplingWithMessages() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Sampling sampling = mock(ClientCapabilities.Sampling.class); + when(capabilities.sampling()).thenReturn(sampling); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + CreateMessageResult expectedResult = mock(CreateMessageResult.class); + when(exchange.createMessage(any(CreateMessageRequest.class))).thenReturn(Mono.just(expectedResult)); + + Mono result = context.sample("Message 1", "Message 2"); + + StepVerifier.create(result).expectNext(expectedResult).verifyComplete(); + } + + @Test + public void testSamplingWithConsumer() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Sampling sampling = mock(ClientCapabilities.Sampling.class); + when(capabilities.sampling()).thenReturn(sampling); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + CreateMessageResult expectedResult = mock(CreateMessageResult.class); + when(exchange.createMessage(any(CreateMessageRequest.class))).thenReturn(Mono.just(expectedResult)); + + Mono result = context.sample(spec -> { + spec.message(new TextContent("Test message")); + spec.systemPrompt("System prompt"); + spec.temperature(0.7); + spec.maxTokens(100); + }); + + StepVerifier.create(result).expectNext(expectedResult).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(CreateMessageRequest.class); + verify(exchange).createMessage(captor.capture()); + + CreateMessageRequest capturedRequest = captor.getValue(); + assertThat(capturedRequest.systemPrompt()).isEqualTo("System prompt"); + assertThat(capturedRequest.temperature()).isEqualTo(0.7); + assertThat(capturedRequest.maxTokens()).isEqualTo(100); + } + + @Test + public void testSamplingWithRequest() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Sampling sampling = mock(ClientCapabilities.Sampling.class); + when(capabilities.sampling()).thenReturn(sampling); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + CreateMessageResult expectedResult = mock(CreateMessageResult.class); + CreateMessageRequest createRequest = CreateMessageRequest.builder() + .messages(java.util.List.of(new SamplingMessage(Role.USER, new TextContent("Test")))) + .maxTokens(500) + .build(); + + when(exchange.createMessage(createRequest)).thenReturn(Mono.just(expectedResult)); + + Mono result = context.sample(createRequest); + + StepVerifier.create(result).expectNext(expectedResult).verifyComplete(); + } + + @Test + public void testSamplingWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + CreateMessageRequest createRequest = CreateMessageRequest.builder() + .messages(java.util.List.of(new SamplingMessage(Role.USER, new TextContent("Test")))) + .maxTokens(500) + .build(); + + Mono result = context.sample(createRequest); + + StepVerifier.create(result).verifyComplete(); + } + + // Progress Tests + + @Test + public void testProgressWithPercentage() { + CallToolRequest requestWithToken = CallToolRequest.builder() + .name("test-tool") + .arguments(Map.of()) + .progressToken("token-123") + .build(); + McpAsyncRequestContext contextWithToken = DefaultMcpAsyncRequestContext.builder() + .request(requestWithToken) + .exchange(exchange) + .build(); + + when(exchange.progressNotification(any(ProgressNotification.class))).thenReturn(Mono.empty()); + + StepVerifier.create(contextWithToken.progress(50)).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ProgressNotification.class); + verify(exchange).progressNotification(captor.capture()); + + ProgressNotification notification = captor.getValue(); + assertThat(notification.progressToken()).isEqualTo("token-123"); + assertThat(notification.progress()).isEqualTo(0.5); + assertThat(notification.total()).isEqualTo(1.0); + } + + @Test + public void testProgressWithInvalidPercentage() { + assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { + context.progress(-1); + })).hasMessageContaining("Percentage must be between 0 and 100"); + + assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { + context.progress(101); + })).hasMessageContaining("Percentage must be between 0 and 100"); + } + + @Test + public void testProgressWithConsumer() { + CallToolRequest requestWithToken = CallToolRequest.builder() + .name("test-tool") + .arguments(Map.of()) + .progressToken("token-123") + .build(); + McpAsyncRequestContext contextWithToken = DefaultMcpAsyncRequestContext.builder() + .request(requestWithToken) + .exchange(exchange) + .build(); + + when(exchange.progressNotification(any(ProgressNotification.class))).thenReturn(Mono.empty()); + + StepVerifier.create(contextWithToken.progress(spec -> { + spec.progress(0.75); + spec.total(1.0); + spec.message("Processing..."); + })).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ProgressNotification.class); + verify(exchange).progressNotification(captor.capture()); + + ProgressNotification notification = captor.getValue(); + assertThat(notification.progressToken()).isEqualTo("token-123"); + assertThat(notification.progress()).isEqualTo(0.75); + assertThat(notification.total()).isEqualTo(1.0); + assertThat(notification.message()).isEqualTo("Processing..."); + } + + @Test + public void testProgressWithNotification() { + ProgressNotification notification = new ProgressNotification("token-123", 0.5, 1.0, "Test", null); + when(exchange.progressNotification(notification)).thenReturn(Mono.empty()); + + StepVerifier.create(context.progress(notification)).verifyComplete(); + + verify(exchange).progressNotification(notification); + } + + @Test + public void testProgressWithoutToken() { + // request already has no progress token (null by default) + // Should not throw, just log warning and return empty + StepVerifier.create(context.progress(50)).verifyComplete(); + } + + // Ping Tests + + @Test + public void testPing() { + when(exchange.ping()).thenReturn(Mono.just(new Object())); + + StepVerifier.create(context.ping()).expectNextCount(1).verifyComplete(); + + verify(exchange).ping(); + } + + // Logging Tests + + @Test + public void testLogWithConsumer() { + when(exchange.loggingNotification(any(LoggingMessageNotification.class))).thenReturn(Mono.empty()); + + StepVerifier.create(context.log(spec -> { + spec.message("Test log message"); + spec.level(LoggingLevel.INFO); + spec.logger("test-logger"); + })).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Test log message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.INFO); + assertThat(notification.logger()).isEqualTo("test-logger"); + } + + @Test + public void testDebug() { + when(exchange.loggingNotification(any(LoggingMessageNotification.class))).thenReturn(Mono.empty()); + + StepVerifier.create(context.debug("Debug message")).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Debug message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.DEBUG); + } + + @Test + public void testInfo() { + when(exchange.loggingNotification(any(LoggingMessageNotification.class))).thenReturn(Mono.empty()); + + StepVerifier.create(context.info("Info message")).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Info message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.INFO); + } + + @Test + public void testWarn() { + when(exchange.loggingNotification(any(LoggingMessageNotification.class))).thenReturn(Mono.empty()); + + StepVerifier.create(context.warn("Warning message")).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Warning message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.WARNING); + } + + @Test + public void testError() { + when(exchange.loggingNotification(any(LoggingMessageNotification.class))).thenReturn(Mono.empty()); + + StepVerifier.create(context.error("Error message")).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Error message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.ERROR); + } + + @Test + public void testLogWithEmptyMessage() { + assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { + context.debug(""); + })).hasMessageContaining("Log message must not be empty"); + } + + // Getter Tests + + @Test + public void testGetRequest() { + assertThat(context.request()).isEqualTo(request); + } + + @Test + public void testGetExchange() { + assertThat(context.exchange()).isEqualTo(exchange); + } + + @Test + public void testGetSessionId() { + when(exchange.sessionId()).thenReturn("session-123"); + + assertThat(context.sessionId()).isEqualTo("session-123"); + } + + @Test + public void testGetClientInfo() { + Implementation clientInfo = mock(Implementation.class); + when(exchange.getClientInfo()).thenReturn(clientInfo); + + assertThat(context.clientInfo()).isEqualTo(clientInfo); + } + + @Test + public void testGetClientCapabilities() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + assertThat(context.clientCapabilities()).isEqualTo(capabilities); + } + + @Test + public void testGetRequestMeta() { + Map meta = Map.of("key", "value"); + CallToolRequest requestWithMeta = CallToolRequest.builder() + .name("test-tool") + .arguments(Map.of()) + .meta(meta) + .build(); + McpAsyncRequestContext contextWithMeta = DefaultMcpAsyncRequestContext.builder() + .request(requestWithMeta) + .exchange(exchange) + .build(); + + assertThat(contextWithMeta.requestMeta()).isEqualTo(meta); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java new file mode 100644 index 0000000..b2ece1a --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java @@ -0,0 +1,641 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.Map; +import java.util.Optional; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.Implementation; +import io.modelcontextprotocol.spec.McpSchema.ListRootsResult; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link DefaultMcpSyncRequestContext}. + * + * @author Christian Tzolov + */ +public class DefaultMcpSyncRequestContextTests { + + private CallToolRequest request; + + private McpSyncServerExchange exchange; + + private McpSyncRequestContext context; + + @BeforeEach + public void setUp() { + request = new CallToolRequest("test-tool", Map.of()); + exchange = mock(McpSyncServerExchange.class); + context = DefaultMcpSyncRequestContext.builder().request(request).exchange(exchange).build(); + } + + // Builder Tests + + @Test + public void testBuilderWithValidParameters() { + CallToolRequest testRequest = new CallToolRequest("test-tool", Map.of()); + McpSyncRequestContext ctx = DefaultMcpSyncRequestContext.builder() + .request(testRequest) + .exchange(exchange) + .build(); + + assertThat(ctx).isNotNull(); + assertThat(ctx.request()).isEqualTo(testRequest); + assertThat(ctx.exchange()).isEqualTo(exchange); + } + + @Test + public void testBuilderWithNullRequest() { + assertThatThrownBy(() -> DefaultMcpSyncRequestContext.builder().request(null).exchange(exchange).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Request must not be null"); + } + + @Test + public void testBuilderWithNullExchange() { + CallToolRequest testRequest = new CallToolRequest("test-tool", Map.of()); + assertThatThrownBy(() -> DefaultMcpSyncRequestContext.builder().request(testRequest).exchange(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Exchange must not be null"); + } + + // Roots Tests + + @Test + public void testRootsWhenSupported() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + McpSchema.ClientCapabilities.RootCapabilities roots = mock(McpSchema.ClientCapabilities.RootCapabilities.class); + when(capabilities.roots()).thenReturn(roots); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + ListRootsResult expectedResult = mock(ListRootsResult.class); + when(exchange.listRoots()).thenReturn(expectedResult); + + Optional result = context.roots(); + + assertThat(result).isPresent(); + assertThat(result.get()).isEqualTo(expectedResult); + verify(exchange).listRoots(); + } + + @Test + public void testRootsWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + Optional result = context.roots(); + + assertThat(result).isEmpty(); + } + + @Test + public void testRootsWhenCapabilitiesNullRoots() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + when(capabilities.roots()).thenReturn(null); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + Optional result = context.roots(); + + assertThat(result).isEmpty(); + } + + // Elicitation Tests + + @Test + public void testElicitationWithTypeAndMessage() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + Map contentMap = Map.of("name", "John", "age", 30); + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(null); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); + + Optional>> result = context.elicit(e -> e.message("Test message"), + new TypeReference>() { + }); + + assertThat(result).isPresent(); + assertThat(result.get().action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.get().structuredContent()).isNotNull(); + assertThat(result.get().structuredContent()).containsEntry("name", "John"); + assertThat(result.get().structuredContent()).containsEntry("age", 30); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ElicitRequest.class); + verify(exchange).createElicitation(captor.capture()); + + ElicitRequest capturedRequest = captor.getValue(); + assertThat(capturedRequest.message()).isEqualTo("Test message"); + assertThat(capturedRequest.requestedSchema()).isNotNull(); + } + + @Test + public void testElicitationWithTypeMessageAndMeta() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + record Person(String name, int age) { + } + + Map contentMap = Map.of("name", "Jane", "age", 25); + Map requestMeta = Map.of("key", "value"); + Map resultMeta = Map.of("resultKey", "resultValue"); + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(resultMeta); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); + + Optional> result = context + .elicit(e -> e.message("Test message").meta(requestMeta), new TypeReference() { + }); + + assertThat(result).isPresent(); + assertThat(result.get().action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.get().structuredContent()).isNotNull(); + assertThat(result.get().structuredContent().name()).isEqualTo("Jane"); + assertThat(result.get().structuredContent().age()).isEqualTo(25); + assertThat(result.get().meta()).containsEntry("resultKey", "resultValue"); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ElicitRequest.class); + verify(exchange).createElicitation(captor.capture()); + + ElicitRequest capturedRequest = captor.getValue(); + assertThat(capturedRequest.meta()).containsEntry("key", "value"); + } + + @Test + public void testElicitationWithNullResponseType() { + assertThatThrownBy(() -> context.elicit((TypeReference) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Elicitation response type must not be null"); + } + + @Test + public void testElicitationWithTypeReturnsEmptyWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + Optional>> result = context + .elicit(new TypeReference>() { + }); + + assertThat(result).isEmpty(); + } + + @Test + public void testElicitationWithTypeReturnsEmptyWhenActionIsNotAccept() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.DECLINE); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); + + Optional>> result = context.elicit(e -> e.message("Test message"), + new TypeReference>() { + }); + + assertThat(result).isEmpty(); + } + + @Test + public void testElicitationWithTypeConvertsComplexTypes() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + record Address(String street, String city) { + } + record PersonWithAddress(String name, int age, Address address) { + } + + Map addressMap = Map.of("street", "123 Main St", "city", "Springfield"); + Map contentMap = Map.of("name", "John", "age", 30, "address", addressMap); + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(null); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); + + Optional> result = context + .elicit(e -> e.message("Test message").meta(null), new TypeReference() { + }); + + assertThat(result).isPresent(); + assertThat(result.get().action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.get().structuredContent()).isNotNull(); + assertThat(result.get().structuredContent().name()).isEqualTo("John"); + assertThat(result.get().structuredContent().age()).isEqualTo(30); + assertThat(result.get().structuredContent().address()).isNotNull(); + assertThat(result.get().structuredContent().address().street()).isEqualTo("123 Main St"); + assertThat(result.get().structuredContent().address().city()).isEqualTo("Springfield"); + } + + @Test + public void testElicitationWithTypeHandlesListTypes() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + Map contentMap = Map.of("items", + java.util.List.of(Map.of("name", "Item1"), Map.of("name", "Item2"))); + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(null); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); + + Optional>> result = context + .elicit(e -> e.message("Test message").meta(null), new TypeReference>() { + }); + + assertThat(result).isPresent(); + assertThat(result.get().structuredContent()).containsKey("items"); + } + + @Test + public void testElicitationWithTypeReference() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + Map contentMap = Map.of("result", "success", "data", "test value"); + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); + + Optional>> result = context + .elicit(e -> e.message("Test message").meta(null), new TypeReference>() { + }); + + assertThat(result).isPresent(); + assertThat(result.get().structuredContent()).containsEntry("result", "success"); + assertThat(result.get().structuredContent()).containsEntry("data", "test value"); + } + + @Test + public void testElicitationWithRequest() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + ElicitResult expectedResult = mock(ElicitResult.class); + ElicitRequest elicitRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "string")) + .build(); + + when(exchange.createElicitation(elicitRequest)).thenReturn(expectedResult); + + Optional result = context.elicit(elicitRequest); + + assertThat(result).isPresent(); + assertThat(result.get()).isEqualTo(expectedResult); + } + + @Test + public void testElicitationWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + ElicitRequest elicitRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "string")) + .build(); + + Optional result = context.elicit(elicitRequest); + + assertThat(result).isEmpty(); + } + + // Sampling Tests + + @Test + public void testSamplingWithMessages() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Sampling sampling = mock(ClientCapabilities.Sampling.class); + when(capabilities.sampling()).thenReturn(sampling); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + CreateMessageResult expectedResult = mock(CreateMessageResult.class); + when(exchange.createMessage(any(CreateMessageRequest.class))).thenReturn(expectedResult); + + Optional result = context.sample("Message 1", "Message 2"); + + assertThat(result).isPresent(); + assertThat(result.get()).isEqualTo(expectedResult); + } + + @Test + public void testSamplingWithConsumer() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Sampling sampling = mock(ClientCapabilities.Sampling.class); + when(capabilities.sampling()).thenReturn(sampling); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + CreateMessageResult expectedResult = mock(CreateMessageResult.class); + when(exchange.createMessage(any(CreateMessageRequest.class))).thenReturn(expectedResult); + + Optional result = context.sample(spec -> { + spec.message(new TextContent("Test message")); + spec.systemPrompt("System prompt"); + spec.temperature(0.7); + spec.maxTokens(100); + }); + + assertThat(result).isPresent(); + assertThat(result.get()).isEqualTo(expectedResult); + + ArgumentCaptor captor = ArgumentCaptor.forClass(CreateMessageRequest.class); + verify(exchange).createMessage(captor.capture()); + + CreateMessageRequest capturedRequest = captor.getValue(); + assertThat(capturedRequest.systemPrompt()).isEqualTo("System prompt"); + assertThat(capturedRequest.temperature()).isEqualTo(0.7); + assertThat(capturedRequest.maxTokens()).isEqualTo(100); + } + + @Test + public void testSamplingWithRequest() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Sampling sampling = mock(ClientCapabilities.Sampling.class); + when(capabilities.sampling()).thenReturn(sampling); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + CreateMessageResult expectedResult = mock(CreateMessageResult.class); + CreateMessageRequest createRequest = CreateMessageRequest.builder() + .messages(java.util.List.of(new SamplingMessage(Role.USER, new TextContent("Test")))) + .maxTokens(500) + .build(); + + when(exchange.createMessage(createRequest)).thenReturn(expectedResult); + + Optional result = context.sample(createRequest); + + assertThat(result).isPresent(); + assertThat(result.get()).isEqualTo(expectedResult); + } + + @Test + public void testSamplingWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + CreateMessageRequest createRequest = CreateMessageRequest.builder() + .messages(java.util.List.of(new SamplingMessage(Role.USER, new TextContent("Test")))) + .maxTokens(500) + .build(); + + Optional result = context.sample(createRequest); + + assertThat(result).isEmpty(); + } + + // Progress Tests + + @Test + public void testProgressWithPercentage() { + CallToolRequest requestWithToken = CallToolRequest.builder() + .name("test-tool") + .arguments(Map.of()) + .progressToken("token-123") + .build(); + McpSyncRequestContext contextWithToken = DefaultMcpSyncRequestContext.builder() + .request(requestWithToken) + .exchange(exchange) + .build(); + + contextWithToken.progress(50); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ProgressNotification.class); + verify(exchange).progressNotification(captor.capture()); + + ProgressNotification notification = captor.getValue(); + assertThat(notification.progressToken()).isEqualTo("token-123"); + assertThat(notification.progress()).isEqualTo(0.5); + assertThat(notification.total()).isEqualTo(1.0); + } + + @Test + public void testProgressWithInvalidPercentage() { + assertThatThrownBy(() -> context.progress(-1)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Percentage must be between 0 and 100"); + + assertThatThrownBy(() -> context.progress(101)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Percentage must be between 0 and 100"); + } + + @Test + public void testProgressWithConsumer() { + CallToolRequest requestWithToken = CallToolRequest.builder() + .name("test-tool") + .arguments(Map.of()) + .progressToken("token-123") + .build(); + McpSyncRequestContext contextWithToken = DefaultMcpSyncRequestContext.builder() + .request(requestWithToken) + .exchange(exchange) + .build(); + + contextWithToken.progress(spec -> { + spec.progress(0.75); + spec.total(1.0); + spec.message("Processing..."); + }); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ProgressNotification.class); + verify(exchange).progressNotification(captor.capture()); + + ProgressNotification notification = captor.getValue(); + assertThat(notification.progressToken()).isEqualTo("token-123"); + assertThat(notification.progress()).isEqualTo(0.75); + assertThat(notification.total()).isEqualTo(1.0); + assertThat(notification.message()).isEqualTo("Processing..."); + } + + @Test + public void testProgressWithNotification() { + ProgressNotification notification = new ProgressNotification("token-123", 0.5, 1.0, "Test", null); + + context.progress(notification); + + verify(exchange).progressNotification(notification); + } + + @Test + public void testProgressWithoutToken() { + // request already has no progress token (null by default) + // Should not throw, just log warning + context.progress(50); + } + + // Ping Tests + + @Test + public void testPing() { + context.ping(); + + verify(exchange).ping(); + } + + // Logging Tests + + @Test + public void testLogWithConsumer() { + context.log(spec -> { + spec.message("Test log message"); + spec.level(LoggingLevel.INFO); + spec.logger("test-logger"); + }); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Test log message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.INFO); + assertThat(notification.logger()).isEqualTo("test-logger"); + } + + @Test + public void testDebug() { + context.debug("Debug message"); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Debug message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.DEBUG); + } + + @Test + public void testInfo() { + context.info("Info message"); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Info message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.INFO); + } + + @Test + public void testWarn() { + context.warn("Warning message"); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Warning message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.WARNING); + } + + @Test + public void testError() { + context.error("Error message"); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Error message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.ERROR); + } + + @Test + public void testLogWithEmptyMessage() { + assertThatThrownBy(() -> context.debug("")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Log message must not be empty"); + } + + // Getter Tests + + @Test + public void testGetRequest() { + assertThat(context.request()).isEqualTo(request); + } + + @Test + public void testGetExchange() { + assertThat(context.exchange()).isEqualTo(exchange); + } + + @Test + public void testGetSessionId() { + when(exchange.sessionId()).thenReturn("session-123"); + + assertThat(context.sessionId()).isEqualTo("session-123"); + } + + @Test + public void testGetClientInfo() { + Implementation clientInfo = mock(Implementation.class); + when(exchange.getClientInfo()).thenReturn(clientInfo); + + assertThat(context.clientInfo()).isEqualTo(clientInfo); + } + + @Test + public void testGetClientCapabilities() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + assertThat(context.clientCapabilities()).isEqualTo(capabilities); + } + + @Test + public void testGetRequestMeta() { + Map meta = Map.of("key", "value"); + CallToolRequest requestWithMeta = CallToolRequest.builder() + .name("test-tool") + .arguments(Map.of()) + .meta(meta) + .build(); + McpSyncRequestContext contextWithMeta = DefaultMcpSyncRequestContext.builder() + .request(requestWithMeta) + .exchange(exchange) + .build(); + + assertThat(contextWithMeta.requestMeta()).isEqualTo(meta); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultProgressSpecTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultProgressSpecTests.java new file mode 100644 index 0000000..113a076 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultProgressSpecTests.java @@ -0,0 +1,167 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link DefaultProgressSpec}. + * + * @author Christian Tzolov + */ +public class DefaultProgressSpecTests { + + @Test + public void testDefaultValues() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + + assertThat(spec.progress).isEqualTo(0.0); + assertThat(spec.total).isEqualTo(1.0); + assertThat(spec.message).isNull(); + assertThat(spec.meta).isEmpty(); + } + + @Test + public void testProgressSetting() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + + spec.progress(0.5); + + assertThat(spec.progress).isEqualTo(0.5); + } + + @Test + public void testTotalSetting() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + + spec.total(100.0); + + assertThat(spec.total).isEqualTo(100.0); + } + + @Test + public void testMessageSetting() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + + spec.message("Processing..."); + + assertThat(spec.message).isEqualTo("Processing..."); + } + + @Test + public void testMetaWithMap() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + Map metaMap = new HashMap<>(); + metaMap.put("key1", "value1"); + metaMap.put("key2", "value2"); + + spec.meta(metaMap); + + assertThat(spec.meta).containsEntry("key1", "value1").containsEntry("key2", "value2"); + } + + @Test + public void testMetaWithNullMap() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + + spec.meta((Map) null); + + assertThat(spec.meta).isEmpty(); + } + + @Test + public void testMetaWithKeyValue() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + spec.meta = new HashMap<>(); + + spec.meta("key", "value"); + + assertThat(spec.meta).containsEntry("key", "value"); + } + + @Test + public void testMetaWithNullKey() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + spec.meta = new HashMap<>(); + + spec.meta(null, "value"); + + assertThat(spec.meta).isEmpty(); + } + + @Test + public void testMetaWithNullValue() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + spec.meta = new HashMap<>(); + + spec.meta("key", null); + + assertThat(spec.meta).isEmpty(); + } + + @Test + public void testMetaMultipleEntries() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + spec.meta = new HashMap<>(); + + spec.meta("key1", "value1").meta("key2", "value2").meta("key3", "value3"); + + assertThat(spec.meta).hasSize(3) + .containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .containsEntry("key3", "value3"); + } + + @Test + public void testFluentInterface() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + spec.meta = new HashMap<>(); + + McpRequestContextTypes.ProgressSpec result = spec.progress(0.75) + .total(1.0) + .message("Processing...") + .meta("key", "value"); + + assertThat(result).isSameAs(spec); + assertThat(spec.progress).isEqualTo(0.75); + assertThat(spec.total).isEqualTo(1.0); + assertThat(spec.message).isEqualTo("Processing..."); + assertThat(spec.meta).containsEntry("key", "value"); + } + + @Test + public void testProgressBoundaries() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + + spec.progress(0.0); + assertThat(spec.progress).isEqualTo(0.0); + + spec.progress(1.0); + assertThat(spec.progress).isEqualTo(1.0); + + spec.progress(0.5); + assertThat(spec.progress).isEqualTo(0.5); + } + + @Test + public void testTotalValues() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + + spec.total(50.0); + assertThat(spec.total).isEqualTo(50.0); + + spec.total(100.0); + assertThat(spec.total).isEqualTo(100.0); + + spec.total(1.0); + assertThat(spec.total).isEqualTo(1.0); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultSamplingSpecTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultSamplingSpecTests.java new file mode 100644 index 0000000..b5adde3 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultSamplingSpecTests.java @@ -0,0 +1,215 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.Map; + +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest.ContextInclusionStrategy; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link DefaultSamplingSpec}. + * + * @author Christian Tzolov + */ +public class DefaultSamplingSpecTests { + + @Test + public void testDefaultValues() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + assertThat(spec.messages).isEmpty(); + assertThat(spec.modelPreferences).isNull(); + assertThat(spec.systemPrompt).isNull(); + assertThat(spec.temperature).isNull(); + assertThat(spec.maxTokens).isNull(); + assertThat(spec.stopSequences).isEmpty(); + assertThat(spec.metadata).isEmpty(); + assertThat(spec.meta).isEmpty(); + assertThat(spec.includeContextStrategy).isEqualTo(ContextInclusionStrategy.NONE); + } + + @Test + public void testMessageWithTextContent() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + TextContent content = new TextContent("Test message"); + + spec.message(content); + + assertThat(spec.messages).hasSize(1); + assertThat(spec.messages.get(0).role()).isEqualTo(Role.USER); + assertThat(spec.messages.get(0).content()).isEqualTo(content); + } + + @Test + public void testMessageWithMultipleTextContent() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + TextContent content1 = new TextContent("Message 1"); + TextContent content2 = new TextContent("Message 2"); + + spec.message(content1, content2); + + assertThat(spec.messages).hasSize(2); + } + + @Test + public void testMessageWithSamplingMessage() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + SamplingMessage message = new SamplingMessage(Role.ASSISTANT, new TextContent("Assistant message")); + + spec.message(message); + + assertThat(spec.messages).hasSize(1); + assertThat(spec.messages.get(0)).isEqualTo(message); + } + + @Test + public void testSystemPrompt() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + spec.systemPrompt("System instructions"); + + assertThat(spec.systemPrompt).isEqualTo("System instructions"); + } + + @Test + public void testTemperature() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + spec.temperature(0.7); + + assertThat(spec.temperature).isEqualTo(0.7); + } + + @Test + public void testMaxTokens() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + spec.maxTokens(1000); + + assertThat(spec.maxTokens).isEqualTo(1000); + } + + @Test + public void testStopSequences() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + spec.stopSequences("STOP", "END"); + + assertThat(spec.stopSequences).containsExactly("STOP", "END"); + } + + @Test + public void testIncludeContextStrategy() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + spec.includeContextStrategy(ContextInclusionStrategy.ALL_SERVERS); + + assertThat(spec.includeContextStrategy).isEqualTo(ContextInclusionStrategy.ALL_SERVERS); + } + + @Test + public void testMetadataWithMap() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + Map metadataMap = Map.of("key1", "value1", "key2", "value2"); + + spec.metadata(metadataMap); + + assertThat(spec.metadata).containsEntry("key1", "value1").containsEntry("key2", "value2"); + } + + @Test + public void testMetadataWithKeyValue() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + spec.metadata("key", "value"); + + assertThat(spec.metadata).containsEntry("key", "value"); + } + + @Test + public void testMetaWithMap() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + Map metaMap = Map.of("key1", "value1", "key2", "value2"); + + spec.meta(metaMap); + + assertThat(spec.meta).containsEntry("key1", "value1").containsEntry("key2", "value2"); + } + + @Test + public void testMetaWithKeyValue() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + spec.meta("key", "value"); + + assertThat(spec.meta).containsEntry("key", "value"); + } + + @Test + public void testModelPreferences() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + spec.modelPreferences(prefs -> { + prefs.modelHint("gpt-4"); + prefs.costPriority(0.5); + prefs.speedPriority(0.8); + prefs.intelligencePriority(0.9); + }); + + assertThat(spec.modelPreferences).isNotNull(); + assertThat(spec.modelPreferences.hints()).hasSize(1); + assertThat(spec.modelPreferences.costPriority()).isEqualTo(0.5); + assertThat(spec.modelPreferences.speedPriority()).isEqualTo(0.8); + assertThat(spec.modelPreferences.intelligencePriority()).isEqualTo(0.9); + } + + @Test + public void testFluentInterface() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + McpRequestContextTypes.SamplingSpec result = spec.message(new TextContent("Test")) + .systemPrompt("System") + .temperature(0.7) + .maxTokens(100) + .stopSequences("STOP") + .metadata("key", "value") + .meta("metaKey", "metaValue"); + + assertThat(result).isSameAs(spec); + assertThat(spec.messages).hasSize(1); + assertThat(spec.systemPrompt).isEqualTo("System"); + assertThat(spec.temperature).isEqualTo(0.7); + assertThat(spec.maxTokens).isEqualTo(100); + assertThat(spec.stopSequences).containsExactly("STOP"); + assertThat(spec.metadata).containsEntry("key", "value"); + assertThat(spec.meta).containsEntry("metaKey", "metaValue"); + } + + // ModelPreferenceSpec Tests + + @Test + public void testModelPreferenceSpecWithNullModelHint() { + DefaultSamplingSpec.DefaultModelPreferenceSpec spec = new DefaultSamplingSpec.DefaultModelPreferenceSpec(); + + assertThatThrownBy(() -> spec.modelHint(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Model hint must not be null"); + } + + @Test + public void testModelPreferenceSpecWithNullModelHints() { + DefaultSamplingSpec.DefaultModelPreferenceSpec spec = new DefaultSamplingSpec.DefaultModelPreferenceSpec(); + + assertThatThrownBy(() -> spec.modelHints((String[]) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Models must not be null"); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallbackTests.java index 0f5d9cf..7f0e5b4 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallbackTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallbackTests.java @@ -18,6 +18,7 @@ import org.springaicommunity.mcp.annotation.McpMeta; import org.springaicommunity.mcp.annotation.McpTool; import org.springaicommunity.mcp.annotation.McpToolParam; +import org.springaicommunity.mcp.context.McpAsyncRequestContext; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -73,6 +74,11 @@ public Mono monoToolWithExchange(McpAsyncServerExchange exchange, String return Mono.just("Exchange tool: " + message); } + @McpTool(name = "context-mono-tool", description = "Mono tool with context parameter") + public Mono monoToolWithContext(McpAsyncRequestContext context, String message) { + return Mono.just("Context tool: " + message); + } + @McpTool(name = "list-mono-tool", description = "Mono tool with list parameter") public Mono processListMono(List items) { return Mono.just("Items: " + String.join(", ", items)); @@ -664,12 +670,34 @@ public void testIsExchangeType() throws Exception { // Test that McpAsyncServerExchange is recognized as exchange type assertThat(callback.isExchangeOrContextType(McpAsyncServerExchange.class)).isTrue(); + // Test that McpAsyncRequestContext is recognized as context type + assertThat(callback.isExchangeOrContextType(McpAsyncRequestContext.class)).isTrue(); + // Test that other types are not recognized as exchange type assertThat(callback.isExchangeOrContextType(String.class)).isFalse(); assertThat(callback.isExchangeOrContextType(Integer.class)).isFalse(); assertThat(callback.isExchangeOrContextType(Object.class)).isFalse(); } + @Test + public void testMonoToolWithContextParameter() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("monoToolWithContext", McpAsyncRequestContext.class, + String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("context-mono-tool", Map.of("message", "hello")); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Context tool: hello"); + }).verifyComplete(); + } + @Test public void testMonoToolWithOptionalParameters() throws Exception { TestAsyncToolProvider provider = new TestAsyncToolProvider(); diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallbackTests.java index 1f84686..066debc 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallbackTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallbackTests.java @@ -17,6 +17,7 @@ import org.junit.jupiter.api.Test; import org.springaicommunity.mcp.annotation.McpTool; import org.springaicommunity.mcp.annotation.McpToolParam; +import org.springaicommunity.mcp.context.McpSyncRequestContext; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -56,6 +57,11 @@ public String toolWithExchange(McpSyncServerExchange exchange, String message) { return "Exchange tool: " + message; } + @McpTool(name = "context-tool", description = "Tool with context parameter") + public String toolWithContext(McpSyncRequestContext context, String message) { + return "Context tool: " + message; + } + @McpTool(name = "list-tool", description = "Tool with list parameter") public String processList(List items) { return "Items: " + String.join(", ", items); @@ -443,12 +449,33 @@ public void testIsExchangeType() throws Exception { // Test that McpSyncServerExchange is recognized as exchange type assertThat(callback.isExchangeOrContextType(McpSyncServerExchange.class)).isTrue(); + // Test that McpSyncRequestContext is recognized as context type + assertThat(callback.isExchangeOrContextType(McpSyncRequestContext.class)).isTrue(); + // Test that other types are not recognized as exchange type assertThat(callback.isExchangeOrContextType(String.class)).isFalse(); assertThat(callback.isExchangeOrContextType(Integer.class)).isFalse(); assertThat(callback.isExchangeOrContextType(Object.class)).isFalse(); } + @Test + public void testToolWithContextParameter() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("toolWithContext", McpSyncRequestContext.class, String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("context-tool", Map.of("message", "hello")); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Context tool: hello"); + } + @Test public void testToolWithInvalidJsonConversion() throws Exception { TestToolProvider provider = new TestToolProvider();