Skip to content

Commit

Permalink
Remove process-id checks from asyncio. Asyncio and fork() do not mix.
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed Aug 24, 2023
1 parent 19b55c6 commit b3e7893
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 120 deletions.
124 changes: 6 additions & 118 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
import copy
import enum
import inspect
import os
import socket
import ssl
import sys
import threading
import weakref
from abc import abstractmethod
from itertools import chain
Expand Down Expand Up @@ -41,7 +39,6 @@
from redis.exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
ChildDeadlockedError,
ConnectionError,
DataError,
RedisError,
Expand Down Expand Up @@ -97,7 +94,6 @@ class AbstractConnection:
"""Manages communication to and from a Redis server"""

__slots__ = (
"pid",
"db",
"username",
"client_name",
Expand Down Expand Up @@ -158,7 +154,6 @@ def __init__(
"1. 'password' and (optional) 'username'\n"
"2. 'credential_provider'"
)
self.pid = os.getpid()
self.db = db
self.client_name = client_name
self.lib_name = lib_name
Expand Down Expand Up @@ -381,12 +376,11 @@ async def disconnect(self, nowait: bool = False) -> None:
if not self.is_connected:
return
try:
if os.getpid() == self.pid:
self._writer.close() # type: ignore[union-attr]
# wait for close to finish, except when handling errors and
# forcefully disconnecting.
if not nowait:
await self._writer.wait_closed() # type: ignore[union-attr]
self._writer.close() # type: ignore[union-attr]
# wait for close to finish, except when handling errors and
# forcefully disconnecting.
if not nowait:
await self._writer.wait_closed() # type: ignore[union-attr]
except OSError:
pass
finally:
Expand Down Expand Up @@ -1004,15 +998,6 @@ def __init__(
self.connection_kwargs = connection_kwargs
self.max_connections = max_connections

# a lock to protect the critical section in _checkpid().
# this lock is acquired when the process id changes, such as
# after a fork. during this time, multiple threads in the child
# process could attempt to acquire this lock. the first thread
# to acquire the lock will reset the data structures and lock
# object of this pool. subsequent threads acquiring this lock
# will notice the first thread already did the work and simply
# release the lock.
self._fork_lock = threading.Lock()
self._lock = asyncio.Lock()
self._created_connections: int
self._available_connections: List[AbstractConnection]
Expand All @@ -1032,67 +1017,8 @@ def reset(self):
self._available_connections = []
self._in_use_connections = set()

# this must be the last operation in this method. while reset() is
# called when holding _fork_lock, other threads in this process
# can call _checkpid() which compares self.pid and os.getpid() without
# holding any lock (for performance reasons). keeping this assignment
# as the last operation ensures that those other threads will also
# notice a pid difference and block waiting for the first thread to
# release _fork_lock. when each of these threads eventually acquire
# _fork_lock, they will notice that another thread already called
# reset() and they will immediately release _fork_lock and continue on.
self.pid = os.getpid()

def _checkpid(self):
# _checkpid() attempts to keep ConnectionPool fork-safe on modern
# systems. this is called by all ConnectionPool methods that
# manipulate the pool's state such as get_connection() and release().
#
# _checkpid() determines whether the process has forked by comparing
# the current process id to the process id saved on the ConnectionPool
# instance. if these values are the same, _checkpid() simply returns.
#
# when the process ids differ, _checkpid() assumes that the process
# has forked and that we're now running in the child process. the child
# process cannot use the parent's file descriptors (e.g., sockets).
# therefore, when _checkpid() sees the process id change, it calls
# reset() in order to reinitialize the child's ConnectionPool. this
# will cause the child to make all new connection objects.
#
# _checkpid() is protected by self._fork_lock to ensure that multiple
# threads in the child process do not call reset() multiple times.
#
# there is an extremely small chance this could fail in the following
# scenario:
# 1. process A calls _checkpid() for the first time and acquires
# self._fork_lock.
# 2. while holding self._fork_lock, process A forks (the fork()
# could happen in a different thread owned by process A)
# 3. process B (the forked child process) inherits the
# ConnectionPool's state from the parent. that state includes
# a locked _fork_lock. process B will not be notified when
# process A releases the _fork_lock and will thus never be
# able to acquire the _fork_lock.
#
# to mitigate this possible deadlock, _checkpid() will only wait 5
# seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
# that time it is assumed that the child is deadlocked and a
# redis.ChildDeadlockedError error is raised.
if self.pid != os.getpid():
acquired = self._fork_lock.acquire(timeout=5)
if not acquired:
raise ChildDeadlockedError
# reset() the instance for the new process if another thread
# hasn't already done so
try:
if self.pid != os.getpid():
self.reset()
finally:
self._fork_lock.release()

async def get_connection(self, command_name, *keys, **options):
"""Get a connection from the pool"""
self._checkpid()
async with self._lock:
try:
connection = self._available_connections.pop()
Expand Down Expand Up @@ -1141,7 +1067,6 @@ def make_connection(self):

async def release(self, connection: AbstractConnection):
"""Releases the connection back to the pool"""
self._checkpid()
async with self._lock:
try:
self._in_use_connections.remove(connection)
Expand All @@ -1150,18 +1075,7 @@ async def release(self, connection: AbstractConnection):
# that the pool doesn't actually own
pass

if self.owns_connection(connection):
self._available_connections.append(connection)
else:
# pool doesn't own this connection. do not add it back
# to the pool and decrement the count so that another
# connection can take its place if needed
self._created_connections -= 1
await connection.disconnect()
return

def owns_connection(self, connection: AbstractConnection):
return connection.pid == self.pid
self._available_connections.append(connection)

async def disconnect(self, inuse_connections: bool = True):
"""
Expand All @@ -1171,7 +1085,6 @@ async def disconnect(self, inuse_connections: bool = True):
currently in use, potentially by other tasks. Otherwise only disconnect
connections that are idle in the pool.
"""
self._checkpid()
async with self._lock:
if inuse_connections:
connections: Iterable[AbstractConnection] = chain(
Expand Down Expand Up @@ -1259,17 +1172,6 @@ def reset(self):
# disconnect them later.
self._connections = []

# this must be the last operation in this method. while reset() is
# called when holding _fork_lock, other threads in this process
# can call _checkpid() which compares self.pid and os.getpid() without
# holding any lock (for performance reasons). keeping this assignment
# as the last operation ensures that those other threads will also
# notice a pid difference and block waiting for the first thread to
# release _fork_lock. when each of these threads eventually acquire
# _fork_lock, they will notice that another thread already called
# reset() and they will immediately release _fork_lock and continue on.
self.pid = os.getpid()

def make_connection(self):
"""Make a fresh connection."""
connection = self.connection_class(**self.connection_kwargs)
Expand All @@ -1288,8 +1190,6 @@ async def get_connection(self, command_name, *keys, **options):
create new connections when we need to, i.e.: the actual number of
connections will only increase in response to demand.
"""
# Make sure we haven't changed process.
self._checkpid()

# Try and get a connection from the pool. If one isn't available within
# self.timeout then raise a ``ConnectionError``.
Expand Down Expand Up @@ -1331,17 +1231,6 @@ async def get_connection(self, command_name, *keys, **options):

async def release(self, connection: AbstractConnection):
"""Releases the connection back to the pool."""
# Make sure we haven't changed process.
self._checkpid()
if not self.owns_connection(connection):
# pool doesn't own this connection. do not add it back
# to the pool. instead add a None value which is a placeholder
# that will cause the pool to recreate the connection if
# it's needed.
await connection.disconnect()
self.pool.put_nowait(None)
return

# Put the connection back into the pool.
try:
self.pool.put_nowait(connection)
Expand All @@ -1352,7 +1241,6 @@ async def release(self, connection: AbstractConnection):

async def disconnect(self, inuse_connections: bool = True):
"""Disconnects all connections in the pool."""
self._checkpid()
async with self._lock:
resp = await asyncio.gather(
*(connection.disconnect() for connection in self._connections),
Expand Down
2 changes: 0 additions & 2 deletions tests/test_asyncio/test_connection_pool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import os
import re

import pytest
Expand Down Expand Up @@ -94,7 +93,6 @@ class DummyConnection(Connection):

def __init__(self, **kwargs):
self.kwargs = kwargs
self.pid = os.getpid()

async def connect(self):
pass
Expand Down

0 comments on commit b3e7893

Please sign in to comment.