-
-
Notifications
You must be signed in to change notification settings - Fork 308
/
csrf.py
329 lines (246 loc) · 9.93 KB
/
csrf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
import hashlib
import hmac
import logging
import os
from urllib.parse import urlparse
from flask import Blueprint
from flask import current_app
from flask import g
from flask import request
from flask import session
from itsdangerous import BadData
from itsdangerous import SignatureExpired
from itsdangerous import URLSafeTimedSerializer
from werkzeug.exceptions import BadRequest
from wtforms import ValidationError
from wtforms.csrf.core import CSRF
__all__ = ("generate_csrf", "validate_csrf", "CSRFProtect")
logger = logging.getLogger(__name__)
def generate_csrf(secret_key=None, token_key=None):
    """Return a signed CSRF token for the current request.

    The signed token is stored on ``g`` under the field name, so repeated
    calls within one request return the same value. The raw (unsigned)
    token lives in the session under the same name and is created on
    first use.

    During testing, it might be useful to access the signed token in
    ``g.csrf_token`` and the raw token in ``session['csrf_token']``.

    :param secret_key: Used to securely sign the token. Default is
        ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
    :param token_key: Key where token is stored in session for comparison.
        Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.
    :raises RuntimeError: if no secret key or field name is configured.
    """
    signing_key = _get_config(
        secret_key,
        "WTF_CSRF_SECRET_KEY",
        current_app.secret_key,
        message="A secret key is required to use CSRF.",
    )
    name = _get_config(
        token_key,
        "WTF_CSRF_FIELD_NAME",
        "csrf_token",
        message="A field name is required to use CSRF.",
    )

    # A token was already generated during this request: reuse it.
    if name in g:
        return g.get(name)

    serializer = URLSafeTimedSerializer(signing_key, salt="wtf-csrf-token")

    if name not in session:
        session[name] = hashlib.sha1(os.urandom(64)).hexdigest()

    try:
        signed = serializer.dumps(session[name])
    except TypeError:
        # The value found in the session could not be serialized;
        # replace it with a fresh raw token and sign that instead.
        session[name] = hashlib.sha1(os.urandom(64)).hexdigest()
        signed = serializer.dumps(session[name])

    setattr(g, name, signed)
    return g.get(name)
def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
    """Verify a signed CSRF token against the raw token kept in the session.

    Returns ``None`` when the token is valid; every failure is reported by
    raising ``ValidationError`` with a message describing the reason.

    :param data: The signed CSRF token to be checked.
    :param secret_key: Used to securely sign the token. Default is
        ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
    :param time_limit: Number of seconds that the token is valid. Default is
        ``WTF_CSRF_TIME_LIMIT`` or 3600 seconds (60 minutes).
    :param token_key: Key where token is stored in session for comparison.
        Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.

    :raises ValidationError: Contains the reason that validation failed.
    :raises RuntimeError: if no secret key or field name is configured.

    .. versionchanged:: 0.14
        Raises ``ValidationError`` with a specific error message rather than
        returning ``True`` or ``False``.
    """
    signing_key = _get_config(
        secret_key,
        "WTF_CSRF_SECRET_KEY",
        current_app.secret_key,
        message="A secret key is required to use CSRF.",
    )
    session_key = _get_config(
        token_key,
        "WTF_CSRF_FIELD_NAME",
        "csrf_token",
        message="A field name is required to use CSRF.",
    )
    max_age = _get_config(time_limit, "WTF_CSRF_TIME_LIMIT", 3600, required=False)

    if not data:
        raise ValidationError("The CSRF token is missing.")

    if session_key not in session:
        raise ValidationError("The CSRF session token is missing.")

    serializer = URLSafeTimedSerializer(signing_key, salt="wtf-csrf-token")

    try:
        unsigned = serializer.loads(data, max_age=max_age)
    except SignatureExpired as exc:
        raise ValidationError("The CSRF token has expired.") from exc
    except BadData as exc:
        raise ValidationError("The CSRF token is invalid.") from exc

    # Constant-time comparison of the session's raw token with the payload
    # recovered from the signed token.
    if hmac.compare_digest(session[session_key], unsigned):
        return

    raise ValidationError("The CSRF tokens do not match.")
