Skip to content

Commit e7306aa

Browse files
Fix issue 2349: Let async HiredisParser finish parsing after a Connection.disconnect() (#2557)
* A failing unittest * Do not clear the redis-reader's state when we disconnect so that it can finish reading the final message * Test that reading a message of two chunks after a disconnect() works. * Add Changes * fix typos
1 parent 9e00b91 commit e7306aa

File tree

3 files changed

+92
-6
lines changed

3 files changed

+92
-6
lines changed

CHANGES

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* Add test and fix async HiredisParser when reading during a disconnect() (#2349)
12
* Use hiredis-py pack_command if available.
23
* Support `.unlink()` in ClusterPipeline
34
* Simplify synchronous SocketBuffer state management

redis/asyncio/connection.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -350,13 +350,14 @@ async def _readline(self) -> bytes:
350350
class HiredisParser(BaseParser):
351351
"""Parser class for connections using Hiredis"""
352352

353-
__slots__ = BaseParser.__slots__ + ("_reader",)
353+
__slots__ = BaseParser.__slots__ + ("_reader", "_connected")
354354

355355
def __init__(self, socket_read_size: int):
356356
if not HIREDIS_AVAILABLE:
357357
raise RedisError("Hiredis is not available.")
358358
super().__init__(socket_read_size=socket_read_size)
359359
self._reader: Optional[hiredis.Reader] = None
360+
self._connected: bool = False
360361

361362
def on_connect(self, connection: "Connection"):
362363
self._stream = connection._reader
@@ -369,13 +370,13 @@ def on_connect(self, connection: "Connection"):
369370
kwargs["errors"] = connection.encoder.encoding_errors
370371

371372
self._reader = hiredis.Reader(**kwargs)
373+
self._connected = True
372374

373375
def on_disconnect(self):
374-
self._stream = None
375-
self._reader = None
376+
self._connected = False
376377

377378
async def can_read_destructive(self):
378-
if not self._stream or not self._reader:
379+
if not self._connected:
379380
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
380381
if self._reader.gets():
381382
return True
@@ -397,8 +398,10 @@ async def read_from_socket(self):
397398
async def read_response(
398399
self, disable_decoding: bool = False
399400
) -> Union[EncodableT, List[EncodableT]]:
400-
if not self._stream or not self._reader:
401-
self.on_disconnect()
401+
# If `on_disconnect()` has been called, prohibit any more reads
402+
# even if they could happen because data might be present.
403+
# We still allow reads in progress to finish
404+
if not self._connected:
402405
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
403406

404407
response = self._reader.gets()

tests/test_asyncio/test_connection.py

+82
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
from redis.asyncio.connection import (
1111
BaseParser,
1212
Connection,
13+
HiredisParser,
1314
PythonParser,
1415
UnixDomainSocketConnection,
1516
)
1617
from redis.asyncio.retry import Retry
1718
from redis.backoff import NoBackoff
1819
from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
20+
from redis.utils import HIREDIS_AVAILABLE
1921
from tests.conftest import skip_if_server_version_lt
2022

2123
from .compat import mock
@@ -191,3 +193,83 @@ async def test_connection_parse_response_resume(r: redis.Redis):
191193
pytest.fail("didn't receive a response")
192194
assert response
193195
assert i > 0
196+
197+
198+
@pytest.mark.onlynoncluster
199+
@pytest.mark.parametrize(
200+
"parser_class", [PythonParser, HiredisParser], ids=["PythonParser", "HiredisParser"]
201+
)
202+
async def test_connection_disconect_race(parser_class):
203+
"""
204+
This test reproduces the case in issue #2349
205+
where a connection is closed while the parser is reading to feed the
206+
internal buffer.The stream `read()` will succeed, but when it returns,
207+
another task has already called `disconnect()` and is waiting for
208+
close to finish. When we attempts to feed the buffer, we will fail
209+
since the buffer is no longer there.
210+
211+
This test verifies that a read in progress can finish even
212+
if the `disconnect()` method is called.
213+
"""
214+
if parser_class == PythonParser:
215+
pytest.xfail("doesn't work yet with PythonParser")
216+
if parser_class == HiredisParser and not HIREDIS_AVAILABLE:
217+
pytest.skip("Hiredis not available")
218+
219+
args = {}
220+
args["parser_class"] = parser_class
221+
222+
conn = Connection(**args)
223+
224+
cond = asyncio.Condition()
225+
# 0 == initial
226+
# 1 == reader is reading
227+
# 2 == closer has closed and is waiting for close to finish
228+
state = 0
229+
230+
# Mock read function, which wait for a close to happen before returning
231+
# Can either be invoked as two `read()` calls (HiredisParser)
232+
# or as a `readline()` followed by `readexact()` (PythonParser)
233+
chunks = [b"$13\r\n", b"Hello, World!\r\n"]
234+
235+
async def read(_=None):
236+
nonlocal state
237+
async with cond:
238+
if state == 0:
239+
state = 1 # we are reading
240+
cond.notify()
241+
# wait until the closing task has done
242+
await cond.wait_for(lambda: state == 2)
243+
return chunks.pop(0)
244+
245+
# function closes the connection while reader is still blocked reading
246+
async def do_close():
247+
nonlocal state
248+
async with cond:
249+
await cond.wait_for(lambda: state == 1)
250+
state = 2
251+
cond.notify()
252+
await conn.disconnect()
253+
254+
async def do_read():
255+
return await conn.read_response()
256+
257+
reader = mock.AsyncMock()
258+
writer = mock.AsyncMock()
259+
writer.transport = mock.Mock()
260+
writer.transport.get_extra_info.side_effect = None
261+
262+
# for HiredisParser
263+
reader.read.side_effect = read
264+
# for PythonParser
265+
reader.readline.side_effect = read
266+
reader.readexactly.side_effect = read
267+
268+
async def open_connection(*args, **kwargs):
269+
return reader, writer
270+
271+
with patch.object(asyncio, "open_connection", open_connection):
272+
await conn.connect()
273+
274+
vals = await asyncio.gather(do_read(), do_close())
275+
assert vals == [b"Hello, World!", None]

0 commit comments

Comments
 (0)