Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add "address_remap" feature to RedisCluster #2726

Merged
merged 7 commits into from
May 2, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Add `address_remap` parameter to `RedisCluster`
* Fix incorrect usage of once flag in async Sentinel
* asyncio: Fix memory leak caused by hiredis (#2693)
* Allow data to drain from async PythonParser when reading during a disconnect()
31 changes: 30 additions & 1 deletion redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
@@ -5,12 +5,14 @@
import warnings
from typing import (
Any,
Callable,
Deque,
Dict,
Generator,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
@@ -147,6 +149,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
maximum number of connections are already created, a
:class:`~.MaxConnectionsError` is raised. This error may be retried as defined
by :attr:`connection_error_retry_attempts`
:param address_remap:
| An optional callable which, when provided with an internal network
address of a node, e.g. a `(host, port)` tuple, will return the address
where the node is reachable. This can be used to map the addresses at
which the nodes _think_ they are, to addresses at which a client may
reach them, such as when they sit behind a proxy.

| Rest of the arguments will be passed to the
:class:`~redis.asyncio.connection.Connection` instances when created
@@ -250,6 +258,7 @@ def __init__(
ssl_certfile: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_keyfile: Optional[str] = None,
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
) -> None:
if db:
raise RedisClusterException(
@@ -337,7 +346,12 @@ def __init__(
if host and port:
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))

self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs)
self.nodes_manager = NodesManager(
startup_nodes,
require_full_coverage,
kwargs,
address_remap=address_remap,
)
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self.read_from_replicas = read_from_replicas
self.reinitialize_steps = reinitialize_steps
@@ -1059,17 +1073,20 @@ class NodesManager:
"require_full_coverage",
"slots_cache",
"startup_nodes",
"address_remap",
)

def __init__(
self,
startup_nodes: List["ClusterNode"],
require_full_coverage: bool,
connection_kwargs: Dict[str, Any],
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
) -> None:
self.startup_nodes = {node.name: node for node in startup_nodes}
self.require_full_coverage = require_full_coverage
self.connection_kwargs = connection_kwargs
self.address_remap = address_remap

self.default_node: "ClusterNode" = None
self.nodes_cache: Dict[str, "ClusterNode"] = {}
@@ -1228,6 +1245,7 @@ async def initialize(self) -> None:
if host == "":
host = startup_node.host
port = int(primary_node[1])
host, port = self.remap_host_port(host, port)

target_node = tmp_nodes_cache.get(get_node_name(host, port))
if not target_node:
@@ -1246,6 +1264,7 @@ async def initialize(self) -> None:
for replica_node in replica_nodes:
host = replica_node[0]
port = replica_node[1]
host, port = self.remap_host_port(host, port)

target_replica_node = tmp_nodes_cache.get(
get_node_name(host, port)
@@ -1319,6 +1338,16 @@ async def close(self, attr: str = "nodes_cache") -> None:
)
)

def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
    """
    Translate a node address reported by the cluster into the address
    the client should connect to.

    When ``self.address_remap`` is set, it is called with a single
    ``(host, port)`` tuple and its result is returned as a 2-tuple.
    When it is ``None`` the address is returned unchanged. Useful if the
    client is not connecting directly to the cluster.

    :param host: host name or IP as reported by the cluster
    :param port: port as reported by the cluster
    :return: the ``(host, port)`` pair to use for the connection
    """
    # Compare against None rather than relying on truthiness, so that a
    # callable object which happens to be falsy is still honoured.
    if self.address_remap is not None:
        # Unpack and rebuild so the result is always a 2-tuple, even if
        # the user callable returned another sequence type (e.g. a list).
        new_host, new_port = self.address_remap((host, port))
        return new_host, new_port
    return host, port


