From 20739e03ec6ccb010391b5179315368a2dd3a594 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 13:58:33 +0200 Subject: [PATCH] Improve exception handling during handshake. Also refactor tests for Sans-I/O client and server. --- src/websockets/client.py | 8 +- src/websockets/server.py | 22 +- tests/test_client.py | 726 ++++++++++++++++++++------------------- tests/test_connection.py | 1 + tests/test_server.py | 658 +++++++++++++++++++++-------------- 5 files changed, 799 insertions(+), 616 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 0e36fd028..e5f294986 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -175,10 +175,10 @@ def process_response(self, response: Response) -> None: try: s_w_accept = headers["Sec-WebSocket-Accept"] - except KeyError as exc: - raise InvalidHeader("Sec-WebSocket-Accept") from exc - except MultipleValuesError as exc: - raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc + except KeyError: + raise InvalidHeader("Sec-WebSocket-Accept") from None + except MultipleValuesError: + raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from None if s_w_accept != accept_key(self.key): raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) diff --git a/src/websockets/server.py b/src/websockets/server.py index b2671f402..006d5bdd5 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -253,10 +253,10 @@ def process_request( try: key = headers["Sec-WebSocket-Key"] - except KeyError as exc: - raise InvalidHeader("Sec-WebSocket-Key") from exc - except MultipleValuesError as exc: - raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc + except KeyError: + raise InvalidHeader("Sec-WebSocket-Key") from None + except MultipleValuesError: + raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from None try: raw_key = base64.b64decode(key.encode(), validate=True) @@ -267,10 +267,10 @@ def process_request( try: version = headers["Sec-WebSocket-Version"] - except KeyError as exc: - raise InvalidHeader("Sec-WebSocket-Version") from exc - except MultipleValuesError as exc: - raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc + except KeyError: + raise InvalidHeader("Sec-WebSocket-Version") from None + except MultipleValuesError: + raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from None if version != "13": raise InvalidHeaderValue("Sec-WebSocket-Version", version) @@ -308,8 +308,8 @@ def process_origin(self, headers: Headers) -> Origin | None: # per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3. try: origin = headers.get("Origin") - except MultipleValuesError as exc: - raise InvalidHeader("Origin", "multiple values") from exc + except MultipleValuesError: + raise InvalidHeader("Origin", "multiple values") from None if origin is not None: origin = cast(Origin, origin) if self.origins is not None: @@ -503,7 +503,7 @@ def reject(self, status: StatusLike, text: str) -> Response: HTTP response to send to the client. """ - # If a user passes an int instead of a HTTPStatus, fix it automatically. + # If status is an int instead of an HTTPStatus, fix it automatically. status = http.HTTPStatus(status) body = text.encode() headers = Headers( diff --git a/tests/test_client.py b/tests/test_client.py index 47558c1c0..2468be85e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,11 +1,14 @@ +import contextlib +import dataclasses import logging +import types import unittest -import unittest.mock +from unittest.mock import patch from websockets.client import * from websockets.client import backoff from websockets.datastructures import Headers -from websockets.exceptions import InvalidHandshake, InvalidHeader +from websockets.exceptions import InvalidHandshake, InvalidHeader, InvalidStatus from websockets.frames import OP_TEXT, Frame from websockets.http11 import Request, Response from websockets.protocol import CONNECTING, OPEN @@ -22,13 +25,19 @@ from .utils import DATE, DeprecationTestCase -class ConnectTests(unittest.TestCase): - def test_send_connect(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("wss://example.com/test")) +URI = parse_uri("wss://example.com/test") # for tests where the URI doesn't matter + + +@patch("websockets.client.generate_key", return_value=KEY) +class BasicTests(unittest.TestCase): + """Test basic opening handshake scenarios.""" + + def test_send_request(self, _generate_key): + """Client sends a handshake request.""" + client = ClientProtocol(URI) request = client.connect() - self.assertIsInstance(request, Request) client.send_request(request) + self.assertEqual( client.data_to_send(), [ @@ -42,11 +51,56 @@ def test_send_connect(self): ], ) self.assertFalse(client.close_expected()) + self.assertEqual(client.state, CONNECTING) + + def test_receive_successful_response(self, _generate_key): + """Client receives a successful handshake response.""" + client = ClientProtocol(URI) + client.receive_data( + ( + f"HTTP/1.1 101 Switching Protocols\r\n" + f"Upgrade: websocket\r\n" + f"Connection: Upgrade\r\n" + f"Sec-WebSocket-Accept: {ACCEPT}\r\n" + f"Date: {DATE}\r\n" + f"\r\n" + ).encode(), + ) + + self.assertEqual(client.data_to_send(), []) + self.assertFalse(client.close_expected()) + self.assertEqual(client.state, OPEN) + + def test_receive_failed_response(self, _generate_key): + """Client receives a failed handshake response.""" + client = ClientProtocol(URI) + client.receive_data( + ( + f"HTTP/1.1 404 Not Found\r\n" + f"Date: {DATE}\r\n" + f"Content-Length: 13\r\n" + f"Content-Type: text/plain; charset=utf-8\r\n" + f"Connection: close\r\n" + f"\r\n" + f"Sorry folks.\n" + ).encode(), + ) + + self.assertEqual(client.data_to_send(), [b""]) + self.assertTrue(client.close_expected()) + self.assertEqual(client.state, CONNECTING) + + +class RequestTests(unittest.TestCase): + """Test generating opening handshake requests.""" - def test_connect_request(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("wss://example.com/test")) + @patch("websockets.client.generate_key", return_value=KEY) + def test_connect(self, _generate_key): + """connect() creates an opening handshake request.""" + client = ClientProtocol(URI) request = client.connect() + + self.assertIsInstance(request, Request) self.assertEqual(request.path, "/test") self.assertEqual( request.headers, @@ -62,12 +116,14 @@ def test_connect_request(self): ) def test_path(self): + """connect() uses the path from the URI.""" client = ClientProtocol(parse_uri("wss://example.com/endpoint?test=1")) request = client.connect() self.assertEqual(request.path, "/endpoint?test=1") def test_port(self): + """connect() uses the port from the URI or the default port.""" for uri, host in [ ("ws://example.com/", "example.com"), ("ws://example.com:80/", "example.com"), @@ -83,85 +139,41 @@ def test_port(self): self.assertEqual(request.headers["Host"], host) def test_user_info(self): + """connect() perfoms HTTP Basic Authentication with user info from the URI.""" client = ClientProtocol(parse_uri("wss://hello:iloveyou@example.com/")) request = client.connect() self.assertEqual(request.headers["Authorization"], "Basic aGVsbG86aWxvdmV5b3U=") def test_origin(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - origin="https://example.com", - ) + """connect(origin=...) generates an Origin header.""" + client = ClientProtocol(URI, origin="https://example.com") request = client.connect() self.assertEqual(request.headers["Origin"], "https://example.com") def test_extensions(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory()], - ) + """connect(extensions=...) generates a Sec-WebSocket-Extensions header.""" + client = ClientProtocol(URI, extensions=[ClientOpExtensionFactory()]) request = client.connect() self.assertEqual(request.headers["Sec-WebSocket-Extensions"], "x-op; op") def test_subprotocols(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - subprotocols=["chat"], - ) + """connect(subprotocols=...) generates a Sec-WebSocket-Protocol header.""" + client = ClientProtocol(URI, subprotocols=["chat"]) request = client.connect() self.assertEqual(request.headers["Sec-WebSocket-Protocol"], "chat") -class AcceptRejectTests(unittest.TestCase): - def test_receive_accept(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() - client.receive_data( - ( - f"HTTP/1.1 101 Switching Protocols\r\n" - f"Upgrade: websocket\r\n" - f"Connection: Upgrade\r\n" - f"Sec-WebSocket-Accept: {ACCEPT}\r\n" - f"Date: {DATE}\r\n" - f"\r\n" - ).encode(), - ) - [response] = client.events_received() - self.assertIsInstance(response, Response) - self.assertEqual(client.data_to_send(), []) - self.assertFalse(client.close_expected()) - self.assertEqual(client.state, OPEN) - - def test_receive_reject(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() - client.receive_data( - ( - f"HTTP/1.1 404 Not Found\r\n" - f"Date: {DATE}\r\n" - f"Content-Length: 13\r\n" - f"Content-Type: text/plain; charset=utf-8\r\n" - f"Connection: close\r\n" - f"\r\n" - f"Sorry folks.\n" - ).encode(), - ) - [response] = client.events_received() - self.assertIsInstance(response, Response) - self.assertEqual(client.data_to_send(), [b""]) - self.assertTrue(client.close_expected()) - self.assertEqual(client.state, CONNECTING) +@patch("websockets.client.generate_key", return_value=KEY) +class ResponseTests(unittest.TestCase): + """Test receiving opening handshake responses.""" - def test_accept_response(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() + def test_receive_successful_response(self, _generate_key): + """Client receives a successful handshake response.""" + client = ClientProtocol(URI) client.receive_data( ( f"HTTP/1.1 101 Switching Protocols\r\n" @@ -173,6 +185,7 @@ def test_accept_response(self): ).encode(), ) [response] = client.events_received() + self.assertEqual(response.status_code, 101) self.assertEqual(response.reason_phrase, "Switching Protocols") self.assertEqual( @@ -187,11 +200,11 @@ def test_accept_response(self): ), ) self.assertIsNone(response.body) + self.assertIsNone(client.handshake_exc) - def test_reject_response(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() + def test_receive_failed_response(self, _generate_key): + """Client receives a failed handshake response.""" + client = ClientProtocol(URI) client.receive_data( ( f"HTTP/1.1 404 Not Found\r\n" @@ -204,6 +217,7 @@ def test_reject_response(self): ).encode(), ) [response] = client.events_received() + self.assertEqual(response.status_code, 404) self.assertEqual(response.reason_phrase, "Not Found") self.assertEqual( @@ -218,394 +232,416 @@ def test_reject_response(self): ), ) self.assertEqual(response.body, b"Sorry folks.\n") + self.assertIsInstance(client.handshake_exc, InvalidStatus) + self.assertEqual( + str(client.handshake_exc), + "server rejected WebSocket connection: HTTP 404", + ) - def test_no_response(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() + def test_receive_no_response(self, _generate_key): + """Client receives no handshake response.""" + client = ClientProtocol(URI) client.receive_eof() + self.assertEqual(client.events_received(), []) + self.assertIsInstance(client.handshake_exc, EOFError) + self.assertEqual( + str(client.handshake_exc), + "connection closed while reading HTTP status line", + ) - def test_partial_response(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() + def test_receive_truncated_response(self, _generate_key): + """Client receives a truncated handshake response.""" + client = ClientProtocol(URI) client.receive_data(b"HTTP/1.1 101 Switching Protocols\r\n") client.receive_eof() + self.assertEqual(client.events_received(), []) + self.assertIsInstance(client.handshake_exc, EOFError) + self.assertEqual( + str(client.handshake_exc), + "connection closed while reading HTTP headers", + ) - def test_random_response(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() + def test_receive_random_response(self, _generate_key): + """Client receives a junk handshake response.""" + client = ClientProtocol(URI) client.receive_data(b"220 smtp.invalid\r\n") client.receive_data(b"250 Hello relay.invalid\r\n") client.receive_data(b"250 Ok\r\n") client.receive_data(b"250 Ok\r\n") - client.receive_eof() - self.assertEqual(client.events_received(), []) - def make_accept_response(self, client): - request = client.connect() - return Response( - status_code=101, - reason_phrase="Switching Protocols", - headers=Headers( - { - "Upgrade": "websocket", - "Connection": "Upgrade", - "Sec-WebSocket-Accept": accept_key( - request.headers["Sec-WebSocket-Key"] - ), - } - ), + self.assertEqual(client.events_received(), []) + self.assertIsInstance(client.handshake_exc, ValueError) + self.assertEqual( + str(client.handshake_exc), + "invalid HTTP status line: 220 smtp.invalid", ) - def test_basic(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - client.receive_data(response.serialize()) - [response] = client.events_received() +@contextlib.contextmanager +def alter_and_receive_response(client): + """Generate a handshake response that can be altered for testing.""" + # We could start by sending a handshake request, i.e.: + # request = client.connect() + # client.send_request(request) + # However, in the current implementation, these calls have no effect on the + # state of the client. Therefore, they're unnecessary and can be skipped. + response = Response( + status_code=101, + reason_phrase="Switching Protocols", + headers=Headers( + { + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Accept": accept_key(client.key), + } + ), + ) + yield response + client.receive_data(response.serialize()) + [parsed_response] = client.events_received() + assert response == dataclasses.replace(parsed_response, _exception=None) + + +class HandshakeTests(unittest.TestCase): + """Test processing of handshake responses to configure the connection.""" + + def assertHandshakeSuccess(self, client): + """Assert that the opening handshake succeeded.""" self.assertEqual(client.state, OPEN) + self.assertIsNone(client.handshake_exc) - def test_missing_connection(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Connection"] - client.receive_data(response.serialize()) - [response] = client.events_received() - + def assertHandshakeError(self, client, exc_type, msg): + """Assert that the opening handshake failed with the given exception.""" self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "missing Connection header") + self.assertIsInstance(client.handshake_exc, exc_type) + # Exception chaining isn't used is client handshake implementation. + assert client.handshake_exc.__cause__ is None + self.assertEqual(str(client.handshake_exc), msg) - def test_invalid_connection(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Connection"] - response.headers["Connection"] = "close" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_basic(self): + """Handshake succeeds.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client): + pass - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "invalid Connection header: close") + self.assertHandshakeSuccess(client) - def test_missing_upgrade(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Upgrade"] - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_missing_connection(self): + """Handshake fails when the Connection header is missing.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Connection"] + + self.assertHandshakeError( + client, + InvalidHeader, + "missing Connection header", + ) - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "missing Upgrade header") + def test_invalid_connection(self): + """Handshake fails when the Connection header is invalid.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Connection"] + response.headers["Connection"] = "close" + + self.assertHandshakeError( + client, + InvalidHeader, + "invalid Connection header: close", + ) - def test_invalid_upgrade(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Upgrade"] - response.headers["Upgrade"] = "h2c" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_missing_upgrade(self): + """Handshake fails when the Upgrade header is missing.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Upgrade"] + + self.assertHandshakeError( + client, + InvalidHeader, + "missing Upgrade header", + ) - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") + def test_invalid_upgrade(self): + """Handshake fails when the Upgrade header is invalid.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Upgrade"] + response.headers["Upgrade"] = "h2c" + + self.assertHandshakeError( + client, + InvalidHeader, + "invalid Upgrade header: h2c", + ) def test_missing_accept(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Sec-WebSocket-Accept"] - client.receive_data(response.serialize()) - [response] = client.events_received() - - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Accept header") + """Handshake fails when the Sec-WebSocket-Accept header is missing.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Sec-WebSocket-Accept"] + + self.assertHandshakeError( + client, + InvalidHeader, + "missing Sec-WebSocket-Accept header", + ) def test_multiple_accept(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Accept"] = ACCEPT - client.receive_data(response.serialize()) - [response] = client.events_received() - - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual( - str(raised.exception), + """Handshake fails when the Sec-WebSocket-Accept header is repeated.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Accept"] = ACCEPT + + self.assertHandshakeError( + client, + InvalidHeader, "invalid Sec-WebSocket-Accept header: multiple values", ) def test_invalid_accept(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Sec-WebSocket-Accept"] - response.headers["Sec-WebSocket-Accept"] = ACCEPT - client.receive_data(response.serialize()) - [response] = client.events_received() - - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual( - str(raised.exception), f"invalid Sec-WebSocket-Accept header: {ACCEPT}" + """Handshake fails when the Sec-WebSocket-Accept header is invalid.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Sec-WebSocket-Accept"] + response.headers["Sec-WebSocket-Accept"] = ACCEPT + + self.assertHandshakeError( + client, + InvalidHeader, + f"invalid Sec-WebSocket-Accept header: {ACCEPT}", ) def test_no_extensions(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - client.receive_data(response.serialize()) - [response] = client.events_received() + """Handshake succeeds without extensions.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client): + pass - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, []) - def test_no_extension(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory()], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_offer_extension(self): + """Client offers an extension.""" + client = ClientProtocol(URI, extensions=[ClientRsv2ExtensionFactory()]) + request = client.connect() - self.assertEqual(client.state, OPEN) - self.assertEqual(client.extensions, [OpExtension()]) + self.assertEqual(request.headers["Sec-WebSocket-Extensions"], "x-rsv2") - def test_extension(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientRsv2ExtensionFactory()], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_enable_extension(self): + """Client offers an extension and the server enables it.""" + client = ClientProtocol(URI, extensions=[ClientRsv2ExtensionFactory()]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [Rsv2Extension()]) - def test_unexpected_extension(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_extension_not_enabled(self): + """Client offers an extension, but the server doesn't enable it.""" + client = ClientProtocol(URI, extensions=[ClientRsv2ExtensionFactory()]) + with alter_and_receive_response(client): + pass - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "no extensions supported") + self.assertHandshakeSuccess(client) + self.assertEqual(client.extensions, []) - def test_unsupported_extension(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientRsv2ExtensionFactory()], + def test_no_extensions_offered(self): + """Server enables an extension when the client didn't offer any.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" + + self.assertHandshakeError( + client, + InvalidHandshake, + "no extensions supported", ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - client.receive_data(response.serialize()) - [response] = client.events_received() - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual( - str(raised.exception), + def test_extension_not_offered(self): + """Server enables an extension that the client didn't offer.""" + client = ClientProtocol(URI, extensions=[ClientRsv2ExtensionFactory()]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-op; op" + + self.assertHandshakeError( + client, + InvalidHandshake, "Unsupported extension: name = x-op, params = [('op', None)]", ) def test_supported_extension_parameters(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory("this")], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" - client.receive_data(response.serialize()) - [response] = client.events_received() + """Server enables an extension with parameters supported by the client.""" + client = ClientProtocol(URI, extensions=[ClientOpExtensionFactory("this")]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [OpExtension("this")]) def test_unsupported_extension_parameters(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory("this")], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" - client.receive_data(response.serialize()) - [response] = client.events_received() - - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual( - str(raised.exception), + """Server enables an extension with parameters unsupported by the client.""" + client = ClientProtocol(URI, extensions=[ClientOpExtensionFactory("this")]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" + + self.assertHandshakeError( + client, + InvalidHandshake, "Unsupported extension: name = x-op, params = [('op', 'that')]", ) def test_multiple_supported_extension_parameters(self): + """Client offers the same extension with several parameters.""" client = ClientProtocol( - parse_uri("wss://example.com/"), + URI, extensions=[ ClientOpExtensionFactory("this"), ClientOpExtensionFactory("that"), ], ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" - client.receive_data(response.serialize()) - [response] = client.events_received() + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [OpExtension("that")]) def test_multiple_extensions(self): + """Client offers several extensions and the server enables them.""" client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory(), ClientRsv2ExtensionFactory()], + URI, + extensions=[ + ClientOpExtensionFactory(), + ClientRsv2ExtensionFactory(), + ], ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - client.receive_data(response.serialize()) - [response] = client.events_received() + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-op; op" + response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [OpExtension(), Rsv2Extension()]) def test_multiple_extensions_order(self): + """Client respects the order of extensions chosen by the server.""" client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory(), ClientRsv2ExtensionFactory()], + URI, + extensions=[ + ClientOpExtensionFactory(), + ClientRsv2ExtensionFactory(), + ], ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - client.receive_data(response.serialize()) - [response] = client.events_received() + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" + response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [Rsv2Extension(), OpExtension()]) def test_no_subprotocols(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - client.receive_data(response.serialize()) - [response] = client.events_received() + """Handshake succeeds without subprotocols.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client): + pass - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertIsNone(client.subprotocol) - def test_no_subprotocol(self): - client = ClientProtocol(parse_uri("wss://example.com/"), subprotocols=["chat"]) - response = self.make_accept_response(client) - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_no_subprotocol_requested(self): + """Client doesn't offer a subprotocol, but the server enables one.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Protocol"] = "chat" - self.assertEqual(client.state, OPEN) - self.assertIsNone(client.subprotocol) + self.assertHandshakeError( + client, + InvalidHandshake, + "no subprotocols supported", + ) - def test_subprotocol(self): - client = ClientProtocol(parse_uri("wss://example.com/"), subprotocols=["chat"]) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Protocol"] = "chat" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_offer_subprotocol(self): + """Client offers a subprotocol.""" + client = ClientProtocol(URI, subprotocols=["chat"]) + request = client.connect() - self.assertEqual(client.state, OPEN) - self.assertEqual(client.subprotocol, "chat") + self.assertEqual(request.headers["Sec-WebSocket-Protocol"], "chat") - def test_unexpected_subprotocol(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Protocol"] = "chat" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_enable_subprotocol(self): + """Client offers a subprotocol and the server enables it.""" + client = ClientProtocol(URI, subprotocols=["chat"]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Protocol"] = "chat" - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "no subprotocols supported") + self.assertHandshakeSuccess(client) + self.assertEqual(client.subprotocol, "chat") - def test_multiple_subprotocols(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - subprotocols=["superchat", "chat"], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Protocol"] = "superchat" - response.headers["Sec-WebSocket-Protocol"] = "chat" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_no_subprotocol_accepted(self): + """Client offers a subprotocol, but the server doesn't enable it.""" + client = ClientProtocol(URI, subprotocols=["chat"]) + with alter_and_receive_response(client): + pass - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual( - str(raised.exception), - "invalid Sec-WebSocket-Protocol header: " - "multiple values: superchat, chat", - ) + self.assertHandshakeSuccess(client) + self.assertIsNone(client.subprotocol) - def test_supported_subprotocol(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - subprotocols=["superchat", "chat"], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Protocol"] = "chat" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_multiple_subprotocols(self): + """Client offers several subprotocols and the server enables one.""" + client = ClientProtocol(URI, subprotocols=["superchat", "chat"]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Protocol"] = "chat" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.subprotocol, "chat") def test_unsupported_subprotocol(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - subprotocols=["superchat", "chat"], + """Client offers subprotocols but the server enables another one.""" + client = ClientProtocol(URI, subprotocols=["superchat", "chat"]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Protocol"] = "otherchat" + + self.assertHandshakeError( + client, + InvalidHandshake, + "unsupported subprotocol: otherchat", ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Protocol"] = "otherchat" - client.receive_data(response.serialize()) - [response] = client.events_received() - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "unsupported subprotocol: otherchat") + def test_multiple_subprotocols_accepted(self): + """Server attempts to enable multiple subprotocols.""" + client = ClientProtocol(URI, subprotocols=["superchat", "chat"]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Protocol"] = "superchat" + response.headers["Sec-WebSocket-Protocol"] = "chat" + + self.assertHandshakeError( + client, + InvalidHandshake, + "invalid Sec-WebSocket-Protocol header: " + "multiple values: superchat, chat", + ) class MiscTests(unittest.TestCase): def test_bypass_handshake(self): - client = ClientProtocol(parse_uri("ws://example.com/test"), state=OPEN) + """ClientProtocol bypasses the opening handshake.""" + client = ClientProtocol(URI, state=OPEN) client.receive_data(b"\x81\x06Hello!") [frame] = client.events_received() self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) def test_custom_logger(self): + """ClientProtocol accepts a logger argument.""" logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: - ClientProtocol(parse_uri("wss://example.com/test"), logger=logger) + ClientProtocol(URI, logger=logger) self.assertEqual(len(logs.records), 1) class BackwardsCompatibilityTests(DeprecationTestCase): def test_client_connection_class(self): + """ClientConnection is a deprecated alias for ClientProtocol.""" with self.assertDeprecationWarning( "ClientConnection was renamed to ClientProtocol" ): @@ -618,7 +654,9 @@ def test_client_connection_class(self): class BackoffTests(unittest.TestCase): def test_backoff(self): + """backoff() yields a random delay, then exponentially increasing delays.""" backoff_gen = backoff() + self.assertIsInstance(backoff_gen, types.GeneratorType) initial_delay = next(backoff_gen) self.assertGreaterEqual(initial_delay, 0) diff --git a/tests/test_connection.py b/tests/test_connection.py index 6592d67d0..9ad2ebea4 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -5,6 +5,7 @@ class BackwardsCompatibilityTests(DeprecationTestCase): def test_connection_class(self): + """Connection is a deprecated alias for Protocol.""" with self.assertDeprecationWarning( "websockets.connection was renamed to websockets.protocol " "and Connection was renamed to Protocol" diff --git a/tests/test_server.py b/tests/test_server.py index 52c8a2b99..844ba64ec 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,7 +1,8 @@ import http import logging +import sys import unittest -import unittest.mock +from unittest.mock import patch from websockets.datastructures import Headers from websockets.exceptions import ( @@ -25,8 +26,28 @@ from .utils import DATE, DeprecationTestCase -class ConnectTests(unittest.TestCase): - def test_receive_connect(self): +def make_request(): + """Generate a handshake request that can be altered for testing.""" + return Request( + path="/test", + headers=Headers( + { + "Host": "example.com", + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Key": KEY, + "Sec-WebSocket-Version": "13", + } + ), + ) + + +@patch("email.utils.formatdate", return_value=DATE) +class BasicTests(unittest.TestCase): + """Test basic opening handshake scenarios.""" + + def test_receive_request(self, _formatdate): + """Server receives a handshake request.""" server = ServerProtocol() server.receive_data( ( @@ -39,80 +60,18 @@ def test_receive_connect(self): f"\r\n" ).encode(), ) - [request] = server.events_received() - self.assertIsInstance(request, Request) + self.assertEqual(server.data_to_send(), []) self.assertFalse(server.close_expected()) + self.assertEqual(server.state, CONNECTING) - def test_connect_request(self): - server = ServerProtocol() - server.receive_data( - ( - f"GET /test HTTP/1.1\r\n" - f"Host: example.com\r\n" - f"Upgrade: websocket\r\n" - f"Connection: Upgrade\r\n" - f"Sec-WebSocket-Key: {KEY}\r\n" - f"Sec-WebSocket-Version: 13\r\n" - f"\r\n" - ).encode(), - ) - [request] = server.events_received() - self.assertEqual(request.path, "/test") - self.assertEqual( - request.headers, - Headers( - { - "Host": "example.com", - "Upgrade": "websocket", - "Connection": "Upgrade", - "Sec-WebSocket-Key": KEY, - "Sec-WebSocket-Version": "13", - } - ), - ) - - def test_no_request(self): - server = ServerProtocol() - server.receive_eof() - self.assertEqual(server.events_received(), []) - - def test_partial_request(self): - server = ServerProtocol() - server.receive_data(b"GET /test HTTP/1.1\r\n") - server.receive_eof() - self.assertEqual(server.events_received(), []) - - def test_junk_request(self): - server = ServerProtocol() - server.receive_data(b"HELO relay.invalid\r\n") - server.receive_data(b"MAIL FROM: \r\n") - server.receive_data(b"RCPT TO: \r\n") - self.assertEqual(server.events_received(), []) - - -class AcceptRejectTests(unittest.TestCase): - def make_request(self): - return Request( - path="/test", - headers=Headers( - { - "Host": "example.com", - "Upgrade": "websocket", - "Connection": "Upgrade", - "Sec-WebSocket-Key": KEY, - "Sec-WebSocket-Version": "13", - } - ), - ) - - def test_send_response_after_successful_accept(self): + def test_accept_and_send_successful_response(self, _formatdate): + """Server accepts a handshake request and sends a successful response.""" server = ServerProtocol() - request = self.make_request() - with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.accept(request) - self.assertIsInstance(response, Response) + request = make_request() + response = server.accept(request) server.send_response(response) + self.assertEqual( server.data_to_send(), [ @@ -127,37 +86,37 @@ def test_send_response_after_successful_accept(self): self.assertFalse(server.close_expected()) self.assertEqual(server.state, OPEN) - def test_send_response_after_failed_accept(self): + def test_send_response_after_failed_accept(self, _formatdate): + """Server accepts a handshake request but sends a failed response.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Key"] - with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.accept(request) - self.assertIsInstance(response, Response) + response = server.accept(request) server.send_response(response) + self.assertEqual( server.data_to_send(), [ f"HTTP/1.1 400 Bad Request\r\n" f"Date: {DATE}\r\n" f"Connection: close\r\n" - f"Content-Length: 94\r\n" + f"Content-Length: 73\r\n" f"Content-Type: text/plain; charset=utf-8\r\n" f"\r\n" f"Failed to open a WebSocket connection: " - f"missing Sec-WebSocket-Key header; 'sec-websocket-key'.\n".encode(), + f"missing Sec-WebSocket-Key header.\n".encode(), b"", ], ) self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) - def test_send_response_after_reject(self): + def test_send_response_after_reject(self, _formatdate): + """Server rejects a handshake request and sends a failed response.""" server = ServerProtocol() - with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") - self.assertIsInstance(response, Response) + response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") server.send_response(response) + self.assertEqual( server.data_to_send(), [ @@ -174,23 +133,124 @@ def test_send_response_after_reject(self): self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) - def test_send_response_without_accept_or_reject(self): + def test_send_response_without_accept_or_reject(self, _formatdate): + """Server doesn't accept or reject and sends a failed response.""" server = ServerProtocol() - server.send_response(Response(410, "Gone", Headers(), b"AWOL.\n")) + server.send_response( + Response( + 410, + "Gone", + Headers( + { + "Connection": "close", + "Content-Length": 6, + "Content-Type": "text/plain", + } + ), + b"AWOL.\n", + ) + ) self.assertEqual( server.data_to_send(), [ - "HTTP/1.1 410 Gone\r\n\r\nAWOL.\n".encode(), + "HTTP/1.1 410 Gone\r\n" + "Connection: close\r\n" + "Content-Length: 6\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "AWOL.\n".encode(), b"", ], ) self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) - def test_accept_response(self): + +class RequestTests(unittest.TestCase): + """Test receiving opening handshake requests.""" + + def test_receive_request(self): + """Server receives a handshake request.""" server = ServerProtocol() - with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.accept(self.make_request()) + server.receive_data( + ( + f"GET /test HTTP/1.1\r\n" + f"Host: example.com\r\n" + f"Upgrade: websocket\r\n" + f"Connection: Upgrade\r\n" + f"Sec-WebSocket-Key: {KEY}\r\n" + f"Sec-WebSocket-Version: 13\r\n" + f"\r\n" + ).encode(), + ) + [request] = server.events_received() + + self.assertIsInstance(request, Request) + self.assertEqual(request.path, "/test") + self.assertEqual( + request.headers, + Headers( + { + "Host": "example.com", + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Key": KEY, + "Sec-WebSocket-Version": "13", + } + ), + ) + self.assertIsNone(server.handshake_exc) + + def test_receive_no_request(self): + """Server receives no handshake request.""" + server = ServerProtocol() + server.receive_eof() + + self.assertEqual(server.events_received(), []) + self.assertIsInstance(server.handshake_exc, EOFError) + self.assertEqual( + str(server.handshake_exc), + "connection closed while reading HTTP request line", + ) + + def test_receive_truncated_request(self): + """Server receives a truncated handshake request.""" + server = ServerProtocol() + server.receive_data(b"GET /test HTTP/1.1\r\n") + server.receive_eof() + + self.assertEqual(server.events_received(), []) + self.assertIsInstance(server.handshake_exc, EOFError) + self.assertEqual( + str(server.handshake_exc), + "connection closed while reading HTTP headers", + ) + + def test_receive_junk_request(self): + """Server receives a junk handshake request.""" + server = ServerProtocol() + server.receive_data(b"HELO relay.invalid\r\n") + server.receive_data(b"MAIL FROM: \r\n") + server.receive_data(b"RCPT TO: \r\n") + + self.assertEqual(server.events_received(), []) + self.assertIsInstance(server.handshake_exc, ValueError) + self.assertEqual( + str(server.handshake_exc), + "invalid HTTP request line: HELO relay.invalid", + ) + + +class ResponseTests(unittest.TestCase): + """Test generating opening handshake responses.""" + + @patch("email.utils.formatdate", return_value=DATE) + def test_accept_response(self, _formatdate): + """accept() creates a successful opening handshake response.""" + server = ServerProtocol() + request = make_request() + response = server.accept(request) + self.assertIsInstance(response, Response) self.assertEqual(response.status_code, 101) self.assertEqual(response.reason_phrase, "Switching Protocols") @@ -207,10 +267,12 @@ def test_accept_response(self): ) self.assertIsNone(response.body) - def test_reject_response(self): + @patch("email.utils.formatdate", return_value=DATE) + def test_reject_response(self, _formatdate): + """reject() creates a failed opening handshake response.""" server = ServerProtocol() - with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") + response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") + self.assertIsInstance(response, Response) self.assertEqual(response.status_code, 404) self.assertEqual(response.reason_phrase, "Not Found") @@ -228,477 +290,552 @@ def test_reject_response(self): self.assertEqual(response.body, b"Sorry folks.\n") def test_reject_response_supports_int_status(self): + """reject() accepts an integer status code instead of an HTTPStatus.""" server = ServerProtocol() response = server.reject(404, "Sorry folks.\n") + self.assertEqual(response.status_code, 404) self.assertEqual(response.reason_phrase, "Not Found") - def test_basic(self): + @patch("websockets.server.ServerProtocol.process_request") + def test_unexpected_error(self, process_request): + """accept() handles unexpected errors and returns an error response.""" server = ServerProtocol() - request = self.make_request() + request = make_request() + process_request.side_effect = (Exception("BOOM"),) response = server.accept(request) - self.assertEqual(response.status_code, 101) + self.assertEqual(response.status_code, 500) + self.assertIsInstance(server.handshake_exc, Exception) + self.assertEqual(str(server.handshake_exc), "BOOM") - def test_unexpected_exception(self): + +class HandshakeTests(unittest.TestCase): + """Test processing of handshake responses to configure the connection.""" + + def assertHandshakeSuccess(self, server): + """Assert that the opening handshake succeeded.""" + self.assertEqual(server.state, OPEN) + self.assertIsNone(server.handshake_exc) + + def assertHandshakeError(self, server, exc_type, msg): + """Assert that the opening handshake failed with the given exception.""" + self.assertEqual(server.state, CONNECTING) + self.assertIsInstance(server.handshake_exc, exc_type) + exc = server.handshake_exc + exc_str = str(exc) + while exc.__cause__ is not None: + exc = exc.__cause__ + exc_str += "; " + str(exc) + self.assertEqual(exc_str, msg) + + def test_basic(self): + """Handshake succeeds.""" server = ServerProtocol() - request = self.make_request() - with unittest.mock.patch( - "websockets.server.ServerProtocol.process_request", - side_effect=Exception("BOOM"), - ): - response = server.accept(request) + request = make_request() + response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 500) - with self.assertRaises(Exception) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), - "BOOM", - ) + self.assertHandshakeSuccess(server) def test_missing_connection(self): + """Handshake fails when the Connection header is missing.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Connection"] response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") - with self.assertRaises(InvalidUpgrade) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidUpgrade, "missing Connection header", ) def test_invalid_connection(self): + """Handshake fails when the Connection header is invalid.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Connection"] request.headers["Connection"] = "close" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") - with self.assertRaises(InvalidUpgrade) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidUpgrade, "invalid Connection header: close", ) def test_missing_upgrade(self): + """Handshake fails when the Upgrade header is missing.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Upgrade"] response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") - with self.assertRaises(InvalidUpgrade) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidUpgrade, "missing Upgrade header", ) def test_invalid_upgrade(self): + """Handshake fails when the Upgrade header is invalid.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Upgrade"] request.headers["Upgrade"] = "h2c" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") - with self.assertRaises(InvalidUpgrade) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidUpgrade, "invalid Upgrade header: h2c", ) def test_missing_key(self): + """Handshake fails when the Sec-WebSocket-Key header is missing.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Key"] response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "missing Sec-WebSocket-Key header", ) def test_multiple_key(self): + """Handshake fails when the Sec-WebSocket-Key header is repeated.""" server = ServerProtocol() - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Key"] = KEY response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "invalid Sec-WebSocket-Key header: multiple values", ) def test_invalid_key(self): + """Handshake fails when the Sec-WebSocket-Key header is invalid.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Key"] - request.headers["Sec-WebSocket-Key"] = "not Base64 data!" + request.headers["Sec-WebSocket-Key"] = "" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), - "invalid Sec-WebSocket-Key header: not Base64 data!", + if sys.version_info[:2] >= (3, 11): + b64_exc = "Only base64 data is allowed" + else: # pragma: no cover + b64_exc = "Non-base64 digit found" + self.assertHandshakeError( + server, + InvalidHeader, + f"invalid Sec-WebSocket-Key header: ; {b64_exc}", ) def test_truncated_key(self): + """Handshake fails when the Sec-WebSocket-Key header is truncated.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Key"] - request.headers["Sec-WebSocket-Key"] = KEY[ - :16 - ] # 12 bytes instead of 16, Base64-encoded + # 12 bytes instead of 16, Base64-encoded + request.headers["Sec-WebSocket-Key"] = KEY[:16] response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, f"invalid Sec-WebSocket-Key header: {KEY[:16]}", ) def test_missing_version(self): + """Handshake fails when the Sec-WebSocket-Version header is missing.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Version"] response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "missing Sec-WebSocket-Version header", ) def test_multiple_version(self): + """Handshake fails when the Sec-WebSocket-Version header is repeated.""" server = ServerProtocol() - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Version"] = "11" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "invalid Sec-WebSocket-Version header: multiple values", ) def test_invalid_version(self): + """Handshake fails when the Sec-WebSocket-Version header is invalid.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Version"] request.headers["Sec-WebSocket-Version"] = "11" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "invalid Sec-WebSocket-Version header: 11", ) - def test_no_origin(self): + def test_origin(self): + """Handshake succeeds when checking origin.""" server = ServerProtocol(origins=["https://example.com"]) - request = self.make_request() + request = make_request() + request.headers["Origin"] = "https://example.com" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 403) - with self.assertRaises(InvalidOrigin) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), - "missing Origin header", - ) + self.assertHandshakeSuccess(server) + self.assertEqual(server.origin, "https://example.com") - def test_origin(self): + def test_no_origin(self): + """Handshake fails when checking origin and the Origin header is missing.""" server = ServerProtocol(origins=["https://example.com"]) - request = self.make_request() - request.headers["Origin"] = "https://example.com" + request = make_request() response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) - self.assertEqual(server.origin, "https://example.com") + self.assertEqual(response.status_code, 403) + self.assertHandshakeError( + server, + InvalidOrigin, + "missing Origin header", + ) def test_unexpected_origin(self): + """Handshake fails when checking origin and the Origin header is unexpected.""" server = ServerProtocol(origins=["https://example.com"]) - request = self.make_request() + request = make_request() request.headers["Origin"] = "https://other.example.com" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 403) - with self.assertRaises(InvalidOrigin) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidOrigin, "invalid Origin header: https://other.example.com", ) def test_multiple_origin(self): + """Handshake fails when checking origins and the Origin header is repeated.""" server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) - request = self.make_request() + request = make_request() request.headers["Origin"] = "https://example.com" request.headers["Origin"] = "https://other.example.com" response = server.accept(request) + server.send_response(response) # This is prohibited by the HTTP specification, so the return code is # 400 Bad Request rather than 403 Forbidden. self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "invalid Origin header: multiple values", ) def test_supported_origin(self): + """Handshake succeeds when checking origins and the origin is supported.""" server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) - request = self.make_request() + request = make_request() request.headers["Origin"] = "https://other.example.com" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(server.origin, "https://other.example.com") def test_unsupported_origin(self): + """Handshake succeeds when checking origins and the origin is unsupported.""" server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) - request = self.make_request() + request = make_request() request.headers["Origin"] = "https://original.example.com" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 403) - with self.assertRaises(InvalidOrigin) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidOrigin, "invalid Origin header: https://original.example.com", ) def test_no_origin_accepted(self): + """Handshake succeeds when the lack of an origin is accepted.""" server = ServerProtocol(origins=[None]) - request = self.make_request() + request = make_request() response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertIsNone(server.origin) def test_no_extensions(self): + """Handshake succeeds without extensions.""" server = ServerProtocol() - request = self.make_request() - response = server.accept(request) - - self.assertEqual(response.status_code, 101) - self.assertNotIn("Sec-WebSocket-Extensions", response.headers) - self.assertEqual(server.extensions, []) - - def test_no_extension(self): - server = ServerProtocol(extensions=[ServerOpExtensionFactory()]) - request = self.make_request() + request = make_request() response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_extension(self): + """Server enables an extension when the client offers it.""" server = ServerProtocol(extensions=[ServerOpExtensionFactory()]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Extensions"], "x-op; op") self.assertEqual(server.extensions, [OpExtension()]) - def test_unexpected_extension(self): + def test_extension_not_enabled(self): + """Server doesn't enable an extension when the client doesn't offer it.""" + server = ServerProtocol(extensions=[ServerOpExtensionFactory()]) + request = make_request() + response = server.accept(request) + server.send_response(response) + + self.assertHandshakeSuccess(server) + self.assertNotIn("Sec-WebSocket-Extensions", response.headers) + self.assertEqual(server.extensions, []) + + def test_no_extensions_supported(self): + """Client offers an extension, but the server doesn't support any.""" server = ServerProtocol() - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) - def test_unsupported_extension(self): + def test_extension_not_supported(self): + """Client offers an extension, but the server doesn't support it.""" server = ServerProtocol(extensions=[ServerRsv2ExtensionFactory()]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_supported_extension_parameters(self): + """Client offers an extension with parameters supported by the server.""" server = ServerProtocol(extensions=[ServerOpExtensionFactory("this")]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Extensions"], "x-op; op=this") self.assertEqual(server.extensions, [OpExtension("this")]) def test_unsupported_extension_parameters(self): + """Client offers an extension with parameters unsupported by the server.""" server = ServerProtocol(extensions=[ServerOpExtensionFactory("this")]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_multiple_supported_extension_parameters(self): + """Server supports the same extension with several parameters.""" server = ServerProtocol( extensions=[ ServerOpExtensionFactory("this"), ServerOpExtensionFactory("that"), ] ) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Extensions"], "x-op; op=that") self.assertEqual(server.extensions, [OpExtension("that")]) def test_multiple_extensions(self): + """Server enables several extensions when the client offers them.""" server = ServerProtocol( extensions=[ServerOpExtensionFactory(), ServerRsv2ExtensionFactory()] ) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" request.headers["Sec-WebSocket-Extensions"] = "x-rsv2" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual( response.headers["Sec-WebSocket-Extensions"], "x-op; op, x-rsv2" ) self.assertEqual(server.extensions, [OpExtension(), Rsv2Extension()]) def test_multiple_extensions_order(self): + """Server respects the order of extensions set in its configuration.""" server = ServerProtocol( extensions=[ServerOpExtensionFactory(), ServerRsv2ExtensionFactory()] ) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-rsv2" request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual( response.headers["Sec-WebSocket-Extensions"], "x-rsv2, x-op; op" ) self.assertEqual(server.extensions, [Rsv2Extension(), OpExtension()]) def test_no_subprotocols(self): + """Handshake succeeds without subprotocols.""" server = ServerProtocol() - request = self.make_request() + request = make_request() response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) - def test_no_subprotocol(self): + def test_no_subprotocol_requested(self): + """Server expects a subprotocol, but the client doesn't offer it.""" server = ServerProtocol(subprotocols=["chat"]) - request = self.make_request() + request = make_request() response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(NegotiationError) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + NegotiationError, "missing subprotocol", ) def test_subprotocol(self): + """Server enables a subprotocol when the client offers it.""" server = ServerProtocol(subprotocols=["chat"]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "chat") self.assertEqual(server.subprotocol, "chat") - def test_unexpected_subprotocol(self): + def test_no_subprotocols_supported(self): + """Client offers a subprotocol, but the server doesn't support any.""" server = ServerProtocol() - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) def test_multiple_subprotocols(self): + """Server enables all of the subprotocols when the client offers them.""" server = ServerProtocol(subprotocols=["superchat", "chat"]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" request.headers["Sec-WebSocket-Protocol"] = "superchat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "superchat") self.assertEqual(server.subprotocol, "superchat") def test_supported_subprotocol(self): + """Server enables one of the subprotocols when the client offers it.""" server = ServerProtocol(subprotocols=["superchat", "chat"]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "chat") self.assertEqual(server.subprotocol, "chat") def test_unsupported_subprotocol(self): + """Server expects one of the subprotocols, but the client doesn't offer any.""" server = ServerProtocol(subprotocols=["superchat", "chat"]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "otherchat" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(NegotiationError) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + NegotiationError, "invalid subprotocol; expected one of superchat, chat", ) @@ -708,34 +845,40 @@ def optional_chat(protocol, subprotocols): return "chat" def test_select_subprotocol(self): + """Server enables a subprotocol with select_subprotocol.""" server = ServerProtocol(select_subprotocol=self.optional_chat) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "chat") self.assertEqual(server.subprotocol, "chat") def test_select_no_subprotocol(self): + """Server doesn't enable any subprotocol with select_subprotocol.""" server = ServerProtocol(select_subprotocol=self.optional_chat) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "otherchat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) class MiscTests(unittest.TestCase): def test_bypass_handshake(self): + """ServerProtocol bypasses the opening handshake.""" server = ServerProtocol(state=OPEN) server.receive_data(b"\x81\x86\x00\x00\x00\x00Hello!") [frame] = server.events_received() self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) def test_custom_logger(self): + """ServerProtocol accepts a logger argument.""" logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: ServerProtocol(logger=logger) @@ -744,6 +887,7 @@ def test_custom_logger(self): class BackwardsCompatibilityTests(DeprecationTestCase): def test_server_connection_class(self): + """ServerConnection is a deprecated alias for ServerProtocol.""" with self.assertDeprecationWarning( "ServerConnection was renamed to ServerProtocol" ):