diff --git a/tests/test_tcp.py b/tests/test_tcp.py index 33d806b6..f37154c9 100644 --- a/tests/test_tcp.py +++ b/tests/test_tcp.py @@ -2609,14 +2609,18 @@ async def client(addr): def test_remote_shutdown_receives_trailing_data(self): if self.implementation == 'asyncio': + # this is an issue in asyncio raise unittest.SkipTest() - CHUNK = 1024 * 128 - SIZE = 32 + CHUNK = 1024 * 16 + SIZE = 8 + count = 0 sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) client_sslctx = self._create_client_ssl_context() future = None + filled = threading.Lock() + eof_received = threading.Lock() def server(sock): incoming = ssl.MemoryBIO() @@ -2647,68 +2651,71 @@ def server(sock): sslobj.write(b'pong') sock.send(outgoing.read()) - time.sleep(0.2) # wait for the peer to fill its backlog - - # send close_notify but don't wait for response - with self.assertRaises(ssl.SSLWantReadError): - sslobj.unwrap() - sock.send(outgoing.read()) - - # should receive all data data_len = 0 - while True: - try: - chunk = len(sslobj.read(16384)) - data_len += chunk - except ssl.SSLWantReadError: - incoming.write(sock.recv(16384)) - except ssl.SSLZeroReturnError: - break - - self.assertEqual(data_len, CHUNK * SIZE) - - # verify that close_notify is received - sslobj.unwrap() - sock.close() + with filled: + # trigger peer's resume_writing() + incoming.write(sock.recv(65536 * 4)) + while True: + try: + chunk = len(sslobj.read(16384)) + data_len += chunk + except ssl.SSLWantReadError: + break - def eof_server(sock): - sock.starttls(sslctx, server_side=True) - self.assertEqual(sock.recv_all(4), b'ping') - sock.send(b'pong') + # send close_notify but don't wait for response + with self.assertRaises(ssl.SSLWantReadError): + sslobj.unwrap() + sock.send(outgoing.read()) - time.sleep(0.2) # wait for the peer to fill its backlog + with eof_received: + # should receive all data + while True: + try: + chunk = len(sslobj.read(16384)) + data_len += chunk + except ssl.SSLWantReadError: + incoming.write(sock.recv(16384)) + except ssl.SSLZeroReturnError: + break - # send EOF - sock.shutdown(socket.SHUT_WR) + self.assertEqual(data_len, CHUNK * count) - # should receive all data - data = sock.recv_all(CHUNK * SIZE) - self.assertEqual(len(data), CHUNK * SIZE) + # verify that close_notify is received + sslobj.unwrap() sock.close() async def client(addr): - nonlocal future + nonlocal future, count future = self.loop.create_future() - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='') - writer.write(b'ping') - data = await reader.readexactly(4) - self.assertEqual(data, b'pong') - - # fill write backlog in a hacky way - renegotiation won't help - for _ in range(SIZE): - writer.transport._test__append_write_backlog(b'x' * CHUNK) + with eof_received: + with filled: + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='') + writer.write(b'ping') + data = await reader.readexactly(4) + self.assertEqual(data, b'pong') + + count = 0 + try: + while True: + writer.write(b'x' * CHUNK) + count += 1 + await asyncio.wait_for( + asyncio.ensure_future(writer.drain()), 0.5) + except asyncio.TimeoutError: + # fill write backlog in a hacky way + for _ in range(SIZE): + writer.transport._test__append_write_backlog( + b'x' * CHUNK) + count += 1 - try: data = await reader.read() self.assertEqual(data, b'') - except (BrokenPipeError, ConnectionResetError): - pass await future @@ -2728,9 +2735,6 @@ def wrapper(sock): with self.tcp_server(run(server)) as srv: self.loop.run_until_complete(client(srv.addr)) - with self.tcp_server(run(eof_server)) as srv: - self.loop.run_until_complete(client(srv.addr)) - def test_connect_timeout_warning(self): s = socket.socket(socket.AF_INET) s.bind(('127.0.0.1', 0)) @@ -2842,7 +2846,7 @@ def server(sock): sock.shutdown(socket.SHUT_WR) loop.call_soon_threadsafe(eof.set) # make sure we have enough time to reproduce the issue - assert sock.recv(1024) == b'' + self.assertEqual(sock.recv(1024), b'') sock.close() class Protocol(asyncio.Protocol): @@ -2875,7 +2879,92 @@ async def client(addr): tr.resume_reading() await pr.fut tr.close() - assert extra == b'extra bytes' + if self.implementation != 'asyncio': + # extra data received after transport.close() should be + # ignored - this is likely a bug in asyncio + self.assertIsNone(extra) + + with self.tcp_server(server) as srv: + loop.run_until_complete(client(srv.addr)) + + def test_shutdown_while_pause_reading(self): + if self.implementation == 'asyncio': + raise unittest.SkipTest() + + loop = self.loop + conn_made = loop.create_future() + eof_recvd = loop.create_future() + conn_lost = loop.create_future() + data_recv = False + + def server(sock): + sslctx = self._create_server_ssl_context(self.ONLYCERT, + self.ONLYKEY) + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True) + + while True: + try: + sslobj.do_handshake() + sslobj.write(b'trailing data') + break + except ssl.SSLWantReadError: + if outgoing.pending: + sock.send(outgoing.read()) + incoming.write(sock.recv(16384)) + if outgoing.pending: + sock.send(outgoing.read()) + + while True: + try: + self.assertEqual(sslobj.read(), b'') # close_notify + break + except ssl.SSLWantReadError: + incoming.write(sock.recv(16384)) + + while True: + try: + sslobj.unwrap() + except ssl.SSLWantReadError: + if outgoing.pending: + sock.send(outgoing.read()) + # incoming.write(sock.recv(16384)) + else: + if outgoing.pending: + sock.send(outgoing.read()) + break + + self.assertEqual(sock.recv(16384), b'') # socket closed + + class Protocol(asyncio.Protocol): + def connection_made(self, transport): + conn_made.set_result(None) + + def data_received(self, data): + nonlocal data_recv + data_recv = True + + def eof_received(self): + eof_recvd.set_result(None) + + def connection_lost(self, exc): + if exc is None: + conn_lost.set_result(None) + else: + conn_lost.set_exception(exc) + + async def client(addr): + ctx = self._create_client_ssl_context() + tr, _ = await loop.create_connection(Protocol, *addr, ssl=ctx) + await conn_made + self.assertFalse(data_recv) + + tr.pause_reading() + tr.close() + + await eof_recvd + await conn_lost with self.tcp_server(server) as srv: loop.run_until_complete(client(srv.addr)) diff --git a/uvloop/includes/stdlib.pxi b/uvloop/includes/stdlib.pxi index 7fd8ac98..adf9806b 100644 --- a/uvloop/includes/stdlib.pxi +++ b/uvloop/includes/stdlib.pxi @@ -129,6 +129,7 @@ cdef ssl_MemoryBIO = ssl.MemoryBIO cdef ssl_create_default_context = ssl.create_default_context cdef ssl_SSLError = ssl.SSLError cdef ssl_SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError) +cdef ssl_SSLZeroReturnError = ssl.SSLZeroReturnError cdef ssl_CertificateError = ssl.CertificateError cdef int ssl_SSL_ERROR_WANT_READ = ssl.SSL_ERROR_WANT_READ cdef int ssl_SSL_ERROR_WANT_WRITE = ssl.SSL_ERROR_WANT_WRITE diff --git a/uvloop/sslproto.pxd b/uvloop/sslproto.pxd index bc94bfd5..a6daa5c0 100644 --- a/uvloop/sslproto.pxd +++ b/uvloop/sslproto.pxd @@ -24,7 +24,7 @@ cdef enum AppProtocolState: cdef class _SSLProtocolTransport: cdef: - object _loop + Loop _loop SSLProtocol _ssl_protocol bint _closed @@ -41,7 +41,7 @@ cdef class SSLProtocol: size_t _write_buffer_size object _waiter - object _loop + Loop _loop _SSLProtocolTransport _app_transport bint _app_transport_created @@ -65,7 +65,6 @@ cdef class SSLProtocol: bint _ssl_writing_paused bint _app_reading_paused - bint _eof_received size_t _incoming_high_water size_t _incoming_low_water @@ -100,6 +99,7 @@ cdef class SSLProtocol: cdef _start_shutdown(self) cdef _check_shutdown_timeout(self) + cdef _do_read_into_void(self) cdef _do_flush(self) cdef _do_shutdown(self) cdef _on_shutdown_complete(self, shutdown_exc) diff --git a/uvloop/sslproto.pyx b/uvloop/sslproto.pyx index 66676574..ac87d499 100644 --- a/uvloop/sslproto.pyx +++ b/uvloop/sslproto.pyx @@ -17,7 +17,7 @@ cdef class _SSLProtocolTransport: # TODO: # _sendfile_compatible = constants._SendfileMode.FALLBACK - def __cinit__(self, loop, ssl_protocol): + def __cinit__(self, Loop loop, ssl_protocol): self._loop = loop # SSLProtocol instance self._ssl_protocol = ssl_protocol @@ -278,7 +278,6 @@ cdef class SSLProtocol: self._incoming_high_water = 0 self._incoming_low_water = 0 self._set_read_buffer_limits() - self._eof_received = False self._app_writing_paused = False self._outgoing_high_water = 0 @@ -392,7 +391,6 @@ cdef class SSLProtocol: will close itself. If it returns a true value, closing the transport is up to the protocol. """ - self._eof_received = True try: if self._loop.get_debug(): aio_logger.debug("%r received EOF", self) @@ -400,20 +398,17 @@ cdef class SSLProtocol: if self._state == DO_HANDSHAKE: self._on_handshake_complete(ConnectionResetError) - elif self._state == WRAPPED: - self._set_state(FLUSHING) - if self._app_reading_paused: - return True - else: - self._do_flush() - - elif self._state == FLUSHING: - self._do_write() + elif self._state == WRAPPED or self._state == FLUSHING: + # We treat a low-level EOF as a critical situation similar to a + # broken connection - just send whatever is in the buffer and + # close. No application level eof_received() is called - + # because we don't want the user to think that this is a + # graceful shutdown triggered by SSL "close_notify". self._set_state(SHUTDOWN) - self._do_shutdown() + self._on_shutdown_complete(None) elif self._state == SHUTDOWN: - self._do_shutdown() + self._on_shutdown_complete(None) except Exception: self._transport.close() @@ -444,6 +439,9 @@ cdef class SSLProtocol: elif self._state == WRAPPED and new_state == FLUSHING: allowed = True + elif self._state == WRAPPED and new_state == SHUTDOWN: + allowed = True + elif self._state == FLUSHING and new_state == SHUTDOWN: allowed = True @@ -505,7 +503,7 @@ cdef class SSLProtocol: cdef _on_handshake_complete(self, handshake_exc): if self._handshake_timeout_handle is not None: self._handshake_timeout_handle.cancel() - self._shutdown_timeout_handle = None + self._handshake_timeout_handle = None sslobj = self._sslobj try: @@ -561,23 +559,60 @@ cdef class SSLProtocol: self._transport._force_close( aio_TimeoutError('SSL shutdown timed out')) - cdef _do_flush(self): - self._do_read() - self._set_state(SHUTDOWN) - self._do_shutdown() + cdef _do_read_into_void(self): + """Consume and discard incoming application data. - cdef _do_shutdown(self): + If close_notify is received for the first time, call eof_received. + """ + cdef: + bint close_notify = False try: - if not self._eof_received: - self._sslobj.unwrap() + while True: + if not self._sslobj_read(SSL_READ_MAX_SIZE): + close_notify = True + break except ssl_SSLAgainErrors as exc: + pass + except ssl_SSLZeroReturnError: + close_notify = True + if close_notify: + self._call_eof_received() + + cdef _do_flush(self): + """Flush the write backlog, discarding new data received. + + We don't send close_notify in FLUSHING because we still want to send + the remaining data over SSL, even if we received a close_notify. Also, + no application-level resume_writing() or pause_writing() will be called + in FLUSHING, as we could fully manage the flow control internally. + """ + try: + self._do_read_into_void() + self._do_write() self._process_outgoing() - except ssl_SSLError as exc: - self._on_shutdown_complete(exc) + self._control_ssl_reading() + except Exception as ex: + self._on_shutdown_complete(ex) else: - self._process_outgoing() - self._call_eof_received() - self._on_shutdown_complete(None) + if not self._get_write_buffer_size(): + self._set_state(SHUTDOWN) + self._do_shutdown() + + cdef _do_shutdown(self): + """Send close_notify and wait for the same from the peer.""" + try: + # we must skip all application data (if any) before unwrap + self._do_read_into_void() + try: + self._sslobj.unwrap() + except ssl_SSLAgainErrors as exc: + self._process_outgoing() + else: + self._process_outgoing() + if not self._get_write_buffer_size(): + self._on_shutdown_complete(None) + except Exception as ex: + self._on_shutdown_complete(ex) cdef _on_shutdown_complete(self, shutdown_exc): if self._shutdown_timeout_handle is not None: @@ -585,9 +620,9 @@ cdef class SSLProtocol: self._shutdown_timeout_handle = None if shutdown_exc: - self._fatal_error(shutdown_exc) + self._fatal_error(shutdown_exc, 'Error occurred during shutdown') else: - self._loop.call_soon(self._transport.close) + self._transport.close() cdef _abort(self, exc): self._set_state(UNWRAPPED) @@ -610,11 +645,14 @@ cdef class SSLProtocol: try: if self._state == WRAPPED: self._do_write() + self._process_outgoing() + self._control_app_writing() except Exception as ex: self._fatal_error(ex, 'Fatal error on SSL protocol') cdef _do_write(self): + """Do SSL write, consumes write backlog and fills outgoing BIO.""" cdef size_t data_len, count try: while self._write_backlog: @@ -631,19 +669,18 @@ cdef class SSLProtocol: self._write_buffer_size -= data_len except ssl_SSLAgainErrors as exc: pass - self._process_outgoing() cdef _process_outgoing(self): + """Send bytes from the outgoing BIO.""" if not self._ssl_writing_paused: data = self._outgoing_read() if len(data): self._transport.write(data) - self._control_app_writing() # Incoming flow cdef _do_read(self): - if self._state != WRAPPED and self._state != FLUSHING: + if self._state != WRAPPED: return try: if not self._app_reading_paused: @@ -653,8 +690,8 @@ cdef class SSLProtocol: self._do_read__copied() if self._write_backlog: self._do_write() - else: - self._process_outgoing() + self._process_outgoing() + self._control_app_writing() self._control_ssl_reading() except Exception as ex: self._fatal_error(ex, 'Fatal error on SSL protocol') @@ -689,7 +726,11 @@ cdef class SSLProtocol: else: break else: - self._loop.call_soon(lambda: self._do_read()) + self._loop._call_soon_handle( + new_MethodHandle(self._loop, + "SSLProtocol._do_read", + self._do_read, + self)) except ssl_SSLAgainErrors as exc: pass finally: @@ -734,17 +775,18 @@ cdef class SSLProtocol: self._start_shutdown() cdef _call_eof_received(self): - try: - if self._app_state == STATE_CON_MADE: - self._app_state = STATE_EOF + if self._app_state == STATE_CON_MADE: + self._app_state = STATE_EOF + try: keep_open = self._app_protocol.eof_received() + except (KeyboardInterrupt, SystemExit): + raise + except BaseException as ex: + self._fatal_error(ex, 'Error calling eof_received()') + else: if keep_open: aio_logger.warning('returning true from eof_received() ' 'has no effect when using ssl') - except (KeyboardInterrupt, SystemExit): - raise - except BaseException as ex: - self._fatal_error(ex, 'Error calling eof_received()') # Flow control for writes from APP socket @@ -794,15 +836,12 @@ cdef class SSLProtocol: cdef _resume_reading(self): if self._app_reading_paused: self._app_reading_paused = False - - def resume(): - if self._state == WRAPPED: - self._do_read() - elif self._state == FLUSHING: - self._do_flush() - elif self._state == SHUTDOWN: - self._do_shutdown() - self._loop.call_soon(resume) + if self._state == WRAPPED: + self._loop._call_soon_handle( + new_MethodHandle(self._loop, + "SSLProtocol._do_read", + self._do_read, + self)) # Flow control for reads from SSL socket @@ -839,7 +878,16 @@ cdef class SSLProtocol: """ assert self._ssl_writing_paused self._ssl_writing_paused = False - self._process_outgoing() + + if self._state == WRAPPED: + self._process_outgoing() + self._control_app_writing() + + elif self._state == FLUSHING: + self._do_flush() + + elif self._state == SHUTDOWN: + self._do_shutdown() cdef _fatal_error(self, exc, message='Fatal error on transport'): if self._transport: