56
56
DEFAULT_SUB_PENDING_MSGS_LIMIT ,
57
57
Subscription ,
58
58
)
59
+ from .transport import Transport , TcpTransport , WebSocketTransport
59
60
60
61
__version__ = '2.2.0'
61
62
__lang__ = 'python3'
@@ -208,10 +209,7 @@ def __init__(self) -> None:
208
209
self ._pings_outstanding : int = 0
209
210
self ._pongs_received : int = 0
210
211
self ._pongs : List [asyncio .Future ] = []
211
- self ._bare_io_reader : Optional [asyncio .StreamReader ] = None
212
- self ._io_reader : Optional [asyncio .StreamReader ] = None
213
- self ._bare_io_writer : Optional [asyncio .StreamWriter ] = None
214
- self ._io_writer : Optional [asyncio .StreamWriter ] = None
212
+ self ._transport : Optional [Transport ] = None
215
213
self ._err : Optional [Exception ] = None
216
214
217
215
# callbacks
@@ -663,14 +661,14 @@ async def _close(self, status: int, do_cbs: bool = True) -> None:
663
661
# Relinquish control to allow background tasks to wrap up.
664
662
await asyncio .sleep (0 )
665
663
666
- assert self ._io_writer , "Client.connect must be called first"
664
+ assert self ._transport , "Client.connect must be called first"
667
665
if self ._current_server is not None :
668
666
# In case there is any pending data at this point, flush before disconnecting.
669
667
if self ._pending_data_size > 0 :
670
- self ._io_writer .writelines (self ._pending [:])
668
+ self ._transport .writelines (self ._pending [:])
671
669
self ._pending = []
672
670
self ._pending_data_size = 0
673
- await self ._io_writer .drain ()
671
+ await self ._transport .drain ()
674
672
675
673
# Cleanup subscriptions since not reconnecting so no need
676
674
# to replay the subscriptions anymore.
@@ -682,10 +680,10 @@ async def _close(self, status: int, do_cbs: bool = True) -> None:
682
680
sub ._message_iterator ._cancel ()
683
681
self ._subs .clear ()
684
682
685
- if self ._io_writer is not None :
686
- self ._io_writer .close ()
683
+ if self ._transport is not None :
684
+ self ._transport .close ()
687
685
try :
688
- await self ._io_writer .wait_closed ()
686
+ await self ._transport .wait_closed ()
689
687
except Exception as e :
690
688
await self ._error_cb (e )
691
689
@@ -1167,6 +1165,17 @@ def connected_server_version(self) -> ServerVersion:
1167
1165
return ServerVersion (self ._current_server .server_version )
1168
1166
return ServerVersion ("0.0.0-unknown" )
1169
1167
1168
+ @property
1169
+ def ssl_context (self ) -> ssl .SSLContext :
1170
+ ssl_context : Optional [ssl .SSLContext ] = None
1171
+ if "tls" in self .options :
1172
+ ssl_context = self .options .get ('tls' )
1173
+ else :
1174
+ ssl_context = ssl .create_default_context ()
1175
+ if ssl_context is None :
1176
+ raise errors .Error ('nats: no ssl context provided' )
1177
+ return ssl_context
1178
+
1170
1179
async def _send_command (self , cmd : bytes , priority : bool = False ) -> None :
1171
1180
if priority :
1172
1181
self ._pending .insert (0 , cmd )
@@ -1208,6 +1217,8 @@ def _setup_server_pool(self, connect_url: Union[str, List[str]]) -> None:
1208
1217
# Closer to how the Go client handles this.
1209
1218
# e.g. nats://localhost:4222
1210
1219
uri = urlparse (connect_url )
1220
+ elif "ws://" in connect_url or "wss://" in connect_url :
1221
+ uri = urlparse (connect_url )
1211
1222
elif ":" in connect_url :
1212
1223
# Expand the scheme for the user
1213
1224
# e.g. localhost:4222
@@ -1234,6 +1245,14 @@ def _setup_server_pool(self, connect_url: Union[str, List[str]]) -> None:
1234
1245
self ._server_pool .append (Srv (uri ))
1235
1246
except ValueError :
1236
1247
raise errors .Error ("nats: invalid connect url option" )
1248
+ # make sure protocols aren't mixed
1249
+ if not (all (server .uri .scheme in ("nats" , "tls" )
1250
+ for server in self ._server_pool )
1251
+ or all (server .uri .scheme in ("ws" , "wss" )
1252
+ for server in self ._server_pool )):
1253
+ raise errors .Error (
1254
+ "nats: mixing of websocket and non websocket URLs is not allowed"
1255
+ )
1237
1256
else :
1238
1257
raise errors .Error ("nats: invalid connect url option" )
1239
1258
@@ -1264,23 +1283,27 @@ async def _select_next_server(self) -> None:
1264
1283
await asyncio .sleep (self .options ["reconnect_time_wait" ])
1265
1284
try :
1266
1285
s .last_attempt = time .monotonic ()
1267
- connection_future = asyncio .open_connection (
1268
- s .uri .hostname , s .uri .port , limit = DEFAULT_BUFFER_SIZE
1269
- )
1270
- r , w = await asyncio .wait_for (
1271
- connection_future , self .options ['connect_timeout' ]
1272
- )
1286
+ if not self ._transport :
1287
+ if s .uri .scheme in ("ws" , "wss" ):
1288
+ self ._transport = WebSocketTransport ()
1289
+ else :
1290
+ # use TcpTransport as a fallback
1291
+ self ._transport = TcpTransport ()
1292
+ if s .uri .scheme == "wss" :
1293
+ # wss is expected to connect directly with tls
1294
+ await self ._transport .connect_tls (
1295
+ s .uri ,
1296
+ ssl_context = self .ssl_context ,
1297
+ buffer_size = DEFAULT_BUFFER_SIZE ,
1298
+ connect_timeout = self .options ['connect_timeout' ]
1299
+ )
1300
+ else :
1301
+ await self ._transport .connect (
1302
+ s .uri ,
1303
+ buffer_size = DEFAULT_BUFFER_SIZE ,
1304
+ connect_timeout = self .options ['connect_timeout' ]
1305
+ )
1273
1306
self ._current_server = s
1274
-
1275
- # We keep a reference to the initial transport we used when
1276
- # establishing the connection in case we later upgrade to TLS
1277
- # after getting the first INFO message. This is in order to
1278
- # prevent the GC closing the socket after we send CONNECT
1279
- # and replace the transport.
1280
- #
1281
- # See https://github.com/nats-io/asyncio-nats/issues/43
1282
- self ._bare_io_reader = self ._io_reader = r
1283
- self ._bare_io_writer = self ._io_writer = w
1284
1307
break
1285
1308
except Exception as e :
1286
1309
s .last_attempt = time .monotonic ()
@@ -1362,10 +1385,10 @@ async def _attempt_reconnect(self) -> None:
1362
1385
):
1363
1386
self ._flusher_task .cancel ()
1364
1387
1365
- if self ._io_writer is not None :
1366
- self ._io_writer .close ()
1388
+ if self ._transport is not None :
1389
+ self ._transport .close ()
1367
1390
try :
1368
- await self ._io_writer .wait_closed ()
1391
+ await self ._transport .wait_closed ()
1369
1392
except Exception as e :
1370
1393
await self ._error_cb (e )
1371
1394
@@ -1388,7 +1411,7 @@ async def _attempt_reconnect(self) -> None:
1388
1411
# Try to establish a TCP connection to a server in
1389
1412
# the cluster then send CONNECT command to it.
1390
1413
await self ._select_next_server ()
1391
- assert self ._io_writer , "_select_next_server must've set _io_writer "
1414
+ assert self ._transport , "_select_next_server must've set _transport "
1392
1415
await self ._process_connect_init ()
1393
1416
1394
1417
# Consider a reconnect to be done once CONNECT was
@@ -1416,16 +1439,16 @@ async def _attempt_reconnect(self) -> None:
1416
1439
sub_cmd = prot_command .sub_cmd (
1417
1440
sub ._subject , sub ._queue , sid
1418
1441
)
1419
- self ._io_writer .write (sub_cmd )
1442
+ self ._transport .write (sub_cmd )
1420
1443
1421
1444
if max_msgs > 0 :
1422
1445
unsub_cmd = prot_command .unsub_cmd (sid , max_msgs )
1423
- self ._io_writer .write (unsub_cmd )
1446
+ self ._transport .write (unsub_cmd )
1424
1447
1425
1448
for sid in subs_to_remove :
1426
1449
self ._subs .pop (sid )
1427
1450
1428
- await self ._io_writer .drain ()
1451
+ await self ._transport .drain ()
1429
1452
1430
1453
# Flush pending data before continuing in connected status.
1431
1454
# FIXME: Could use future here and wait for an error result
@@ -1820,12 +1843,11 @@ async def _process_connect_init(self) -> None:
1820
1843
with authentication. It is also responsible of setting up the
1821
1844
reading and ping interval tasks from the client.
1822
1845
"""
1823
- assert self ._io_reader , "must be called only from Client.connect"
1824
- assert self ._io_writer , "must be called only from Client.connect"
1846
+ assert self ._transport , "must be called only from Client.connect"
1825
1847
assert self ._current_server , "must be called only from Client.connect"
1826
1848
self ._status = Client .CONNECTING
1827
1849
1828
- connection_completed = self ._io_reader .readline ()
1850
+ connection_completed = self ._transport .readline ()
1829
1851
info_line = await asyncio .wait_for (
1830
1852
connection_completed , self .options ["connect_timeout" ]
1831
1853
)
@@ -1855,14 +1877,6 @@ async def _process_connect_init(self) -> None:
1855
1877
1856
1878
if 'tls_required' in self ._server_info and self ._server_info [
1857
1879
'tls_required' ]:
1858
- ssl_context : Optional [ssl .SSLContext ] = None
1859
- if "tls" in self .options :
1860
- ssl_context = self .options .get ('tls' )
1861
- elif self ._current_server .uri .scheme == 'tls' :
1862
- ssl_context = ssl .create_default_context ()
1863
- if ssl_context is None :
1864
- raise errors .Error ('nats: no ssl context provided' )
1865
-
1866
1880
# Check whether to reuse the original hostname for an implicit route.
1867
1881
hostname = None
1868
1882
if "tls_hostname" in self .options :
@@ -1872,55 +1886,26 @@ async def _process_connect_init(self) -> None:
1872
1886
else :
1873
1887
hostname = self ._current_server .uri .hostname
1874
1888
1875
- await self ._io_writer .drain () # just in case something is left
1876
-
1877
- # loop.start_tls was introduced in python 3.7
1878
- # the previous method is removed in 3.9
1879
- if sys .version_info .minor >= 7 :
1880
- # manually recreate the stream reader/writer with a tls upgraded transport
1881
- reader = asyncio .StreamReader ()
1882
- protocol = asyncio .StreamReaderProtocol (reader )
1883
- transport_future = asyncio .get_running_loop ().start_tls (
1884
- self ._io_writer .transport ,
1885
- protocol ,
1886
- ssl_context ,
1887
- server_hostname = hostname
1888
- )
1889
- transport = await asyncio .wait_for (
1890
- transport_future , self .options ['connect_timeout' ]
1891
- )
1892
- writer = asyncio .StreamWriter (
1893
- transport , protocol , reader , asyncio .get_running_loop ()
1894
- )
1895
- self ._io_reader , self ._io_writer = reader , writer
1896
- else :
1897
- transport = self ._io_writer .transport
1898
- sock = transport .get_extra_info ('socket' )
1899
- if not sock :
1900
- # This shouldn't happen
1901
- raise errors .Error ('nats: unable to get socket' )
1902
-
1903
- connection_future = asyncio .open_connection (
1904
- limit = DEFAULT_BUFFER_SIZE ,
1905
- sock = sock ,
1906
- ssl = ssl_context ,
1907
- server_hostname = hostname ,
1908
- )
1909
- self ._io_reader , self ._io_writer = await asyncio .wait_for (
1910
- connection_future , self .options ['connect_timeout' ]
1911
- )
1889
+ await self ._transport .drain () # just in case something is left
1890
+
1891
+ # connect to transport via tls
1892
+ await self ._transport .connect_tls (
1893
+ hostname ,
1894
+ self .ssl_context ,
1895
+ DEFAULT_BUFFER_SIZE ,
1896
+ self .options ['connect_timeout' ],
1897
+ )
1912
1898
1913
1899
# Refresh state of parser upon reconnect.
1914
1900
if self .is_reconnecting :
1915
1901
self ._ps .reset ()
1916
1902
1917
- assert self ._io_reader
1918
- assert self ._io_writer
1903
+ assert self ._transport
1919
1904
connect_cmd = self ._connect_command ()
1920
- self ._io_writer .write (connect_cmd )
1921
- await self ._io_writer .drain ()
1905
+ self ._transport .write (connect_cmd )
1906
+ await self ._transport .drain ()
1922
1907
if self .options ["verbose" ]:
1923
- future = self ._io_reader .readline ()
1908
+ future = self ._transport .readline ()
1924
1909
next_op = await asyncio .wait_for (
1925
1910
future , self .options ["connect_timeout" ]
1926
1911
)
@@ -1936,10 +1921,10 @@ async def _process_connect_init(self) -> None:
1936
1921
# await self._process_err(err_msg)
1937
1922
raise errors .Error ("nats: " + err_msg .rstrip ('\r \n ' ))
1938
1923
1939
- self ._io_writer .write (PING_PROTO )
1940
- await self ._io_writer .drain ()
1924
+ self ._transport .write (PING_PROTO )
1925
+ await self ._transport .drain ()
1941
1926
1942
- future = self ._io_reader .readline ()
1927
+ future = self ._transport .readline ()
1943
1928
next_op = await asyncio .wait_for (
1944
1929
future , self .options ["connect_timeout" ]
1945
1930
)
@@ -1973,11 +1958,12 @@ async def _process_connect_init(self) -> None:
1973
1958
)
1974
1959
1975
1960
async def _send_ping (self , future : asyncio .Future = None ) -> None :
1976
- assert self ._io_writer , "Client.connect must be called first"
1961
+ assert self ._transport , "Client.connect must be called first"
1977
1962
if future is None :
1978
1963
future = asyncio .Future ()
1979
1964
self ._pongs .append (future )
1980
- self ._io_writer .write (PING_PROTO )
1965
+ self ._transport .write (PING_PROTO )
1966
+ self ._pending_data_size += len (PING_PROTO )
1981
1967
await self ._flush_pending ()
1982
1968
1983
1969
async def _flusher (self ) -> None :
@@ -1986,7 +1972,7 @@ async def _flusher(self) -> None:
1986
1972
and then flushes them to the socket.
1987
1973
"""
1988
1974
assert self ._error_cb , "Client.connect must be called first"
1989
- assert self ._io_writer , "Client.connect must be called first"
1975
+ assert self ._transport , "Client.connect must be called first"
1990
1976
assert self ._flush_queue , "Client.connect must be called first"
1991
1977
while True :
1992
1978
if not self .is_connected or self .is_connecting :
@@ -1996,10 +1982,10 @@ async def _flusher(self) -> None:
1996
1982
1997
1983
try :
1998
1984
if self ._pending_data_size > 0 :
1999
- self ._io_writer .writelines (self ._pending [:])
1985
+ self ._transport .writelines (self ._pending [:])
2000
1986
self ._pending = []
2001
1987
self ._pending_data_size = 0
2002
- await self ._io_writer .drain ()
1988
+ await self ._transport .drain ()
2003
1989
except OSError as e :
2004
1990
await self ._error_cb (e )
2005
1991
await self ._process_op_err (e )
@@ -2037,15 +2023,15 @@ async def _read_loop(self) -> None:
2037
2023
while True :
2038
2024
try :
2039
2025
should_bail = self .is_closed or self .is_reconnecting
2040
- if should_bail or self ._io_reader is None :
2026
+ if should_bail or self ._transport is None :
2041
2027
break
2042
- if self .is_connected and self ._io_reader .at_eof ():
2028
+ if self .is_connected and self ._transport .at_eof ():
2043
2029
err = errors .UnexpectedEOF ()
2044
2030
await self ._error_cb (err )
2045
2031
await self ._process_op_err (err )
2046
2032
break
2047
2033
2048
- b = await self ._io_reader .read (DEFAULT_BUFFER_SIZE )
2034
+ b = await self ._transport .read (DEFAULT_BUFFER_SIZE )
2049
2035
await self ._ps .parse (b )
2050
2036
except errors .ProtocolError :
2051
2037
await self ._process_op_err (errors .ProtocolError ())
0 commit comments