def _get_config(
value, config_name, default=None, required=True, message="CSRF is not configured."
):
"""Find config value based on provided value, Flask config, and default
value.
:param value: already provided config value
:param config_name: Flask ``config`` key
:param default: default value if not provided or configured
:param required: whether the value must not be ``None``
:param message: error message if required config is not found
:raises KeyError: if required config is not found
"""
if value is None:
value = current_app.config.get(config_name, default)
if required and value is None:
raise RuntimeError(message)
return value
class _FlaskFormCSRF(CSRF):
    """WTForms CSRF implementation that delegates to :func:`generate_csrf`
    and :func:`validate_csrf`, taking its settings from the form's ``meta``.
    """

    def setup_form(self, form):
        """Remember the form's ``meta`` and continue with the base setup."""
        self.meta = form.meta
        return super().setup_form(form)

    def generate_csrf_token(self, csrf_token_field):
        """Return the signed token to render for ``csrf_token_field``."""
        meta = self.meta
        return generate_csrf(
            secret_key=meta.csrf_secret,
            token_key=meta.csrf_field_name,
        )

    def validate_csrf_token(self, form, field):
        """Validate the submitted token unless the request was already
        marked valid on ``g``.

        :raises ValidationError: re-raised from :func:`validate_csrf` after
            the reason has been logged.
        """
        already_valid = g.get("csrf_valid", False)
        if already_valid:
            # already validated by CSRFProtect
            return

        meta = self.meta
        try:
            validate_csrf(
                field.data,
                meta.csrf_secret,
                meta.csrf_time_limit,
                meta.csrf_field_name,
            )
        except ValidationError as exc:
            logger.info(exc.args[0])
            raise
class CSRFProtect:
    """Enable CSRF protection globally for a Flask app.

    ::

        app = Flask(__name__)
        csrf = CSRFProtect(app)

    Checks the ``csrf_token`` field sent with forms, or the ``X-CSRFToken``
    header sent with JavaScript requests. Render the token in templates using
    ``{{ csrf_token() }}``.

    See the :ref:`csrf` documentation.
    """

    def __init__(self, app=None):
        """Create the extension, optionally binding it to ``app`` right away.

        :param app: Flask application to initialize, or ``None`` to defer
            to a later :meth:`init_app` call.
        """
        # Dotted "module.name" locations of exempt views, and exempt
        # blueprint objects.
        self._exempt_views = set()
        self._exempt_blueprints = set()

        if app:
            self.init_app(app)

    def init_app(self, app):
        """Register the extension on ``app``: fill in config defaults,
        expose ``csrf_token`` to templates, and install the
        ``before_request`` check.
        """
        app.extensions["csrf"] = self

        config = app.config
        config.setdefault("WTF_CSRF_ENABLED", True)
        config.setdefault("WTF_CSRF_CHECK_DEFAULT", True)
        # Always stored as a set, whatever iterable the user configured.
        config["WTF_CSRF_METHODS"] = set(
            config.get("WTF_CSRF_METHODS", ["POST", "PUT", "PATCH", "DELETE"])
        )
        remaining_defaults = (
            ("WTF_CSRF_FIELD_NAME", "csrf_token"),
            ("WTF_CSRF_HEADERS", ["X-CSRFToken", "X-CSRF-Token"]),
            ("WTF_CSRF_TIME_LIMIT", 3600),
            ("WTF_CSRF_SSL_STRICT", True),
        )
        for key, fallback in remaining_defaults:
            config.setdefault(key, fallback)

        app.jinja_env.globals["csrf_token"] = generate_csrf
        app.context_processor(lambda: {"csrf_token": generate_csrf})

        @app.before_request
        def csrf_protect():
            """Run :meth:`protect` for every request that is not disabled,
            out of scope, or exempt.
            """
            skip = (
                not app.config["WTF_CSRF_ENABLED"]
                or not app.config["WTF_CSRF_CHECK_DEFAULT"]
                or request.method not in app.config["WTF_CSRF_METHODS"]
                or not request.endpoint
            )
            if skip:
                return

            blueprint = app.blueprints.get(request.blueprint)
            if blueprint in self._exempt_blueprints:
                return

            view = app.view_functions.get(request.endpoint)
            location = f"{view.__module__}.{view.__name__}"
            if location in self._exempt_views:
                return

            self.protect()

    def _get_csrf_token(self):
        """Return the first non-empty token found in the form data or the
        configured headers, or ``None`` if there is none.
        """
        field_name = current_app.config["WTF_CSRF_FIELD_NAME"]
        form = request.form

        # find the token in the form data
        token = form.get(field_name)
        if token:
            return token

        # if the form has a prefix, the name will be {prefix}-csrf_token
        prefixed_values = (form[key] for key in form if key.endswith(field_name))
        token = next(filter(None, prefixed_values), None)
        if token:
            return token

        # find the token in the headers
        header_values = (
            request.headers.get(header_name)
            for header_name in current_app.config["WTF_CSRF_HEADERS"]
        )
        return next(filter(None, header_values), None)

    def protect(self):
        """Validate the current request's CSRF token and, for secure
        requests with ``WTF_CSRF_SSL_STRICT`` enabled, its referrer.

        Failures are reported through :meth:`_error_response`. On success
        ``g.csrf_valid`` is set to ``True``.
        """
        config = current_app.config

        if request.method not in config["WTF_CSRF_METHODS"]:
            return

        try:
            validate_csrf(self._get_csrf_token())
        except ValidationError as exc:
            logger.info(exc.args[0])
            self._error_response(exc.args[0])

        if request.is_secure and config["WTF_CSRF_SSL_STRICT"]:
            if not request.referrer:
                self._error_response("The referrer header is missing.")

            expected_referrer = f"https://{request.host}/"

            if not same_origin(request.referrer, expected_referrer):
                self._error_response("The referrer does not match the host.")

        g.csrf_valid = True  # mark this request as CSRF valid

    def exempt(self, view):
        """Mark a view or blueprint to be excluded from CSRF protection.

        ::

            @app.route('/some-view', methods=['POST'])
            @csrf.exempt
            def some_view():
                ...

        ::

            bp = Blueprint(...)
            csrf.exempt(bp)

        A string is taken as the view's dotted ``module.name`` location.
        The argument is returned unchanged so this works as a decorator.
        """
        if isinstance(view, Blueprint):
            self._exempt_blueprints.add(view)
            return view

        view_location = (
            view
            if isinstance(view, str)
            else ".".join((view.__module__, view.__name__))
        )
        self._exempt_views.add(view_location)
        return view

    def _error_response(self, reason):
        """Abort the request by raising :class:`CSRFError` with ``reason``."""
        raise CSRFError(reason)
