diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenEncoder.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenEncoder.java new file mode 100644 index 00000000000..45c1c2bafc4 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfTokenEncoder.java @@ -0,0 +1,44 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.csrf; + +import org.jspecify.annotations.Nullable; + +/** + * Interface for encoding and decoding CSRF tokens. + * + * Defines methods to encode a CSRF token and to decode an encoded token + * by referencing the original unencoded token. + * + * This is primarily used to safely transform CSRF tokens for security purposes. + * + * @author Cheol Jeon + * @since + * @see XorCsrfTokenEncoder + */ +public interface CsrfTokenEncoder { + + String encode(String token); + + /** + * Decodes the encoded CSRF token using the original unencoded token. + * This is necessary because the decoding process requires the original token length. + */ + @Nullable + String decode(String encodedToken, String originalToken); + +} diff --git a/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenEncoder.java b/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenEncoder.java new file mode 100644 index 00000000000..261664d8566 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenEncoder.java @@ -0,0 +1,112 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.csrf; + +import org.jspecify.annotations.Nullable; +import org.springframework.core.log.LogMessage; +import org.springframework.security.crypto.codec.Utf8; +import org.springframework.util.Assert; + +import java.security.SecureRandom; +import java.util.Base64; + +import static org.springframework.security.web.csrf.CsrfTokenRequestHandlerLoggerHolder.logger; + +/** + * Implementation of CsrfTokenEncoder that uses XOR operation combined with a random key + * to encode and decode CSRF tokens. + * + * The encode method generates a random byte array and XORs it with the UTF-8 bytes of the token, + * then combines both arrays and encodes them in Base64 URL-safe format. + * + * The decode method reverses this process by decoding the Base64 string, splitting the bytes, + * and XORing the two parts to retrieve the original token. + * + * This approach enhances CSRF token security by obfuscating the token value with randomness. + * + * @author Cheol Jeon + * @since + * @see XorCsrfTokenRequestAttributeHandler + */ +public class XorCsrfTokenEncoder implements CsrfTokenEncoder { + private SecureRandom secureRandom; + + public XorCsrfTokenEncoder() { + this(new SecureRandom()); + } + + public XorCsrfTokenEncoder(SecureRandom secureRandom) { + Assert.notNull(secureRandom, "secureRandom cannot be null"); + this.secureRandom = secureRandom; + } + + @Override + public String encode(String token) { + byte[] tokenBytes = Utf8.encode(token); + byte[] randomBytes = new byte[tokenBytes.length]; + secureRandom.nextBytes(randomBytes); + + byte[] xoredBytes = xor(randomBytes, tokenBytes); + byte[] combinedBytes = new byte[tokenBytes.length + randomBytes.length]; + System.arraycopy(randomBytes, 0, combinedBytes, 0, randomBytes.length); + System.arraycopy(xoredBytes, 0, combinedBytes, randomBytes.length, xoredBytes.length); + + return Base64.getUrlEncoder().encodeToString(combinedBytes); + } + + @Override + public @Nullable String decode(String encodedToken, String originalToken) { + byte[] actualBytes; + try { + actualBytes = Base64.getUrlDecoder().decode(encodedToken); + } + catch (Exception ex) { + logger.trace(LogMessage.format("Not returning the CSRF token since it's not Base64-encoded"), ex); + return null; + } + + byte[] tokenBytes = Utf8.encode(originalToken); + int tokenSize = tokenBytes.length; + if (actualBytes.length != tokenSize * 2) { + logger.trace(LogMessage.format( + "Not returning the CSRF token since its Base64-decoded length (%d) is not equal to (%d)", + actualBytes.length, tokenSize * 2)); + return null; + } + + // extract token and random bytes + byte[] xoredCsrf = new byte[tokenSize]; + byte[] randomBytes = new byte[tokenSize]; + + System.arraycopy(actualBytes, 0, randomBytes, 0, tokenSize); + System.arraycopy(actualBytes, tokenSize, xoredCsrf, 0, tokenSize); + + byte[] csrfBytes = xor(randomBytes, xoredCsrf); + return Utf8.decode(csrfBytes); + } + + private byte[] xor(byte[] randomBytes, byte[] csrfBytes) { + Assert.isTrue(randomBytes.length == csrfBytes.length, "arrays must be equal length"); + int len = csrfBytes.length; + byte[] xoredCsrf = new byte[len]; + System.arraycopy(csrfBytes, 0, xoredCsrf, 0, len); + for (int i = 0; i < len; i++) { + xoredCsrf[i] ^= randomBytes[i]; + } + return xoredCsrf; + } +} diff --git a/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java b/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java index ad1220ca571..dec72b48229 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java +++ b/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java @@ -16,20 +16,16 @@ package org.springframework.security.web.csrf; -import java.security.SecureRandom; -import java.util.Base64; -import java.util.function.Supplier; - import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.jspecify.annotations.Nullable; - -import org.springframework.core.log.LogMessage; -import org.springframework.security.crypto.codec.Utf8; import org.springframework.util.Assert; +import java.security.SecureRandom; +import java.util.function.Supplier; + /** * An implementation of the {@link CsrfTokenRequestHandler} interface that is capable of * masking the value of the {@link CsrfToken} on each request and resolving the raw token @@ -37,13 +33,14 @@ * * @author Steve Riesenberg * @author Yoobin Yoon + * @author Cheol Jeon * @since 5.8 */ public final class XorCsrfTokenRequestAttributeHandler extends CsrfTokenRequestAttributeHandler { private static final Log logger = LogFactory.getLog(XorCsrfTokenRequestAttributeHandler.class); - private SecureRandom secureRandom = new SecureRandom(); + private CsrfTokenEncoder csrfTokenEncoder = new XorCsrfTokenEncoder(); /** * Specifies the {@code SecureRandom} used to generate random bytes that are used to @@ -52,7 +49,7 @@ public final class XorCsrfTokenRequestAttributeHandler extends CsrfTokenRequestA */ public void setSecureRandom(SecureRandom secureRandom) { Assert.notNull(secureRandom, "secureRandom cannot be null"); - this.secureRandom = secureRandom; + this.csrfTokenEncoder = new XorCsrfTokenEncoder(secureRandom); } @Override @@ -69,7 +66,7 @@ private Supplier deferCsrfTokenUpdate(Supplier csrfTokenSu return new CachedCsrfTokenSupplier(() -> { CsrfToken csrfToken = csrfTokenSupplier.get(); Assert.state(csrfToken != null, "csrfToken supplier returned null"); - String updatedToken = createXoredCsrfToken(this.secureRandom, csrfToken.getToken()); + String updatedToken = csrfTokenEncoder.encode(csrfToken.getToken()); return new DefaultCsrfToken(csrfToken.getHeaderName(), csrfToken.getParameterName(), updatedToken); }); } @@ -80,61 +77,7 @@ private Supplier deferCsrfTokenUpdate(Supplier csrfTokenSu if (actualToken == null) { return null; } - return getTokenValue(actualToken, csrfToken.getToken()); - } - - private static @Nullable String getTokenValue(String actualToken, String token) { - byte[] actualBytes; - try { - actualBytes = Base64.getUrlDecoder().decode(actualToken); - } - catch (Exception ex) { - logger.trace(LogMessage.format("Not returning the CSRF token since it's not Base64-encoded"), ex); - return null; - } - - byte[] tokenBytes = Utf8.encode(token); - int tokenSize = tokenBytes.length; - if (actualBytes.length != tokenSize * 2) { - logger.trace(LogMessage.format( - "Not returning the CSRF token since its Base64-decoded length (%d) is not equal to (%d)", - actualBytes.length, tokenSize * 2)); - return null; - } - - // extract token and random bytes - byte[] xoredCsrf = new byte[tokenSize]; - byte[] randomBytes = new byte[tokenSize]; - - System.arraycopy(actualBytes, 0, randomBytes, 0, tokenSize); - System.arraycopy(actualBytes, tokenSize, xoredCsrf, 0, tokenSize); - - byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf); - return Utf8.decode(csrfBytes); - } - - private static String createXoredCsrfToken(SecureRandom secureRandom, String token) { - byte[] tokenBytes = Utf8.encode(token); - byte[] randomBytes = new byte[tokenBytes.length]; - secureRandom.nextBytes(randomBytes); - - byte[] xoredBytes = xorCsrf(randomBytes, tokenBytes); - byte[] combinedBytes = new byte[tokenBytes.length + randomBytes.length]; - System.arraycopy(randomBytes, 0, combinedBytes, 0, randomBytes.length); - System.arraycopy(xoredBytes, 0, combinedBytes, randomBytes.length, xoredBytes.length); - - return Base64.getUrlEncoder().encodeToString(combinedBytes); - } - - private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) { - Assert.isTrue(randomBytes.length == csrfBytes.length, "arrays must be equal length"); - int len = csrfBytes.length; - byte[] xoredCsrf = new byte[len]; - System.arraycopy(csrfBytes, 0, xoredCsrf, 0, len); - for (int i = 0; i < len; i++) { - xoredCsrf[i] ^= randomBytes[i]; - } - return xoredCsrf; + return csrfTokenEncoder.decode(actualToken, csrfToken.getToken()); } private static final class CachedCsrfTokenSupplier implements Supplier { diff --git a/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenEncoderTest.java b/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenEncoderTest.java new file mode 100644 index 00000000000..6463e6a2155 --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenEncoderTest.java @@ -0,0 +1,93 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.csrf; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.mock.web.MockHttpServletRequest; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +/** + * Tests for {@link XorCsrfTokenEncoder}. + * + * @author Cheol Jeon + * @since + */ +public class XorCsrfTokenEncoderTest { + + private XorCsrfTokenEncoder encoder; + + private CsrfToken csrfToken; + + @BeforeEach + void setup() { + this.encoder = new XorCsrfTokenEncoder(); + this.csrfToken = new CookieCsrfTokenRepository().generateToken(new MockHttpServletRequest()); + } + + @Test + void encodeAndDecode_shouldReturnOriginalToken() { + String originalToken = csrfToken.getToken(); + + String encoded = encoder.encode(originalToken); + assertNotNull(encoded, "Encoded token should not be null"); + + String decoded = encoder.decode(encoded, originalToken); + assertEquals(originalToken, decoded, "Decoded token should match the original"); + } + + @Test + void decode_withInvalidBase64_shouldReturnNull() { + String invalidEncoded = "not-base64!!"; + + String decoded = encoder.decode(invalidEncoded, "any-token"); + assertNull(decoded, "Decoding invalid base64 should return null"); + } + + @Test + void decode_withIncorrectLength_shouldReturnNull() { + String originalToken = csrfToken.getToken(); + + String encoded = encoder.encode(originalToken); + + // The CSRF token generated in Spring Security uses UUID.randomUUID().toString(), + // which produces a 36‑byte ASCII string (hyphens + hex digits). Because 36 is + // a multiple of 3, Base64 encoding of that input will not include padding ('='). + // Therefore, removing a single character from the encoded string (encoded.length() - 1) + // is sufficient here to simulate corruption of the token for this test case — + // i.e. it will produce an encoded value that no longer decodes back to the original token. + String truncated = encoded.substring(0, encoded.length() - 1); + + String decoded = encoder.decode(truncated, originalToken); + assertNull(decoded, "Decoding token with invalid length should return null"); + } + + @Test + void encode_shouldProduceDifferentValuesForSameInput() { + String originalToken = csrfToken.getToken(); + + String encoded1 = encoder.encode(originalToken); + String encoded2 = encoder.encode(originalToken); + + // Because random bytes used, encoded results should differ + assertNotEquals(encoded1, encoded2, "Encoded values for same input should differ"); + } +}