Skip to content

Commit

Permalink
Ensure connection creation can be subclassed via make_connection()
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed Sep 10, 2023
1 parent 010d8b8 commit 004e15e
Showing 1 changed file with 23 additions and 26 deletions.
49 changes: 23 additions & 26 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,10 +998,8 @@ def __init__(
self.connection_kwargs = connection_kwargs
self.max_connections = max_connections

self._created_connections: int
self._available_connections: List[AbstractConnection]
self._in_use_connections: Set[AbstractConnection]
self.reset() # lgtm [py/init-calls-subclass]
self._available_connections: List[AbstractConnection] = []
self._in_use_connections: Set[AbstractConnection] = set()
self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)

def __repr__(self):
Expand All @@ -1011,43 +1009,29 @@ def __repr__(self):
)

def reset(self):
self._created_connections = 0
self._available_connections = []
self._in_use_connections = set()

def can_get_connection(self) -> bool:
"""Return True if a connection can be retrieved from the pool."""
return (
self._available_connections
or self._created_connections < self.max_connections
or len(self._in_use_connections) < self.max_connections
)

async def get_connection(self, command_name, *keys, **options):
"""Get a connection from the pool"""
try:
connection = self._available_connections.pop()
except IndexError:
if len(self._in_use_connections) >= self.max_connections:
raise ConnectionError("Too many connections") from None
connection = self.make_connection()
self._in_use_connections.add(connection)

try:
# ensure this connection is connected to Redis
await connection.connect()
# connections that the pool provides should be ready to send
# a command. if not, the connection was either returned to the
# pool before all data has been read or the socket has been
# closed. either way, reconnect and verify everything is good.
try:
if await connection.can_read_destructive():
raise ConnectionError("Connection has data") from None
except (ConnectionError, OSError):
await connection.disconnect()
await connection.connect()
if await connection.can_read_destructive():
raise ConnectionError("Connection not ready") from None
await self.ensure_connection(connection)
except BaseException:
# release the connection back to the pool so that we don't
# leak it
await self.release(connection)
raise

Expand All @@ -1063,12 +1047,25 @@ def get_encoder(self):
)

def make_connection(self):
"""Create a new connection"""
if self._created_connections >= self.max_connections:
raise ConnectionError("Too many connections")
self._created_connections += 1
"""Create a new connection. Can be overridden by child classes."""
return self.connection_class(**self.connection_kwargs)

async def ensure_connection(self, connection: AbstractConnection):
"""Ensure that the connection object is connected and valid"""
await connection.connect()
# connections that the pool provides should be ready to send
# a command. if not, the connection was either returned to the
# pool before all data has been read or the socket has been
# closed. either way, reconnect and verify everything is good.
try:
if await connection.can_read_destructive():
raise ConnectionError("Connection has data") from None
except (ConnectionError, OSError):
await connection.disconnect()
await connection.connect()
if await connection.can_read_destructive():
raise ConnectionError("Connection not ready") from None

async def release(self, connection: AbstractConnection):
"""Releases the connection back to the pool"""
# Connections should always be returned to the correct pool,
Expand Down

0 comments on commit 004e15e

Please # to comment.