12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import asyncio
15
16
import types
16
17
from unittest import mock
17
18
@@ -45,6 +46,35 @@ def __exit__(self, *args):
45
46
return self
46
47
47
48
49
+ class MockAsyncCursor :
50
+ def __init__ (self , * args , ** kwargs ):
51
+ pass
52
+
53
+ # pylint: disable=unused-argument, no-self-use
54
+ async def execute (self , query , params = None , throw_exception = False ):
55
+ if throw_exception :
56
+ raise Exception ("Test Exception" )
57
+
58
+ # pylint: disable=unused-argument, no-self-use
59
+ async def executemany (self , query , params = None , throw_exception = False ):
60
+ if throw_exception :
61
+ raise Exception ("Test Exception" )
62
+
63
+ # pylint: disable=unused-argument, no-self-use
64
+ async def callproc (self , query , params = None , throw_exception = False ):
65
+ if throw_exception :
66
+ raise Exception ("Test Exception" )
67
+
68
+ async def __aenter__ (self , * args , ** kwargs ):
69
+ return self
70
+
71
+ async def __aexit__ (self , * args , ** kwargs ):
72
+ pass
73
+
74
+ def close (self ):
75
+ pass
76
+
77
+
48
78
class MockConnection :
49
79
commit = mock .MagicMock (spec = types .MethodType )
50
80
commit .__name__ = "commit"
@@ -64,22 +94,68 @@ def get_dsn_parameters(self): # pylint: disable=no-self-use
64
94
return {"dbname" : "test" }
65
95
66
96
97
+ class MockAsyncConnection :
98
+ commit = mock .MagicMock (spec = types .MethodType )
99
+ commit .__name__ = "commit"
100
+
101
+ rollback = mock .MagicMock (spec = types .MethodType )
102
+ rollback .__name__ = "rollback"
103
+
104
+ def __init__ (self , * args , ** kwargs ):
105
+ self .cursor_factory = kwargs .pop ("cursor_factory" , None )
106
+
107
+ @staticmethod
108
+ async def connect (* args , ** kwargs ):
109
+ return MockAsyncConnection (** kwargs )
110
+
111
+ def cursor (self ):
112
+ if self .cursor_factory :
113
+ cur = self .cursor_factory (self )
114
+ return cur
115
+ return MockAsyncCursor ()
116
+
117
+ def get_dsn_parameters (self ): # pylint: disable=no-self-use
118
+ return {"dbname" : "test" }
119
+
120
+ async def __aenter__ (self ):
121
+ return self
122
+
123
+ async def __aexit__ (self , * args ):
124
+ return mock .MagicMock (spec = types .MethodType )
125
+
126
+
67
127
class TestPostgresqlIntegration (TestBase ):
68
128
def setUp (self ):
69
129
super ().setUp ()
70
130
self .cursor_mock = mock .patch (
71
131
"opentelemetry.instrumentation.psycopg.pg_cursor" , MockCursor
72
132
)
133
+ self .cursor_async_mock = mock .patch (
134
+ "opentelemetry.instrumentation.psycopg.pg_async_cursor" ,
135
+ MockAsyncCursor ,
136
+ )
73
137
self .connection_mock = mock .patch ("psycopg.connect" , MockConnection )
138
+ self .connection_sync_mock = mock .patch (
139
+ "psycopg.Connection.connect" , MockConnection
140
+ )
141
+ self .connection_async_mock = mock .patch (
142
+ "psycopg.AsyncConnection.connect" , MockAsyncConnection .connect
143
+ )
74
144
75
145
self .cursor_mock .start ()
146
+ self .cursor_async_mock .start ()
76
147
self .connection_mock .start ()
148
+ self .connection_sync_mock .start ()
149
+ self .connection_async_mock .start ()
77
150
78
151
def tearDown (self ):
79
152
super ().tearDown ()
80
153
self .memory_exporter .clear ()
81
154
self .cursor_mock .stop ()
155
+ self .cursor_async_mock .stop ()
82
156
self .connection_mock .stop ()
157
+ self .connection_sync_mock .stop ()
158
+ self .connection_async_mock .stop ()
83
159
with self .disable_logging ():
84
160
PsycopgInstrumentor ().uninstrument ()
85
161
@@ -114,6 +190,91 @@ def test_instrumentor(self):
114
190
spans_list = self .memory_exporter .get_finished_spans ()
115
191
self .assertEqual (len (spans_list ), 1 )
116
192
193
+ # pylint: disable=unused-argument
194
+ def test_instrumentor_with_connection_class (self ):
195
+ PsycopgInstrumentor ().instrument ()
196
+
197
+ cnx = psycopg .Connection .connect (database = "test" )
198
+
199
+ cursor = cnx .cursor ()
200
+
201
+ query = "SELECT * FROM test"
202
+ cursor .execute (query )
203
+
204
+ spans_list = self .memory_exporter .get_finished_spans ()
205
+ self .assertEqual (len (spans_list ), 1 )
206
+ span = spans_list [0 ]
207
+
208
+ # Check version and name in span's instrumentation info
209
+ self .assertEqualSpanInstrumentationInfo (
210
+ span , opentelemetry .instrumentation .psycopg
211
+ )
212
+
213
+ # check that no spans are generated after uninstrument
214
+ PsycopgInstrumentor ().uninstrument ()
215
+
216
+ cnx = psycopg .Connection .connect (database = "test" )
217
+ cursor = cnx .cursor ()
218
+ query = "SELECT * FROM test"
219
+ cursor .execute (query )
220
+
221
+ spans_list = self .memory_exporter .get_finished_spans ()
222
+ self .assertEqual (len (spans_list ), 1 )
223
+
224
+ async def test_wrap_async_connection_class_with_cursor (self ):
225
+ PsycopgInstrumentor ().instrument ()
226
+
227
+ async def test_async_connection ():
228
+ acnx = await psycopg .AsyncConnection .connect (database = "test" )
229
+ async with acnx as cnx :
230
+ async with cnx .cursor () as cursor :
231
+ await cursor .execute ("SELECT * FROM test" )
232
+
233
+ asyncio .run (test_async_connection ())
234
+ spans_list = self .memory_exporter .get_finished_spans ()
235
+ self .assertEqual (len (spans_list ), 1 )
236
+ span = spans_list [0 ]
237
+
238
+ # Check version and name in span's instrumentation info
239
+ self .assertEqualSpanInstrumentationInfo (
240
+ span , opentelemetry .instrumentation .psycopg
241
+ )
242
+
243
+ # check that no spans are generated after uninstrument
244
+ PsycopgInstrumentor ().uninstrument ()
245
+
246
+ asyncio .run (test_async_connection ())
247
+
248
+ spans_list = self .memory_exporter .get_finished_spans ()
249
+ self .assertEqual (len (spans_list ), 1 )
250
+
251
+ # pylint: disable=unused-argument
252
+ async def test_instrumentor_with_async_connection_class (self ):
253
+ PsycopgInstrumentor ().instrument ()
254
+
255
+ async def test_async_connection ():
256
+ acnx = await psycopg .AsyncConnection .connect (database = "test" )
257
+ async with acnx as cnx :
258
+ await cnx .execute ("SELECT * FROM test" )
259
+
260
+ asyncio .run (test_async_connection ())
261
+
262
+ spans_list = self .memory_exporter .get_finished_spans ()
263
+ self .assertEqual (len (spans_list ), 1 )
264
+ span = spans_list [0 ]
265
+
266
+ # Check version and name in span's instrumentation info
267
+ self .assertEqualSpanInstrumentationInfo (
268
+ span , opentelemetry .instrumentation .psycopg
269
+ )
270
+
271
+ # check that no spans are generated after uninstrument
272
+ PsycopgInstrumentor ().uninstrument ()
273
+ asyncio .run (test_async_connection ())
274
+
275
+ spans_list = self .memory_exporter .get_finished_spans ()
276
+ self .assertEqual (len (spans_list ), 1 )
277
+
117
278
def test_span_name (self ):
118
279
PsycopgInstrumentor ().instrument ()
119
280
@@ -140,6 +301,33 @@ def test_span_name(self):
140
301
self .assertEqual (spans_list [4 ].name , "query" )
141
302
self .assertEqual (spans_list [5 ].name , "query" )
142
303
304
+ async def test_span_name_async (self ):
305
+ PsycopgInstrumentor ().instrument ()
306
+
307
+ cnx = psycopg .AsyncConnection .connect (database = "test" )
308
+ async with cnx .cursor () as cursor :
309
+ await cursor .execute ("Test query" , ("param1Value" , False ))
310
+ await cursor .execute (
311
+ """multi
312
+ line
313
+ query"""
314
+ )
315
+ await cursor .execute ("tab\t separated query" )
316
+ await cursor .execute ("/* leading comment */ query" )
317
+ await cursor .execute (
318
+ "/* leading comment */ query /* trailing comment */"
319
+ )
320
+ await cursor .execute ("query /* trailing comment */" )
321
+
322
+ spans_list = self .memory_exporter .get_finished_spans ()
323
+ self .assertEqual (len (spans_list ), 6 )
324
+ self .assertEqual (spans_list [0 ].name , "Test" )
325
+ self .assertEqual (spans_list [1 ].name , "multi" )
326
+ self .assertEqual (spans_list [2 ].name , "tab" )
327
+ self .assertEqual (spans_list [3 ].name , "query" )
328
+ self .assertEqual (spans_list [4 ].name , "query" )
329
+ self .assertEqual (spans_list [5 ].name , "query" )
330
+
143
331
# pylint: disable=unused-argument
144
332
def test_not_recording (self ):
145
333
mock_tracer = mock .Mock ()
@@ -160,6 +348,26 @@ def test_not_recording(self):
160
348
161
349
PsycopgInstrumentor ().uninstrument ()
162
350
351
+ # pylint: disable=unused-argument
352
+ async def test_not_recording_async (self ):
353
+ mock_tracer = mock .Mock ()
354
+ mock_span = mock .Mock ()
355
+ mock_span .is_recording .return_value = False
356
+ mock_tracer .start_span .return_value = mock_span
357
+ PsycopgInstrumentor ().instrument ()
358
+ with mock .patch ("opentelemetry.trace.get_tracer" ) as tracer :
359
+ tracer .return_value = mock_tracer
360
+ cnx = psycopg .AsyncConnection .connect (database = "test" )
361
+ async with cnx .cursor () as cursor :
362
+ query = "SELECT * FROM test"
363
+ cursor .execute (query )
364
+ self .assertFalse (mock_span .is_recording ())
365
+ self .assertTrue (mock_span .is_recording .called )
366
+ self .assertFalse (mock_span .set_attribute .called )
367
+ self .assertFalse (mock_span .set_status .called )
368
+
369
+ PsycopgInstrumentor ().uninstrument ()
370
+
163
371
# pylint: disable=unused-argument
164
372
def test_custom_tracer_provider (self ):
165
373
resource = resources .Resource .create ({})
0 commit comments