Skip to content

Commit c08e8a4

Browse files
authored
Merge pull request #356 from bvanelli/feature/implement-websockets
Feature/implement websockets
2 parents acbf469 + b538861 commit c08e8a4

File tree

7 files changed

+506
-97
lines changed

7 files changed

+506
-97
lines changed

Pipfile

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ mypy = "*"
99
pytest = "*"
1010
pytest-cov = "*"
1111
yapf = "*"
12+
toml = "*" # see https://github.com/google/yapf/issues/936
1213

1314
[packages]
1415
nkeys = "*"
16+
aiohttp = "*"

nats/aio/client.py

+83-97
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
DEFAULT_SUB_PENDING_MSGS_LIMIT,
5757
Subscription,
5858
)
59+
from .transport import Transport, TcpTransport, WebSocketTransport
5960

6061
__version__ = '2.2.0'
6162
__lang__ = 'python3'
@@ -208,10 +209,7 @@ def __init__(self) -> None:
208209
self._pings_outstanding: int = 0
209210
self._pongs_received: int = 0
210211
self._pongs: List[asyncio.Future] = []
211-
self._bare_io_reader: Optional[asyncio.StreamReader] = None
212-
self._io_reader: Optional[asyncio.StreamReader] = None
213-
self._bare_io_writer: Optional[asyncio.StreamWriter] = None
214-
self._io_writer: Optional[asyncio.StreamWriter] = None
212+
self._transport: Optional[Transport] = None
215213
self._err: Optional[Exception] = None
216214

217215
# callbacks
@@ -663,14 +661,14 @@ async def _close(self, status: int, do_cbs: bool = True) -> None:
663661
# Relinquish control to allow background tasks to wrap up.
664662
await asyncio.sleep(0)
665663

666-
assert self._io_writer, "Client.connect must be called first"
664+
assert self._transport, "Client.connect must be called first"
667665
if self._current_server is not None:
668666
# In case there is any pending data at this point, flush before disconnecting.
669667
if self._pending_data_size > 0:
670-
self._io_writer.writelines(self._pending[:])
668+
self._transport.writelines(self._pending[:])
671669
self._pending = []
672670
self._pending_data_size = 0
673-
await self._io_writer.drain()
671+
await self._transport.drain()
674672

675673
# Cleanup subscriptions since not reconnecting so no need
676674
# to replay the subscriptions anymore.
@@ -682,10 +680,10 @@ async def _close(self, status: int, do_cbs: bool = True) -> None:
682680
sub._message_iterator._cancel()
683681
self._subs.clear()
684682

685-
if self._io_writer is not None:
686-
self._io_writer.close()
683+
if self._transport is not None:
684+
self._transport.close()
687685
try:
688-
await self._io_writer.wait_closed()
686+
await self._transport.wait_closed()
689687
except Exception as e:
690688
await self._error_cb(e)
691689

@@ -1167,6 +1165,17 @@ def connected_server_version(self) -> ServerVersion:
11671165
return ServerVersion(self._current_server.server_version)
11681166
return ServerVersion("0.0.0-unknown")
11691167

1168+
@property
1169+
def ssl_context(self) -> ssl.SSLContext:
1170+
ssl_context: Optional[ssl.SSLContext] = None
1171+
if "tls" in self.options:
1172+
ssl_context = self.options.get('tls')
1173+
else:
1174+
ssl_context = ssl.create_default_context()
1175+
if ssl_context is None:
1176+
raise errors.Error('nats: no ssl context provided')
1177+
return ssl_context
1178+
11701179
async def _send_command(self, cmd: bytes, priority: bool = False) -> None:
11711180
if priority:
11721181
self._pending.insert(0, cmd)
@@ -1208,6 +1217,8 @@ def _setup_server_pool(self, connect_url: Union[str, List[str]]) -> None:
12081217
# Closer to how the Go client handles this.
12091218
# e.g. nats://localhost:4222
12101219
uri = urlparse(connect_url)
1220+
elif "ws://" in connect_url or "wss://" in connect_url:
1221+
uri = urlparse(connect_url)
12111222
elif ":" in connect_url:
12121223
# Expand the scheme for the user
12131224
# e.g. localhost:4222
@@ -1234,6 +1245,14 @@ def _setup_server_pool(self, connect_url: Union[str, List[str]]) -> None:
12341245
self._server_pool.append(Srv(uri))
12351246
except ValueError:
12361247
raise errors.Error("nats: invalid connect url option")
1248+
# make sure protocols aren't mixed
1249+
if not (all(server.uri.scheme in ("nats", "tls")
1250+
for server in self._server_pool)
1251+
or all(server.uri.scheme in ("ws", "wss")
1252+
for server in self._server_pool)):
1253+
raise errors.Error(
1254+
"nats: mixing of websocket and non websocket URLs is not allowed"
1255+
)
12371256
else:
12381257
raise errors.Error("nats: invalid connect url option")
12391258