class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
"""
22 changes: 22 additions & 0 deletions redis/cluster.py
Original file line number Diff line number Diff line change
@@ -466,6 +466,7 @@ def __init__(
read_from_replicas: bool = False,
dynamic_startup_nodes: bool = True,
url: Optional[str] = None,
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
**kwargs,
):
"""
@@ -514,6 +515,12 @@ def __init__(
reinitialize_steps to 1.
To avoid reinitializing the cluster on moved errors, set
reinitialize_steps to 0.
:param address_remap:
An optional callable which, when provided with an internal network
address of a node, e.g. a `(host, port)` tuple, will return the address
where the node is reachable. This can be used to map the addresses at
which the nodes _think_ they are, to addresses at which a client may
reach them, such as when they sit behind a proxy.

:**kwargs:
Extra arguments that will be sent into Redis instance when created
@@ -594,6 +601,7 @@ def __init__(
from_url=from_url,
require_full_coverage=require_full_coverage,
dynamic_startup_nodes=dynamic_startup_nodes,
address_remap=address_remap,
**kwargs,
)

@@ -1269,6 +1277,7 @@ def __init__(
lock=None,
dynamic_startup_nodes=True,
connection_pool_class=ConnectionPool,
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
**kwargs,
):
self.nodes_cache = {}
@@ -1280,6 +1289,7 @@ def __init__(
self._require_full_coverage = require_full_coverage
self._dynamic_startup_nodes = dynamic_startup_nodes
self.connection_pool_class = connection_pool_class
self.address_remap = address_remap
self._moved_exception = None
self.connection_kwargs = kwargs
self.read_load_balancer = LoadBalancer()
@@ -1502,6 +1512,7 @@ def initialize(self):
if host == "":
host = startup_node.host
port = int(primary_node[1])
host, port = self.remap_host_port(host, port)

target_node = self._get_or_create_cluster_node(
host, port, PRIMARY, tmp_nodes_cache
@@ -1518,6 +1529,7 @@ def initialize(self):
for replica_node in replica_nodes:
host = str_if_bytes(replica_node[0])
port = replica_node[1]
host, port = self.remap_host_port(host, port)

target_replica_node = self._get_or_create_cluster_node(
host, port, REPLICA, tmp_nodes_cache
@@ -1591,6 +1603,16 @@ def reset(self):
# The read_load_balancer is None, do nothing
pass

def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
    """
    Map an address announced by a cluster node to the address at which
    the client can reach that node.

    If ``self.address_remap`` is not ``None`` it is invoked with one
    argument, the ``(host, port)`` tuple, and whatever pair it returns is
    handed back as a 2-tuple. Otherwise the input address is returned
    untouched. Useful if the client is not connecting directly to the
    cluster.

    :param host: host name or IP as reported by the cluster
    :param port: port as reported by the cluster
    :return: the ``(host, port)`` pair to use for the connection
    """
    # An explicit None check keeps a falsy-but-callable remapper working.
    if self.address_remap is not None:
        # Rebuild the pair so callers always receive a real tuple, as the
        # return annotation promises, whatever sequence the callable gave.
        new_host, new_port = self.address_remap((host, port))
        return new_host, new_port
    return host, port


class ClusterPubSub(PubSub):
"""
110 changes: 109 additions & 1 deletion tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
from _pytest.fixtures import FixtureRequest

from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster
from redis.asyncio.connection import Connection, SSLConnection
from redis.asyncio.connection import Connection, SSLConnection, async_timeout
from redis.asyncio.parser import CommandsParser
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff
@@ -49,6 +49,71 @@
]


class NodeProxy:
    """A class to proxy a node connection to a different port.

    Listens on ``addr`` and forwards every accepted connection, in both
    directions, to ``redis_addr``. ``n_connections`` counts how many
    connections were forwarded so a test can verify the proxy was used.
    """

    def __init__(self, addr, redis_addr):
        # addr: (host, port) the proxy listens on
        # redis_addr: (host, port) connections are forwarded to
        self.addr = addr
        self.redis_addr = redis_addr
        # Not used by any method of this class; kept for compatibility.
        self.send_event = asyncio.Event()
        self.server = None
        self.task = None
        self.n_connections = 0

    async def start(self):
        """Check that redis is reachable, then start listening on ``addr``."""
        # test that we can connect to redis
        async with async_timeout(2):
            _, redis_writer = await asyncio.open_connection(*self.redis_addr)
        redis_writer.close()
        self.server = await asyncio.start_server(
            self.handle, *self.addr, reuse_address=True
        )
        self.task = asyncio.create_task(self.server.serve_forever())

    async def handle(self, reader, writer):
        """Forward one client connection to redis until either side closes."""
        # establish connection to redis
        redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr)
        self.n_connections += 1
        pipes = [
            asyncio.create_task(self.pipe(reader, redis_writer)),
            asyncio.create_task(self.pipe(redis_reader, writer)),
        ]
        done = set()
        try:
            # Stop as soon as ONE direction ends. Waiting for both would
            # block forever after a client disconnect, because the redis
            # side only sees EOF once its writer has been closed.
            done, _ = await asyncio.wait(pipes, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for pipe_task in pipes:
                pipe_task.cancel()
            # Collect the cancelled task(s) so nothing is left pending.
            await asyncio.gather(*pipes, return_exceptions=True)
            # Close both ends, not just the redis side.
            redis_writer.close()
            writer.close()
        for pipe_task in done:
            # Re-raise a failure from the pipe that finished first.
            pipe_task.result()

    async def aclose(self):
        """Stop serving and wait for the listening server to close.

        Safe to call even if ``start()`` never completed.
        """
        if self.task is not None:
            self.task.cancel()
            try:
                await self.task
            except asyncio.CancelledError:
                pass
        if self.server is not None:
            # Close explicitly instead of relying on serve_forever() to do it.
            self.server.close()
            await self.server.wait_closed()

    async def pipe(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ):
        """Copy data from ``reader`` to ``writer`` until ``reader`` hits EOF."""
        while True:
            data = await reader.read(1000)
            if not data:
                break
            writer.write(data)
            await writer.drain()


@pytest.fixture
def redis_addr(request):
    """Return the ``(host, port)`` named by the ``--redis-url`` option.

    The URL must use the ``redis`` scheme. The port defaults to 6379 when
    the URL does not specify one.
    """
    redis_url = request.config.getoption("--redis-url")
    parsed = urlparse(redis_url)
    assert parsed.scheme == "redis"
    # Use the parsed hostname/port rather than splitting netloc on ":",
    # which fails for URLs carrying credentials ("user:pw@host:port") or
    # bracketed IPv6 literals ("[::1]:6379").
    host = parsed.hostname or parsed.netloc
    port = parsed.port if parsed.port is not None else 6379
    return host, port


@pytest_asyncio.fixture()
async def slowlog(r: RedisCluster) -> None:
"""
@@ -809,6 +874,49 @@ async def test_default_node_is_replaced_after_exception(self, r):
# Rollback to the old default node
r.replace_default_node(curr_default_node)

async def test_address_remap(self, create_redis, redis_addr):
    """Test that we can create a rediscluster object with
    a host-port remapper and map connections through proxy objects
    """
    redis_host, base_port = redis_addr

    # Six consecutive ports, starting at the configured redis port, are
    # remapped to local proxies listening `port_offset` higher.
    port_offset = 1000
    node_count = 6
    cluster_ports = [base_port + idx for idx in range(node_count)]

    def address_remap(address):
        # Addresses outside the proxied port range pass through untouched.
        node_host, node_port = address
        if int(node_port) not in cluster_ports:
            return node_host, node_port
        return "127.0.0.1", int(node_port) + port_offset

    # One proxy per cluster port, each forwarding to the real node.
    proxies = []
    for node_port in cluster_ports:
        listen_addr = ("127.0.0.1", node_port + port_offset)
        proxies.append(NodeProxy(listen_addr, (redis_host, node_port)))

    await asyncio.gather(*[proxy.start() for proxy in proxies])
    try:
        # Build the cluster client with the remapper installed.
        client = await create_redis(
            cls=RedisCluster, flushdb=False, address_remap=address_remap
        )
        try:
            assert await client.ping() is True
            assert await client.set("byte_string", b"giraffe")
            assert await client.get("byte_string") == b"giraffe"
        finally:
            await client.close()
    finally:
        await asyncio.gather(*[proxy.aclose() for proxy in proxies])

    # More than one proxy must have carried a connection.
    proxies_used = len([proxy for proxy in proxies if proxy.n_connections])
    assert proxies_used > 1


class TestClusterRedisCommands:
"""
Loading