diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 456fb4563b..00865515ab 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -998,10 +998,8 @@ def __init__( self.connection_kwargs = connection_kwargs self.max_connections = max_connections - self._created_connections: int - self._available_connections: List[AbstractConnection] - self._in_use_connections: Set[AbstractConnection] - self.reset() # lgtm [py/init-calls-subclass] + self._available_connections: List[AbstractConnection] = [] + self._in_use_connections: Set[AbstractConnection] = set() self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder) def __repr__(self): @@ -1011,7 +1009,6 @@ def __repr__(self): ) def reset(self): - self._created_connections = 0 self._available_connections = [] self._in_use_connections = set() @@ -1019,7 +1016,7 @@ def can_get_connection(self) -> bool: """Return True if a connection can be retrieved from the pool.""" return ( self._available_connections - or self._created_connections < self.max_connections + or len(self._in_use_connections) < self.max_connections ) async def get_connection(self, command_name, *keys, **options): @@ -1027,27 +1024,14 @@ async def get_connection(self, command_name, *keys, **options): try: connection = self._available_connections.pop() except IndexError: + if len(self._in_use_connections) >= self.max_connections: + raise ConnectionError("Too many connections") from None connection = self.make_connection() self._in_use_connections.add(connection) try: - # ensure this connection is connected to Redis - await connection.connect() - # connections that the pool provides should be ready to send - # a command. if not, the connection was either returned to the - # pool before all data has been read or the socket has been - # closed. either way, reconnect and verify everything is good. - try: - if await connection.can_read_destructive(): - raise ConnectionError("Connection has data") from None - except (ConnectionError, OSError): - await connection.disconnect() - await connection.connect() - if await connection.can_read_destructive(): - raise ConnectionError("Connection not ready") from None + await self.ensure_connection(connection) except BaseException: - # release the connection back to the pool so that we don't - # leak it await self.release(connection) raise @@ -1063,12 +1047,25 @@ def get_encoder(self): ) def make_connection(self): - """Create a new connection""" - if self._created_connections >= self.max_connections: - raise ConnectionError("Too many connections") - self._created_connections += 1 + """Create a new connection. Can be overridden by child classes.""" return self.connection_class(**self.connection_kwargs) + async def ensure_connection(self, connection: AbstractConnection): + """Ensure that the connection object is connected and valid""" + await connection.connect() + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the + # pool before all data has been read or the socket has been + # closed. either way, reconnect and verify everything is good. + try: + if await connection.can_read_destructive(): + raise ConnectionError("Connection has data") from None + except (ConnectionError, OSError): + await connection.disconnect() + await connection.connect() + if await connection.can_read_destructive(): + raise ConnectionError("Connection not ready") from None + async def release(self, connection: AbstractConnection): """Releases the connection back to the pool""" # Connections should always be returned to the correct pool,