Skip to content

Commit 6b76a56

Browse files
authored
Merge pull request #48254 from mkouba/wsnext-client-conn-cleanup
WebSockets Next: perform client connection cleanup during app shutdown
2 parents 61a7aeb + f5ed6c3 commit 6b76a56

File tree

7 files changed

+204
-23
lines changed

7 files changed

+204
-23
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package io.quarkus.websockets.next.test.connection;
2+
3+
import static org.junit.jupiter.api.Assertions.assertFalse;
4+
import static org.junit.jupiter.api.Assertions.assertTrue;
5+
6+
import java.net.URI;
7+
import java.util.concurrent.CountDownLatch;
8+
import java.util.concurrent.ExecutorService;
9+
import java.util.concurrent.Executors;
10+
import java.util.concurrent.TimeUnit;
11+
import java.util.concurrent.atomic.AtomicBoolean;
12+
13+
import jakarta.inject.Inject;
14+
15+
import org.junit.jupiter.api.Test;
16+
import org.junit.jupiter.api.extension.RegisterExtension;
17+
18+
import io.quarkus.test.QuarkusUnitTest;
19+
import io.quarkus.test.common.http.TestHTTPResource;
20+
import io.quarkus.websockets.next.OnClose;
21+
import io.quarkus.websockets.next.OnOpen;
22+
import io.quarkus.websockets.next.OnTextMessage;
23+
import io.quarkus.websockets.next.WebSocket;
24+
import io.quarkus.websockets.next.WebSocketClient;
25+
import io.quarkus.websockets.next.WebSocketClientConnection;
26+
import io.quarkus.websockets.next.WebSocketConnector;
27+
import io.quarkus.websockets.next.test.utils.WSClient;
28+
29+
public class ConnectionIdleTimeoutTest {
30+
31+
@RegisterExtension
32+
public static final QuarkusUnitTest test = new QuarkusUnitTest()
33+
.withApplicationRoot(root -> {
34+
root.addClasses(ServerEndpoint.class, ClientEndpoint.class, WSClient.class);
35+
}).overrideConfigKey("quarkus.websockets-next.client.connection-idle-timeout", "500ms");;
36+
37+
@TestHTTPResource("/")
38+
URI uri;
39+
40+
@Inject
41+
WebSocketConnector<ClientEndpoint> connector;
42+
43+
@Test
44+
public void testTimeout() throws InterruptedException {
45+
WebSocketClientConnection conn = connector.baseUri(uri.toString()).connectAndAwait();
46+
ExecutorService executor = Executors.newSingleThreadExecutor();
47+
try {
48+
TimeUnit.MILLISECONDS.sleep(500);
49+
executor.execute(() -> {
50+
try {
51+
conn.sendTextAndAwait("ok");
52+
} catch (Throwable ignored) {
53+
}
54+
});
55+
} finally {
56+
executor.shutdownNow();
57+
}
58+
assertTrue(ServerEndpoint.CLOSED.await(5, TimeUnit.SECONDS));
59+
assertTrue(ClientEndpoint.CLOSED.await(5, TimeUnit.SECONDS));
60+
assertFalse(ServerEndpoint.MESSAGE.get());
61+
}
62+
63+
@WebSocket(path = "/end")
64+
public static class ServerEndpoint {
65+
66+
static final CountDownLatch CLOSED = new CountDownLatch(1);
67+
static final AtomicBoolean MESSAGE = new AtomicBoolean();
68+
69+
@OnTextMessage
70+
void onText(String message) {
71+
MESSAGE.set(true);
72+
}
73+
74+
@OnClose
75+
void close() {
76+
CLOSED.countDown();
77+
}
78+
79+
}
80+
81+
@WebSocketClient(path = "/end")
82+
public static class ClientEndpoint {
83+
84+
static final CountDownLatch CLOSED = new CountDownLatch(1);
85+
86+
@OnOpen
87+
void open() {
88+
}
89+
90+
@OnClose
91+
void close(WebSocketClientConnection conn) {
92+
CLOSED.countDown();
93+
}
94+
95+
}
96+
}

extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/BasicWebSocketConnectorImpl.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ public Uni<WebSocketClientConnection> connect() {
155155
throw new WebSocketClientException(e);
156156
}
157157

