1
+ import typing
1
2
from unittest import mock
2
3
3
4
import opentelemetry .instrumentation .asgi as otel_asgi
5
+ from opentelemetry import trace
6
+ from opentelemetry .context import Context
7
+ from opentelemetry .propagate import get_global_textmap , set_global_textmap
8
+ from opentelemetry .propagators .textmap import (
9
+ CarrierT ,
10
+ Getter ,
11
+ Setter ,
12
+ TextMapPropagator ,
13
+ default_getter ,
14
+ default_setter ,
15
+ )
4
16
from opentelemetry .test .asgitestutil import AsgiTestBase
5
17
from opentelemetry .test .test_base import TestBase
6
18
from opentelemetry .trace import SpanKind
13
25
from .test_asgi_middleware import simple_asgi
14
26
15
27
28
+ class MockTextMapPropagator (TextMapPropagator ):
29
+ """Mock propagator for testing purposes using both getter `get` and `all`."""
30
+
31
+ TRACE_ID_KEY = "mock-traceid"
32
+ SPAN_ID_KEY = "mock-spanid"
33
+
34
+ def extract (
35
+ self ,
36
+ carrier : CarrierT ,
37
+ context : typing .Optional [Context ] = None ,
38
+ getter : Getter = default_getter ,
39
+ ) -> Context :
40
+ if context is None :
41
+ context = Context ()
42
+
43
+ trace_id_list = getter .get (carrier , self .TRACE_ID_KEY )
44
+ span_id_list = getter .get (carrier , self .SPAN_ID_KEY )
45
+ carrier_keys = getter .keys (carrier )
46
+
47
+ if not trace_id_list or not span_id_list :
48
+ assert not any (key in carrier_keys for key in self .fields )
49
+ return context
50
+
51
+ assert all (key in carrier_keys for key in self .fields )
52
+ return trace .set_span_in_context (
53
+ trace .NonRecordingSpan (
54
+ trace .SpanContext (
55
+ trace_id = int (trace_id_list [0 ]),
56
+ span_id = int (span_id_list [0 ]),
57
+ is_remote = True ,
58
+ )
59
+ ),
60
+ context ,
61
+ )
62
+
63
+ def inject (
64
+ self ,
65
+ carrier : CarrierT ,
66
+ context : typing .Optional [Context ] = None ,
67
+ setter : Setter = default_setter ,
68
+ ) -> None :
69
+ span = trace .get_current_span (context )
70
+ setter .set (
71
+ carrier , self .TRACE_ID_KEY , str (span .get_span_context ().trace_id )
72
+ )
73
+ setter .set (
74
+ carrier , self .SPAN_ID_KEY , str (span .get_span_context ().span_id )
75
+ )
76
+
77
+ @property
78
+ def fields (self ):
79
+ return {self .TRACE_ID_KEY , self .SPAN_ID_KEY }
80
+
81
+
16
82
async def http_app_with_custom_headers (scope , receive , send ):
17
83
message = await receive ()
18
84
assert scope ["type" ] == "http"
@@ -34,6 +100,8 @@ async def http_app_with_custom_headers(scope, receive, send):
34
100
b"my-custom-regex-value-3,my-custom-regex-value-4" ,
35
101
),
36
102
(b"my-secret-header" , b"my-secret-value" ),
103
+ (MockTextMapPropagator .TRACE_ID_KEY .encode (), b"1" ),
104
+ (MockTextMapPropagator .SPAN_ID_KEY .encode (), b"2" ),
37
105
],
38
106
}
39
107
)
@@ -60,6 +128,8 @@ async def websocket_app_with_custom_headers(scope, receive, send):
60
128
b"my-custom-regex-value-3,my-custom-regex-value-4" ,
61
129
),
62
130
(b"my-secret-header" , b"my-secret-value" ),
131
+ (MockTextMapPropagator .TRACE_ID_KEY .encode (), b"1" ),
132
+ (MockTextMapPropagator .SPAN_ID_KEY .encode (), b"2" ),
63
133
],
64
134
}
65
135
)
@@ -88,6 +158,11 @@ def setUp(self):
88
158
self .app = otel_asgi .OpenTelemetryMiddleware (
89
159
simple_asgi , tracer_provider = self .tracer_provider
90
160
)
161
+ self .previous_propagator = get_global_textmap ()
162
+ set_global_textmap (MockTextMapPropagator ())
163
+
164
+ def tearDown (self ):
165
+ set_global_textmap (self .previous_propagator )
91
166
92
167
def test_http_custom_request_headers_in_span_attributes (self ):
93
168
self .scope ["headers" ].extend (
0 commit comments