@@ -1264,23 +1283,27 @@ async def _select_next_server(self) -> None:
12641283
await asyncio.sleep(self.options["reconnect_time_wait"])
12651284
try:
12661285
s.last_attempt = time.monotonic()
1267-
connection_future = asyncio.open_connection(
1268-
s.uri.hostname, s.uri.port, limit=DEFAULT_BUFFER_SIZE
1269-
)
1270-
r, w = await asyncio.wait_for(
1271-
connection_future, self.options['connect_timeout']
1272-
)
1286+
if not self._transport:
1287+
if s.uri.scheme in ("ws", "wss"):
1288+
self._transport = WebSocketTransport()
1289+
else:
1290+
# use TcpTransport as a fallback
1291+
self._transport = TcpTransport()
1292+
if s.uri.scheme == "wss":
1293+
# wss is expected to connect directly with tls
1294+
await self._transport.connect_tls(
1295+
s.uri,
1296+
ssl_context=self.ssl_context,
1297+
buffer_size=DEFAULT_BUFFER_SIZE,
1298+
connect_timeout=self.options['connect_timeout']
1299+
)
1300+
else:
1301+
await self._transport.connect(
1302+
s.uri,
1303+
buffer_size=DEFAULT_BUFFER_SIZE,
1304+
connect_timeout=self.options['connect_timeout']
1305+
)
12731306
self._current_server = s
1274-
1275-
# We keep a reference to the initial transport we used when
1276-
# establishing the connection in case we later upgrade to TLS
1277-
# after getting the first INFO message. This is in order to
1278-
# prevent the GC closing the socket after we send CONNECT
1279-
# and replace the transport.
1280-
#
1281-
# See https://github.com/nats-io/asyncio-nats/issues/43
1282-
self._bare_io_reader = self._io_reader = r
1283-
self._bare_io_writer = self._io_writer = w
12841307
break
12851308
except Exception as e:
12861309
s.last_attempt = time.monotonic()
@@ -1362,10 +1385,10 @@ async def _attempt_reconnect(self) -> None:
13621385
):
13631386
self._flusher_task.cancel()
13641387

1365-
if self._io_writer is not None:
1366-
self._io_writer.close()
1388+
if self._transport is not None:
1389+
self._transport.close()
13671390
try:
1368-
await self._io_writer.wait_closed()
1391+
await self._transport.wait_closed()
13691392
except Exception as e:
13701393
await self._error_cb(e)
13711394

@@ -1388,7 +1411,7 @@ async def _attempt_reconnect(self) -> None:
13881411
# Try to establish a TCP connection to a server in
13891412
# the cluster then send CONNECT command to it.
13901413
await self._select_next_server()
1391-
assert self._io_writer, "_select_next_server must've set _io_writer"
1414+
assert self._transport, "_select_next_server must've set _transport"
13921415
await self._process_connect_init()
13931416

13941417
# Consider a reconnect to be done once CONNECT was
@@ -1416,16 +1439,16 @@ async def _attempt_reconnect(self) -> None:
14161439
sub_cmd = prot_command.sub_cmd(
14171440
sub._subject, sub._queue, sid
14181441
)
1419-
self._io_writer.write(sub_cmd)
1442+
self._transport.write(sub_cmd)
14201443

