Skip to content

Commit 440748e

Browse files
author
Steve Riesenberg
committed
Add test support for Xor CSRF tokens
Issue gh-4001
1 parent 8bd25f9 commit 440748e

File tree

5 files changed

+50
-25
lines changed

5 files changed

+50
-25
lines changed

config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import org.springframework.security.web.access.AccessDeniedHandler;
4242
import org.springframework.security.web.csrf.CsrfFilter;
4343
import org.springframework.security.web.csrf.CsrfToken;
44+
import org.springframework.security.web.csrf.CsrfTokenRepository;
4445
import org.springframework.security.web.util.matcher.RequestMatcher;
4546
import org.springframework.stereotype.Controller;
4647
import org.springframework.test.context.junit.jupiter.SpringExtension;
@@ -301,33 +302,35 @@ public void getWhenUsingCsrfAndCustomRequestAttributeThenSetUsingCsrfAttrName()
301302
}
302303

303304
@Test
304-
public void postWhenUsingCsrfAndXorCsrfTokenRequestProcessorThenOk() throws Exception {
305+
public void postWhenUsingCsrfAndXorCsrfTokenRequestAttributeHandlerThenOk() throws Exception {
305306
this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers"))
306307
.autowire();
307308
// @formatter:off
308309
MvcResult mvcResult = this.mvc.perform(get("/ok"))
309310
.andExpect(status().isOk())
310311
.andReturn();
311312
MockHttpSession session = (MockHttpSession) mvcResult.getRequest().getSession();
312-
CsrfToken csrfToken = (CsrfToken) mvcResult.getRequest().getAttribute("_csrf");
313313
MockHttpServletRequestBuilder ok = post("/ok")
314-
.header(csrfToken.getHeaderName(), csrfToken.getToken())
314+
.with(csrf())
315315
.session(session);
316316
this.mvc.perform(ok).andExpect(status().isOk());
317317
// @formatter:on
318318
}
319319

320320
@Test
321-
public void postWhenUsingCsrfAndXorCsrfTokenRequestProcessorWithRawTokenThenForbidden() throws Exception {
321+
public void postWhenUsingCsrfAndXorCsrfTokenRequestAttributeHandlerWithRawTokenThenForbidden() throws Exception {
322322
this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers"))
323323
.autowire();
324324
// @formatter:off
325-
MvcResult mvcResult = this.mvc.perform(get("/ok"))
325+
MvcResult mvcResult = this.mvc.perform(get("/csrf"))
326326
.andExpect(status().isOk())
327327
.andReturn();
328-
MockHttpSession session = (MockHttpSession) mvcResult.getRequest().getSession();
328+
MockHttpServletRequest request = mvcResult.getRequest();
329+
MockHttpSession session = (MockHttpSession) request.getSession();
330+
CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
331+
CsrfToken csrfToken = repository.loadToken(request);
329332
MockHttpServletRequestBuilder ok = post("/ok")
330-
.with(csrf())
333+
.header(csrfToken.getHeaderName(), csrfToken.getToken())
331334
.session(session);
332335
this.mvc.perform(ok).andExpect(status().isForbidden());
333336
// @formatter:on
@@ -594,7 +597,7 @@ static class CsrfReturnedResultMatcher implements ResultMatcher {
594597
@Override
595598
public void match(MvcResult result) throws Exception {
596599
MockHttpServletRequest request = result.getRequest();
597-
CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request);
600+
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
598601
assertThat(token).isNotNull();
599602
assertThat(token.getToken()).isEqualTo(this.token.apply(result));
600603
}

test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@
9595
import org.springframework.security.web.csrf.CsrfFilter;
9696
import org.springframework.security.web.csrf.CsrfToken;
9797
import org.springframework.security.web.csrf.CsrfTokenRepository;
98+
import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
99+
import org.springframework.security.web.csrf.DeferredCsrfToken;
98100
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
99101
import org.springframework.test.util.ReflectionTestUtils;
100102
import org.springframework.test.web.servlet.MockMvc;
@@ -499,6 +501,10 @@ public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request)
499501
*/
500502
public static final class CsrfRequestPostProcessor implements RequestPostProcessor {
501503

504+
private static final byte[] INVALID_TOKEN_BYTES = new byte[] { 1, 1, 1, 96, 99, 98 };
505+
506+
private static final String INVALID_TOKEN_VALUE = Base64.getEncoder().encodeToString(INVALID_TOKEN_BYTES);
507+
502508
private boolean asHeader;
503509

504510
private boolean useInvalidToken;
@@ -509,14 +515,17 @@ private CsrfRequestPostProcessor() {
509515
@Override
510516
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
511517
CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
518+
CsrfTokenRequestHandler handler = WebTestUtils.getCsrfTokenRequestHandler(request);
512519
if (!(repository instanceof TestCsrfTokenRepository)) {
513520
repository = new TestCsrfTokenRepository(new HttpSessionCsrfTokenRepository());
514521
WebTestUtils.setCsrfTokenRepository(request, repository);
515522
}
516523
TestCsrfTokenRepository.enable(request);
517-
CsrfToken token = repository.generateToken(request);
518-
repository.saveToken(token, request, new MockHttpServletResponse());
519-
String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() : token.getToken();
524+
MockHttpServletResponse response = new MockHttpServletResponse();
525+
DeferredCsrfToken deferredCsrfToken = repository.loadDeferredToken(request, response);
526+
handler.handle(request, response, deferredCsrfToken::get);
527+
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
528+
String tokenValue = this.useInvalidToken ? INVALID_TOKEN_VALUE : token.getToken();
520529
if (this.asHeader) {
521530
request.addHeader(token.getHeaderName(), tokenValue);
522531
}

test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
import org.springframework.security.web.context.SecurityContextRepository;
3232
import org.springframework.security.web.csrf.CsrfFilter;
3333
import org.springframework.security.web.csrf.CsrfTokenRepository;
34+
import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
3435
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
36+
import org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler;
3537
import org.springframework.test.util.ReflectionTestUtils;
3638
import org.springframework.web.context.WebApplicationContext;
3739
import org.springframework.web.context.support.WebApplicationContextUtils;
@@ -48,6 +50,8 @@ public abstract class WebTestUtils {
4850

4951
private static final CsrfTokenRepository DEFAULT_TOKEN_REPO = new HttpSessionCsrfTokenRepository();
5052

53+
private static final CsrfTokenRequestHandler DEFAULT_CSRF_HANDLER = new XorCsrfTokenRequestAttributeHandler();
54+
5155
private WebTestUtils() {
5256
}
5357

@@ -107,6 +111,23 @@ public static CsrfTokenRepository getCsrfTokenRepository(HttpServletRequest requ
107111
return (CsrfTokenRepository) ReflectionTestUtils.getField(filter, "tokenRepository");
108112
}
109113

114+
/**
115+
* Gets the {@link CsrfTokenRequestHandler} for the specified
116+
* {@link HttpServletRequest}. If one is not found, the default
117+
* {@link XorCsrfTokenRequestAttributeHandler} is used.
118+
* @param request the {@link HttpServletRequest} to obtain the
119+
* {@link CsrfTokenRequestHandler}
120+
* @return the {@link CsrfTokenRequestHandler} for the specified
121+
* {@link HttpServletRequest}
122+
*/
123+
public static CsrfTokenRequestHandler getCsrfTokenRequestHandler(HttpServletRequest request) {
124+
CsrfFilter filter = findFilter(request, CsrfFilter.class);
125+
if (filter == null) {
126+
return DEFAULT_CSRF_HANDLER;
127+
}
128+
return (CsrfTokenRequestHandler) ReflectionTestUtils.getField(filter, "requestHandler");
129+
}
130+
110131
/**
111132
* Sets the {@link CsrfTokenRepository} for the specified {@link HttpServletRequest}.
112133
* @param request the {@link HttpServletRequest} to obtain the

test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import org.springframework.http.MediaType;
2626
import org.springframework.mock.web.MockHttpServletRequest;
2727
import org.springframework.mock.web.MockServletContext;
28-
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
2928
import org.springframework.security.web.csrf.CsrfToken;
3029
import org.springframework.test.web.servlet.MockMvc;
3130
import org.springframework.test.web.servlet.MvcResult;
@@ -52,8 +51,7 @@ public void setup() {
5251
@Test
5352
public void defaults() {
5453
MockHttpServletRequest request = formLogin().buildRequest(this.servletContext);
55-
CsrfToken token = (CsrfToken) request
56-
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
54+
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
5755
assertThat(request.getParameter("username")).isEqualTo("user");
5856
assertThat(request.getParameter("password")).isEqualTo("password");
5957
assertThat(request.getMethod()).isEqualTo("POST");
@@ -66,8 +64,7 @@ public void defaults() {
6664
public void custom() {
6765
MockHttpServletRequest request = formLogin("/login").user("username", "admin").password("password", "secret")
6866
.buildRequest(this.servletContext);
69-
CsrfToken token = (CsrfToken) request
70-
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
67+
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
7168
assertThat(request.getParameter("username")).isEqualTo("admin");
7269
assertThat(request.getParameter("password")).isEqualTo("secret");
7370
assertThat(request.getMethod()).isEqualTo("POST");
@@ -79,8 +76,7 @@ public void custom() {
7976
public void customWithUriVars() {
8077
MockHttpServletRequest request = formLogin().loginProcessingUrl("/uri-login/{var1}/{var2}", "val1", "val2")
8178
.user("username", "admin").password("password", "secret").buildRequest(this.servletContext);
82-
CsrfToken token = (CsrfToken) request
83-
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
79+
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
8480
assertThat(request.getParameter("username")).isEqualTo("admin");
8581
assertThat(request.getParameter("password")).isEqualTo("secret");
8682
assertThat(request.getMethod()).isEqualTo("POST");

test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import org.springframework.http.MediaType;
2626
import org.springframework.mock.web.MockHttpServletRequest;
2727
import org.springframework.mock.web.MockServletContext;
28-
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
2928
import org.springframework.security.web.csrf.CsrfToken;
3029
import org.springframework.test.web.servlet.MockMvc;
3130
import org.springframework.test.web.servlet.MvcResult;
@@ -52,8 +51,7 @@ public void setup() {
5251
@Test
5352
public void defaults() {
5453
MockHttpServletRequest request = logout().buildRequest(this.servletContext);
55-
CsrfToken token = (CsrfToken) request
56-
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
54+
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
5755
assertThat(request.getMethod()).isEqualTo("POST");
5856
assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
5957
assertThat(request.getRequestURI()).isEqualTo("/logout");
@@ -62,8 +60,7 @@ public void defaults() {
6260
@Test
6361
public void custom() {
6462
MockHttpServletRequest request = logout("/admin/logout").buildRequest(this.servletContext);
65-
CsrfToken token = (CsrfToken) request
66-
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
63+
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
6764
assertThat(request.getMethod()).isEqualTo("POST");
6865
assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
6966
assertThat(request.getRequestURI()).isEqualTo("/admin/logout");
@@ -73,8 +70,7 @@ public void custom() {
7370
public void customWithUriVars() {
7471
MockHttpServletRequest request = logout().logoutUrl("/uri-logout/{var1}/{var2}", "val1", "val2")
7572
.buildRequest(this.servletContext);
76-
CsrfToken token = (CsrfToken) request
77-
.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
73+
CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
7874
assertThat(request.getMethod()).isEqualTo("POST");
7975
assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
8076
assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2");

0 commit comments

Comments
 (0)