Skip to content

Commit 778f473

Browse files
committed
tests: add 'connect' tests for all Redis connection classes
1 parent 363f6e3 commit 778f473

File tree

2 files changed

+330
-0
lines changed

2 files changed

+330
-0
lines changed

Diff for: tests/test_asyncio/test_connect.py

+145
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import asyncio
2+
import logging
3+
import re
4+
import socket
5+
import ssl
6+
7+
import pytest
8+
9+
from redis.asyncio.connection import (
10+
Connection,
11+
SSLConnection,
12+
UnixDomainSocketConnection,
13+
)
14+
15+
from ..ssl_utils import get_ssl_filename
16+
17+
_logger = logging.getLogger(__name__)
18+
19+
20+
_CLIENT_NAME = "test-suite-client"
21+
_CMD_SEP = b"\r\n"
22+
_SUCCESS_RESP = b"+OK" + _CMD_SEP
23+
_ERROR_RESP = b"-ERR" + _CMD_SEP
24+
_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP}
25+
26+
27+
@pytest.fixture
28+
def tcp_address():
29+
with socket.socket() as sock:
30+
sock.bind(("127.0.0.1", 0))
31+
return sock.getsockname()
32+
33+
34+
@pytest.fixture
35+
def uds_address(tmpdir):
36+
return tmpdir / "uds.sock"
37+
38+
39+
async def test_tcp_connect(tcp_address):
40+
host, port = tcp_address
41+
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10)
42+
await _assert_connect(conn, tcp_address)
43+
44+
45+
async def test_uds_connect(uds_address):
46+
path = str(uds_address)
47+
conn = UnixDomainSocketConnection(
48+
path=path, client_name=_CLIENT_NAME, socket_timeout=10
49+
)
50+
await _assert_connect(conn, path)
51+
52+
53+
@pytest.mark.ssl
54+
async def test_tcp_ssl_connect(tcp_address):
55+
host, port = tcp_address
56+
certfile = get_ssl_filename("server-cert.pem")
57+
keyfile = get_ssl_filename("server-key.pem")
58+
conn = SSLConnection(
59+
host=host,
60+
port=port,
61+
client_name=_CLIENT_NAME,
62+
ssl_ca_certs=certfile,
63+
socket_timeout=10,
64+
)
65+
await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)
66+
67+
68+
async def _assert_connect(conn, server_address, certfile=None, keyfile=None):
69+
stop_event = asyncio.Event()
70+
finished = asyncio.Event()
71+
72+
async def _handler(reader, writer):
73+
try:
74+
return await _redis_request_handler(reader, writer, stop_event)
75+
finally:
76+
finished.set()
77+
78+
if isinstance(server_address, str):
79+
server = await asyncio.start_unix_server(_handler, path=server_address)
80+
elif certfile:
81+
host, port = server_address
82+
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
83+
context.minimum_version = ssl.TLSVersion.TLSv1_2
84+
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
85+
server = await asyncio.start_server(_handler, host=host, port=port, ssl=context)
86+
else:
87+
host, port = server_address
88+
server = await asyncio.start_server(_handler, host=host, port=port)
89+
90+
async with server as aserver:
91+
await aserver.start_serving()
92+
try:
93+
await conn.connect()
94+
await conn.disconnect()
95+
finally:
96+
stop_event.set()
97+
aserver.close()
98+
await aserver.wait_closed()
99+
await finished.wait()
100+
101+
102+
async def _redis_request_handler(reader, writer, stop_event):
103+
buffer = b""
104+
command = None
105+
command_ptr = None
106+
fragment_length = None
107+
while not stop_event.is_set() or buffer:
108+
_logger.info(str(stop_event.is_set()))
109+
try:
110+
buffer += await asyncio.wait_for(reader.read(1024), timeout=0.5)
111+
except TimeoutError:
112+
continue
113+
if not buffer:
114+
continue
115+
parts = re.split(_CMD_SEP, buffer)
116+
buffer = parts[-1]
117+
for fragment in parts[:-1]:
118+
fragment = fragment.decode()
119+
_logger.info("Command fragment: %s", fragment)
120+
121+
if fragment.startswith("*") and command is None:
122+
command = [None for _ in range(int(fragment[1:]))]
123+
command_ptr = 0
124+
fragment_length = None
125+
continue
126+
127+
if fragment.startswith("$") and command[command_ptr] is None:
128+
fragment_length = int(fragment[1:])
129+
continue
130+
131+
assert len(fragment) == fragment_length
132+
command[command_ptr] = fragment
133+
command_ptr += 1
134+
135+
if command_ptr < len(command):
136+
continue
137+
138+
command = " ".join(command)
139+
_logger.info("Command %s", command)
140+
resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP)
141+
_logger.info("Response from %s", resp)
142+
writer.write(resp)
143+
await writer.drain()
144+
command = None
145+
_logger.info("Exit handler")

Diff for: tests/test_connect.py

