@@ -238,13 +238,14 @@ def response_hook(span: Span, status: str, response_headers: List):
238
238
API
239
239
---
240
240
"""
241
+ import weakref
241
242
from logging import getLogger
242
- from threading import get_ident
243
243
from time import time_ns
244
244
from timeit import default_timer
245
245
from typing import Collection
246
246
247
247
import flask
248
+ from packaging import version as package_version
248
249
249
250
import opentelemetry .instrumentation .wsgi as otel_wsgi
250
251
from opentelemetry import context , trace
@@ -265,11 +266,21 @@ def response_hook(span: Span, status: str, response_headers: List):
265
266
_ENVIRON_STARTTIME_KEY = "opentelemetry-flask.starttime_key"
266
267
_ENVIRON_SPAN_KEY = "opentelemetry-flask.span_key"
267
268
_ENVIRON_ACTIVATION_KEY = "opentelemetry-flask.activation_key"
268
- _ENVIRON_THREAD_ID_KEY = "opentelemetry-flask.thread_id_key "
269
+ _ENVIRON_REQCTX_REF_KEY = "opentelemetry-flask.reqctx_ref_key "
269
270
_ENVIRON_TOKEN = "opentelemetry-flask.token"
270
271
271
272
_excluded_urls_from_env = get_excluded_urls ("FLASK" )
272
273
274
+ if package_version .parse (flask .__version__ ) >= package_version .parse ("2.2.0" ):
275
+
276
+ def _request_ctx_ref () -> weakref .ReferenceType :
277
+ return weakref .ref (flask .globals .request_ctx ._get_current_object ())
278
+
279
+ else :
280
+
281
+ def _request_ctx_ref () -> weakref .ReferenceType :
282
+ return weakref .ref (flask ._request_ctx_stack .top )
283
+
273
284
274
285
def get_default_span_name ():
275
286
try :
@@ -399,7 +410,7 @@ def _before_request():
399
410
activation = trace .use_span (span , end_on_exit = True )
400
411
activation .__enter__ () # pylint: disable=E1101
401
412
flask_request_environ [_ENVIRON_ACTIVATION_KEY ] = activation
402
- flask_request_environ [_ENVIRON_THREAD_ID_KEY ] = get_ident ()
413
+ flask_request_environ [_ENVIRON_REQCTX_REF_KEY ] = _request_ctx_ref ()
403
414
flask_request_environ [_ENVIRON_SPAN_KEY ] = span
404
415
flask_request_environ [_ENVIRON_TOKEN ] = token
405
416
@@ -439,17 +450,22 @@ def _teardown_request(exc):
439
450
return
440
451
441
452
activation = flask .request .environ .get (_ENVIRON_ACTIVATION_KEY )
442
- thread_id = flask .request .environ .get (_ENVIRON_THREAD_ID_KEY )
443
- if not activation or thread_id != get_ident ():
453
+
454
+ original_reqctx_ref = flask .request .environ .get (
455
+ _ENVIRON_REQCTX_REF_KEY
456
+ )
457
+ current_reqctx_ref = _request_ctx_ref ()
458
+ if not activation or original_reqctx_ref != current_reqctx_ref :
444
459
# This request didn't start a span, maybe because it was created in
445
460
# a way that doesn't run `before_request`, like when it is created
446
461
# with `app.test_request_context`.
447
462
#
448
- # Similarly, check the thread_id against the current thread to ensure
449
- # tear down only happens on the original thread. This situation can
450
- # arise if the original thread handling the request spawn children
451
- # threads and then uses something like copy_current_request_context
452
- # to copy the request context.
463
+ # Similarly, check that the request_ctx that created the span
464
+ # matches the current request_ctx, and only tear down if they match.
465
+ # This situation can arise if the original request_ctx handling
466
+ # the request calls functions that push new request_ctx's,
467
+ # like any decorated with `flask.copy_current_request_context`.
468
+
453
469
return
454
470
if exc is None :
455
471
activation .__exit__ (None , None , None )
0 commit comments