13
13
* See the License for the specific language governing permissions and
14
14
* limitations under the License.
15
15
*/
16
+
16
17
package org .springframework .security .web .csrf ;
17
18
18
19
import java .io .IOException ;
20
+ import java .security .MessageDigest ;
19
21
import java .util .Arrays ;
20
22
import java .util .HashSet ;
21
23
28
30
import org .apache .commons .logging .Log ;
29
31
import org .apache .commons .logging .LogFactory ;
30
32
33
+ import org .springframework .core .log .LogMessage ;
34
+ import org .springframework .security .access .AccessDeniedException ;
35
+ import org .springframework .security .crypto .codec .Utf8 ;
31
36
import org .springframework .security .web .access .AccessDeniedHandler ;
32
37
import org .springframework .security .web .access .AccessDeniedHandlerImpl ;
33
38
import org .springframework .security .web .util .UrlUtils ;
34
39
import org .springframework .security .web .util .matcher .RequestMatcher ;
35
40
import org .springframework .util .Assert ;
36
41
import org .springframework .web .filter .OncePerRequestFilter ;
37
42
38
- import static java .lang .Boolean .TRUE ;
39
-
40
43
/**
41
44
* <p>
42
45
* Applies
58
61
* @since 3.2
59
62
*/
60
63
public final class CsrfFilter extends OncePerRequestFilter {
64
+
61
65
/**
62
66
* The default {@link RequestMatcher} that indicates if CSRF protection is required or
63
67
* not. The default is to ignore GET, HEAD, TRACE, OPTIONS and process all other
@@ -66,18 +70,21 @@ public final class CsrfFilter extends OncePerRequestFilter {
66
70
public static final RequestMatcher DEFAULT_CSRF_MATCHER = new DefaultRequiresCsrfMatcher ();
67
71
68
72
/**
69
- * The attribute name to use when marking a given request as one that should not be filtered.
73
+ * The attribute name to use when marking a given request as one that should not be
74
+ * filtered.
70
75
*
71
- * To use, set the attribute on your {@link HttpServletRequest}:
72
- * <pre>
76
+ * To use, set the attribute on your {@link HttpServletRequest}: <pre>
73
77
* CsrfFilter.skipRequest(request);
74
78
* </pre>
75
79
*/
76
80
private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfFilter .class .getName ();
77
81
78
82
private final Log logger = LogFactory .getLog (getClass ());
83
+
79
84
private final CsrfTokenRepository tokenRepository ;
85
+
80
86
private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER ;
87
+
81
88
private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl ();
82
89
83
90
public CsrfFilter (CsrfTokenRepository csrfTokenRepository ) {
@@ -87,62 +94,46 @@ public CsrfFilter(CsrfTokenRepository csrfTokenRepository) {
87
94
88
95
@ Override
89
96
protected boolean shouldNotFilter (HttpServletRequest request ) throws ServletException {
90
- return TRUE .equals (request .getAttribute (SHOULD_NOT_FILTER ));
97
+ return Boolean . TRUE .equals (request .getAttribute (SHOULD_NOT_FILTER ));
91
98
}
92
99
93
- /*
94
- * (non-Javadoc)
95
- *
96
- * @see
97
- * org.springframework.web.filter.OncePerRequestFilter#doFilterInternal(javax.servlet
98
- * .http.HttpServletRequest, javax.servlet.http.HttpServletResponse,
99
- * javax.servlet.FilterChain)
100
- */
101
100
@ Override
102
- protected void doFilterInternal (HttpServletRequest request ,
103
- HttpServletResponse response , FilterChain filterChain )
104
- throws ServletException , IOException {
101
+ protected void doFilterInternal (HttpServletRequest request , HttpServletResponse response , FilterChain filterChain )
102
+ throws ServletException , IOException {
105
103
request .setAttribute (HttpServletResponse .class .getName (), response );
106
-
107
104
CsrfToken csrfToken = this .tokenRepository .loadToken (request );
108
- final boolean missingToken = csrfToken == null ;
105
+ boolean missingToken = ( csrfToken == null ) ;
109
106
if (missingToken ) {
110
107
csrfToken = this .tokenRepository .generateToken (request );
111
108
this .tokenRepository .saveToken (csrfToken , request , response );
112
109
}
113
110
request .setAttribute (CsrfToken .class .getName (), csrfToken );
114
111
request .setAttribute (csrfToken .getParameterName (), csrfToken );
115
-
116
112
if (!this .requireCsrfProtectionMatcher .matches (request )) {
113
+ if (this .logger .isTraceEnabled ()) {
114
+ this .logger .trace ("Did not protect against CSRF since request did not match "
115
+ + this .requireCsrfProtectionMatcher );
116
+ }
117
117
filterChain .doFilter (request , response );
118
118
return ;
119
119
}
120
-
121
120
String actualToken = request .getHeader (csrfToken .getHeaderName ());
122
121
if (actualToken == null ) {
123
122
actualToken = request .getParameter (csrfToken .getParameterName ());
124
123
}
125
- if (!csrfToken .getToken ().equals (actualToken )) {
126
- if (this .logger .isDebugEnabled ()) {
127
- this .logger .debug ("Invalid CSRF token found for "
128
- + UrlUtils .buildFullRequestUrl (request ));
129
- }
130
- if (missingToken ) {
131
- this .accessDeniedHandler .handle (request , response ,
132
- new MissingCsrfTokenException (actualToken ));
133
- }
134
- else {
135
- this .accessDeniedHandler .handle (request , response ,
136
- new InvalidCsrfTokenException (csrfToken , actualToken ));
137
- }
124
+ if (!equalsConstantTime (csrfToken .getToken (), actualToken )) {
125
+ this .logger .debug (
126
+ LogMessage .of (() -> "Invalid CSRF token found for " + UrlUtils .buildFullRequestUrl (request )));
127
+ AccessDeniedException exception = (!missingToken ) ? new InvalidCsrfTokenException (csrfToken , actualToken )
128
+ : new MissingCsrfTokenException (actualToken );
129
+ this .accessDeniedHandler .handle (request , response , exception );
138
130
return ;
139
131
}
140
-
141
132
filterChain .doFilter (request , response );
142
133
}
143
134
144
135
public static void skipRequest (HttpServletRequest request ) {
145
- request .setAttribute (SHOULD_NOT_FILTER , TRUE );
136
+ request .setAttribute (SHOULD_NOT_FILTER , Boolean . TRUE );
146
137
}
147
138
148
139
/**
@@ -154,14 +145,11 @@ public static void skipRequest(HttpServletRequest request) {
154
145
* The default is to apply CSRF protection for any HTTP method other than GET, HEAD,
155
146
* TRACE, OPTIONS.
156
147
* </p>
157
- *
158
148
* @param requireCsrfProtectionMatcher the {@link RequestMatcher} used to determine if
159
149
* CSRF protection should be applied.
160
150
*/
161
- public void setRequireCsrfProtectionMatcher (
162
- RequestMatcher requireCsrfProtectionMatcher ) {
163
- Assert .notNull (requireCsrfProtectionMatcher ,
164
- "requireCsrfProtectionMatcher cannot be null" );
151
+ public void setRequireCsrfProtectionMatcher (RequestMatcher requireCsrfProtectionMatcher ) {
152
+ Assert .notNull (requireCsrfProtectionMatcher , "requireCsrfProtectionMatcher cannot be null" );
165
153
this .requireCsrfProtectionMatcher = requireCsrfProtectionMatcher ;
166
154
}
167
155
@@ -172,28 +160,45 @@ public void setRequireCsrfProtectionMatcher(
172
160
* <p>
173
161
* The default is to use AccessDeniedHandlerImpl with no arguments.
174
162
* </p>
175
- *
176
163
* @param accessDeniedHandler the {@link AccessDeniedHandler} to use
177
164
*/
178
165
public void setAccessDeniedHandler (AccessDeniedHandler accessDeniedHandler ) {
179
166
Assert .notNull (accessDeniedHandler , "accessDeniedHandler cannot be null" );
180
167
this .accessDeniedHandler = accessDeniedHandler ;
181
168
}
182
169
170
+ /**
171
+ * Constant time comparison to prevent against timing attacks.
172
+ * @param expected
173
+ * @param actual
174
+ * @return
175
+ */
176
+ private static boolean equalsConstantTime (String expected , String actual ) {
177
+ byte [] expectedBytes = bytesUtf8 (expected );
178
+ byte [] actualBytes = bytesUtf8 (actual );
179
+ return MessageDigest .isEqual (expectedBytes , actualBytes );
180
+ }
181
+
182
+ private static byte [] bytesUtf8 (String s ) {
183
+ // need to check if Utf8.encode() runs in constant time (probably not).
184
+ // This may leak length of string.
185
+ return (s != null ) ? Utf8 .encode (s ) : null ;
186
+ }
187
+
183
188
private static final class DefaultRequiresCsrfMatcher implements RequestMatcher {
184
- private final HashSet <String > allowedMethods = new HashSet <>(
185
- Arrays .asList ("GET" , "HEAD" , "TRACE" , "OPTIONS" ));
186
-
187
- /*
188
- * (non-Javadoc)
189
- *
190
- * @see
191
- * org.springframework.security.web.util.matcher.RequestMatcher#matches(javax.
192
- * servlet.http.HttpServletRequest)
193
- */
189
+
190
+ private final HashSet <String > allowedMethods = new HashSet <>(Arrays .asList ("GET" , "HEAD" , "TRACE" , "OPTIONS" ));
191
+
194
192
@ Override
195
193
public boolean matches (HttpServletRequest request ) {
196
194
return !this .allowedMethods .contains (request .getMethod ());
197
195
}
196
+
197
+ @ Override
198
+ public String toString () {
199
+ return "CsrfNotRequired " + this .allowedMethods ;
200
+ }
201
+
198
202
}
203
+
199
204
}
0 commit comments