From 1900efd13822541e84b448a2e49784fa7c908ed9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 5 Apr 2023 12:55:54 +0000 Subject: [PATCH 1/7] add cluster "host_port_remap" feature for asyncio.RedisCluster --- redis/asyncio/cluster.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index a4a9561cf1..7758665247 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -5,12 +5,14 @@ import warnings from typing import ( Any, + Callable, Deque, Dict, Generator, List, Mapping, Optional, + Tuple, Type, TypeVar, Union, @@ -250,6 +252,7 @@ def __init__( ssl_certfile: Optional[str] = None, ssl_check_hostname: bool = False, ssl_keyfile: Optional[str] = None, + host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: if db: raise RedisClusterException( @@ -337,7 +340,12 @@ def __init__( if host and port: startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs)) - self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs) + self.nodes_manager = NodesManager( + startup_nodes, + require_full_coverage, + kwargs, + host_port_remap=host_port_remap, + ) self.encoder = Encoder(encoding, encoding_errors, decode_responses) self.read_from_replicas = read_from_replicas self.reinitialize_steps = reinitialize_steps @@ -1059,6 +1067,7 @@ class NodesManager: "require_full_coverage", "slots_cache", "startup_nodes", + "host_port_remap", ) def __init__( @@ -1066,10 +1075,12 @@ def __init__( startup_nodes: List["ClusterNode"], require_full_coverage: bool, connection_kwargs: Dict[str, Any], + host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: self.startup_nodes = {node.name: node for node in startup_nodes} self.require_full_coverage = require_full_coverage self.connection_kwargs = connection_kwargs + self.host_port_remap = host_port_remap self.default_node: "ClusterNode" = None self.nodes_cache: Dict[str, "ClusterNode"] = {} @@ -1228,6 +1239,7 @@ async def initialize(self) -> None: if host == "": host = startup_node.host port = int(primary_node[1]) + host, port = self.remap_host_port(host, port) target_node = tmp_nodes_cache.get(get_node_name(host, port)) if not target_node: @@ -1246,6 +1258,7 @@ async def initialize(self) -> None: for replica_node in replica_nodes: host = replica_node[0] port = replica_node[1] + host, port = self.remap_host_port(host, port) target_replica_node = tmp_nodes_cache.get( get_node_name(host, port) @@ -1319,6 +1332,16 @@ async def close(self, attr: str = "nodes_cache") -> None: ) ) + def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: + """ + Remap the host and port returned from the cluster to a different + internal value. Useful if the client is not connecting directly + to the cluster. + """ + if self.host_port_remap: + return self.host_port_remap(host, port) + return host, port + class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): """ From 936b3c86ee904239a93c927e21054a1c700b4efe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 27 Apr 2023 12:51:53 +0000 Subject: [PATCH 2/7] Add a unittest for asyncio.RedisCluster --- tests/test_asyncio/test_cluster.py | 109 ++++++++++++++++++++++++++++- 1 file changed, 108 insertions(+), 1 deletion(-) diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 13e5e26ae3..5ad3e2b4a8 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -11,7 +11,7 @@ from _pytest.fixtures import FixtureRequest from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster -from redis.asyncio.connection import Connection, SSLConnection +from redis.asyncio.connection import Connection, SSLConnection, async_timeout from redis.asyncio.parser import CommandsParser from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff @@ -49,6 +49,71 @@ ] +class NodeProxy: + """A class to proxy a node connection to a different port""" + + def __init__(self, addr, redis_addr): + self.addr = addr + self.redis_addr = redis_addr + self.send_event = asyncio.Event() + self.server = None + self.task = None + self.n_connections = 0 + + async def start(self): + # test that we can connect to redis + async with async_timeout(2): + _, redis_writer = await asyncio.open_connection(*self.redis_addr) + redis_writer.close() + self.server = await asyncio.start_server( + self.handle, *self.addr, reuse_address=True + ) + self.task = asyncio.create_task(self.server.serve_forever()) + + async def handle(self, reader, writer): + # establish connection to redis + redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr) + try: + self.n_connections += 1 + pipe1 = asyncio.create_task(self.pipe(reader, redis_writer)) + pipe2 = asyncio.create_task(self.pipe(redis_reader, writer)) + await asyncio.gather(pipe1, pipe2) + finally: + redis_writer.close() + + async def aclose(self): + self.task.cancel() + try: + await self.task + except asyncio.CancelledError: + pass + await self.server.wait_closed() + + async def pipe( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ): + while True: + data = await reader.read(1000) + if not data: + break + writer.write(data) + await writer.drain() + + +@pytest.fixture +def redis_addr(request): + redis_url = request.config.getoption("--redis-url") + scheme, netloc = urlparse(redis_url)[:2] + assert scheme == "redis" + if ":" in netloc: + host, port = netloc.split(":") + return host, int(port) + else: + return netloc, 6379 + + @pytest_asyncio.fixture() async def slowlog(r: RedisCluster) -> None: """ @@ -809,6 +874,48 @@ async def test_default_node_is_replaced_after_exception(self, r): # Rollback to the old default node r.replace_default_node(curr_default_node) + async def test_host_port_remap(self, create_redis, redis_addr): + """Test that we can create a rediscluster object with + a host-port remapper and map connections through proxy objects + """ + + # we remap the first n nodes + offset = 1000 + n = 6 + ports = [redis_addr[1] + i for i in range(n)] + + def host_port_remap(host, port): + # remap first three nodes to our local proxy + # old = host, port + if int(port) in ports: + host, port = "127.0.0.1", int(port) + offset + # print(f"{old} {host, port}") + return host, port + + # create the proxies + proxies = [ + NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port)) + for port in ports + ] + await asyncio.gather(*[p.start() for p in proxies]) + try: + # create cluster: + r = await create_redis( + cls=RedisCluster, flushdb=False, host_port_remap=host_port_remap + ) + try: + assert await r.ping() is True + assert await r.set("byte_string", b"giraffe") + assert await r.get("byte_string") == b"giraffe" + finally: + await r.close() + finally: + await asyncio.gather(*[p.aclose() for p in proxies]) + + # verify that the proxies were indeed used + n_used = sum((1 if p.n_connections else 0) for p in proxies) + assert n_used > 1 + class TestClusterRedisCommands: """ From cdc4acf4867e383fa5c5e5a63457591ea296ed7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 27 Apr 2023 17:15:08 +0000 Subject: [PATCH 3/7] Add host_port_remap to _sync_ RedisCluster --- redis/cluster.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/redis/cluster.py b/redis/cluster.py index 5e6e7da546..b455ad57d4 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -466,6 +466,7 @@ def __init__( read_from_replicas: bool = False, dynamic_startup_nodes: bool = True, url: Optional[str] = None, + host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, **kwargs, ): """ @@ -594,6 +595,7 @@ def __init__( from_url=from_url, require_full_coverage=require_full_coverage, dynamic_startup_nodes=dynamic_startup_nodes, + host_port_remap=host_port_remap, **kwargs, ) @@ -1269,6 +1271,7 @@ def __init__( lock=None, dynamic_startup_nodes=True, connection_pool_class=ConnectionPool, + host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, **kwargs, ): self.nodes_cache = {} @@ -1280,6 +1283,7 @@ def __init__( self._require_full_coverage = require_full_coverage self._dynamic_startup_nodes = dynamic_startup_nodes self.connection_pool_class = connection_pool_class + self.host_port_remap = host_port_remap self._moved_exception = None self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() @@ -1502,6 +1506,7 @@ def initialize(self): if host == "": host = startup_node.host port = int(primary_node[1]) + host, port = self.remap_host_port(host, port) target_node = self._get_or_create_cluster_node( host, port, PRIMARY, tmp_nodes_cache @@ -1518,6 +1523,7 @@ def initialize(self): for replica_node in replica_nodes: host = str_if_bytes(replica_node[0]) port = replica_node[1] + host, port = self.remap_host_port(host, port) target_replica_node = self._get_or_create_cluster_node( host, port, REPLICA, tmp_nodes_cache @@ -1591,6 +1597,16 @@ def reset(self): # The read_load_balancer is None, do nothing pass + def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: + """ + Remap the host and port returned from the cluster to a different + internal value. Useful if the client is not connecting directly + to the cluster. + """ + if self.host_port_remap: + return self.host_port_remap(host, port) + return host, port + class ClusterPubSub(PubSub): """ From b1254d77115263030bbead66eafad127f95030fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 27 Apr 2023 17:59:01 +0000 Subject: [PATCH 4/7] add synchronous tests --- tests/test_cluster.py | 128 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 58f9b77d7d..55b99a7c49 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -1,9 +1,14 @@ import binascii import datetime +import select +import socket +import socketserver +import threading import warnings from queue import LifoQueue, Queue from time import sleep from unittest.mock import DEFAULT, Mock, call, patch +from urllib.parse import urlparse import pytest @@ -53,6 +58,85 @@ ] +class ProxyRequestHandler(socketserver.BaseRequestHandler): + def recv(self, sock): + """A recv with a timeout""" + r = select.select([sock], [], [], 0.01) + if not r[0]: + return None + return sock.recv(1000) + + def handle(self): + self.server.proxy.n_connections += 1 + conn = socket.create_connection(self.server.proxy.redis_addr) + stop = False + + def from_server(): + # read from server and pass to client + while not stop: + data = self.recv(conn) + if data is None: + continue + if not data: + self.request.shutdown(socket.SHUT_WR) + return + self.request.sendall(data) + + thread = threading.Thread(target=from_server) + thread.start() + try: + while True: + # read from client and send to server + data = self.request.recv(1000) + if not data: + return + conn.sendall(data) + finally: + conn.shutdown(socket.SHUT_WR) + stop = True # for safety + thread.join() + conn.close() + + +class NodeProxy: + """A class to proxy a node connection to a different port""" + + def __init__(self, addr, redis_addr): + self.addr = addr + self.redis_addr = redis_addr + self.server = socketserver.ThreadingTCPServer(self.addr, ProxyRequestHandler) + self.server.proxy = self + self.server.socket_reuse_address = True + self.thread = None + self.n_connections = 0 + + def start(self): + # test that we can connect to redis + s = socket.create_connection(self.redis_addr, timeout=2) + s.close() + # Start a thread with the server -- that thread will then start one + # more thread for each request + self.thread = threading.Thread(target=self.server.serve_forever) + # Exit the server thread when the main thread terminates + self.thread.daemon = True + self.thread.start() + + def close(self): + self.server.shutdown() + + +@pytest.fixture +def redis_addr(request): + redis_url = request.config.getoption("--redis-url") + scheme, netloc = urlparse(redis_url)[:2] + assert scheme == "redis" + if ":" in netloc: + host, port = netloc.split(":") + return host, int(port) + else: + return netloc, 6379 + + @pytest.fixture() def slowlog(request, r): """ @@ -823,6 +907,50 @@ def raise_connection_error(): assert "myself" not in nodes.get(curr_default_node.name).get("flags") assert r.get_default_node() != curr_default_node + def test_host_port_remap(self, request, redis_addr): + """Test that we can create a rediscluster object with + a host-port remapper and map connections through proxy objects + """ + + # we remap the first n nodes + offset = 1000 + n = 6 + ports = [redis_addr[1] + i for i in range(n)] + + def host_port_remap(host, port): + # remap first three nodes to our local proxy + # old = host, port + if int(port) in ports: + host, port = "127.0.0.1", int(port) + offset + # print(f"{old} {host, port}") + return host, port + + # create the proxies + proxies = [ + NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port)) + for port in ports + ] + for p in proxies: + p.start() + try: + # create cluster: + r = _get_client( + RedisCluster, request, flushdb=False, host_port_remap=host_port_remap + ) + try: + assert r.ping() is True + assert r.set("byte_string", b"giraffe") + assert r.get("byte_string") == b"giraffe" + finally: + r.close() + finally: + for p in proxies: + p.close() + + # verify that the proxies were indeed used + n_used = sum((1 if p.n_connections else 0) for p in proxies) + assert n_used > 1 + @pytest.mark.onlycluster class TestClusterRedisCommands: From d2d34f02978bee9ac29fa868c82b360cd34737da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 28 Apr 2023 09:49:00 +0000 Subject: [PATCH 5/7] rename arg to `address_remap` and take and return an address tuple. --- redis/asyncio/cluster.py | 14 +++++++------- redis/cluster.py | 12 ++++++------ tests/test_asyncio/test_cluster.py | 7 ++++--- tests/test_cluster.py | 7 ++++--- 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 7758665247..ffa0840f30 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -252,7 +252,7 @@ def __init__( ssl_certfile: Optional[str] = None, ssl_check_hostname: bool = False, ssl_keyfile: Optional[str] = None, - host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: if db: raise RedisClusterException( @@ -344,7 +344,7 @@ def __init__( startup_nodes, require_full_coverage, kwargs, - host_port_remap=host_port_remap, + address_remap=address_remap, ) self.encoder = Encoder(encoding, encoding_errors, decode_responses) self.read_from_replicas = read_from_replicas @@ -1067,7 +1067,7 @@ class NodesManager: "require_full_coverage", "slots_cache", "startup_nodes", - "host_port_remap", + "address_remap", ) def __init__( @@ -1075,12 +1075,12 @@ def __init__( startup_nodes: List["ClusterNode"], require_full_coverage: bool, connection_kwargs: Dict[str, Any], - host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: self.startup_nodes = {node.name: node for node in startup_nodes} self.require_full_coverage = require_full_coverage self.connection_kwargs = connection_kwargs - self.host_port_remap = host_port_remap + self.address_remap = address_remap self.default_node: "ClusterNode" = None self.nodes_cache: Dict[str, "ClusterNode"] = {} @@ -1338,8 +1338,8 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: internal value. Useful if the client is not connecting directly to the cluster. """ - if self.host_port_remap: - return self.host_port_remap(host, port) + if self.address_remap: + return self.address_remap((host, port)) return host, port diff --git a/redis/cluster.py b/redis/cluster.py index b455ad57d4..5cbba23d2b 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -466,7 +466,7 @@ def __init__( read_from_replicas: bool = False, dynamic_startup_nodes: bool = True, url: Optional[str] = None, - host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, **kwargs, ): """ @@ -595,7 +595,7 @@ def __init__( from_url=from_url, require_full_coverage=require_full_coverage, dynamic_startup_nodes=dynamic_startup_nodes, - host_port_remap=host_port_remap, + address_remap=address_remap, **kwargs, ) @@ -1271,7 +1271,7 @@ def __init__( lock=None, dynamic_startup_nodes=True, connection_pool_class=ConnectionPool, - host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, **kwargs, ): self.nodes_cache = {} @@ -1283,7 +1283,7 @@ def __init__( self._require_full_coverage = require_full_coverage self._dynamic_startup_nodes = dynamic_startup_nodes self.connection_pool_class = connection_pool_class - self.host_port_remap = host_port_remap + self.address_remap = address_remap self._moved_exception = None self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() @@ -1603,8 +1603,8 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: internal value. Useful if the client is not connecting directly to the cluster. """ - if self.host_port_remap: - return self.host_port_remap(host, port) + if self.address_remap: + return self.address_remap((host, port)) return host, port diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 5ad3e2b4a8..6d0aba73fb 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -874,7 +874,7 @@ async def test_default_node_is_replaced_after_exception(self, r): # Rollback to the old default node r.replace_default_node(curr_default_node) - async def test_host_port_remap(self, create_redis, redis_addr): + async def test_address_remap(self, create_redis, redis_addr): """Test that we can create a rediscluster object with a host-port remapper and map connections through proxy objects """ @@ -884,9 +884,10 @@ async def test_host_port_remap(self, create_redis, redis_addr): n = 6 ports = [redis_addr[1] + i for i in range(n)] - def host_port_remap(host, port): + def address_remap(address): # remap first three nodes to our local proxy # old = host, port + host, port = address if int(port) in ports: host, port = "127.0.0.1", int(port) + offset # print(f"{old} {host, port}") @@ -901,7 +902,7 @@ def host_port_remap(host, port): try: # create cluster: r = await create_redis( - cls=RedisCluster, flushdb=False, host_port_remap=host_port_remap + cls=RedisCluster, flushdb=False, address_remap=address_remap ) try: assert await r.ping() is True diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 55b99a7c49..1f037c9edf 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -907,7 +907,7 @@ def raise_connection_error(): assert "myself" not in nodes.get(curr_default_node.name).get("flags") assert r.get_default_node() != curr_default_node - def test_host_port_remap(self, request, redis_addr): + def test_address_remap(self, request, redis_addr): """Test that we can create a rediscluster object with a host-port remapper and map connections through proxy objects """ @@ -917,9 +917,10 @@ def test_host_port_remap(self, request, redis_addr): n = 6 ports = [redis_addr[1] + i for i in range(n)] - def host_port_remap(host, port): + def address_remap(address): # remap first three nodes to our local proxy # old = host, port + host, port = address if int(port) in ports: host, port = "127.0.0.1", int(port) + offset # print(f"{old} {host, port}") @@ -935,7 +936,7 @@ def host_port_remap(host, port): try: # create cluster: r = _get_client( - RedisCluster, request, flushdb=False, host_port_remap=host_port_remap + RedisCluster, request, flushdb=False, address_remap=address_remap ) try: assert r.ping() is True From c83ee087f7a980a5ce20c6e050dca8493b7ceb80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 28 Apr 2023 09:51:35 +0000 Subject: [PATCH 6/7] Add class documentation --- redis/asyncio/cluster.py | 6 ++++++ redis/cluster.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index ffa0840f30..eb5f4db061 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -149,6 +149,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand maximum number of connections are already created, a :class:`~.MaxConnectionsError` is raised. This error may be retried as defined by :attr:`connection_error_retry_attempts` + :param address_remap: + | An optional callable which, when provided with an internal network + address of a node, e.g. a `(host, port)` tuple, will return the address + where the node is reachable. This can be used to map the addresses at + which the nodes _think_ they are, to addresses at which a client may + reach them, such as when they sit behind a proxy. | Rest of the arguments will be passed to the :class:`~redis.asyncio.connection.Connection` instances when created diff --git a/redis/cluster.py b/redis/cluster.py index 5cbba23d2b..3ecc2dab56 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -515,6 +515,12 @@ def __init__( reinitialize_steps to 1. To avoid reinitializing the cluster on moved errors, set reinitialize_steps to 0. + :param address_remap: + An optional callable which, when provided with an internal network + address of a node, e.g. a `(host, port)` tuple, will return the address + where the node is reachable. This can be used to map the addresses at + which the nodes _think_ they are, to addresses at which a client may + reach them, such as when they sit behind a proxy. :**kwargs: Extra arguments that will be sent into Redis instance when created From 0e7447a506691ef3fcb64b4de860d64ad42f0733 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 1 May 2023 18:26:10 +0000 Subject: [PATCH 7/7] Add CHANGES --- CHANGES | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGES b/CHANGES index 8f2017218a..3865ed1067 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Add `address_remap` parameter to `RedisCluster` * Fix incorrect usage of once flag in async Sentinel * asyncio: Fix memory leak caused by hiredis (#2693) * Allow data to drain from async PythonParser when reading during a disconnect()