Skip to content

Commit d1c4b17

Browse files
committed
ws: make sure MCP connection is initialized before a message is consumed
- fixes intermittent CI failures
1 parent 822e8d6 commit d1c4b17

File tree

2 files changed

+21
-15
lines changed

2 files changed

+21
-15
lines changed

transports/websocket/deployment/src/main/java/io/quarkiverse/mcp/server/websocket/deployment/WebSocketMcpServerProcessor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ void generateEndpoints(McpWebSocketServersBuildTimeConfig config, BuildProducer<
6868
.className(endpointClassName)
6969
.superClass(WebSocketMcpMessageHandler.class)
7070
.build();
71-
// @WebSocket(path = "/foo/bar")
71+
// @WebSocket(path = "/foo/bar", inboundProcessingMode = InboundProcessingMode.CONCURRENT)
7272
endpointCreator.addAnnotation(
7373
AnnotationInstance.builder(WebSocket.class)
7474
.add("path", e.getValue().websocket().endpointPath())

transports/websocket/runtime/src/main/java/io/quarkiverse/mcp/server/websocket/runtime/WebSocketMcpMessageHandler.java

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package io.quarkiverse.mcp.server.websocket.runtime;
22

33
import java.util.List;
4+
import java.util.concurrent.ConcurrentHashMap;
5+
import java.util.concurrent.ConcurrentMap;
46

57
import jakarta.enterprise.inject.Instance;
68

@@ -30,7 +32,6 @@
3032
import io.quarkus.websockets.next.OnClose;
3133
import io.quarkus.websockets.next.OnOpen;
3234
import io.quarkus.websockets.next.OnTextMessage;
33-
import io.quarkus.websockets.next.UserData.TypedKey;
3435
import io.quarkus.websockets.next.WebSocketConnection;
3536
import io.smallrye.mutiny.Uni;
3637
import io.smallrye.mutiny.vertx.UniHelper;
@@ -44,8 +45,6 @@ public abstract class WebSocketMcpMessageHandler extends McpMessageHandler<WebSo
4445

4546
private static final Logger LOG = Logger.getLogger(WebSocketMcpMessageHandler.class);
4647

47-
private static final String MCP_CONNECTION_ID = "mcpConnectionId";
48-
4948
CurrentIdentityAssociation currentIdentityAssociation;
5049

5150
protected WebSocketMcpMessageHandler(McpServersRuntimeConfig config,
@@ -70,20 +69,17 @@ protected WebSocketMcpMessageHandler(McpServersRuntimeConfig config,
7069

7170
protected abstract String serverName();
7271

72+
private final ConcurrentMap<String, WebSocketMcpConnection> connections = new ConcurrentHashMap<>();
73+
7374
@OnOpen
7475
void openConnection(WebSocketConnection connection) {
75-
String id = ConnectionManager.connectionId();
76-
WebSocketMcpConnection mcpConnection = new WebSocketMcpConnection(id, config.servers().get(serverName()), connection);
77-
connectionManager.add(mcpConnection);
78-
LOG.debugf("WebSocket connection initialized [%s]", id);
79-
connection.userData().put(TypedKey.forString(MCP_CONNECTION_ID), id);
76+
LOG.debugf("MCP WebSocket connection open [id: %s]", connection.id());
8077
}
8178

8279
@SuppressWarnings("unchecked")
8380
@OnTextMessage
84-
Uni<Void> consumeMessage(WebSocketConnection connection, String message) {
85-
String connectionId = connection.userData().get(TypedKey.forString(MCP_CONNECTION_ID));
86-
WebSocketMcpConnection mcpConnection = (WebSocketMcpConnection) connectionManager.get(connectionId);
81+
Uni<Void> consumeMessage(WebSocketConnection connection, String message) throws InterruptedException {
82+
WebSocketMcpConnection mcpConnection = connections.computeIfAbsent(connection.id(), k -> newConnection(connection));
8783
Object json = Json.decodeValue(message);
8884

8985
SecuritySupport securitySupport;
@@ -106,16 +102,26 @@ public void setCurrentIdentity(CurrentIdentityAssociation currentIdentityAssocia
106102

107103
@OnClose
108104
void closeConnection(WebSocketConnection connection) {
109-
String id = connection.userData().get(TypedKey.forString(MCP_CONNECTION_ID));
110-
connectionManager.remove(id);
111-
LOG.debugf("WebSocket connection closed [%s]", id);
105+
WebSocketMcpConnection mcpConnection = connections.remove(connection.id());
106+
if (mcpConnection != null) {
107+
connectionManager.remove(mcpConnection.id());
108+
LOG.debugf("MCP WebSocket connection closed [mcpId: %s, id: %s]", mcpConnection.id(), connection.id());
109+
}
112110
}
113111

114112
@Override
115113
protected Transport transport() {
116114
return Transport.WEBSOCKET;
117115
}
118116

117+
private WebSocketMcpConnection newConnection(WebSocketConnection connection) {
118+
String id = ConnectionManager.connectionId();
119+
WebSocketMcpConnection mcpConnection = new WebSocketMcpConnection(id, config.servers().get(serverName()), connection);
120+
connectionManager.add(mcpConnection);
121+
LOG.debugf("MCP WebSocket connection initialized [mcpId: %s, id: %s]", mcpConnection.id(), connection.id());
122+
return mcpConnection;
123+
}
124+
119125
static class WebSocketMcpRequest extends McpRequestImpl<WebSocketMcpConnection> {
120126

121127
WebSocketMcpRequest(String serverName, Object json, WebSocketMcpConnection connection,

0 commit comments

Comments
 (0)