diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e9b4fdf9d..50f591d950 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,6 +57,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#1824](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1824)) - Fix sqlalchemy instrumentation wrap methods to accept sqlcommenter options ([#1873](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1873)) +- Exclude background task execution from root server span in ASGI middleware + ([#1952](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1952)) ### Added diff --git a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py index 8d5aa4e2d2..ae47a5cb4f 100644 --- a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py @@ -576,7 +576,7 @@ async def __call__(self, scope, receive, send): if scope["type"] == "http": self.active_requests_counter.add(1, active_requests_count_attrs) try: - with trace.use_span(span, end_on_exit=True) as current_span: + with trace.use_span(span, end_on_exit=False) as current_span: if current_span.is_recording(): for key, value in attributes.items(): current_span.set_attribute(key, value) @@ -630,6 +630,8 @@ async def __call__(self, scope, receive, send): ) if token: context.detach(token) + if span.is_recording(): + span.end() # pylint: enable=too-many-branches @@ -653,8 +655,11 @@ async def otel_receive(): def _get_otel_send( self, server_span, server_span_name, scope, send, duration_attrs ): + expecting_trailers = False + @wraps(send) async def otel_send(message): + nonlocal expecting_trailers with self.tracer.start_as_current_span( " ".join((server_span_name, scope["type"], "send")) ) as send_span: @@ -668,6 +673,8 @@ async def otel_send(message): ] = status_code set_status_code(server_span, status_code) set_status_code(send_span, status_code) + + expecting_trailers = message.get("trailers", False) elif message["type"] == "websocket.send": set_status_code(server_span, 200) set_status_code(send_span, 200) @@ -703,5 +710,15 @@ async def otel_send(message): pass await send(message) + if ( + not expecting_trailers + and message["type"] == "http.response.body" + and not message.get("more_body", False) + ) or ( + expecting_trailers + and message["type"] == "http.response.trailers" + and not message.get("more_trailers", False) + ): + server_span.end() return otel_send diff --git a/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py b/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py index 209acdf663..da7bc8ea74 100644 --- a/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py +++ b/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py @@ -16,6 +16,7 @@ import asyncio import sys +import time import unittest from timeit import default_timer from unittest import mock @@ -57,6 +58,8 @@ "http.server.request.size": _duration_attrs, } +_SIMULATED_BACKGROUND_TASK_EXECUTION_TIME_S = 0.01 + async def http_app(scope, receive, send): message = await receive() @@ -99,6 +102,108 @@ async def simple_asgi(scope, receive, send): await websocket_app(scope, receive, send) +async def long_response_asgi(scope, receive, send): + assert isinstance(scope, dict) + assert scope["type"] == "http" + message = await receive() + scope["headers"] = [(b"content-length", b"128")] + assert scope["type"] == "http" + if message.get("type") == "http.request": + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + [b"Content-Type", b"text/plain"], + [b"content-length", b"1024"], + ], + } + ) + await send( + {"type": "http.response.body", "body": b"*", "more_body": True} + ) + await send( + {"type": "http.response.body", "body": b"*", "more_body": True} + ) + await send( + {"type": "http.response.body", "body": b"*", "more_body": True} + ) + await send( + {"type": "http.response.body", "body": b"*", "more_body": False} + ) + + +async def background_execution_asgi(scope, receive, send): + assert isinstance(scope, dict) + assert scope["type"] == "http" + message = await receive() + scope["headers"] = [(b"content-length", b"128")] + assert scope["type"] == "http" + if message.get("type") == "http.request": + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + [b"Content-Type", b"text/plain"], + [b"content-length", b"1024"], + ], + } + ) + await send( + { + "type": "http.response.body", + "body": b"*", + } + ) + time.sleep(_SIMULATED_BACKGROUND_TASK_EXECUTION_TIME_S) + + +async def background_execution_trailers_asgi(scope, receive, send): + assert isinstance(scope, dict) + assert scope["type"] == "http" + message = await receive() + scope["headers"] = [(b"content-length", b"128")] + assert scope["type"] == "http" + if message.get("type") == "http.request": + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + [b"Content-Type", b"text/plain"], + [b"content-length", b"1024"], + ], + "trailers": True, + } + ) + await send( + {"type": "http.response.body", "body": b"*", "more_body": True} + ) + await send( + {"type": "http.response.body", "body": b"*", "more_body": False} + ) + await send( + { + "type": "http.response.trailers", + "headers": [ + [b"trailer", b"test-trailer"], + ], + "more_trailers": True, + } + ) + await send( + { + "type": "http.response.trailers", + "headers": [ + [b"trailer", b"second-test-trailer"], + ], + "more_trailers": False, + } + ) + time.sleep(_SIMULATED_BACKGROUND_TASK_EXECUTION_TIME_S) + + async def error_asgi(scope, receive, send): assert isinstance(scope, dict) assert scope["type"] == "http" @@ -127,14 +232,19 @@ def validate_outputs(self, outputs, error=None, modifiers=None): # Ensure modifiers is a list modifiers = modifiers or [] # Check for expected outputs - self.assertEqual(len(outputs), 2) response_start = outputs[0] - response_body = outputs[1] + response_final_body = [ + output + for output in outputs + if output["type"] == "http.response.body" + ][-1] + self.assertEqual(response_start["type"], "http.response.start") - self.assertEqual(response_body["type"], "http.response.body") + self.assertEqual(response_final_body["type"], "http.response.body") + self.assertEqual(response_final_body.get("more_body", False), False) # Check http response body - self.assertEqual(response_body["body"], b"*") + self.assertEqual(response_final_body["body"], b"*") # Check http response start self.assertEqual(response_start["status"], 200) @@ -153,7 +263,6 @@ def validate_outputs(self, outputs, error=None, modifiers=None): # Check spans span_list = self.memory_exporter.get_finished_spans() - self.assertEqual(len(span_list), 4) expected = [ { "name": "GET / http receive", @@ -194,6 +303,7 @@ def validate_outputs(self, outputs, error=None, modifiers=None): for modifier in modifiers: expected = modifier(expected) # Check that output matches + self.assertEqual(len(span_list), len(expected)) for span, expected in zip(span_list, expected): self.assertEqual(span.name, expected["name"]) self.assertEqual(span.kind, expected["kind"]) @@ -232,6 +342,80 @@ def test_asgi_exc_info(self): outputs = self.get_all_output() self.validate_outputs(outputs, error=ValueError) + def test_long_response(self): + """Test that the server span is ended on the final response body message. + + If the server span is ended early then this test will fail due + to discrepancies in the expected list of spans and the emitted list of spans. + """ + app = otel_asgi.OpenTelemetryMiddleware(long_response_asgi) + self.seed_app(app) + self.send_default_request() + outputs = self.get_all_output() + + def add_more_body_spans(expected: list): + more_body_span = { + "name": "GET / http send", + "kind": trace_api.SpanKind.INTERNAL, + "attributes": {"type": "http.response.body"}, + } + extra_spans = [more_body_span] * 3 + expected[2:2] = extra_spans + return expected + + self.validate_outputs(outputs, modifiers=[add_more_body_spans]) + + def test_background_execution(self): + """Test that the server span is ended BEFORE the background task is finished.""" + app = otel_asgi.OpenTelemetryMiddleware(background_execution_asgi) + self.seed_app(app) + self.send_default_request() + outputs = self.get_all_output() + self.validate_outputs(outputs) + span_list = self.memory_exporter.get_finished_spans() + server_span = span_list[-1] + assert server_span.kind == SpanKind.SERVER + span_duration_nanos = server_span.end_time - server_span.start_time + self.assertLessEqual( + span_duration_nanos, + _SIMULATED_BACKGROUND_TASK_EXECUTION_TIME_S * 10**9, + ) + + def test_trailers(self): + """Test that trailers are emitted as expected and that the server span is ended + BEFORE the background task is finished.""" + app = otel_asgi.OpenTelemetryMiddleware( + background_execution_trailers_asgi + ) + self.seed_app(app) + self.send_default_request() + outputs = self.get_all_output() + + def add_body_and_trailer_span(expected: list): + body_span = { + "name": "GET / http send", + "kind": trace_api.SpanKind.INTERNAL, + "attributes": {"type": "http.response.body"}, + } + trailer_span = { + "name": "GET / http send", + "kind": trace_api.SpanKind.INTERNAL, + "attributes": {"type": "http.response.trailers"}, + } + expected[2:2] = [body_span] + expected[4:4] = [trailer_span] * 2 + return expected + + self.validate_outputs(outputs, modifiers=[add_body_and_trailer_span]) + span_list = self.memory_exporter.get_finished_spans() + server_span = span_list[-1] + assert server_span.kind == SpanKind.SERVER + span_duration_nanos = server_span.end_time - server_span.start_time + self.assertLessEqual( + span_duration_nanos, + _SIMULATED_BACKGROUND_TASK_EXECUTION_TIME_S * 10**9, + ) + def test_override_span_name(self): """Test that default span_names can be overwritten by our callback function.""" span_name = "Dymaxion"