Skip to content

Commit 1181740

Browse files
Rob Winchrwinch
Rob Winch
authored andcommitted
Constant Time Comparison for CSRF tokens
Closes gh-9291
1 parent 628ea00 commit 1181740

File tree

2 files changed

+124
-106
lines changed

2 files changed

+124
-106
lines changed

web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java

+57-52
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
1617
package org.springframework.security.web.csrf;
1718

1819
import java.io.IOException;
20+
import java.security.MessageDigest;
1921
import java.util.Arrays;
2022
import java.util.HashSet;
2123

@@ -28,15 +30,16 @@
2830
import org.apache.commons.logging.Log;
2931
import org.apache.commons.logging.LogFactory;
3032

33+
import org.springframework.core.log.LogMessage;
34+
import org.springframework.security.access.AccessDeniedException;
35+
import org.springframework.security.crypto.codec.Utf8;
3136
import org.springframework.security.web.access.AccessDeniedHandler;
3237
import org.springframework.security.web.access.AccessDeniedHandlerImpl;
3338
import org.springframework.security.web.util.UrlUtils;
3439
import org.springframework.security.web.util.matcher.RequestMatcher;
3540
import org.springframework.util.Assert;
3641
import org.springframework.web.filter.OncePerRequestFilter;
3742

38-
import static java.lang.Boolean.TRUE;
39-
4043
/**
4144
* <p>
4245
* Applies
@@ -58,6 +61,7 @@
5861
* @since 3.2
5962
*/
6063
public final class CsrfFilter extends OncePerRequestFilter {
64+
6165
/**
6266
* The default {@link RequestMatcher} that indicates if CSRF protection is required or
6367
* not. The default is to ignore GET, HEAD, TRACE, OPTIONS and process all other
@@ -66,18 +70,21 @@ public final class CsrfFilter extends OncePerRequestFilter {
6670
public static final RequestMatcher DEFAULT_CSRF_MATCHER = new DefaultRequiresCsrfMatcher();
6771

6872
/**
69-
* The attribute name to use when marking a given request as one that should not be filtered.
73+
* The attribute name to use when marking a given request as one that should not be
74+
* filtered.
7075
*
71-
* To use, set the attribute on your {@link HttpServletRequest}:
72-
* <pre>
76+
* To use, set the attribute on your {@link HttpServletRequest}: <pre>
7377
* CsrfFilter.skipRequest(request);
7478
* </pre>
7579
*/
7680
private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfFilter.class.getName();
7781

7882
private final Log logger = LogFactory.getLog(getClass());
83+
7984
private final CsrfTokenRepository tokenRepository;
85+
8086
private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
87+
8188
private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl();
8289

8390
public CsrfFilter(CsrfTokenRepository csrfTokenRepository) {
@@ -87,62 +94,46 @@ public CsrfFilter(CsrfTokenRepository csrfTokenRepository) {
8794

8895
@Override
8996
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
90-
return TRUE.equals(request.getAttribute(SHOULD_NOT_FILTER));
97+
return Boolean.TRUE.equals(request.getAttribute(SHOULD_NOT_FILTER));
9198
}
9299

93-
/*
94-
* (non-Javadoc)
95-
*
96-
* @see
97-
* org.springframework.web.filter.OncePerRequestFilter#doFilterInternal(javax.servlet
98-
* .http.HttpServletRequest, javax.servlet.http.HttpServletResponse,
99-
* javax.servlet.FilterChain)
100-
*/
101100
@Override
102-
protected void doFilterInternal(HttpServletRequest request,
103-
HttpServletResponse response, FilterChain filterChain)
104-
throws ServletException, IOException {
101+
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
102+
throws ServletException, IOException {
105103
request.setAttribute(HttpServletResponse.class.getName(), response);
106-
107104
CsrfToken csrfToken = this.tokenRepository.loadToken(request);
108-
final boolean missingToken = csrfToken == null;
105+
boolean missingToken = (csrfToken == null);
109106
if (missingToken) {
110107
csrfToken = this.tokenRepository.generateToken(request);
111108
this.tokenRepository.saveToken(csrfToken, request, response);
112109
}
113110
request.setAttribute(CsrfToken.class.getName(), csrfToken);
114111
request.setAttribute(csrfToken.getParameterName(), csrfToken);
115-
116112
if (!this.requireCsrfProtectionMatcher.matches(request)) {
113+
if (this.logger.isTraceEnabled()) {
114+
this.logger.trace("Did not protect against CSRF since request did not match "
115+
+ this.requireCsrfProtectionMatcher);
116+
}
117117
filterChain.doFilter(request, response);
118118
return;
119119
}
120-
121120
String actualToken = request.getHeader(csrfToken.getHeaderName());
122121
if (actualToken == null) {
123122
actualToken = request.getParameter(csrfToken.getParameterName());
124123
}
125-
if (!csrfToken.getToken().equals(actualToken)) {
126-
if (this.logger.isDebugEnabled()) {
127-
this.logger.debug("Invalid CSRF token found for "
128-
+ UrlUtils.buildFullRequestUrl(request));
129-
}
130-
if (missingToken) {
131-
this.accessDeniedHandler.handle(request, response,
132-
new MissingCsrfTokenException(actualToken));
133-
}
134-
else {
135-
this.accessDeniedHandler.handle(request, response,
136-
new InvalidCsrfTokenException(csrfToken, actualToken));
137-
}
124+
if (!equalsConstantTime(csrfToken.getToken(), actualToken)) {
125+
this.logger.debug(
126+
LogMessage.of(() -> "Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request)));
127+
AccessDeniedException exception = (!missingToken) ? new InvalidCsrfTokenException(csrfToken, actualToken)
128+
: new MissingCsrfTokenException(actualToken);
129+
this.accessDeniedHandler.handle(request, response, exception);
138130
return;
139131
}
140-
141132
filterChain.doFilter(request, response);
142133
}
143134

144135
public static void skipRequest(HttpServletRequest request) {
145-
request.setAttribute(SHOULD_NOT_FILTER, TRUE);
136+
request.setAttribute(SHOULD_NOT_FILTER, Boolean.TRUE);
146137
}
147138

148139
/**
@@ -154,14 +145,11 @@ public static void skipRequest(HttpServletRequest request) {
154145
* The default is to apply CSRF protection for any HTTP method other than GET, HEAD,
155146
* TRACE, OPTIONS.
156147
* </p>
157-
*
158148
* @param requireCsrfProtectionMatcher the {@link RequestMatcher} used to determine if
159149
* CSRF protection should be applied.
160150
*/
161-
public void setRequireCsrfProtectionMatcher(
162-
RequestMatcher requireCsrfProtectionMatcher) {
163-
Assert.notNull(requireCsrfProtectionMatcher,
164-
"requireCsrfProtectionMatcher cannot be null");
151+
public void setRequireCsrfProtectionMatcher(RequestMatcher requireCsrfProtectionMatcher) {
152+
Assert.notNull(requireCsrfProtectionMatcher, "requireCsrfProtectionMatcher cannot be null");
165153
this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher;
166154
}
167155

@@ -172,28 +160,45 @@ public void setRequireCsrfProtectionMatcher(
172160
* <p>
173161
* The default is to use AccessDeniedHandlerImpl with no arguments.
174162
* </p>
175-
*
176163
* @param accessDeniedHandler the {@link AccessDeniedHandler} to use
177164
*/
178165
public void setAccessDeniedHandler(AccessDeniedHandler accessDeniedHandler) {
179166
Assert.notNull(accessDeniedHandler, "accessDeniedHandler cannot be null");
180167
this.accessDeniedHandler = accessDeniedHandler;
181168
}
182169

170+
/**
171+
* Constant time comparison to prevent against timing attacks.
172+
* @param expected
173+
* @param actual
174+
* @return
175+
*/
176+
private static boolean equalsConstantTime(String expected, String actual) {
177+
byte[] expectedBytes = bytesUtf8(expected);
178+
byte[] actualBytes = bytesUtf8(actual);
179+
return MessageDigest.isEqual(expectedBytes, actualBytes);
180+
}
181+
182+
private static byte[] bytesUtf8(String s) {
183+
// need to check if Utf8.encode() runs in constant time (probably not).
184+
// This may leak length of string.
185+
return (s != null) ? Utf8.encode(s) : null;
186+
}
187+
183188
private static final class DefaultRequiresCsrfMatcher implements RequestMatcher {
184-
private final HashSet<String> allowedMethods = new HashSet<>(
185-
Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS"));
186-
187-
/*
188-
* (non-Javadoc)
189-
*
190-
* @see
191-
* org.springframework.security.web.util.matcher.RequestMatcher#matches(javax.
192-
* servlet.http.HttpServletRequest)
193-
*/
189+
190+
private final HashSet<String> allowedMethods = new HashSet<>(Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS"));
191+
194192
@Override
195193
public boolean matches(HttpServletRequest request) {
196194
return !this.allowedMethods.contains(request.getMethod());
197195
}
196+
197+
@Override
198+
public String toString() {
199+
return "CsrfNotRequired " + this.allowedMethods;
200+
}
201+
198202
}
203+
199204
}

0 commit comments

Comments
 (0)