Skip to content

Commit 5370844

Browse files
committed
WebSockets Next: perform connection cleanup during app shutdown
1 parent 4e66327 commit 5370844

File tree

5 files changed

+91
-22
lines changed

5 files changed

+91
-22
lines changed

extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/BasicWebSocketConnectorImpl.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ public Uni<WebSocketClientConnection> connect() {
155155
throw new WebSocketClientException(e);
156156
}
157157

158-
Uni<WebSocket> websocket = Uni.createFrom().<WebSocket> emitter(e -> {
158+
Uni<WebSocketOpen> websocketOpen = Uni.createFrom().<WebSocketOpen> emitter(e -> {
159159
// Create a new event loop context for each client, otherwise the current context is used
160160
// We want to avoid a situation where if multiple clients/connections are created in a row,
161161
// the same event loop is used and so writing/receiving messages is de-facto serialized
@@ -171,7 +171,7 @@ public void handle(Void event) {
171171
@Override
172172
public void handle(AsyncResult<WebSocket> r) {
173173
if (r.succeeded()) {
174-
e.complete(r.result());
174+
e.complete(new WebSocketOpen(newCleanupConsumer(c, context), r.result()));
175175
} else {
176176
e.fail(r.cause());
177177
}
@@ -183,14 +183,20 @@ public void handle(AsyncResult<WebSocket> r) {
183183
}
184184
});
185185
});
186-
return websocket.map(ws -> {
186+
return websocketOpen.map(wsOpen -> {
187+
WebSocket ws = wsOpen.websocket();
187188
String clientId = BasicWebSocketConnector.class.getName();
188189
TrafficLogger trafficLogger = TrafficLogger.forClient(config);
189-
WebSocketClientConnectionImpl connection = new WebSocketClientConnectionImpl(clientId, ws,
190+
WebSocketClientConnectionImpl connection = new WebSocketClientConnectionImpl(clientId,
191+
ws,
190192
codecs,
191193
pathParams,
192194
serverEndpointUri,
193-
headers, trafficLogger, userData, null);
195+
headers,
196+
trafficLogger,
197+
userData,
198+
null,
199+
wsOpen.cleanup());
194200
if (trafficLogger != null) {
195201
trafficLogger.connectionOpened(connection);
196202
}

extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ClientConnectionManager.java

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,21 @@
22

33
import java.util.Iterator;
44
import java.util.List;
5+
import java.util.Map.Entry;
56
import java.util.Set;
67
import java.util.concurrent.ConcurrentHashMap;
78
import java.util.concurrent.ConcurrentMap;
89
import java.util.concurrent.CopyOnWriteArrayList;
910
import java.util.stream.Stream;
1011

11-
import jakarta.annotation.PreDestroy;
1212
import jakarta.enterprise.event.Event;
1313
import jakarta.inject.Singleton;
1414

1515
import org.jboss.logging.Logger;
1616

1717
import io.quarkus.arc.Arc;
1818
import io.quarkus.arc.ArcContainer;
19+
import io.quarkus.runtime.Shutdown;
1920
import io.quarkus.websockets.next.Closed;
2021
import io.quarkus.websockets.next.Open;
2122
import io.quarkus.websockets.next.OpenClientConnections;
@@ -26,7 +27,7 @@ public class ClientConnectionManager implements OpenClientConnections {
2627

2728
private static final Logger LOG = Logger.getLogger(ClientConnectionManager.class);
2829

29-
private final ConcurrentMap<String, Set<WebSocketClientConnection>> endpointToConnections = new ConcurrentHashMap<>();
30+
private final ConcurrentMap<String, Set<WebSocketClientConnectionImpl>> endpointToConnections = new ConcurrentHashMap<>();
3031

3132
private final List<ClientConnectionListener> listeners = new CopyOnWriteArrayList<>();
3233

@@ -50,10 +51,11 @@ public Iterator<WebSocketClientConnection> iterator() {
5051

5152
@Override
5253
public Stream<WebSocketClientConnection> stream() {
53-
return endpointToConnections.values().stream().flatMap(Set::stream).filter(WebSocketClientConnection::isOpen);
54+
return endpointToConnections.values().stream().flatMap(Set::stream).filter(WebSocketClientConnection::isOpen)
55+
.map(WebSocketClientConnection.class::cast);
5456
}
5557

56-
void add(String endpoint, WebSocketClientConnection connection) {
58+
void add(String endpoint, WebSocketClientConnectionImpl connection) {
5759
LOG.debugf("Add client connection: %s", connection);
5860
if (endpointToConnections.computeIfAbsent(endpoint, e -> ConcurrentHashMap.newKeySet()).add(connection)) {
5961
if (openEvent != null) {
@@ -72,9 +74,9 @@ void add(String endpoint, WebSocketClientConnection connection) {
7274
}
7375
}
7476

75-
void remove(String endpoint, WebSocketClientConnection connection) {
77+
void remove(String endpoint, WebSocketClientConnectionImpl connection) {
7678
LOG.debugf("Remove client connection: %s", connection);
77-
Set<WebSocketClientConnection> connections = endpointToConnections.get(endpoint);
79+
Set<WebSocketClientConnectionImpl> connections = endpointToConnections.get(endpoint);
7880
if (connections != null) {
7981
if (connections.remove(connection)) {
8082
if (closedEvent != null) {
@@ -99,8 +101,8 @@ void remove(String endpoint, WebSocketClientConnection connection) {
99101
* @param endpoint
100102
* @return the connections for the given client endpoint, never {@code null}
101103
*/
102-
public Set<WebSocketClientConnection> getConnections(String endpoint) {
103-
Set<WebSocketClientConnection> ret = endpointToConnections.get(endpoint);
104+
public Set<WebSocketClientConnectionImpl> getConnections(String endpoint) {
105+
Set<WebSocketClientConnectionImpl> ret = endpointToConnections.get(endpoint);
104106
if (ret == null) {
105107
return Set.of();
106108
}
@@ -111,9 +113,19 @@ public void addListener(ClientConnectionListener listener) {
111113
this.listeners.add(listener);
112114
}
113115

114-
@PreDestroy
115-
void destroy() {
116-
endpointToConnections.clear();
116+
@Shutdown
117+
void cleanup() {
118+
if (!endpointToConnections.isEmpty()) {
119+
int sum = 0;
120+
for (Entry<String, Set<WebSocketClientConnectionImpl>> e : endpointToConnections.entrySet()) {
121+
for (WebSocketClientConnectionImpl c : e.getValue()) {
122+
c.cleanup();
123+
sum++;
124+
}
125+
}
126+
LOG.debugf("Cleanup performed for %s connections", sum);
127+
endpointToConnections.clear();
128+
}
117129
}
118130

119131
public interface ClientConnectionListener {

extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketClientConnectionImpl.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import java.util.Map;
77
import java.util.Map.Entry;
88
import java.util.Objects;
9+
import java.util.function.Consumer;
910

1011
import io.quarkus.websockets.next.HandshakeRequest;
1112
import io.quarkus.websockets.next.WebSocketClientConnection;
@@ -19,13 +20,17 @@ class WebSocketClientConnectionImpl extends WebSocketConnectionBase implements W
1920

2021
private final WebSocket webSocket;
2122

23+
private final Consumer<WebSocketClientConnection> cleanup;
24+
2225
WebSocketClientConnectionImpl(String clientId, WebSocket webSocket, Codecs codecs,
2326
Map<String, String> pathParams, URI serverEndpointUri, Map<String, List<String>> headers,
24-
TrafficLogger trafficLogger, Map<String, Object> userData, SendingInterceptor sendingInterceptor) {
27+
TrafficLogger trafficLogger, Map<String, Object> userData, SendingInterceptor sendingInterceptor,
28+
Consumer<WebSocketClientConnection> cleanup) {
2529
super(Map.copyOf(pathParams), codecs, new ClientHandshakeRequestImpl(serverEndpointUri, headers), trafficLogger,
2630
new UserDataImpl(userData), sendingInterceptor);
2731
this.clientId = clientId;
2832
this.webSocket = Objects.requireNonNull(webSocket);
33+
this.cleanup = cleanup;
2934
}
3035

3136
@Override
@@ -48,6 +53,12 @@ public int hashCode() {
4853
return Objects.hash(identifier);
4954
}
5055

56+
protected void cleanup() {
57+
if (cleanup != null) {
58+
cleanup.accept(this);
59+
}
60+
}
61+
5162
@Override
5263
public boolean equals(Object obj) {
5364
if (this == obj)

extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorBase.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,32 @@
1111
import java.util.Objects;
1212
import java.util.Optional;
1313
import java.util.Set;
14+
import java.util.function.Consumer;
1415
import java.util.regex.Matcher;
1516
import java.util.regex.Pattern;
1617

18+
import org.jboss.logging.Logger;
19+
1720
import io.quarkus.tls.TlsConfiguration;
1821
import io.quarkus.tls.TlsConfigurationRegistry;
1922
import io.quarkus.tls.runtime.config.TlsConfigUtils;
2023
import io.quarkus.websockets.next.UserData.TypedKey;
24+
import io.quarkus.websockets.next.WebSocketClientConnection;
2125
import io.quarkus.websockets.next.WebSocketClientException;
2226
import io.quarkus.websockets.next.runtime.config.WebSocketsClientRuntimeConfig;
2327
import io.vertx.core.Vertx;
28+
import io.vertx.core.http.WebSocket;
29+
import io.vertx.core.http.WebSocketClient;
2430
import io.vertx.core.http.WebSocketClientOptions;
2531
import io.vertx.core.http.WebSocketConnectOptions;
32+
import io.vertx.core.impl.ContextImpl;
2633

2734
abstract class WebSocketConnectorBase<THIS extends WebSocketConnectorBase<THIS>> {
2835

2936
protected static final Pattern PATH_PARAM_PATTERN = Pattern.compile("\\{[a-zA-Z0-9_]+\\}");
3037

38+
private static final Logger LOG = Logger.getLogger(WebSocketConnectorBase.class);
39+
3140
// mutable state
3241

3342
protected URI baseUri;
@@ -201,4 +210,28 @@ protected WebSocketConnectOptions newConnectOptions(URI serverEndpointUri) {
201210
protected boolean isSecure(URI uri) {
202211
return "https".equals(uri.getScheme()) || "wss".equals(uri.getScheme());
203212
}
213+
214+
record WebSocketOpen(Consumer<WebSocketClientConnection> cleanup, WebSocket websocket) {
215+
}
216+
217+
Consumer<WebSocketClientConnection> newCleanupConsumer(WebSocketClient client, ContextImpl context) {
218+
return new Consumer<WebSocketClientConnection>() {
219+
@Override
220+
public void accept(WebSocketClientConnection conn) {
221+
try {
222+
client.close();
223+
LOG.debugf("Client closed for connection %s", conn.id());
224+
} catch (Throwable e) {
225+
LOG.errorf(e, "Unable to close the client for connection %s", conn.id());
226+
}
227+
try {
228+
context.close();
229+
LOG.debugf("Context closed for connection %s", conn.id());
230+
} catch (Throwable e) {
231+
LOG.errorf(e, "Unable to close the context for connection %s", conn.id());
232+
}
233+
}
234+
};
235+
}
236+
204237
}

extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ public Uni<WebSocketClientConnection> connect() {
102102

103103
var telemetrySupport = telemetryProvider == null ? null
104104
: telemetryProvider.createClientTelemetrySupport(clientEndpoint.path);
105-
Uni<WebSocket> websocket = Uni.createFrom().<WebSocket> emitter(e -> {
105+
Uni<WebSocketOpen> websocketOpen = Uni.createFrom().<WebSocketOpen> emitter(e -> {
106106
// Create a new event loop context for each client, otherwise the current context is used
107107
// We want to avoid a situation where if multiple clients/connections are created in a row,
108108
// the same event loop is used and so writing/receiving messages is de-facto serialized
@@ -121,7 +121,7 @@ public void handle(Void event) {
121121
@Override
122122
public void handle(AsyncResult<WebSocket> r) {
123123
if (r.succeeded()) {
124-
e.complete(r.result());
124+
e.complete(new WebSocketOpen(newCleanupConsumer(c, context), r.result()));
125125
} else {
126126
if (telemetrySupport != null && telemetrySupport.interceptConnection()) {
127127
telemetrySupport.connectionOpeningFailed(r.cause());
@@ -136,13 +136,20 @@ public void handle(AsyncResult<WebSocket> r) {
136136
}
137137
});
138138
});
139-
return websocket.map(ws -> {
139+
return websocketOpen.map(wsOpen -> {
140+
WebSocket ws = wsOpen.websocket();
140141
TrafficLogger trafficLogger = TrafficLogger.forClient(config);
141142
SendingInterceptor sendingInterceptor = telemetrySupport == null ? null : telemetrySupport.getSendingInterceptor();
142-
WebSocketClientConnectionImpl connection = new WebSocketClientConnectionImpl(clientEndpoint.clientId, ws,
143+
WebSocketClientConnectionImpl connection = new WebSocketClientConnectionImpl(clientEndpoint.clientId,
144+
ws,
143145
codecs,
144146
pathParams,
145-
serverEndpointUri, headers, trafficLogger, userData, sendingInterceptor);
147+
serverEndpointUri,
148+
headers,
149+
trafficLogger,
150+
userData,
151+
sendingInterceptor,
152+
wsOpen.cleanup());
146153
if (trafficLogger != null) {
147154
trafficLogger.connectionOpened(connection);
148155
}

0 commit comments

Comments
 (0)