Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add tls_handshake_first option. #511

Merged
merged 4 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 31 additions & 19 deletions nats/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
)
from .transport import TcpTransport, Transport, WebSocketTransport

__version__ = '2.4.0'
__version__ = '2.5.0'
__lang__ = 'python3'
_logger = logging.getLogger(__name__)
PROTOCOL = 1
Expand Down Expand Up @@ -305,6 +305,7 @@ async def connect(
no_echo: bool = False,
tls: Optional[ssl.SSLContext] = None,
tls_hostname: Optional[str] = None,
tls_handshake_first: bool = False,
user: Optional[str] = None,
password: Optional[str] = None,
token: Optional[str] = None,
Expand Down Expand Up @@ -448,6 +449,7 @@ async def subscribe_handler(msg):
self.options["token"] = token
self.options["connect_timeout"] = connect_timeout
self.options["drain_timeout"] = drain_timeout
self.options['tls_handshake_first'] = tls_handshake_first

if tls:
self.options['tls'] = tls
Expand Down Expand Up @@ -1886,6 +1888,24 @@ async def _process_connect_init(self) -> None:
assert self._current_server, "must be called only from Client.connect"
self._status = Client.CONNECTING

# Check whether to reuse the original hostname for an implicit route.
hostname = None
if "tls_hostname" in self.options:
hostname = self.options["tls_hostname"]
elif self._current_server.tls_name is not None:
hostname = self._current_server.tls_name
else:
hostname = self._current_server.uri.hostname

handshake_first = self.options['tls_handshake_first']
if handshake_first:
await self._transport.connect_tls(
hostname,
self.ssl_context,
DEFAULT_BUFFER_SIZE,
self.options['connect_timeout'],
)

connection_completed = self._transport.readline()
info_line = await asyncio.wait_for(
connection_completed, self.options["connect_timeout"]
Expand Down Expand Up @@ -1921,24 +1941,16 @@ async def _process_connect_init(self) -> None:

if 'tls_required' in self._server_info and self._server_info[
'tls_required'] and self._current_server.uri.scheme != "ws":
# Check whether to reuse the original hostname for an implicit route.
hostname = None
if "tls_hostname" in self.options:
hostname = self.options["tls_hostname"]
elif self._current_server.tls_name is not None:
hostname = self._current_server.tls_name
else:
hostname = self._current_server.uri.hostname

await self._transport.drain() # just in case something is left

# connect to transport via tls
await self._transport.connect_tls(
hostname,
self.ssl_context,
DEFAULT_BUFFER_SIZE,
self.options['connect_timeout'],
)
if not handshake_first:
await self._transport.drain() # just in case something is left

# connect to transport via tls
await self._transport.connect_tls(
hostname,
self.ssl_context,
DEFAULT_BUFFER_SIZE,
self.options['connect_timeout'],
)

# Refresh state of parser upon reconnect.
if self.is_reconnecting:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# These are here for GitHub's dependency graph and help with setuptools support in some environments.
setup(
name="nats-py",
version='2.4.0',
version='2.5.0',
license='Apache 2 License',
extras_require={
'nkeys': ['nkeys'],
Expand Down
8 changes: 8 additions & 0 deletions tests/conf/tls_handshake_first.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
port: 4224
tls {
cert_file: "./tests/certs/server-cert.pem"
key_file: "./tests/certs/server-key.pem"
ca_file: "./tests/certs/ca.pem"
handshake_first: true
verify: true
}
106 changes: 106 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import http.client
import json
import ssl
import os
import time
import unittest
import urllib
Expand All @@ -21,6 +22,7 @@
MultiTLSServerAuthTestCase,
SingleServerTestCase,
TLSServerTestCase,
TLSServerHandshakeFirstTestCase,
NoAuthUserServerTestCase,
async_test,
)
Expand Down Expand Up @@ -1797,6 +1799,110 @@ async def worker_handler(msg):
self.assertEqual(1, err_count)


class ClientTLSHandshakeFirstTest(TLSServerHandshakeFirstTestCase):

@async_test
async def test_connect(self):
if os.environ.get('NATS_SERVER_VERSION') != 'main':
pytest.skip("test requires nats-server@main")

nc = await nats.connect(
'nats://127.0.0.1:4224',
tls=self.ssl_ctx,
tls_handshake_first=True
)
self.assertEqual(nc._server_info['max_payload'], nc.max_payload)
self.assertTrue(nc._server_info['tls_required'])
self.assertTrue(nc._server_info['tls_verify'])
self.assertTrue(nc.max_payload > 0)
self.assertTrue(nc.is_connected)
await nc.close()
self.assertTrue(nc.is_closed)
self.assertFalse(nc.is_connected)

@async_test
async def test_default_connect_using_tls_scheme(self):
if os.environ.get('NATS_SERVER_VERSION') != 'main':
pytest.skip("test requires nats-server@main")

nc = NATS()

# Will attempt to connect using TLS with default certs.
with self.assertRaises(ssl.SSLError):
await nc.connect(
servers=['tls://127.0.0.1:4224'],
allow_reconnect=False,
tls_handshake_first=True,
)

@async_test
async def test_default_connect_using_tls_scheme_in_url(self):
if os.environ.get('NATS_SERVER_VERSION') != 'main':
pytest.skip("test requires nats-server@main")

nc = NATS()

# Will attempt to connect using TLS with default certs.
with self.assertRaises(ssl.SSLError):
await nc.connect(
'tls://127.0.0.1:4224',
allow_reconnect=False,
tls_handshake_first=True
)

@async_test
async def test_connect_tls_with_custom_hostname(self):
if os.environ.get('NATS_SERVER_VERSION') != 'main':
pytest.skip("test requires nats-server@main")

nc = NATS()

# Will attempt to connect using TLS with an invalid hostname.
with self.assertRaises(ssl.SSLError):
await nc.connect(
servers=['nats://127.0.0.1:4224'],
tls=self.ssl_ctx,
tls_hostname="nats.example",
tls_handshake_first=True,
allow_reconnect=False,
)

@async_test
async def test_subscribe(self):
if os.environ.get('NATS_SERVER_VERSION') != 'main':
pytest.skip("test requires nats-server@main")

nc = NATS()
msgs = []

async def subscription_handler(msg):
msgs.append(msg)

payload = b'hello world'
await nc.connect(
servers=['nats://127.0.0.1:4224'],
tls=self.ssl_ctx,
tls_handshake_first=True
)
sub = await nc.subscribe("foo", cb=subscription_handler)
await nc.publish("foo", payload)
await nc.publish("bar", payload)

with self.assertRaises(nats.errors.BadSubjectError):
await nc.publish("", b'')

# Wait a bit for message to be received.
await asyncio.sleep(0.2)

self.assertEqual(1, len(msgs))
msg = msgs[0]
self.assertEqual('foo', msg.subject)
self.assertEqual('', msg.reply)
self.assertEqual(payload, msg.data)
self.assertEqual(1, sub._received)
await nc.close()


class ClusterDiscoveryTest(ClusteringTestCase):

@async_test
Expand Down
27 changes: 27 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,33 @@ def tearDown(self):
self.loop.close()


class TLSServerHandshakeFirstTestCase(unittest.TestCase):

def setUp(self):
super().setUp()
self.loop = asyncio.new_event_loop()

self.natsd = NATSD(
port=4224,
config_file=get_config_file('conf/tls_handshake_first.conf')
)
start_natsd(self.natsd)

self.ssl_ctx = ssl.create_default_context(
purpose=ssl.Purpose.SERVER_AUTH
)
# self.ssl_ctx.protocol = ssl.PROTOCOL_TLSv1_2
self.ssl_ctx.load_verify_locations(get_config_file('certs/ca.pem'))
self.ssl_ctx.load_cert_chain(
certfile=get_config_file('certs/client-cert.pem'),
keyfile=get_config_file('certs/client-key.pem')
)

def tearDown(self):
self.natsd.stop()
self.loop.close()


class MultiTLSServerAuthTestCase(unittest.TestCase):

def setUp(self):
Expand Down