Skip to content

Commit 83c9ec4

Browse files
committed
Efficient and consistent setAllowedOrigins collection type
Issue: SPR-13761 (cherry picked from commit 3d1ae9c)
1 parent d03d8cb commit 83c9ec4

File tree

7 files changed

+83
-94
lines changed

7 files changed

+83
-94
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616

1717
package org.springframework.web.socket.server.support;
1818

19-
import java.util.ArrayList;
2019
import java.util.Collection;
2120
import java.util.Collections;
22-
import java.util.List;
21+
import java.util.LinkedHashSet;
2322
import java.util.Map;
23+
import java.util.Set;
2424

2525
import org.apache.commons.logging.Log;
2626
import org.apache.commons.logging.LogFactory;
@@ -34,8 +34,8 @@
3434
import org.springframework.web.util.WebUtils;
3535

3636
/**
37-
* An interceptor to check request {@code Origin} header value against a collection of
38-
* allowed origins.
37+
* An interceptor to check request {@code Origin} header value against a
38+
* collection of allowed origins.
3939
*
4040
* @author Sebastien Deleuze
4141
* @since 4.1.2
@@ -44,58 +44,57 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
4444

4545
protected Log logger = LogFactory.getLog(getClass());
4646

47-
private final List<String> allowedOrigins;
47+
private final Set<String> allowedOrigins = new LinkedHashSet<String>();
4848

4949

5050
/**
5151
* Default constructor with only same origin requests allowed.
5252
*/
5353
public OriginHandshakeInterceptor() {
54-
this.allowedOrigins = new ArrayList<String>();
5554
}
5655

5756
/**
5857
* Constructor using the specified allowed origin values.
59-
*
6058
* @see #setAllowedOrigins(Collection)
6159
*/
6260
public OriginHandshakeInterceptor(Collection<String> allowedOrigins) {
63-
this();
6461
setAllowedOrigins(allowedOrigins);
6562
}
6663

64+
6765
/**
68-
* Configure allowed {@code Origin} header values. This check is mostly designed for
69-
* browser clients. There is nothing preventing other types of client to modify the
70-
* {@code Origin} header value.
71-
*
72-
* <p>Each provided allowed origin must start by "http://", "https://" or be "*"
73-
* (means that all origins are allowed).
74-
*
66+
* Configure allowed {@code Origin} header values. This check is mostly
67+
* designed for browsers. There is nothing preventing other types of client
68+
* to modify the {@code Origin} header value.
69+
* <p>Each provided allowed origin must have a scheme, and optionally a port
70+
* (e.g. "http://example.org", "http://example.org:9090"). An allowed origin
71+
* string may also be "*" in which case all origins are allowed.
7572
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
7673
*/
7774
public void setAllowedOrigins(Collection<String> allowedOrigins) {
78-
Assert.notNull(allowedOrigins, "Allowed origin Collection must not be null");
75+
Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null");
7976
this.allowedOrigins.clear();
8077
this.allowedOrigins.addAll(allowedOrigins);
8178
}
8279

8380
/**
84-
* @see #setAllowedOrigins(Collection)
8581
* @since 4.1.5
82+
* @see #setAllowedOrigins
8683
*/
8784
public Collection<String> getAllowedOrigins() {
88-
return Collections.unmodifiableList(this.allowedOrigins);
85+
return Collections.unmodifiableSet(this.allowedOrigins);
8986
}
9087

88+
9189
@Override
9290
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
9391
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
92+
9493
if (!WebUtils.isValidOrigin(request, this.allowedOrigins)) {
9594
response.setStatusCode(HttpStatus.FORBIDDEN);
9695
if (logger.isDebugEnabled()) {
97-
logger.debug("Handshake request rejected, Origin header value "
98-
+ request.getHeaders().getOrigin() + " not allowed");
96+
logger.debug("Handshake request rejected, Origin header value " +
97+
request.getHeaders().getOrigin() + " not allowed");
9998
}
10099
return false;
101100
}

spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818

