Skip to content

Commit e312bec

Browse files
committed
Convert the unittest tests to pytest test style.
1 parent 767adcc commit e312bec

File tree

1 file changed

+102
-114
lines changed

1 file changed

+102
-114
lines changed

instrumentation/opentelemetry-instrumentation-aiohttp-server/tests/test_aiohttp_server_integration.py

+102-114
Original file line numberDiff line numberDiff line change
@@ -12,124 +12,112 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import asyncio
16-
import contextlib
17-
import typing
18-
import unittest
19-
import urllib.parse
20-
from functools import partial
21-
from unittest import mock
22-
15+
import pytest
16+
import pytest_asyncio
2317
import aiohttp
24-
import aiohttp.test_utils
18+
from http import HTTPMethod, HTTPStatus
2519
from pkg_resources import iter_entry_points
20+
from unittest import mock
2621

22+
from opentelemetry import trace as trace_api
23+
from opentelemetry.test.test_base import TestBase
2724
from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor
2825
from opentelemetry.semconv.trace import SpanAttributes
29-
from opentelemetry.test.test_base import TestBase
3026

27+
from opentelemetry.test.globals_test import (
28+
reset_trace_globals,
29+
)
30+
31+
32+
@pytest.fixture(scope="session")
33+
def tracer():
34+
test_base = TestBase()
35+
36+
tracer_provider, memory_exporter = test_base.create_tracer_provider()
37+
38+
reset_trace_globals()
39+
trace_api.set_tracer_provider(tracer_provider)
40+
41+
yield tracer_provider, memory_exporter
42+
43+
reset_trace_globals()
44+
45+
46+
async def default_handler(request, status=200):
47+
return aiohttp.web.Response(status=status)
48+
49+
50+
@pytest_asyncio.fixture
51+
async def server_fixture(tracer, aiohttp_server):
52+
_, memory_exporter = tracer
53+
54+
AioHttpServerInstrumentor().instrument()
55+
56+
app = aiohttp.web.Application()
57+
app.add_routes(
58+
[aiohttp.web.get("/test-path", default_handler)])
59+
60+
server = await aiohttp_server(app)
61+
yield server, app
62+
63+
memory_exporter.clear()
64+
65+
AioHttpServerInstrumentor().uninstrument()
66+
67+
68+
def test_checking_instrumentor_pkg_installed():
69+
entry_points = iter_entry_points(
70+
"opentelemetry_instrumentor", "aiohttp-server"
71+
)
72+
73+
instrumentor = next(entry_points).load()()
74+
assert (isinstance(instrumentor, AioHttpServerInstrumentor))
75+
76+
77+
@pytest.mark.asyncio
78+
@pytest.mark.parametrize("url, expected_method, expected_status_code", [
79+
("/test-path", HTTPMethod.GET, HTTPStatus.OK),
80+
("/not-found", HTTPMethod.GET, HTTPStatus.NOT_FOUND)
81+
])
82+
async def test_status_code_instrumentation(tracer, server_fixture,
83+
aiohttp_client, url,
84+
expected_method,
85+
expected_status_code):
86+
_, memory_exporter = tracer
87+
server, app = server_fixture
88+
89+
assert len(memory_exporter.get_finished_spans()) == 0
90+
91+
client = await aiohttp_client(server)
92+
await client.get(url)
93+
94+
assert len(memory_exporter.get_finished_spans()) == 1
95+
96+
[span] = memory_exporter.get_finished_spans()
97+
98+
assert expected_method == span.attributes[SpanAttributes.HTTP_METHOD]
99+
assert expected_status_code == span.attributes[SpanAttributes.HTTP_STATUS_CODE]
100+
101+
assert f"http://{server.host}:{server.port}{url}" == span.attributes[
102+
SpanAttributes.HTTP_URL
103+
]
104+
105+
106+
@pytest.mark.skip(reason="Historical purposes. Can't see the reason of this mock.")
107+
def test_not_recording(self):
108+
mock_tracer = mock.Mock()
109+
mock_span = mock.Mock()
110+
mock_span.is_recording.return_value = False
111+
mock_tracer.start_span.return_value = mock_span
112+
with mock.patch("opentelemetry.trace.get_tracer") as patched:
113+
patched.start_span.return_value = mock_span
114+
# pylint: disable=W0612
115+
# host, port = run_with_test_server(
116+
# self.get_default_request(), self.URL, self.default_handler
117+
# )
31118

