11package io .quarkiverse .mcp .server .websocket .runtime ;
22
33import java .util .List ;
4+ import java .util .concurrent .ConcurrentHashMap ;
5+ import java .util .concurrent .ConcurrentMap ;
46
57import jakarta .enterprise .inject .Instance ;
68
3032import io .quarkus .websockets .next .OnClose ;
3133import io .quarkus .websockets .next .OnOpen ;
3234import io .quarkus .websockets .next .OnTextMessage ;
33- import io .quarkus .websockets .next .UserData .TypedKey ;
3435import io .quarkus .websockets .next .WebSocketConnection ;
3536import io .smallrye .mutiny .Uni ;
3637import io .smallrye .mutiny .vertx .UniHelper ;
@@ -44,8 +45,6 @@ public abstract class WebSocketMcpMessageHandler extends McpMessageHandler<WebSo
4445
4546 private static final Logger LOG = Logger .getLogger (WebSocketMcpMessageHandler .class );
4647
47- private static final String MCP_CONNECTION_ID = "mcpConnectionId" ;
48-
4948 CurrentIdentityAssociation currentIdentityAssociation ;
5049
5150 protected WebSocketMcpMessageHandler (McpServersRuntimeConfig config ,
@@ -70,20 +69,17 @@ protected WebSocketMcpMessageHandler(McpServersRuntimeConfig config,
7069
7170 protected abstract String serverName ();
7271
72+ private final ConcurrentMap <String , WebSocketMcpConnection > connections = new ConcurrentHashMap <>();
73+
7374 @ OnOpen
7475 void openConnection (WebSocketConnection connection ) {
75- String id = ConnectionManager .connectionId ();
76- WebSocketMcpConnection mcpConnection = new WebSocketMcpConnection (id , config .servers ().get (serverName ()), connection );
77- connectionManager .add (mcpConnection );
78- LOG .debugf ("WebSocket connection initialized [%s]" , id );
79- connection .userData ().put (TypedKey .forString (MCP_CONNECTION_ID ), id );
76+ LOG .debugf ("MCP WebSocket connection open [id: %s]" , connection .id ());
8077 }
8178
8279 @ SuppressWarnings ("unchecked" )
8380 @ OnTextMessage
84- Uni <Void > consumeMessage (WebSocketConnection connection , String message ) {
85- String connectionId = connection .userData ().get (TypedKey .forString (MCP_CONNECTION_ID ));
86- WebSocketMcpConnection mcpConnection = (WebSocketMcpConnection ) connectionManager .get (connectionId );
81+ Uni <Void > consumeMessage (WebSocketConnection connection , String message ) throws InterruptedException {
82+ WebSocketMcpConnection mcpConnection = connections .computeIfAbsent (connection .id (), k -> newConnection (connection ));
8783 Object json = Json .decodeValue (message );
8884
8985 SecuritySupport securitySupport ;
@@ -106,16 +102,26 @@ public void setCurrentIdentity(CurrentIdentityAssociation currentIdentityAssocia
106102
107103 @ OnClose
108104 void closeConnection (WebSocketConnection connection ) {
109- String id = connection .userData ().get (TypedKey .forString (MCP_CONNECTION_ID ));
110- connectionManager .remove (id );
111- LOG .debugf ("WebSocket connection closed [%s]" , id );
105+ WebSocketMcpConnection mcpConnection = connections .remove (connection .id ());
106+ if (mcpConnection != null ) {
107+ connectionManager .remove (mcpConnection .id ());
108+ LOG .debugf ("MCP WebSocket connection closed [mcpId: %s, id: %s]" , mcpConnection .id (), connection .id ());
109+ }
112110 }
113111
114112 @ Override
115113 protected Transport transport () {
116114 return Transport .WEBSOCKET ;
117115 }
118116
117+ private WebSocketMcpConnection newConnection (WebSocketConnection connection ) {
118+ String id = ConnectionManager .connectionId ();
119+ WebSocketMcpConnection mcpConnection = new WebSocketMcpConnection (id , config .servers ().get (serverName ()), connection );
120+ connectionManager .add (mcpConnection );
121+ LOG .debugf ("MCP WebSocket connection initialized [mcpId: %s, id: %s]" , mcpConnection .id (), connection .id ());
122+ return mcpConnection ;
123+ }
124+
119125 static class WebSocketMcpRequest extends McpRequestImpl <WebSocketMcpConnection > {
120126
121127 WebSocketMcpRequest (String serverName , Object json , WebSocketMcpConnection connection ,
0 commit comments