diff --git a/tests/conftest.py b/tests/conftest.py index 16f3fbb9db..bad9f43e42 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import random import time from typing import Callable, TypeVar +from unittest import mock from unittest.mock import Mock from urllib.parse import urlparse @@ -9,7 +10,7 @@ import redis from packaging.version import Version from redis.backoff import NoBackoff -from redis.connection import parse_url +from redis.connection import Connection, parse_url from redis.exceptions import RedisClusterException from redis.retry import Retry @@ -39,7 +40,6 @@ def __init__( help=None, metavar=None, ): - _option_strings = [] for option_string in option_strings: _option_strings.append(option_string) @@ -72,7 +72,6 @@ def format_usage(self): def pytest_addoption(parser): - parser.addoption( "--redis-url", default=default_redis_url, @@ -354,23 +353,23 @@ def sslclient(request): def _gen_cluster_mock_resp(r, response): - connection = Mock() + connection = Mock(spec=Connection) connection.retry = Retry(NoBackoff(), 0) connection.read_response.return_value = response - r.connection = connection - return r + with mock.patch.object(r, "connection", connection): + yield r @pytest.fixture() def mock_cluster_resp_ok(request, **kwargs): r = _get_client(redis.Redis, request, **kwargs) - return _gen_cluster_mock_resp(r, "OK") + yield from _gen_cluster_mock_resp(r, "OK") @pytest.fixture() def mock_cluster_resp_int(request, **kwargs): r = _get_client(redis.Redis, request, **kwargs) - return _gen_cluster_mock_resp(r, 2) + yield from _gen_cluster_mock_resp(r, 2) @pytest.fixture() @@ -384,7 +383,7 @@ def mock_cluster_resp_info(request, **kwargs): "cluster_my_epoch:2\r\ncluster_stats_messages_sent:170262\r\n" "cluster_stats_messages_received:105653\r\n" ) - return _gen_cluster_mock_resp(r, response) + yield from _gen_cluster_mock_resp(r, response) @pytest.fixture() @@ -408,7 +407,7 @@ def mock_cluster_resp_nodes(request, **kwargs): "fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 " "master,fail - 1447829446956 1447829444948 1 disconnected\n" ) - return _gen_cluster_mock_resp(r, response) + yield from _gen_cluster_mock_resp(r, response) @pytest.fixture() @@ -419,7 +418,7 @@ def mock_cluster_resp_slaves(request, **kwargs): "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " "1447836789290 3 connected']" ) - return _gen_cluster_mock_resp(r, response) + yield from _gen_cluster_mock_resp(r, response) @pytest.fixture(scope="session") diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index c837f284f7..10ab4732c2 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -8,7 +8,7 @@ from packaging.version import Version from redis._parsers import _AsyncHiredisParser, _AsyncRESP2Parser from redis.asyncio.client import Monitor -from redis.asyncio.connection import parse_url +from redis.asyncio.connection import Connection, parse_url from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.utils import HIREDIS_AVAILABLE @@ -138,23 +138,25 @@ async def decoded_r(create_redis): def _gen_cluster_mock_resp(r, response): - connection = mock.AsyncMock() + connection = mock.AsyncMock(spec=Connection) connection.retry = Retry(NoBackoff(), 0) connection.read_response.return_value = response - r.connection = connection - return r + with mock.patch.object(r, "connection", connection): + yield r @pytest_asyncio.fixture() async def mock_cluster_resp_ok(create_redis, **kwargs): r = await create_redis(**kwargs) - return _gen_cluster_mock_resp(r, "OK") + for mocked in _gen_cluster_mock_resp(r, "OK"): + yield mocked @pytest_asyncio.fixture() async def mock_cluster_resp_int(create_redis, **kwargs): r = await create_redis(**kwargs) - return _gen_cluster_mock_resp(r, 2) + for mocked in _gen_cluster_mock_resp(r, 2): + yield mocked @pytest_asyncio.fixture() @@ -168,7 +170,8 @@ async def mock_cluster_resp_info(create_redis, **kwargs): "cluster_my_epoch:2\r\ncluster_stats_messages_sent:170262\r\n" "cluster_stats_messages_received:105653\r\n" ) - return _gen_cluster_mock_resp(r, response) + for mocked in _gen_cluster_mock_resp(r, response): + yield mocked @pytest_asyncio.fixture() @@ -192,7 +195,8 @@ async def mock_cluster_resp_nodes(create_redis, **kwargs): "fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 " "master,fail - 1447829446956 1447829444948 1 disconnected\n" ) - return _gen_cluster_mock_resp(r, response) + for mocked in _gen_cluster_mock_resp(r, response): + yield mocked @pytest_asyncio.fixture() @@ -203,7 +207,8 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs): "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " "1447836789290 3 connected']" ) - return _gen_cluster_mock_resp(r, response) + for mocked in _gen_cluster_mock_resp(r, response): + yield mocked async def wait_for_command( diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 1cb1fa5195..332101edd5 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -175,7 +175,7 @@ def cmd_init_mock(self, r: ClusterNode) -> None: def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode: - connection = mock.AsyncMock() + connection = mock.AsyncMock(spec=Connection) connection.is_connected = True connection.read_response.return_value = response while node._free: @@ -185,7 +185,7 @@ def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode: def mock_node_resp_exc(node: ClusterNode, exc: Exception) -> ClusterNode: - connection = mock.AsyncMock() + connection = mock.AsyncMock(spec=Connection) connection.is_connected = True connection.read_response.side_effect = exc while node._free: diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py index 35707553f8..76ec2bbd26 100644 --- a/tests/test_asyncio/test_cwe_404.py +++ b/tests/test_asyncio/test_cwe_404.py @@ -213,8 +213,9 @@ def all_clear(): p.send_event.clear() async def wait_for_send(): - asyncio.wait( - [p.send_event.wait() for p in proxies], return_when=asyncio.FIRST_COMPLETED + await asyncio.wait( + [asyncio.Task(p.send_event.wait()) for p in proxies], + return_when=asyncio.FIRST_COMPLETED, ) @contextlib.contextmanager @@ -228,11 +229,10 @@ def set_delay(delay: float): for p in proxies: await stack.enter_async_context(p) - with contextlib.closing( - RedisCluster.from_url( - f"redis://127.0.0.1:{remap_base}", address_remap=remap - ) - ) as r: + r = RedisCluster.from_url( + f"redis://127.0.0.1:{remap_base}", address_remap=remap + ) + try: await r.initialize() await r.set("foo", "foo") await r.set("bar", "bar") @@ -257,3 +257,5 @@ async def doit(): assert await r.get("foo") == b"foo" await asyncio.gather(*[doit() for _ in range(10)]) + finally: + await r.close() diff --git a/tests/test_connect.py b/tests/test_connect.py index f07750dc80..696e69ceea 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -115,7 +115,7 @@ def get_request(self): return connstream, fromaddr -if hasattr(socket, "UnixStreamServer"): +if hasattr(socketserver, "UnixStreamServer"): class _RedisUDSServer(socketserver.UnixStreamServer): def __init__(self, *args, **kw) -> None: