Skip to content

Commit 40e027c

Browse files
author
Rob Winch
committed
Constant Time Comparison for CSRF tokens
Closes gh-9291
1 parent c066e23 commit 40e027c

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java

+21-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.security.web.csrf;
1818

1919
import java.io.IOException;
20+
import java.security.MessageDigest;
2021
import java.util.Arrays;
2122
import java.util.HashSet;
2223

@@ -31,6 +32,7 @@
3132

3233
import org.springframework.core.log.LogMessage;
3334
import org.springframework.security.access.AccessDeniedException;
35+
import org.springframework.security.crypto.codec.Utf8;
3436
import org.springframework.security.web.access.AccessDeniedHandler;
3537
import org.springframework.security.web.access.AccessDeniedHandlerImpl;
3638
import org.springframework.security.web.util.UrlUtils;
@@ -119,7 +121,7 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
119121
if (actualToken == null) {
120122
actualToken = request.getParameter(csrfToken.getParameterName());
121123
}
122-
if (!csrfToken.getToken().equals(actualToken)) {
124+
if (!equalsConstantTime(csrfToken.getToken(), actualToken)) {
123125
this.logger.debug(
124126
LogMessage.of(() -> "Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request)));
125127
AccessDeniedException exception = (!missingToken) ? new InvalidCsrfTokenException(csrfToken, actualToken)
@@ -165,6 +167,24 @@ public void setAccessDeniedHandler(AccessDeniedHandler accessDeniedHandler) {
165167
this.accessDeniedHandler = accessDeniedHandler;
166168
}
167169

170+
/**
171+
* Constant time comparison to prevent against timing attacks.
172+
* @param expected
173+
* @param actual
174+
* @return
175+
*/
176+
private static boolean equalsConstantTime(String expected, String actual) {
177+
byte[] expectedBytes = bytesUtf8(expected);
178+
byte[] actualBytes = bytesUtf8(actual);
179+
return MessageDigest.isEqual(expectedBytes, actualBytes);
180+
}
181+
182+
private static byte[] bytesUtf8(String s) {
183+
// need to check if Utf8.encode() runs in constant time (probably not).
184+
// This may leak length of string.
185+
return (s != null) ? Utf8.encode(s) : null;
186+
}
187+
168188
private static final class DefaultRequiresCsrfMatcher implements RequestMatcher {
169189

170190
private final HashSet<String> allowedMethods = new HashSet<>(Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS"));

web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java

+21-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.security.web.server.csrf;
1818

19+
import java.security.MessageDigest;
1920
import java.util.Arrays;
2021
import java.util.HashSet;
2122
import java.util.Set;
@@ -28,6 +29,7 @@
2829
import org.springframework.http.MediaType;
2930
import org.springframework.http.codec.multipart.FormFieldPart;
3031
import org.springframework.http.server.reactive.ServerHttpRequest;
32+
import org.springframework.security.crypto.codec.Utf8;
3133
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
3234
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
3335
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
@@ -139,7 +141,7 @@ private Mono<Boolean> containsValidCsrfToken(ServerWebExchange exchange, CsrfTok
139141
return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
140142
.switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
141143
.switchIfEmpty(tokenFromMultipartData(exchange, expected))
142-
.map((actual) -> actual.equals(expected.getToken()));
144+
.map((actual) -> equalsConstantTime(actual, expected.getToken()));
143145
}
144146

145147
private Mono<String> tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
@@ -168,6 +170,24 @@ private Mono<CsrfToken> csrfToken(ServerWebExchange exchange) {
168170
return this.csrfTokenRepository.loadToken(exchange).switchIfEmpty(generateToken(exchange));
169171
}
170172

173+
/**
174+
* Constant time comparison to prevent against timing attacks.
175+
* @param expected
176+
* @param actual
177+
* @return
178+
*/
179+
private static boolean equalsConstantTime(String expected, String actual) {
180+
byte[] expectedBytes = bytesUtf8(expected);
181+
byte[] actualBytes = bytesUtf8(actual);
182+
return MessageDigest.isEqual(expectedBytes, actualBytes);
183+
}
184+
185+
private static byte[] bytesUtf8(String s) {
186+
// need to check if Utf8.encode() runs in constant time (probably not).
187+
// This may leak length of string.
188+
return (s != null) ? Utf8.encode(s) : null;
189+
}
190+
171191
private Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
172192
return this.csrfTokenRepository.generateToken(exchange)
173193
.delayUntil((token) -> this.csrfTokenRepository.saveToken(exchange, token));

0 commit comments

Comments
 (0)