158-
Uni<WebSocket> websocket = Uni.createFrom().<WebSocket> emitter(e -> {
158+
Uni<WebSocketOpen> websocketOpen = Uni.createFrom().<WebSocketOpen> emitter(e -> {
159159
// Create a new event loop context for each client, otherwise the current context is used
160160
// We want to avoid a situation where if multiple clients/connections are created in a row,
161161
// the same event loop is used and so writing/receiving messages is de-facto serialized
@@ -171,7 +171,7 @@ public void handle(Void event) {
171171
@Override
172172
public void handle(AsyncResult<WebSocket> r) {
173173
if (r.succeeded()) {
174-
e.complete(r.result());
174+
e.complete(new WebSocketOpen(newCleanupConsumer(c, context), r.result()));
175175
} else {
176176
e.fail(r.cause());
177177
}
@@ -183,14 +183,20 @@ public void handle(AsyncResult<WebSocket> r) {
183183
}
184184
});
185185
});
186-
return websocket.map(ws -> {
186+
return websocketOpen.map(wsOpen -> {
187+
WebSocket ws = wsOpen.websocket();
187188
String clientId = BasicWebSocketConnector.class.getName();
188189
TrafficLogger trafficLogger = TrafficLogger.forClient(config);
189-
WebSocketClientConnectionImpl connection = new WebSocketClientConnectionImpl(clientId, ws,
190+
WebSocketClientConnectionImpl connection = new WebSocketClientConnectionImpl(clientId,
191+
ws,
190192
codecs,
191193
pathParams,
192194
serverEndpointUri,
193-
headers, trafficLogger, userData, null);
195+
headers,
196+
trafficLogger,
197+
userData,
198+
null,
199+
wsOpen.cleanup());
194200
if (trafficLogger != null) {
195201
trafficLogger.connectionOpened(connection);
196202
}

extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ClientConnectionManager.java

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,21 @@
22

33
import java.util.Iterator;
44
import java.util.List;
5+
import java.util.Map.Entry;
56
import java.util.Set;
67
import java.util.concurrent.ConcurrentHashMap;
78
import java.util.concurrent.ConcurrentMap;
89
import java.util.concurrent.CopyOnWriteArrayList;
910
import java.util.stream.Stream;
1011

11-
import jakarta.annotation.PreDestroy;
1212
import jakarta.enterprise.event.Event;
1313
import jakarta.inject.Singleton;
1414

1515
import org.jboss.logging.Logger;
1616

1717
import io.quarkus.arc.Arc;
1818
import io.quarkus.arc.ArcContainer;
19+
import io.quarkus.runtime.Shutdown;
1920
import io.quarkus.websockets.next.Closed;
2021
import io.quarkus.websockets.next.Open;
2122
import io.quarkus.websockets.next.OpenClientConnections;
@@ -26,7 +27,7 @@ public class ClientConnectionManager implements OpenClientConnections {
2627

2728
private static final Logger LOG = Logger.getLogger(ClientConnectionManager.class);
2829

29-
private final ConcurrentMap<String, Set<WebSocketClientConnection>> endpointToConnections = new ConcurrentHashMap<>();
30+
private final ConcurrentMap<String, Set<WebSocketClientConnectionImpl>> endpointToConnections = new ConcurrentHashMap<>();
3031

3132
private final List<ClientConnectionListener> listeners = new CopyOnWriteArrayList<>();
3233

@@ -50,10 +51,11 @@ public Iterator<WebSocketClientConnection> iterator() {
5051

5152
@Override
5253
public Stream<WebSocketClientConnection> stream() {
53-
return endpointToConnections.values().stream().flatMap(Set::stream).filter(WebSocketClientConnection::isOpen);
54+
return endpointToConnections.values().stream().flatMap(Set::stream).filter(WebSocketClientConnection::isOpen)
55+
.map(WebSocketClientConnection.class::cast);
5456
}
5557

56-
void add(String endpoint, WebSocketClientConnection connection) {
58+
void add(String endpoint, WebSocketClientConnectionImpl connection) {
5759
LOG.debugf("Add client connection: %s", connection);
5860
if (endpointToConnections.computeIfAbsent(endpoint, e -> ConcurrentHashMap.newKeySet()).add(connection)) {
5961
if (openEvent != null) {
@@ -72,9 +74,9 @@ void add(String endpoint, WebSocketClientConnection connection) {
7274
}
7375
}
7476

75-
void remove(String endpoint, WebSocketClientConnection connection) {
77+
void remove(String endpoint, WebSocketClientConnectionImpl connection) {
7678
LOG.debugf("Remove client connection: %s", connection);
77-
Set<WebSocketClientConnection> connections = endpointToConnections.get(endpoint);
79+
Set<WebSocketClientConnectionImpl> connections = endpointToConnections.get(endpoint);
7880
if (connections != null) {
7981
if (connections.remove(connection)) {
8082
if (closedEvent != null) {
@@ -99,8 +101,8 @@ void remove(String endpoint, WebSocketClientConnection connection) {
99101
* @param endpoint
100102
* @return the connections for the given client endpoint, never {@code null}
101103
*/
102-
public Set<WebSocketClientConnection> getConnections(String endpoint) {
103-
Set<WebSocketClientConnection> ret = endpointToConnections.get(endpoint);
104+
public Set<WebSocketClientConnectionImpl> getConnections(String endpoint) {
105+
Set<WebSocketClientConnectionImpl> ret = endpointToConnections.get(endpoint);
104106
if (ret == null) {
105107
return Set.of();
106108
}
@@ -111,9 +113,19 @@ public void addListener(ClientConnectionListener listener) {
111113
this.listeners.add(listener);
112114
}
113115

114-
@PreDestroy
115-
void destroy() {
116-
endpointToConnections.clear();
116+
@Shutdown
117+
void cleanup() {
118+
if (!endpointToConnections.isEmpty()) {
119+
int sum = 0;
120+
for (Entry<String, Set<WebSocketClientConnectionImpl>> e : endpointToConnections.entrySet()) {
121+
for (WebSocketClientConnectionImpl c : e.getValue()) {
122+
c.cleanup();
123+
sum++;
124+
}
125+
}
126+
LOG.debugf("Cleanup performed for %s connections", sum);
127+
endpointToConnections.clear();
128+
}
117129
}
118130

119131
public interface ClientConnectionListener {

extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketClientConnectionImpl.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import java.util.Map;
77
import java.util.Map.Entry;
88
import java.util.Objects;
9+
import java.util.function.Consumer;
910

1011
import io.quarkus.websockets.next.HandshakeRequest;
1112
import io.quarkus.websockets.next.WebSocketClientConnection;
@@ -19,13 +20,17 @@ class WebSocketClientConnectionImpl extends WebSocketConnectionBase implements W
1920

2021
private final WebSocket webSocket;
2122

23+
private final Consumer<WebSocketClientConnection> cleanup;
24+
2225
WebSocketClientConnectionImpl(String clientId, WebSocket webSocket, Codecs codecs,
2326
Map<String, String> pathParams, URI serverEndpointUri, Map<String, List<String>> headers,
24-
TrafficLogger trafficLogger, Map<String, Object> userData, SendingInterceptor sendingInterceptor) {
27+
TrafficLogger trafficLogger, Map<String, Object> userData, SendingInterceptor sendingInterceptor,
28+
Consumer<WebSocketClientConnection> cleanup) {
2529
super(Map.copyOf(pathParams), codecs, new ClientHandshakeRequestImpl(serverEndpointUri, headers), trafficLogger,
2630
new UserDataImpl(userData), sendingInterceptor);
2731
this.clientId = clientId;
2832
this.webSocket = Objects.requireNonNull(webSocket);
33+
this.cleanup = cleanup;
2934
}
3035

3136
@Override
@@ -48,6 +53,12 @@ public int hashCode() {
4853
return Objects.hash(identifier);
4954
}
5055

56+
protected void cleanup() {
57+
if (cleanup != null) {
58+
cleanup.accept(this);
59+
}
60+
}
61+
5162
@Override
5263
public boolean equals(Object obj) {
5364
if (this == obj)

extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorBase.java

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import java.net.URI;
44
import java.net.URLEncoder;
55
import java.nio.charset.StandardCharsets;
6+
import java.time.Duration;
67
import java.util.ArrayList;
78
import java.util.HashMap;
89
import java.util.HashSet;
@@ -11,23 +12,33 @@
1112
import java.util.Objects;
1213
import java.util.Optional;
1314
import java.util.Set;
15+
import java.util.concurrent.TimeUnit;
16+
import java.util.function.Consumer;
1417
import java.util.regex.Matcher;
1518
import java.util.regex.Pattern;
1619

20+
import org.jboss.logging.Logger;
21+
1722
import io.quarkus.tls.TlsConfiguration;
1823
import io.quarkus.tls.TlsConfigurationRegistry;
1924
import io.quarkus.tls.runtime.config.TlsConfigUtils;
2025
import io.quarkus.websockets.next.UserData.TypedKey;
26+
import io.quarkus.websockets.next.WebSocketClientConnection;
2127
import io.quarkus.websockets.next.WebSocketClientException;
2228
import io.quarkus.websockets.next.runtime.config.WebSocketsClientRuntimeConfig;
2329
import io.vertx.core.Vertx;
30+
import io.vertx.core.http.WebSocket;
31+
import io.vertx.core.http.WebSocketClient;
2432
import io.vertx.core.http.WebSocketClientOptions;
2533
import io.vertx.core.http.WebSocketConnectOptions;
34+
import io.vertx.core.impl.ContextImpl;
2635

2736
abstract class WebSocketConnectorBase<THIS extends WebSocketConnectorBase<THIS>> {
2837

2938
protected static final Pattern PATH_PARAM_PATTERN = Pattern.compile("\\{[a-zA-Z0-9_]+\\}");
3039

40+
private static final Logger LOG = Logger.getLogger(WebSocketConnectorBase.class);
41+
3142
// mutable state
3243

3344
protected URI baseUri;
@@ -172,7 +183,16 @@ protected WebSocketClientOptions populateClientOptions() {
172183
if (config.maxFrameSize().isPresent()) {
173184
clientOptions.setMaxFrameSize(config.maxFrameSize().getAsInt());
174185
}
175-
186+
if (config.connectionIdleTimeout().isPresent()) {
187+
Duration timeout = config.connectionIdleTimeout().get();
188+
if (timeout.toMillis() > Integer.MAX_VALUE) {
189+
clientOptions.setIdleTimeoutUnit(TimeUnit.SECONDS);
190+
clientOptions.setIdleTimeout((int) timeout.toSeconds());
191+
} else {
192+
clientOptions.setIdleTimeoutUnit(TimeUnit.MILLISECONDS);
193+
clientOptions.setIdleTimeout((int) timeout.toMillis());
194+
}
195+
}
176196
Optional<TlsConfiguration> maybeTlsConfiguration = TlsConfiguration.from(tlsConfigurationRegistry,
177197
Optional.ofNullable(tlsConfigurationName));
178198
if (maybeTlsConfiguration.isEmpty()) {
@@ -201,4 +221,28 @@ protected WebSocketConnectOptions newConnectOptions(URI serverEndpointUri) {
201221
protected boolean isSecure(URI uri) {
202222
return "https".equals(uri.getScheme()) || "wss".equals(uri.getScheme());
203223
}
224+
225+
record WebSocketOpen(Consumer<WebSocketClientConnection> cleanup, WebSocket websocket) {
226+
}
227+
228+
Consumer<WebSocketClientConnection> newCleanupConsumer(WebSocketClient client, ContextImpl context) {
229+
return new Consumer<WebSocketClientConnection>() {
230+
@Override
231+
public void accept(WebSocketClientConnection conn) {
232+
try {
233+
client.close();
234+
LOG.debugf("Client closed for connection %s", conn.id());
235+
} catch (Throwable e) {
236+
LOG.errorf(e, "Unable to close the client for connection %s", conn.id());
237+
}
238+
try {
239+
context.close();
240+
LOG.debugf("Context closed for connection %s", conn.id());
241+
} catch (Throwable e) {
242+
LOG.errorf(e, "Unable to close the context for connection %s", conn.id());
243+
}
244+
}
245+
};
246+
}
247+
204248
}

0 commit comments

Comments
 (0)