|
16 | 16 |
|
17 | 17 | package org.springframework.security.web.server.csrf;
|
18 | 18 |
|
| 19 | +import java.security.MessageDigest; |
19 | 20 | import java.util.Arrays;
|
20 | 21 | import java.util.HashSet;
|
21 | 22 | import java.util.Set;
|
|
28 | 29 | import org.springframework.http.MediaType;
|
29 | 30 | import org.springframework.http.codec.multipart.FormFieldPart;
|
30 | 31 | import org.springframework.http.server.reactive.ServerHttpRequest;
|
| 32 | +import org.springframework.security.crypto.codec.Utf8; |
31 | 33 | import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
|
32 | 34 | import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
|
33 | 35 | import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
|
@@ -139,7 +141,7 @@ private Mono<Boolean> containsValidCsrfToken(ServerWebExchange exchange, CsrfTok
|
139 | 141 | return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
|
140 | 142 | .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
|
141 | 143 | .switchIfEmpty(tokenFromMultipartData(exchange, expected))
|
142 |
| - .map((actual) -> actual.equals(expected.getToken())); |
| 144 | + .map((actual) -> equalsConstantTime(actual, expected.getToken())); |
143 | 145 | }
|
144 | 146 |
|
145 | 147 | private Mono<String> tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
|
@@ -168,6 +170,24 @@ private Mono<CsrfToken> csrfToken(ServerWebExchange exchange) {
|
168 | 170 | return this.csrfTokenRepository.loadToken(exchange).switchIfEmpty(generateToken(exchange));
|
169 | 171 | }
|
170 | 172 |
|
| 173 | + /** |
| 174 | + * Constant time comparison to prevent against timing attacks. |
| 175 | + * @param expected |
| 176 | + * @param actual |
| 177 | + * @return |
| 178 | + */ |
| 179 | + private static boolean equalsConstantTime(String expected, String actual) { |
| 180 | + byte[] expectedBytes = bytesUtf8(expected); |
| 181 | + byte[] actualBytes = bytesUtf8(actual); |
| 182 | + return MessageDigest.isEqual(expectedBytes, actualBytes); |
| 183 | + } |
| 184 | + |
| 185 | + private static byte[] bytesUtf8(String s) { |
| 186 | + // need to check if Utf8.encode() runs in constant time (probably not). |
| 187 | + // This may leak length of string. |
| 188 | + return (s != null) ? Utf8.encode(s) : null; |
| 189 | + } |
| 190 | + |
171 | 191 | private Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
|
172 | 192 | return this.csrfTokenRepository.generateToken(exchange)
|
173 | 193 | .delayUntil((token) -> this.csrfTokenRepository.saveToken(exchange, token));
|
|
0 commit comments