class CSRFError(BadRequest):
    """Raised when the CSRF data sent by the client fails validation.

    By default this produces a 400 Bad Request response carrying the
    failure reason. To customize the response, register a handler for
    this exception with :meth:`flask.Flask.errorhandler`.
    """

    # Fallback text used when no specific reason is supplied.
    description = "CSRF validation failed."
# Ports implied by a scheme when the URI does not spell one out.
_DEFAULT_PORTS = {"http": 80, "https": 443}


def _origin(uri):
    """Return the ``(scheme, hostname, port)`` triple for ``uri``.

    A missing port is replaced by the scheme's default port when one is
    known, so ``https://host/`` and ``https://host:443/`` yield the same
    triple. Schemes without a known default keep ``None`` as the port.

    :raises ValueError: if the URI or its port cannot be parsed.
    """
    parts = urlparse(uri)
    port = parts.port  # may raise ValueError for a non-numeric/out-of-range port
    if port is None:
        port = _DEFAULT_PORTS.get(parts.scheme)
    return parts.scheme, parts.hostname, port


def same_origin(current_uri, compare_uri):
    """Tell whether two URIs share the same scheme, host, and port.

    Default ports are normalized for ``http`` and ``https``, so an explicit
    ``:443`` on an ``https`` URI matches the same URI without a port.

    :param current_uri: URI to check, e.g. the request's referrer.
    :param compare_uri: URI it is expected to share an origin with.
    :return: ``True`` if both URIs have the same origin, otherwise
        ``False``. A URI that cannot be parsed never matches.
    """
    try:
        return _origin(current_uri) == _origin(compare_uri)
    except ValueError:
        # The referrer is client-controlled; a malformed value must be a
        # plain mismatch rather than an unhandled error.
        return False