1919
import java.io.IOException;
2020
import java.nio.charset.Charset;
21-
import java.util.ArrayList;
2221
import java.util.Arrays;
22+
import java.util.Collection;
2323
import java.util.Collections;
2424
import java.util.Date;
2525
import java.util.HashSet;
26+
import java.util.LinkedHashSet;
2627
import java.util.List;
2728
import java.util.Random;
29+
import java.util.Set;
2830
import java.util.concurrent.TimeUnit;
2931

3032
import org.apache.commons.logging.Log;
@@ -53,7 +55,7 @@
5355
* path resolution and handling of static SockJS requests (e.g. "/info", "/iframe.html",
5456
* etc). Sub-classes must handle session URLs (i.e. transport-specific requests).
5557
*
56-
* By default, only same origin requests are allowed. Use {@link #setAllowedOrigins(List)}
58+
* By default, only same origin requests are allowed. Use {@link #setAllowedOrigins}
5759
* to specify a list of allowed origins (a list containing "*" will allow all origins).
5860
*
5961
* @author Rossen Stoyanchev
@@ -91,10 +93,10 @@ public abstract class AbstractSockJsService implements SockJsService {
9193

9294
private boolean webSocketEnabled = true;
9395

94-
private final List<String> allowedOrigins = new ArrayList<String>();
95-
9696
private boolean suppressCors = false;
9797

98+
protected final Set<String> allowedOrigins = new LinkedHashSet<String>();
99+
98100

99101
public AbstractSockJsService(TaskScheduler scheduler) {
100102
Assert.notNull(scheduler, "TaskScheduler must not be null");
@@ -271,6 +273,24 @@ public boolean isWebSocketEnabled() {
271273
return this.webSocketEnabled;
272274
}
273275

276+
/**
277+
* This option can be used to disable automatic addition of CORS headers for
278+
* SockJS requests.
279+
* <p>The default value is "false".
280+
* @since 4.1.2
281+
*/
282+
public void setSuppressCors(boolean suppressCors) {
283+
this.suppressCors = suppressCors;
284+
}
285+
286+
/**
287+
* @since 4.1.2
288+
* @see #setSuppressCors(boolean)
289+
*/
290+
public boolean shouldSuppressCors() {
291+
return this.suppressCors;
292+
}
293+
274294
/**
275295
* Configure allowed {@code Origin} header values. This check is mostly
276296
* designed for browsers. There is nothing preventing other types of client
@@ -286,36 +306,18 @@ public boolean isWebSocketEnabled() {
286306
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
287307
* @see <a href="https://github.com/sockjs/sockjs-client#supported-transports-by-browser-html-served-from-http-or-https">SockJS supported transports by browser</a>
288308
*/
289-
public void setAllowedOrigins(List<String> allowedOrigins) {
290-
Assert.notNull(allowedOrigins, "Allowed origin List must not be null");
309+
public void setAllowedOrigins(Collection<String> allowedOrigins) {
310+
Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null");
291311
this.allowedOrigins.clear();
292312
this.allowedOrigins.addAll(allowedOrigins);
293313
}
294314

295315
/**
296316
* @since 4.1.2
297-
* @see #setAllowedOrigins(List)
298-
*/
299-
public List<String> getAllowedOrigins() {
300-
return Collections.unmodifiableList(this.allowedOrigins);
301-
}
302-
303-
/**
304-
* This option can be used to disable automatic addition of CORS headers for
305-
* SockJS requests.
306-
* <p>The default value is "false".
307-
* @since 4.1.2
317+
* @see #setAllowedOrigins
308318
*/
309-
public void setSuppressCors(boolean suppressCors) {
310-
this.suppressCors = suppressCors;
311-
}
312-
313-
/**
314-
* @since 4.1.2
315-
* @see #setSuppressCors(boolean)
316-
*/
317-
public boolean shouldSuppressCors() {
318-
return this.suppressCors;
319+
public Collection<String> getAllowedOrigins() {
320+
return Collections.unmodifiableSet(this.allowedOrigins);
319321
}
320322

321323

spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ protected boolean validateRequest(String serverId, String sessionId, String tran
292292
return false;
293293
}
294294

295-
if (!getAllowedOrigins().contains("*")) {
295+
if (!this.allowedOrigins.contains("*")) {
296296
TransportType transportType = TransportType.fromValue(transport);
297297
if (transportType == null || !transportType.supportsOrigin()) {
298298
if (logger.isWarnEnabled()) {

spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,11 @@
1818

1919
import java.io.IOException;
2020
import java.io.InputStream;
21-
import java.util.Arrays;
2221
import java.util.Date;
2322
import java.util.List;
2423
import java.util.Map;
2524
import java.util.concurrent.ScheduledFuture;
2625

27-
import static org.junit.Assert.assertEquals;
28-
import org.junit.Before;
2926
import org.junit.Test;
3027

3128
import org.springframework.beans.factory.xml.XmlBeanDefinitionReader;
@@ -77,13 +74,7 @@
7774
*/
7875
public class HandlersBeanDefinitionParserTests {
7976

80-
private GenericWebApplicationContext appContext;
81-
82-
83-
@Before
84-
public void setup() {
85-
this.appContext = new GenericWebApplicationContext();
86-
}
77+
private final GenericWebApplicationContext appContext = new GenericWebApplicationContext();
8778

8879

8980
@Test
@@ -235,10 +226,12 @@ public void sockJsAttributes() {
235226

236227
List<HandshakeInterceptor> interceptors = transportService.getHandshakeInterceptors();
237228
assertThat(interceptors, contains(instanceOf(OriginHandshakeInterceptor.class)));
238-
assertEquals(Arrays.asList("http://mydomain1.com", "http://mydomain2.com"), transportService.getAllowedOrigins());
239229
assertTrue(transportService.shouldSuppressCors());
230+
assertTrue(transportService.getAllowedOrigins().contains("http://mydomain1.com"));
231+
assertTrue(transportService.getAllowedOrigins().contains("http://mydomain2.com"));
240232
}
241233

234+
242235
private void loadBeanDefinitions(String fileName) {
243236
XmlBeanDefinitionReader reader = new XmlBeanDefinitionReader(this.appContext);
244237
ClassPathResource resource = new ClassPathResource(fileName, HandlersBeanDefinitionParserTests.class);
@@ -279,9 +272,11 @@ public boolean supportsPartialMessages() {
279272
}
280273
}
281274

275+
282276
class FooWebSocketHandler extends TestWebSocketHandler {
283277
}
284278

279+
285280
class TestHandshakeHandler implements HandshakeHandler {
286281

287282
@Override
@@ -292,9 +287,11 @@ public boolean doHandshake(ServerHttpRequest request, ServerHttpResponse respons
292287
}
293288
}
294289

290+
295291
class TestChannelInterceptor extends ChannelInterceptorAdapter {
296292
}
297293

294+
298295
class FooTestInterceptor implements HandshakeInterceptor {
299296

300297
@Override
@@ -310,9 +307,11 @@ public void afterHandshake(ServerHttpRequest request, ServerHttpResponse respons
310307
}
311308
}
312309

310+
313311
class BarTestInterceptor extends FooTestInterceptor {
314312
}
315313

314+
316315
@SuppressWarnings({ "unchecked", "rawtypes" })
317316
class TestTaskScheduler implements TaskScheduler {
318317

@@ -345,9 +344,9 @@ public ScheduledFuture scheduleWithFixedDelay(Runnable task, Date startTime, lon
345344
public ScheduledFuture scheduleWithFixedDelay(Runnable task, long delay) {
346345
return null;
347346
}
348-
349347
}
350348

349+
351350
class TestMessageCodec implements SockJsMessageCodec {
352351

353352
@Override
@@ -364,4 +363,4 @@ public String[] decode(String content) throws IOException {
364363
public String[] decodeInputStream(InputStream content) throws IOException {
365364
return new String[0];
366365
}
367-
}
366+
}

0 commit comments

Comments
 (0)