+185
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
import logging
2+
import re
3+
import socket
4+
import socketserver
5+
import ssl
6+
import threading
7+
8+
import pytest
9+
10+
from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection
11+
12+
from .ssl_utils import get_ssl_filename
13+
14+
_logger = logging.getLogger(__name__)
15+
16+
17+
_CLIENT_NAME = "test-suite-client"
18+
_CMD_SEP = b"\r\n"
19+
_SUCCESS_RESP = b"+OK" + _CMD_SEP
20+
_ERROR_RESP = b"-ERR" + _CMD_SEP
21+
_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP}
22+
23+
24+
@pytest.fixture
25+
def tcp_address():
26+
with socket.socket() as sock:
27+
sock.bind(("127.0.0.1", 0))
28+
return sock.getsockname()
29+
30+
31+
@pytest.fixture
32+
def uds_address(tmpdir):
33+
return tmpdir / "uds.sock"
34+
35+
36+
def test_tcp_connect(tcp_address):
37+
host, port = tcp_address
38+
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10)
39+
_assert_connect(conn, tcp_address)
40+
41+
42+
def test_uds_connect(uds_address):
43+
path = str(uds_address)
44+
conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME, socket_timeout=10)
45+
_assert_connect(conn, path)
46+
47+
48+
@pytest.mark.ssl
49+
def test_tcp_ssl_connect(tcp_address):
50+
host, port = tcp_address
51+
certfile = get_ssl_filename("server-cert.pem")
52+
keyfile = get_ssl_filename("server-key.pem")
53+
conn = SSLConnection(
54+
host=host,
55+
port=port,
56+
client_name=_CLIENT_NAME,
57+
ssl_ca_certs=certfile,
58+
socket_timeout=10,
59+
)
60+
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)
61+
62+
63+
def _assert_connect(conn, server_address, certfile=None, keyfile=None):
64+
if isinstance(server_address, str):
65+
server = _RedisUDSServer(server_address, _RedisRequestHandler)
66+
else:
67+
server = _RedisTCPServer(
68+
server_address, _RedisRequestHandler, certfile=certfile, keyfile=keyfile
69+
)
70+
with server as aserver:
71+
t = threading.Thread(target=aserver.serve_forever)
72+
t.start()
73+
try:
74+
aserver.wait_online()
75+
conn.connect()
76+
conn.disconnect()
77+
finally:
78+
aserver.stop()
79+
t.join(timeout=5)
80+
81+
82+
class _RedisTCPServer(socketserver.TCPServer):
83+
def __init__(self, *args, certfile=None, keyfile=None, **kw) -> None:
84+
self._ready_event = threading.Event()
85+
self._stop_requested = False
86+
self._certfile = certfile
87+
self._keyfile = keyfile
88+
super().__init__(*args, **kw)
89+
90+
def service_actions(self):
91+
self._ready_event.set()
92+
93+
def wait_online(self):
94+
self._ready_event.wait()
95+
96+
def stop(self):
97+
self._stop_requested = True
98+
self.shutdown()
99+
100+
def is_serving(self):
101+
return not self._stop_requested
102+
103+
def get_request(self):
104+
if self._certfile is None:
105+
return super().get_request()
106+
newsocket, fromaddr = self.socket.accept()
107+
connstream = ssl.wrap_socket(
108+
newsocket,
109+
server_side=True,
110+
certfile=self._certfile,
111+
keyfile=self._keyfile,
112+
ssl_version=ssl.PROTOCOL_TLSv1_2,
113+
)
114+
return connstream, fromaddr
115+
116+
117+
class _RedisUDSServer(socketserver.UnixStreamServer):
118+
def __init__(self, *args, **kw) -> None:
119+
self._ready_event = threading.Event()
120+
self._stop_requested = False
121+
super().__init__(*args, **kw)
122+
123+
def service_actions(self):
124+
self._ready_event.set()
125+
126+
def wait_online(self):
127+
self._ready_event.wait()
128+
129+
def stop(self):
130+
self._stop_requested = True
131+
self.shutdown()
132+
133+
def is_serving(self):
134+
return not self._stop_requested
135+
136+
137+
class _RedisRequestHandler(socketserver.StreamRequestHandler):
138+
def setup(self):
139+
_logger.info("%s connected", self.client_address)
140+
141+
def finish(self):
142+
_logger.info("%s disconnected", self.client_address)
143+
144+
def handle(self):
145+
buffer = b""
146+
command = None
147+
command_ptr = None
148+
fragment_length = None
149+
while self.server.is_serving() or buffer:
150+
try:
151+
buffer += self.request.recv(1024)
152+
except socket.timeout:
153+
continue
154+
if not buffer:
155+
continue
156+
parts = re.split(_CMD_SEP, buffer)
157+
buffer = parts[-1]
158+
for fragment in parts[:-1]:
159+
fragment = fragment.decode()
160+
_logger.info("Command fragment: %s", fragment)
161+
162+
if fragment.startswith("*") and command is None:
163+
command = [None for _ in range(int(fragment[1:]))]
164+
command_ptr = 0
165+
fragment_length = None
166+
continue
167+
168+
if fragment.startswith("$") and command[command_ptr] is None:
169+
fragment_length = int(fragment[1:])
170+
continue
171+
172+
assert len(fragment) == fragment_length
173+
command[command_ptr] = fragment
174+
command_ptr += 1
175+
176+
if command_ptr < len(command):
177+
continue
178+
179+
command = " ".join(command)
180+
_logger.info("Command %s", command)
181+
resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP)
182+
_logger.info("Response %s", resp)
183+
self.request.sendall(resp)
184+
command = None
185+
_logger.info("Exit handler")

0 commit comments

Comments
 (0)