Skip to content

Commit c306df9

Browse files
author
Steve Riesenberg
committed
Add XorCsrfChannelInterceptor
Issue gh-12378
1 parent d42405d commit c306df9

File tree

12 files changed

+501
-58
lines changed

12 files changed

+501
-58
lines changed

config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -56,6 +56,8 @@ final class WebSocketMessageBrokerSecurityConfiguration
5656

5757
private static final String SIMPLE_URL_HANDLER_MAPPING_BEAN_NAME = "stompWebSocketHandlerMapping";
5858

59+
private static final String CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME = "csrfChannelInterceptor";
60+
5961
private MessageMatcherDelegatingAuthorizationManager b;
6062

6163
private static final AuthorizationManager<Message<?>> ANY_MESSAGE_AUTHENTICATED = MessageMatcherDelegatingAuthorizationManager
@@ -66,7 +68,7 @@ final class WebSocketMessageBrokerSecurityConfiguration
6668

6769
private final SecurityContextChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor();
6870

69-
private final ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor();
71+
private ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor();
7072

7173
private AuthorizationChannelInterceptor authorizationChannelInterceptor = new AuthorizationChannelInterceptor(
7274
ANY_MESSAGE_AUTHENTICATED);
@@ -86,6 +88,12 @@ public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentRes
8688

8789
@Override
8890
public void configureClientInboundChannel(ChannelRegistration registration) {
91+
ChannelInterceptor csrfChannelInterceptor = getBeanOrNull(CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME,
92+
ChannelInterceptor.class);
93+
if (csrfChannelInterceptor != null) {
94+
this.csrfChannelInterceptor = csrfChannelInterceptor;
95+
}
96+
8997
this.authorizationChannelInterceptor
9098
.setAuthorizationEventPublisher(new SpringAuthorizationEventPublisher(this.context));
9199
this.authorizationChannelInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);

config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -61,6 +61,7 @@
6161
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
6262
import org.springframework.security.web.csrf.CsrfToken;
6363
import org.springframework.security.web.csrf.DefaultCsrfToken;
64+
import org.springframework.security.web.csrf.DeferredCsrfToken;
6465
import org.springframework.security.web.csrf.MissingCsrfTokenException;
6566
import org.springframework.stereotype.Controller;
6667
import org.springframework.test.util.ReflectionTestUtils;
@@ -79,6 +80,7 @@
7980

8081
import static org.assertj.core.api.Assertions.assertThat;
8182
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
83+
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
8284

8385
public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
8486

@@ -284,7 +286,7 @@ public void inboundChannelSecurityDefinedByBean() {
284286

285287
private void assertHandshake(HttpServletRequest request) {
286288
TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class);
287-
assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token);
289+
assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token);
288290
assertThat(handshakeHandler.attributes.get(this.sessionAttr))
289291
.isEqualTo(request.getSession().getAttribute(this.sessionAttr));
290292
}
@@ -306,7 +308,7 @@ private MockHttpServletRequest sockjsHttpRequest(String mapping) {
306308
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
307309
request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
308310
request.getSession().setAttribute(this.sessionAttr, "sessionValue");
309-
request.setAttribute(CsrfToken.class.getName(), this.token);
311+
request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token));
310312
return request;
311313
}
312314

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright 2002-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.config.annotation.web.socket;
18+
19+
import org.springframework.security.web.csrf.CsrfToken;
20+
import org.springframework.security.web.csrf.DeferredCsrfToken;
21+
22+
/**
23+
* @author Steve Riesenberg
24+
*/
25+
final class TestDeferredCsrfToken implements DeferredCsrfToken {
26+
27+
private final CsrfToken csrfToken;
28+
29+
TestDeferredCsrfToken(CsrfToken csrfToken) {
30+
this.csrfToken = csrfToken;
31+
}
32+
33+
@Override
34+
public CsrfToken get() {
35+
return this.csrfToken;
36+
}
37+
38+
@Override
39+
public boolean isGenerated() {
40+
return false;
41+
}
42+
43+
}

config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -70,6 +70,7 @@
7070
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
7171
import org.springframework.security.web.csrf.CsrfToken;
7272
import org.springframework.security.web.csrf.DefaultCsrfToken;
73+
import org.springframework.security.web.csrf.DeferredCsrfToken;
7374
import org.springframework.security.web.csrf.MissingCsrfTokenException;
7475
import org.springframework.stereotype.Controller;
7576
import org.springframework.test.util.ReflectionTestUtils;
@@ -92,6 +93,7 @@
9293
import static org.assertj.core.api.Assertions.fail;
9394
import static org.mockito.Mockito.atLeastOnce;
9495
import static org.mockito.Mockito.verify;
96+
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
9597

9698
public class WebSocketMessageBrokerSecurityConfigurationTests {
9799

@@ -367,7 +369,7 @@ public void sendMessageWhenAnonymousConfiguredAndLoggedInUserThenAccessDeniedExc
367369

368370
private void assertHandshake(HttpServletRequest request) {
369371
TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class);
370-
assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token);
372+
assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token);
371373
assertThat(handshakeHandler.attributes.get(this.sessionAttr))
372374
.isEqualTo(request.getSession().getAttribute(this.sessionAttr));
373375
}
@@ -389,7 +391,7 @@ private MockHttpServletRequest sockjsHttpRequest(String mapping) {
389391
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
390392
request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
391393
request.getSession().setAttribute(this.sessionAttr, "sessionValue");
392-
request.setAttribute(CsrfToken.class.getName(), this.token);
394+
request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token));
393395
return request;
394396
}
395397

