From 9772e27b2ee10cce94b7e268352aa3bf4503ab92 Mon Sep 17 00:00:00 2001 From: Kevin Stanton Date: Thu, 26 Jun 2025 11:05:56 -0500 Subject: [PATCH] WIP: AuthContext --- mcp-bom/pom.xml | 2 +- mcp-spring/mcp-spring-webflux/pom.xml | 6 ++-- mcp-spring/mcp-spring-webmvc/pom.xml | 6 ++-- mcp-test/pom.xml | 4 +-- mcp/pom.xml | 2 +- .../server/McpAsyncServer.java | 7 ++-- .../server/McpAsyncServerExchange.java | 11 +++++- .../server/McpSyncServerExchange.java | 6 +++- .../server/auth/SecurityContext.java | 17 ++++++++++ ...HttpServletSseServerTransportProvider.java | 34 +++++++++++++++++++ .../spec/McpServerSession.java | 20 +++++++++-- pom.xml | 2 +- 12 files changed, 99 insertions(+), 18 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/auth/SecurityContext.java diff --git a/mcp-bom/pom.xml b/mcp-bom/pom.xml index 7214dacda..7dd8ac6eb 100644 --- a/mcp-bom/pom.xml +++ b/mcp-bom/pom.xml @@ -7,7 +7,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.11.1-mcp-tool-level-authnz-SNAPSHOT mcp-bom diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 26452fe95..3b0bec9f7 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.11.1-mcp-tool-level-authnz-SNAPSHOT ../../pom.xml mcp-spring-webflux @@ -25,13 +25,13 @@ io.modelcontextprotocol.sdk mcp - 0.11.0-SNAPSHOT + 0.11.1-mcp-tool-level-authnz-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.11.0-SNAPSHOT + 0.11.1-mcp-tool-level-authnz-SNAPSHOT test diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index 48d1c3465..5a735bf6e 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.11.1-mcp-tool-level-authnz-SNAPSHOT ../../pom.xml mcp-spring-webmvc @@ -25,13 +25,13 @@ io.modelcontextprotocol.sdk mcp - 0.11.0-SNAPSHOT + 0.11.1-mcp-tool-level-authnz-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.11.0-SNAPSHOT + 0.11.1-mcp-tool-level-authnz-SNAPSHOT test diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index f24d9fab2..c652f95af 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.11.1-mcp-tool-level-authnz-SNAPSHOT mcp-test jar @@ -24,7 +24,7 @@ io.modelcontextprotocol.sdk mcp - 0.11.0-SNAPSHOT + 0.11.1-mcp-tool-level-authnz-SNAPSHOT diff --git a/mcp/pom.xml b/mcp/pom.xml index 773432827..5230bccdb 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.11.1-mcp-tool-level-authnz-SNAPSHOT mcp jar diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 1efa13de3..699ecead1 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -17,6 +17,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.auth.SecurityContext; import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; @@ -183,9 +184,9 @@ public class McpAsyncServer { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - mcpTransportProvider.setSessionFactory( - transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + mcpTransportProvider.setSessionFactory(transport -> new McpServerSession(UUID.randomUUID().toString(), + requestTimeout, transport, this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, + notificationHandlers, SecurityContext.EMPTY)); } // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 2fd95a10d..70ff0bcaf 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.server.auth.SecurityContext; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; @@ -28,6 +29,8 @@ public class McpAsyncServerExchange { private final McpSchema.Implementation clientInfo; + private final SecurityContext securityContext; + private volatile LoggingLevel minLoggingLevel = LoggingLevel.INFO; private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { @@ -47,10 +50,11 @@ public class McpAsyncServerExchange { * @param clientInfo The client implementation information. */ public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabilities clientCapabilities, - McpSchema.Implementation clientInfo) { + McpSchema.Implementation clientInfo, SecurityContext securityContext) { this.session = session; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; + this.securityContext = securityContext; } /** @@ -159,6 +163,11 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN }); } + public Mono getSecurityContext() { + // defer()? Could securityContext change over time, e.g. token refreshes? + return Mono.just(securityContext == null ? SecurityContext.EMPTY : securityContext); + } + /** * Set the minimum logging level for the client. Messages below this level will be * filtered out. diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index 25da5a6f9..e1ddcb328 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -4,8 +4,8 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.server.auth.SecurityContext; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; /** @@ -108,4 +108,8 @@ public void loggingNotification(LoggingMessageNotification loggingMessageNotific this.exchange.loggingNotification(loggingMessageNotification).block(); } + public SecurityContext getSecurityContext() { + return this.exchange.getSecurityContext().block(); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/auth/SecurityContext.java b/mcp/src/main/java/io/modelcontextprotocol/server/auth/SecurityContext.java new file mode 100644 index 000000000..5b9214a3f --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/auth/SecurityContext.java @@ -0,0 +1,17 @@ +package io.modelcontextprotocol.server.auth; + +import java.security.Principal; + +public record SecurityContext(Principal principal, String authHeader) { + // absent SecurityContext marker + public static final SecurityContext EMPTY = new SecurityContext(null, ""); + + public boolean isEmpty() { + return this == EMPTY; + } + + public boolean isPresent() { + return !isEmpty(); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index afdbff472..b2e51e28d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -13,6 +13,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.auth.SecurityContext; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; @@ -191,6 +192,14 @@ public Mono notifyClients(String method, Object params) { protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { + SecurityContext securityContext = (SecurityContext) request.getAttribute("securityContext"); + if (securityContext == null) { + // if null but auth is not required... + // TODO: and auth is required... + response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Unauthorized"); + return; + } + String requestURI = request.getRequestURI(); if (!requestURI.endsWith(sseEndpoint)) { response.sendError(HttpServletResponse.SC_NOT_FOUND); @@ -220,6 +229,11 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) // Create a new session using the session factory McpServerSession session = sessionFactory.create(sessionTransport); + + // set security context if available (otherwise it'll be SecurityContext.EMPTY) + if (securityContext != null) { + session.setSecurityContext(securityContext); + } this.sessions.put(sessionId, session); // Send initial endpoint event @@ -246,6 +260,13 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } + SecurityContext securityContext = (SecurityContext) request.getAttribute("securityContext"); + if (securityContext == null) { + // TODO: and auth is required... + response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Unauthorized"); + return; + } + String requestURI = request.getRequestURI(); if (!requestURI.endsWith(messageEndpoint)) { response.sendError(HttpServletResponse.SC_NOT_FOUND); @@ -278,6 +299,19 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } + // set security context if available (otherwise it'll be SecurityContext.EMPTY) + if (securityContext != null) { + // Update the session's security context + if (session.getSecurityContext() == null) { + // TODO should we allow this if no previous security context? + } + else if (!session.getSecurityContext().principal().equals(securityContext.principal())) { + // TODO if the principal has changed, return unauthorized + // don't allow changing the principal for the session + } + session.setSecurityContext(securityContext); + } + try { BufferedReader reader = request.getReader(); StringBuilder body = new StringBuilder(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 86906d859..348cad51d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -9,6 +9,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.auth.SecurityContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; @@ -48,6 +49,8 @@ public class McpServerSession implements McpSession { private final AtomicReference clientInfo = new AtomicReference<>(); + private SecurityContext securityContext; + private static final int STATE_UNINITIALIZED = 0; private static final int STATE_INITIALIZING = 1; @@ -68,10 +71,12 @@ public class McpServerSession implements McpSession { * received. * @param requestHandlers map of request handlers to use * @param notificationHandlers map of notification handlers to use + * @param securityContext the authentication context for this session */ public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, InitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, - Map> requestHandlers, Map notificationHandlers) { + Map> requestHandlers, Map notificationHandlers, + SecurityContext securityContext) { this.id = id; this.requestTimeout = requestTimeout; this.transport = transport; @@ -79,6 +84,7 @@ public McpServerSession(String id, Duration requestTimeout, McpServerTransport t this.initNotificationHandler = initNotificationHandler; this.requestHandlers = requestHandlers; this.notificationHandlers = notificationHandlers; + this.securityContext = securityContext != null ? securityContext : SecurityContext.EMPTY; } /** @@ -242,7 +248,8 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti return Mono.defer(() -> { if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { this.state.lazySet(STATE_INITIALIZED); - exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get())); + exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get(), + this.securityContext)); return this.initNotificationHandler.handle(); } @@ -255,6 +262,14 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti }); } + public SecurityContext getSecurityContext() { + return securityContext; + } + + public void setSecurityContext(SecurityContext securityContext) { + this.securityContext = securityContext != null ? securityContext : SecurityContext.EMPTY; + } + record MethodNotFoundError(String method, String message, Object data) { } @@ -321,6 +336,7 @@ public interface NotificationHandler { * @param the type of the response that is expected as a result of handling the * request. */ + @FunctionalInterface public interface RequestHandler { /** diff --git a/pom.xml b/pom.xml index 3fd0857e8..778e57b81 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.11.1-mcp-tool-level-authnz-SNAPSHOT pom https://github.com/modelcontextprotocol/java-sdk