Skip to content

Commit dbb0933

Browse files
committed
Thread-safe access to WebSocketServerFactory and WebSocketExtensions
Closes gh-24745
1 parent 5953d99 commit dbb0933

File tree

2 files changed

+42
-19
lines changed

2 files changed

+42
-19
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@
2020
import java.lang.reflect.Method;
2121
import java.security.Principal;
2222
import java.util.ArrayList;
23+
import java.util.Collections;
2324
import java.util.List;
2425
import java.util.Map;
2526
import java.util.Set;
@@ -39,6 +40,7 @@
3940
import org.springframework.http.server.ServerHttpResponse;
4041
import org.springframework.http.server.ServletServerHttpRequest;
4142
import org.springframework.http.server.ServletServerHttpResponse;
43+
import org.springframework.lang.Nullable;
4244
import org.springframework.util.Assert;
4345
import org.springframework.util.ClassUtils;
4446
import org.springframework.util.CollectionUtils;
@@ -67,15 +69,18 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Serv
6769
private static final ThreadLocal<WebSocketHandlerContainer> containerHolder =
6870
new NamedThreadLocal<>("WebSocketHandlerContainer");
6971

70-
72+
@Nullable
7173
private WebSocketPolicy policy;
7274

73-
private WebSocketServerFactory factory;
75+
@Nullable
76+
private volatile WebSocketServerFactory factory;
7477

78+
@Nullable
7579
private ServletContext servletContext;
7680

7781
private volatile boolean running = false;
7882

83+
@Nullable
7984
private volatile List<WebSocketExtension> supportedExtensions;
8085

8186