32-
def run_with_test_server(
33-
runnable: typing.Callable, url: str, handler: typing.Callable
34-
) -> typing.Tuple[str, int]:
35-
async def do_request():
36-
app = aiohttp.web.Application()
37-
parsed_url = urllib.parse.urlparse(url)
38-
app.add_routes([aiohttp.web.get(parsed_url.path, handler)])
39-
app.add_routes([aiohttp.web.post(parsed_url.path, handler)])
40-
app.add_routes([aiohttp.web.patch(parsed_url.path, handler)])
41-
42-
with contextlib.suppress(aiohttp.ClientError):
43-
async with aiohttp.test_utils.TestServer(app) as server:
44-
netloc = (server.host, server.port)
45-
await server.start_server()
46-
await runnable(server)
47-
return netloc
48-
49-
loop = asyncio.get_event_loop()
50-
return loop.run_until_complete(do_request())
51-
52-
53-
class TestAioHttpServerIntegration(TestBase):
54-
URL = "/test-path"
55-
56-
def setUp(self):
57-
super().setUp()
58-
AioHttpServerInstrumentor().instrument()
59-
60-
def tearDown(self):
61-
super().tearDown()
62-
AioHttpServerInstrumentor().uninstrument()
63-
64-
@staticmethod
65-
# pylint:disable=unused-argument
66-
async def default_handler(request, status=200):
67-
return aiohttp.web.Response(status=status)
68-
69-
def assert_spans(self, num_spans: int):
70-
finished_spans = self.memory_exporter.get_finished_spans()
71-
self.assertEqual(num_spans, len(finished_spans))
72-
if num_spans == 0:
73-
return None
74-
if num_spans == 1:
75-
return finished_spans[0]
76-
return finished_spans
77-
78-
@staticmethod
79-
def get_default_request(url: str = URL):
80-
async def default_request(server: aiohttp.test_utils.TestServer):
81-
async with aiohttp.test_utils.TestClient(server) as session:
82-
await session.get(url)
83-
84-
return default_request
85-
86-
def test_instrument(self):
87-
host, port = run_with_test_server(
88-
self.get_default_request(), self.URL, self.default_handler
89-
)
90-
span = self.assert_spans(1)
91-
self.assertEqual("GET", span.attributes[SpanAttributes.HTTP_METHOD])
92-
self.assertEqual(
93-
f"http://{host}:{port}/test-path",
94-
span.attributes[SpanAttributes.HTTP_URL],
95-
)
96-
self.assertEqual(200, span.attributes[SpanAttributes.HTTP_STATUS_CODE])
97-
98-
def test_status_codes(self):
99-
error_handler = partial(self.default_handler, status=400)
100-
host, port = run_with_test_server(
101-
self.get_default_request(), self.URL, error_handler
102-
)
103-
span = self.assert_spans(1)
104-
self.assertEqual("GET", span.attributes[SpanAttributes.HTTP_METHOD])
105-
self.assertEqual(
106-
f"http://{host}:{port}/test-path",
107-
span.attributes[SpanAttributes.HTTP_URL],
108-
)
109-
self.assertEqual(400, span.attributes[SpanAttributes.HTTP_STATUS_CODE])
110-
111-
def test_not_recording(self):
112-
mock_tracer = mock.Mock()
113-
mock_span = mock.Mock()
114-
mock_span.is_recording.return_value = False
115-
mock_tracer.start_span.return_value = mock_span
116-
with mock.patch("opentelemetry.trace.get_tracer"):
117-
# pylint: disable=W0612
118-
host, port = run_with_test_server(
119-
self.get_default_request(), self.URL, self.default_handler
120-
)
121-
122-
self.assertFalse(mock_span.is_recording())
123-
self.assertTrue(mock_span.is_recording.called)
124-
self.assertFalse(mock_span.set_attribute.called)
125-
self.assertFalse(mock_span.set_status.called)
126-
127-
128-
class TestLoadingAioHttpInstrumentor(unittest.TestCase):
129-
def test_loading_instrumentor(self):
130-
entry_points = iter_entry_points(
131-
"opentelemetry_instrumentor", "aiohttp-server"
132-
)
133-
134-
instrumentor = next(entry_points).load()()
135-
self.assertIsInstance(instrumentor, AioHttpServerInstrumentor)
119+
self.assertTrue(patched.start_span.called)
120+
self.assertFalse(mock_span.is_recording())
121+
self.assertTrue(mock_span.is_recording.called)
122+
self.assertFalse(mock_span.set_attribute.called)
123+
self.assertFalse(mock_span.set_status.called)

0 commit comments

Comments
 (0)