14211444
if max_msgs > 0:
14221445
unsub_cmd = prot_command.unsub_cmd(sid, max_msgs)
1423-
self._io_writer.write(unsub_cmd)
1446+
self._transport.write(unsub_cmd)
14241447

14251448
for sid in subs_to_remove:
14261449
self._subs.pop(sid)
14271450

1428-
await self._io_writer.drain()
1451+
await self._transport.drain()
14291452

14301453
# Flush pending data before continuing in connected status.
14311454
# FIXME: Could use future here and wait for an error result
@@ -1820,12 +1843,11 @@ async def _process_connect_init(self) -> None:
18201843
with authentication. It is also responsible of setting up the
18211844
reading and ping interval tasks from the client.
18221845
"""
1823-
assert self._io_reader, "must be called only from Client.connect"
1824-
assert self._io_writer, "must be called only from Client.connect"
1846+
assert self._transport, "must be called only from Client.connect"
18251847
assert self._current_server, "must be called only from Client.connect"
18261848
self._status = Client.CONNECTING
18271849

1828-
connection_completed = self._io_reader.readline()
1850+
connection_completed = self._transport.readline()
18291851
info_line = await asyncio.wait_for(
18301852
connection_completed, self.options["connect_timeout"]
18311853
)
@@ -1855,14 +1877,6 @@ async def _process_connect_init(self) -> None:
18551877

18561878
if 'tls_required' in self._server_info and self._server_info[
18571879
'tls_required']:
1858-
ssl_context: Optional[ssl.SSLContext] = None
1859-
if "tls" in self.options:
1860-
ssl_context = self.options.get('tls')
1861-
elif self._current_server.uri.scheme == 'tls':
1862-
ssl_context = ssl.create_default_context()
1863-
if ssl_context is None:
1864-
raise errors.Error('nats: no ssl context provided')
1865-
18661880
# Check whether to reuse the original hostname for an implicit route.
18671881
hostname = None
18681882
if "tls_hostname" in self.options:
@@ -1872,55 +1886,26 @@ async def _process_connect_init(self) -> None:
18721886
else:
18731887
hostname = self._current_server.uri.hostname
18741888

1875-
await self._io_writer.drain() # just in case something is left
1876-
1877-
# loop.start_tls was introduced in python 3.7
1878-
# the previous method is removed in 3.9
1879-
if sys.version_info.minor >= 7:
1880-
# manually recreate the stream reader/writer with a tls upgraded transport
1881-
reader = asyncio.StreamReader()
1882-
protocol = asyncio.StreamReaderProtocol(reader)
1883-
transport_future = asyncio.get_running_loop().start_tls(
1884-
self._io_writer.transport,
1885-
protocol,
1886-
ssl_context,
1887-
server_hostname=hostname
1888-
)
1889-
transport = await asyncio.wait_for(
1890-
transport_future, self.options['connect_timeout']
1891-
)
1892-
writer = asyncio.StreamWriter(
1893-
transport, protocol, reader, asyncio.get_running_loop()
1894-
)
1895-
self._io_reader, self._io_writer = reader, writer
1896-
else:
1897-
transport = self._io_writer.transport
1898-
sock = transport.get_extra_info('socket')
1899-
if not sock:
1900-
# This shouldn't happen
1901-
raise errors.Error('nats: unable to get socket')
1902-
1903-
connection_future = asyncio.open_connection(
1904-
limit=DEFAULT_BUFFER_SIZE,
1905-
sock=sock,
1906-
ssl=ssl_context,
1907-
server_hostname=hostname,
1908-
)
1909-
self._io_reader, self._io_writer = await asyncio.wait_for(
1910-
connection_future, self.options['connect_timeout']
1911-
)
1889+
await self._transport.drain() # just in case something is left
1890+
1891+
# connect to transport via tls
1892+
await self._transport.connect_tls(
1893+
hostname,
1894+
self.ssl_context,
1895+
DEFAULT_BUFFER_SIZE,
1896+
self.options['connect_timeout'],
1897+
)
19121898

19131899
# Refresh state of parser upon reconnect.
19141900
if self.is_reconnecting:
19151901
self._ps.reset()
19161902

1917-
assert self._io_reader
1918-
assert self._io_writer
1903+
assert self._transport
19191904
connect_cmd = self._connect_command()
1920-
self._io_writer.write(connect_cmd)
1921-
await self._io_writer.drain()
1905+
self._transport.write(connect_cmd)
1906+
await self._transport.drain()
19221907
if self.options["verbose"]:
1923-
future = self._io_reader.readline()
1908+
future = self._transport.readline()
19241909
next_op = await asyncio.wait_for(
19251910
future, self.options["connect_timeout"]
19261911
)
@@ -1936,10 +1921,10 @@ async def _process_connect_init(self) -> None:
19361921
# await self._process_err(err_msg)
19371922
raise errors.Error("nats: " + err_msg.rstrip('\r\n'))
19381923

1939-
self._io_writer.write(PING_PROTO)
1940-
await self._io_writer.drain()
1924+
self._transport.write(PING_PROTO)
1925+
await self._transport.drain()
19411926

1942-
future = self._io_reader.readline()
1927+
future = self._transport.readline()
19431928
next_op = await asyncio.wait_for(
19441929
future, self.options["connect_timeout"]
19451930
)
@@ -1973,11 +1958,12 @@ async def _process_connect_init(self) -> None:
19731958
)
19741959

19751960
async def _send_ping(self, future: asyncio.Future = None) -> None:
1976-
assert self._io_writer, "Client.connect must be called first"
1961+
assert self._transport, "Client.connect must be called first"
19771962
if future is None:
19781963
future = asyncio.Future()
19791964
self._pongs.append(future)
1980-
self._io_writer.write(PING_PROTO)
1965+
self._transport.write(PING_PROTO)
1966+
self._pending_data_size += len(PING_PROTO)
19811967
await self._flush_pending()
19821968

19831969
async def _flusher(self) -> None:
@@ -1986,7 +1972,7 @@ async def _flusher(self) -> None:
19861972
and then flushes them to the socket.
19871973
"""
19881974
assert self._error_cb, "Client.connect must be called first"
1989-
assert self._io_writer, "Client.connect must be called first"
1975+
assert self._transport, "Client.connect must be called first"
19901976
assert self._flush_queue, "Client.connect must be called first"
19911977
while True:
19921978
if not self.is_connected or self.is_connecting:
@@ -1996,10 +1982,10 @@ async def _flusher(self) -> None:
19961982