config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -61,6 +61,7 @@
6161
import org.springframework.security.test.context.support.WithMockUser;
6262
import org.springframework.security.web.csrf.CsrfToken;
6363
import org.springframework.security.web.csrf.DefaultCsrfToken;
64+
import org.springframework.security.web.csrf.DeferredCsrfToken;
6465
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
6566
import org.springframework.stereotype.Controller;
6667
import org.springframework.test.context.junit.jupiter.SpringExtension;
@@ -77,6 +78,7 @@
7778
import static org.mockito.ArgumentMatchers.any;
7879
import static org.mockito.BDDMockito.given;
7980
import static org.mockito.Mockito.verify;
81+
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
8082
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
8183

8284
/**
@@ -381,12 +383,14 @@ public void requestWhenConnectMessageThenUsesCsrfTokenHandshakeInterceptor() thr
381383
MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build();
382384
String csrfAttributeName = CsrfToken.class.getName();
383385
String customAttributeName = this.getClass().getName();
384-
MvcResult result = mvc.perform(get("/app").requestAttr(csrfAttributeName, this.token)
385-
.sessionAttr(customAttributeName, "attributeValue")).andReturn();
386+
MvcResult result = mvc.perform(
387+
get("/app").requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token))
388+
.sessionAttr(customAttributeName, "attributeValue"))
389+
.andReturn();
386390
CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName);
387391
String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName);
388392
String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName);
389-
assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
393+
assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
390394
assertThat(handshakeValue).isEqualTo(sessionValue)
391395
.withFailMessage("Explicitly listed session variables are not overridden");
392396
}
@@ -398,12 +402,13 @@ public void requestWhenConnectMessageAndUsingSockJsThenUsesCsrfTokenHandshakeInt
398402
MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build();
399403
String csrfAttributeName = CsrfToken.class.getName();
400404
String customAttributeName = this.getClass().getName();
401-
MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket").requestAttr(csrfAttributeName, this.token)
405+
MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket")
406+
.requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token))
402407
.sessionAttr(customAttributeName, "attributeValue")).andReturn();
403408
CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName);
404409
String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName);
405410
String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName);
406-
assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
411+
assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
407412
assertThat(handshakeValue).isEqualTo(sessionValue)
408413
.withFailMessage("Explicitly listed session variables are not overridden");
409414
}
@@ -526,6 +531,26 @@ private SecurityContextHolderStrategy getSecurityContextHolderStrategy() {
526531
return SecurityContextHolder.getContextHolderStrategy();
527532
}
528533

534+
private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
535+
536+
private final CsrfToken csrfToken;
537+
538+
TestDeferredCsrfToken(CsrfToken csrfToken) {
539+
this.csrfToken = csrfToken;
540+
}
541+
542+
@Override
543+
public CsrfToken get() {
544+
return this.csrfToken;
545+
}
546+
547+
@Override
548+
public boolean isGenerated() {
549+
return false;
550+
}
551+
552+
}
553+
529554
@Controller
530555
static class MessageController {
531556

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright 2002-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.messaging.web.csrf;
18+
19+
import java.security.MessageDigest;
20+
import java.util.Map;
21+
22+
import org.springframework.messaging.Message;
23+
import org.springframework.messaging.MessageChannel;
24+
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
25+
import org.springframework.messaging.simp.SimpMessageType;
26+
import org.springframework.messaging.support.ChannelInterceptor;
27+
import org.springframework.security.crypto.codec.Utf8;
28+
import org.springframework.security.messaging.util.matcher.MessageMatcher;
29+
import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher;
30+
import org.springframework.security.web.csrf.CsrfToken;
31+
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
32+
import org.springframework.security.web.csrf.MissingCsrfTokenException;
33+
34+
/**
35+
* {@link ChannelInterceptor} that validates a CSRF token masked by the
36+
* {@link org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler} in
37+
* the header of any {@link SimpMessageType#CONNECT} message.
38+
*
39+
* @author Steve Riesenberg
40+
* @since 5.8
41+
*/
42+
public final class XorCsrfChannelInterceptor implements ChannelInterceptor {
43+
44+
private final MessageMatcher<Object> matcher = new SimpMessageTypeMatcher(SimpMessageType.CONNECT);
45+
46+
@Override
47+
public Message<?> preSend(Message<?> message, MessageChannel channel) {
48+
if (!this.matcher.matches(message)) {
49+
return message;
50+
}
51+
Map<String, Object> sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders());
52+
CsrfToken expectedToken = (sessionAttributes != null)
53+
? (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()) : null;
54+
if (expectedToken == null) {
55+
throw new MissingCsrfTokenException(null);
56+
}
57+
String actualToken = SimpMessageHeaderAccessor.wrap(message)
58+
.getFirstNativeHeader(expectedToken.getHeaderName());
59+
String actualTokenValue = XorCsrfTokenUtils.getTokenValue(actualToken, expectedToken.getToken());
60+
boolean csrfCheckPassed = equalsConstantTime(expectedToken.getToken(), actualTokenValue);
61+
if (!csrfCheckPassed) {
62+
throw new InvalidCsrfTokenException(expectedToken, actualToken);
63+
}
64+
return message;
65+
}
66+
67+
/**
68+
* Constant time comparison to prevent against timing attacks.
69+
* @param expected
70+
* @param actual
71+
* @return
72+
*/
73+
private static boolean equalsConstantTime(String expected, String actual) {
74+
if (expected == actual) {
75+
return true;
76+
}
77+
if (expected == null || actual == null) {
78+
return false;
79+
}
80+
// Encode after ensure that the string is not null
81+
byte[] expectedBytes = Utf8.encode(expected);
82+
byte[] actualBytes = Utf8.encode(actual);
83+
return MessageDigest.isEqual(expectedBytes, actualBytes);
84+
}
85+
86+
}

0 commit comments

Comments
 (0)