@@ -118,17 +123,20 @@ public void start() {
118123
if (!isRunning()) {
119124
this.running = true;
120125
try {
121-
if (this.factory == null) {
122-
this.factory = new WebSocketServerFactory(this.servletContext, this.policy);
126+
WebSocketServerFactory factory = this.factory;
127+
if (factory == null) {
128+
Assert.state(this.servletContext != null, "No ServletContext set");
129+
factory = new WebSocketServerFactory(this.servletContext, this.policy);
130+
this.factory = factory;
123131
}
124-
this.factory.setCreator((request, response) -> {
132+
factory.setCreator((request, response) -> {
125133
WebSocketHandlerContainer container = containerHolder.get();
126134
Assert.state(container != null, "Expected WebSocketHandlerContainer");
127135
response.setAcceptedSubProtocol(container.getSelectedProtocol());
128136
response.setExtensions(container.getExtensionConfigs());
129137
return container.getHandler();
130138
});
131-
this.factory.start();
139+
factory.start();
132140
}
133141
catch (Throwable ex) {
134142
throw new IllegalStateException("Unable to start Jetty WebSocketServerFactory", ex);
@@ -140,9 +148,10 @@ public void start() {
140148
public void stop() {
141149
if (isRunning()) {
142150
this.running = false;
143-
if (this.factory != null) {
151+
WebSocketServerFactory factory = this.factory;
152+
if (factory != null) {
144153
try {
145-
this.factory.stop();
154+
factory.stop();
146155
}
147156
catch (Throwable ex) {
148157
throw new IllegalStateException("Unable to stop Jetty WebSocketServerFactory", ex);
@@ -164,10 +173,12 @@ public String[] getSupportedVersions() {
164173

165174
@Override
166175
public List<WebSocketExtension> getSupportedExtensions(ServerHttpRequest request) {
167-
if (this.supportedExtensions == null) {
168-
this.supportedExtensions = buildWebSocketExtensions();
176+
List<WebSocketExtension> extensions = this.supportedExtensions;
177+
if (extensions == null) {
178+
extensions = buildWebSocketExtensions();
179+
this.supportedExtensions = extensions;
169180
}
170-
return this.supportedExtensions;
181+
return extensions;
171182
}
172183

173184
private List<WebSocketExtension> buildWebSocketExtensions() {
@@ -181,22 +192,25 @@ private List<WebSocketExtension> buildWebSocketExtensions() {
181192

182193
@SuppressWarnings({"unchecked", "deprecation"})
183194
private Set<String> getExtensionNames() {
195+
WebSocketServerFactory factory = this.factory;
196+
Assert.state(factory != null, "No WebSocketServerFactory available");
184197
try {
185-
return this.factory.getAvailableExtensionNames();
198+
return factory.getAvailableExtensionNames();
186199
}
187200
catch (IncompatibleClassChangeError ex) {
188201
// Fallback for versions prior to 9.4.21:
189202
// 9.4.20.v20190813: ExtensionFactory (abstract class -> interface)
190203
// 9.4.21.v20190926: ExtensionFactory (interface -> abstract class) + deprecated
191204
Class<?> clazz = org.eclipse.jetty.websocket.api.extensions.ExtensionFactory.class;
192205
Method method = ClassUtils.getMethod(clazz, "getExtensionNames");
193-
return (Set<String>) ReflectionUtils.invokeMethod(method, this.factory.getExtensionFactory());
206+
Set<String> result = (Set<String>) ReflectionUtils.invokeMethod(method, factory.getExtensionFactory());
207+
return (result != null ? result : Collections.emptySet());
194208
}
195209
}
196210

197211
@Override
198212
public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
199-
String selectedProtocol, List<WebSocketExtension> selectedExtensions, Principal user,
213+
@Nullable String selectedProtocol, List<WebSocketExtension> selectedExtensions, @Nullable Principal user,
200214
WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException {
201215

202216
Assert.isInstanceOf(ServletServerHttpRequest.class, request, "ServletServerHttpRequest required");
@@ -205,7 +219,9 @@ public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
205219
Assert.isInstanceOf(ServletServerHttpResponse.class, response, "ServletServerHttpResponse required");
206220
HttpServletResponse servletResponse = ((ServletServerHttpResponse) response).getServletResponse();
207221

208-
Assert.isTrue(this.factory.isUpgradeRequest(servletRequest, servletResponse), "Not a WebSocket handshake");
222+
WebSocketServerFactory factory = this.factory;
223+
Assert.state(factory != null, "No WebSocketServerFactory available");
224+
Assert.isTrue(factory.isUpgradeRequest(servletRequest, servletResponse), "Not a WebSocket handshake");
209225

210226
JettyWebSocketSession session = new JettyWebSocketSession(attributes, user);
211227
JettyWebSocketHandlerAdapter handlerAdapter = new JettyWebSocketHandlerAdapter(wsHandler, session);
@@ -215,7 +231,7 @@ public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
215231

216232
try {
217233
containerHolder.set(container);
218-
this.factory.acceptWebSocket(servletRequest, servletResponse);
234+
factory.acceptWebSocket(servletRequest, servletResponse);
219235
}
220236
catch (IOException ex) {
221237
throw new HandshakeFailureException(
@@ -231,12 +247,13 @@ private static class WebSocketHandlerContainer {
231247

232248
private final JettyWebSocketHandlerAdapter handler;
233249

250+
@Nullable
234251
private final String selectedProtocol;
235252

236253
private final List<ExtensionConfig> extensionConfigs;
237254

238-
public WebSocketHandlerContainer(
239-
JettyWebSocketHandlerAdapter handler, String protocol, List<WebSocketExtension> extensions) {
255+
public WebSocketHandlerContainer(JettyWebSocketHandlerAdapter handler,
256+
@Nullable String protocol, List<WebSocketExtension> extensions) {
240257

241258
this.handler = handler;
242259
this.selectedProtocol = protocol;
@@ -255,6 +272,7 @@ public JettyWebSocketHandlerAdapter getHandler() {
255272
return this.handler;
256273
}
257274

275+
@Nullable
258276
public String getSelectedProtocol() {
259277
return this.selectedProtocol;
260278
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
/**
22
* Server-side support for the Jetty 9+ WebSocket API.
33
*/
4+
@NonNullApi
5+
@NonNullFields
46
package org.springframework.web.socket.server.jetty;
7+
8+
import org.springframework.lang.NonNullApi;
9+
import org.springframework.lang.NonNullFields;

0 commit comments

Comments
 (0)