From 789d87562c5108318f2628c7e7534fdc7ba517c7 Mon Sep 17 00:00:00 2001 From: Kevin Stanton Date: Sun, 10 Aug 2025 00:18:05 -0500 Subject: [PATCH 1/5] test transportContext present for both HttpServletSseServerTransportProvider and HttpServletStreamableServerTransportProvider --- ...stractMcpClientServerIntegrationTests.java | 62 +++++++++++++++++++ .../HttpServletSseIntegrationTests.java | 1 + ...HttpServletStreamableIntegrationTests.java | 1 + 3 files changed, 64 insertions(+) diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java index e2adb340c..cc580bdd4 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java @@ -10,6 +10,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertWith; import static org.awaitility.Awaitility.await; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; import java.net.URI; @@ -28,6 +29,7 @@ import java.util.function.Function; import java.util.stream.Collectors; +import jakarta.servlet.http.HttpServletRequest; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -825,6 +827,61 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { mcpServer.close(); } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testToolCallSuccessWithTranportContextExtraction(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var expectedCallResponse = new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("CALL RESPONSE; ctx=value")), null); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + McpTransportContext transportContext = exchange.transportContext(); + assertTrue(transportContext != null, "transportContext should not be null"); + assertTrue(!transportContext.equals(McpTransportContext.EMPTY), "transportContext should not be empty"); + String ctxValue = (String) transportContext.get("important"); + + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("CALL RESPONSE; ctx=" + ctxValue)), null); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull().isEqualTo(expectedCallResponse); + } + + mcpServer.close(); + } + @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient" }) void testToolListChangeHandlingSuccess(String clientType) { @@ -1531,4 +1588,9 @@ private double evaluateExpression(String expression) { }; } + protected static McpTransportContextExtractor extractor = (r, tc) -> { + tc.put("important", "value"); + return tc; + }; + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java index 56e74218f..4435b8b44 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java @@ -40,6 +40,7 @@ public void before() { // Create and configure the transport provider mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() .objectMapper(new ObjectMapper()) + .contextExtractor(extractor) .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) .sseEndpoint(CUSTOM_SSE_ENDPOINT) .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java index 6ac10014e..0815556b9 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java @@ -38,6 +38,7 @@ public void before() { // Create and configure the transport provider mcpServerTransportProvider = HttpServletStreamableServerTransportProvider.builder() .objectMapper(new ObjectMapper()) + .contextExtractor(extractor) .mcpEndpoint(MESSAGE_ENDPOINT) .keepAliveInterval(Duration.ofSeconds(1)) .build(); From 7d5fcb3e76949b0c50b48c493d8fb1d63ce125b0 Mon Sep 17 00:00:00 2001 From: Kevin Stanton Date: Sat, 9 Aug 2025 18:22:25 -0500 Subject: [PATCH 2/5] add McpTransportContext capability to HttpServletSseServerTransportProvider --- .../server/DefaultMcpTransportContext.java | 10 ++++ ...HttpServletSseServerTransportProvider.java | 51 +++++++++++++++++-- .../spec/McpServerSession.java | 30 ++++++++--- 3 files changed, 80 insertions(+), 11 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java b/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java index 9e18e189d..c6a556b6e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.server; import java.util.Map; +import java.util.StringJoiner; import java.util.concurrent.ConcurrentHashMap; /** @@ -46,4 +47,13 @@ public McpTransportContext copy() { return new DefaultMcpTransportContext(new ConcurrentHashMap<>(this.storage)); } + // TODO for debugging + + @Override + public String toString() { + return new StringJoiner(", ", DefaultMcpTransportContext.class.getSimpleName() + "[", "]") + .add("storage=" + storage) + .toString(); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index ceeea31b1..8c45b7cc3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -16,6 +16,9 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.DefaultMcpTransportContext; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; @@ -102,6 +105,8 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement /** Map of active client sessions, keyed by session ID */ private final Map sessions = new ConcurrentHashMap<>(); + private McpTransportContextExtractor contextExtractor; + /** Flag indicating if the transport is in the process of shutting down */ private final AtomicBoolean isClosing = new AtomicBoolean(false); @@ -144,7 +149,7 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String m @Deprecated public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { - this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null, null); } /** @@ -163,11 +168,33 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String b @Deprecated public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint, Duration keepAliveInterval) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, null); + } + + /** + * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE + * endpoint. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param baseUrl The base URL for the server transport + * @param messageEndpoint The endpoint path where clients will send their messages + * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @param keepAliveInterval The interval for keep-alive pings, or null to disable + * @param contextExtractor The extractor for transport context from the request. + * keep-alive functionality + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. + */ + @Deprecated + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval, + McpTransportContextExtractor contextExtractor) { this.objectMapper = objectMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.contextExtractor = contextExtractor; if (keepAliveInterval != null) { @@ -339,10 +366,13 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) body.append(line); } + final McpTransportContext transportContext = contextExtractor.extract(request, + new DefaultMcpTransportContext()); McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); // Process the message through the session's handle method - session.handle(message).block(); // Block for Servlet compatibility + // Block for Servlet compatibility + session.handle(message).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); response.setStatus(HttpServletResponse.SC_OK); } @@ -534,6 +564,8 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + private Duration keepAliveInterval; /** @@ -583,6 +615,19 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } + /** + * Sets the context extractor for extracting transport context from the request. + * @param contextExtractor The context extractor to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public HttpServletSseServerTransportProvider.Builder contextExtractor( + McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + /** * Sets the interval for keep-alive pings. *