19971983
try:
19981984
if self._pending_data_size > 0:
1999-
self._io_writer.writelines(self._pending[:])
1985+
self._transport.writelines(self._pending[:])
20001986
self._pending = []
20011987
self._pending_data_size = 0
2002-
await self._io_writer.drain()
1988+
await self._transport.drain()
20031989
except OSError as e:
20041990
await self._error_cb(e)
20051991
await self._process_op_err(e)
@@ -2037,15 +2023,15 @@ async def _read_loop(self) -> None:
20372023
while True:
20382024
try:
20392025
should_bail = self.is_closed or self.is_reconnecting
2040-
if should_bail or self._io_reader is None:
2026+
if should_bail or self._transport is None:
20412027
break
2042-
if self.is_connected and self._io_reader.at_eof():
2028+
if self.is_connected and self._transport.at_eof():
20432029
err = errors.UnexpectedEOF()
20442030
await self._error_cb(err)
20452031
await self._process_op_err(err)
20462032
break
20472033

2048-
b = await self._io_reader.read(DEFAULT_BUFFER_SIZE)
2034+
b = await self._transport.read(DEFAULT_BUFFER_SIZE)
20492035
await self._ps.parse(b)
20502036
except errors.ProtocolError:
20512037
await self._process_op_err(errors.ProtocolError())

0 commit comments

Comments
 (0)