@@ -609,7 +654,7 @@ public HttpServletSseServerTransportProvider build() { throw new IllegalStateException("MessageEndpoint must be set"); } return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, - keepAliveInterval); + keepAliveInterval, contextExtractor); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 62985dc17..669c10b83 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -198,7 +198,9 @@ public Mono sendNotification(String method, Object params) { * @return a Mono that completes when the message is processed */ public Mono handle(McpSchema.JSONRPCMessage message) { - return Mono.defer(() -> { + return Mono.deferContextual(ctx -> { + McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + // TODO handle errors for communication to without initialization happening // first if (message instanceof McpSchema.JSONRPCResponse response) { @@ -214,7 +216,7 @@ public Mono handle(McpSchema.JSONRPCMessage message) { } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); - return handleIncomingRequest(request).onErrorResume(error -> { + return handleIncomingRequest(request, transportContext).onErrorResume(error -> { var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); @@ -227,7 +229,7 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) { // happening first logger.debug("Received notification: {}", notification); // TODO: in case of error, should the POST request be signalled? - return handleIncomingNotification(notification) + return handleIncomingNotification(notification, transportContext) .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); } else { @@ -240,9 +242,11 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) { /** * Handles an incoming JSON-RPC request by routing it to the appropriate handler. * @param request The incoming JSON-RPC request + * @param transportContext * @return A Mono containing the JSON-RPC response */ - private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { + private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request, + McpTransportContext transportContext) { return Mono.defer(() -> { Mono resultMono; if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { @@ -266,7 +270,11 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR error.message(), error.data()))); } - resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params())); + resultMono = this.exchangeSink.asMono().flatMap(exchange -> { + McpAsyncServerExchange newExchange = new McpAsyncServerExchange(exchange.sessionId(), this, + exchange.getClientCapabilities(), exchange.getClientInfo(), transportContext); + return handler.handle(newExchange, request.params()); + }); } return resultMono .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) @@ -280,16 +288,18 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR /** * Handles an incoming JSON-RPC notification by routing it to the appropriate handler. * @param notification The incoming JSON-RPC notification + * @param transportContext * @return A Mono that completes when the notification is processed */ - private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { + private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification, + McpTransportContext transportContext) { return Mono.defer(() -> { if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { this.state.lazySet(STATE_INITIALIZED); // FIXME: The session ID passed here is not the same as the one in the // legacy SSE transport. exchangeSink.tryEmitValue(new McpAsyncServerExchange(this.id, this, clientCapabilities.get(), - clientInfo.get(), McpTransportContext.EMPTY)); + clientInfo.get(), transportContext)); } var handler = notificationHandlers.get(notification.method()); @@ -297,7 +307,11 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti logger.warn("No handler registered for notification method: {}", notification); return Mono.empty(); } - return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params())); + return this.exchangeSink.asMono().flatMap(exchange -> { + McpAsyncServerExchange newExchange = new McpAsyncServerExchange(exchange.sessionId(), this, + exchange.getClientCapabilities(), exchange.getClientInfo(), transportContext); + return handler.handle(newExchange, notification.params()); + }); }); } From 04b562d0ee4fb120d1ed1089dc2482f01db901f5 Mon Sep 17 00:00:00 2001 From: Kevin Stanton Date: Sun, 10 Aug 2025 01:01:41 -0500 Subject: [PATCH 3/5] remove TODO/debug code --- .../server/DefaultMcpTransportContext.java | 9 --------- 1 file changed, 9 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java b/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java index c6a556b6e..0222a1883 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java @@ -47,13 +47,4 @@ public McpTransportContext copy() { return new DefaultMcpTransportContext(new ConcurrentHashMap<>(this.storage)); } - // TODO for debugging - - @Override - public String toString() { - return new StringJoiner(", ", DefaultMcpTransportContext.class.getSimpleName() + "[", "]") - .add("storage=" + storage) - .toString(); - } - } From f2efd30e04538893ce5a1a9933555738779f561a Mon Sep 17 00:00:00 2001 From: Kevin Stanton Date: Sun, 10 Aug 2025 01:02:20 -0500 Subject: [PATCH 4/5] remove unused import --- .../modelcontextprotocol/server/DefaultMcpTransportContext.java | 1 - 1 file changed, 1 deletion(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java b/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java index 0222a1883..9e18e189d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java @@ -5,7 +5,6 @@ package io.modelcontextprotocol.server; import java.util.Map; -import java.util.StringJoiner; import java.util.concurrent.ConcurrentHashMap; /** From 991aa4a879c2f4f3bbb21f977abdcc06285a9c32 Mon Sep 17 00:00:00 2001 From: Kevin Stanton Date: Sun, 10 Aug 2025 15:00:38 -0500 Subject: [PATCH 5/5] fix comment --- .../server/transport/HttpServletSseServerTransportProvider.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index 8c45b7cc3..d60e927f0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -180,8 +180,8 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String b * @param messageEndpoint The endpoint path where clients will send their messages * @param sseEndpoint The endpoint path where clients will establish SSE connections * @param keepAliveInterval The interval for keep-alive pings, or null to disable + * keep-alive functionality * @param contextExtractor The extractor for transport context from the request. - * keep-alive functionality * @deprecated Use the builder {@link #builder()} instead for better configuration * options. */