From 2d719924b03ac81005aa1ce62aee820356b9637f Mon Sep 17 00:00:00 2001 From: Casper Beyer Date: Thu, 18 Jul 2024 12:30:33 +0200 Subject: [PATCH] More tests and fixups --- examples/micro/service.py | 49 +++++++++++ nats/micro/__init__.py | 4 +- nats/micro/api.py | 34 -------- nats/micro/request.py | 6 +- nats/micro/service.py | 98 +++++++++++++++------- tests/test_micro_service.py | 162 ++++++++++-------------------------- 6 files changed, 167 insertions(+), 186 deletions(-) create mode 100644 examples/micro/service.py delete mode 100644 nats/micro/api.py diff --git a/examples/micro/service.py b/examples/micro/service.py new file mode 100644 index 00000000..182457d5 --- /dev/null +++ b/examples/micro/service.py @@ -0,0 +1,49 @@ +import asyncio +import contextlib +import signal + +from nats import Client, micro + +async def echo(req) -> None: + """Echo the request data back to the client.""" + await req.respond(req.data()) + + +async def main(): + # Define an event to signal when to quit + quit_event = asyncio.Event() + # Attach signal handler to the event loop + loop = asyncio.get_event_loop() + for sig in (signal.Signals.SIGINT, signal.Signals.SIGTERM): + loop.add_signal_handler(sig, lambda *_: quit_event.set()) + # Create an async exit stack + async with contextlib.AsyncExitStack() as stack: + # Create a NATS client + nc = Client() + # Connect to NATS + await nc.connect("nats://localhost:4222") + # Push the client.close() method into the stack to be called on exit + stack.push_async_callback(nc.close) + # Push a new micro service into the stack to be stopped on exit + # The service will be stopped and drain its subscriptions before + # closing the connection. + service = await stack.enter_async_context( + await micro.add_service( + nc, + name="demo-service", + version="1.0.0", + description="Demo service", + ) + ) + group = service.add_group("demo") + # Add an endpoint to the service + await group.add_endpoint( + name="echo", + handler=echo, + ) + # Wait for the quit event + await quit_event.wait() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/nats/micro/__init__.py b/nats/micro/__init__.py index f7ec0b73..39450508 100644 --- a/nats/micro/__init__.py +++ b/nats/micro/__init__.py @@ -20,7 +20,9 @@ from .request import Request, Handler -async def add_service(nc: Client, config: Optional[ServiceConfig] = None, **kwargs) -> Service: +async def add_service( + nc: Client, config: Optional[ServiceConfig] = None, **kwargs +) -> Service: """Add a service.""" if config: config = replace(config, **kwargs) diff --git a/nats/micro/api.py b/nats/micro/api.py deleted file mode 100644 index acea0681..00000000 --- a/nats/micro/api.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2021-2022 The NATS Authors -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import re - -DEFAULT_QUEUE_GROUP = "q" -"""Queue Group name used across all services.""" - -DEFAULT_PREFIX = "$SRV" -"""The root of all control subjects.""" - -ERROR_HEADER = "Nats-Service-Error" -ERROR_CODE_HEADER = "Nats-Service-Error-Code" - -INFO_RESPONSE_TYPE = "io.nats.micro.v1.info_response" -PING_RESPONSE_TYPE = "io.nats.micro.v1.ping_response" -STATS_RESPONSE_TYPE = "io.nats.micro.v1.stats_response" - -SEMVER_REGEX = re.compile( - r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$" -) -NAME_REGEX = re.compile(r"^[A-Za-z0-9\-_]+$") -SUBJECT_REGEX = re.compile(r"^[^ >]*[>]?$") diff --git a/nats/micro/request.py b/nats/micro/request.py index 2444a0ba..eb1f49e6 100644 --- a/nats/micro/request.py +++ b/nats/micro/request.py @@ -12,14 +12,14 @@ # limitations under the License. # -import abc - from dataclasses import dataclass from enum import Enum from typing import Dict, Awaitable, Callable, TypeAlias, Optional from nats.aio.msg import Msg -from nats.micro.api import ERROR_HEADER, ERROR_CODE_HEADER + +ERROR_HEADER = "Nats-Service-Error" +ERROR_CODE_HEADER = "Nats-Service-Error-Code" class Request: diff --git a/nats/micro/service.py b/nats/micro/service.py index ac491481..ab10df71 100644 --- a/nats/micro/service.py +++ b/nats/micro/service.py @@ -2,12 +2,12 @@ from asyncio import Event from dataclasses import dataclass, replace, field -from datetime import UTC, datetime, timedelta +from datetime import datetime, timedelta from enum import Enum from nats.aio.client import Client from nats.aio.msg import Msg from nats.aio.subscription import Subscription -from nats.micro.api import DEFAULT_PREFIX, DEFAULT_QUEUE_GROUP + from typing import ( Any, AsyncContextManager, @@ -15,22 +15,40 @@ Protocol, Dict, List, - Self, overload, TypeAlias, Callable, ) +import re import json import time from .request import Request, Handler + +DEFAULT_QUEUE_GROUP = "q" +"""Queue Group name used across all services.""" + +DEFAULT_PREFIX = "$SRV" +"""The root of all control subjects.""" + +INFO_RESPONSE_TYPE = "io.nats.micro.v1.info_response" +PING_RESPONSE_TYPE = "io.nats.micro.v1.ping_response" +STATS_RESPONSE_TYPE = "io.nats.micro.v1.stats_response" + +SEMVER_REGEX = re.compile( + r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$" +) +NAME_REGEX = re.compile(r"^[A-Za-z0-9\-_]+$") +SUBJECT_REGEX = re.compile(r"^[^ >]*[>]?$") + class ServiceVerb(str, Enum): PING = "PING" STATS = "STATS" INFO = "INFO" + @dataclass class EndpointConfig: name: str @@ -98,7 +116,7 @@ class EndpointStats: The last error the service encountered, if any. """ - data: Dict[str, object] = field(default_factory=dict) + data: Optional[Any] = None """ Additional statistics the endpoint makes available """ @@ -130,6 +148,7 @@ def to_dict(self) -> Dict[str, Any]: "data": self.data, } + @dataclass class EndpointInfo: """The information of an endpoint.""" @@ -171,6 +190,7 @@ def to_dict(self) -> Dict[str, Any]: "metadata": self.metadata, } + class Endpoint: """Endpoint manages a service endpoint.""" @@ -235,7 +255,6 @@ async def _handle_request(self, msg: Msg) -> None: self._average_processing_time = int(self._processing_time / self._num_requests) - @dataclass class GroupConfig: """The configuration of a group.""" @@ -251,6 +270,7 @@ class EndpointManager(Protocol): """ Manages the endpoints of a service. """ + @overload async def add_endpoint(self, config: EndpointConfig) -> None: ... @@ -267,23 +287,22 @@ async def add_endpoint( async def add_endpoint( self, config: Optional[EndpointConfig] = None, **kwargs - ) -> None: - ... + ) -> None: ... + class GroupManager(Protocol): """ Manages the groups of a service. """ + @overload - def add_group( - self, *, name: str, queue_group: Optional[str] = None - ) -> Group: ... + def add_group(self, *, name: str, queue_group: Optional[str] = None) -> Group: ... @overload def add_group(self, config: GroupConfig) -> Group: ... - def add_group(self, config: Optional[GroupConfig] = None, **kwargs) -> Group: - ... + def add_group(self, config: Optional[GroupConfig] = None, **kwargs) -> Group: ... + class Group(GroupManager, EndpointManager): def __init__(self, service: "Service", config: GroupConfig) -> None: @@ -313,16 +332,17 @@ async def add_endpoint( else: config = replace(config, **kwargs) - config = replace(config, - subject = f"{self._prefix.strip('.')}.{config.subject or config.name}".strip('.'), + config = replace( + config, + subject=f"{self._prefix.strip('.')}.{config.subject or config.name}".strip( + "." + ), ) await self._service.add_endpoint(config) @overload - def add_group( - self, *, name: str, queue_group: Optional[str] = None - ) -> Group: ... + def add_group(self, *, name: str, queue_group: Optional[str] = None) -> Group: ... @overload def add_group(self, config: GroupConfig) -> Group: ... @@ -333,7 +353,8 @@ def add_group(self, config: Optional[GroupConfig] = None, **kwargs) -> Group: else: config = replace(config, **kwargs) - config = replace(config, + config = replace( + config, name=f"{self._prefix}.{config.name}", queue_group=config.queue_group or self._queue_group, ) @@ -346,7 +367,6 @@ def add_group(self, config: Optional[GroupConfig] = None, **kwargs) -> Group: A handler function used to configure a custom *STATS* endpoint. """ - @dataclass class ServiceConfig: """The configuration of a service.""" @@ -369,6 +389,9 @@ class ServiceConfig: """The default queue group of the service.""" stats_handler: Optional[StatsHandler] = None + """ + A handler function used to configure a custom *STATS* endpoint. + """ class ServiceIdentity(Protocol): @@ -381,6 +404,7 @@ class ServiceIdentity(Protocol): version: str metadata: Dict[str, str] + @dataclass class ServicePing(ServiceIdentity): """The response to a ping message.""" @@ -389,7 +413,7 @@ class ServicePing(ServiceIdentity): name: str version: str metadata: Dict[str, str] = field(default_factory=dict) - type: str = "io.nats.micro.v1.ping_response" + type: str = PING_RESPONSE_TYPE @classmethod def from_dict(cls, data: Dict[str, Any]) -> ServicePing: @@ -409,6 +433,7 @@ def to_dict(self) -> Dict[str, Any]: "metadata": self.metadata, } + @dataclass class ServiceStats(ServiceIdentity): """The statistics of a service.""" @@ -441,7 +466,7 @@ class ServiceStats(ServiceIdentity): metadata: Dict[str, str] = field(default_factory=dict) """Service metadata.""" - type: str = "io.nats.micro.v1.stats_response" + type: str = STATS_RESPONSE_TYPE """ The schema type of the message """ @@ -507,7 +532,7 @@ class ServiceInfo: The service metadata """ - type: str = "io.nats.micro.v1.info_response" + type: str = INFO_RESPONSE_TYPE """ The type of the message """ @@ -526,7 +551,9 @@ def from_dict(cls, data: Dict[str, Any]) -> ServiceInfo: name=data["name"], version=data["version"], description=data.get("description"), - endpoints=[EndpointInfo.from_dict(endpoint) for endpoint in data["endpoints"]], + endpoints=[ + EndpointInfo.from_dict(endpoint) for endpoint in data["endpoints"] + ], metadata=data["metadata"], type=data.get("type", "io.nats.micro.v1.info_response"), ) @@ -546,6 +573,7 @@ def to_dict(self) -> Dict[str, Any]: "metadata": self.metadata, } + class Service(AsyncContextManager): def __init__(self, client: Client, config: ServiceConfig) -> None: self._id = client._nuid.next().decode() @@ -559,7 +587,7 @@ def __init__(self, client: Client, config: ServiceConfig) -> None: self._client = client self._subscriptions = {} self._endpoints = [] - self._started = datetime.now(UTC) + self._started = datetime.utcnow() self._stopped = Event() self._prefix = DEFAULT_PREFIX @@ -582,18 +610,29 @@ async def start(self) -> None: for verb, verb_handler in verb_request_handlers.items(): verb_subjects = [ - (f"{verb}-all", control_subject(verb, name=None, id=None, prefix=self._prefix)), - (f"{verb}-kind", control_subject(verb, name=self._name, id=None, prefix=self._prefix)), - (verb, control_subject(verb, name=self._name, id=self._id, prefix=self._prefix)), + ( + f"{verb}-all", + control_subject(verb, name=None, id=None, prefix=self._prefix), + ), + ( + f"{verb}-kind", + control_subject( + verb, name=self._name, id=None, prefix=self._prefix + ), + ), + ( + verb, + control_subject( + verb, name=self._name, id=self._id, prefix=self._prefix + ), + ), ] for key, subject in verb_subjects: - print(f"Subscribing to {subject} for {verb}") self._subscriptions[key] = await self._client.subscribe( subject, cb=verb_handler ) - print("Subscriptions all created", self._subscriptions) self._started = datetime.now() await self._client.flush() @@ -742,6 +781,7 @@ async def _handle_info_request(self, msg: Msg) -> None: await msg.respond(data=json.dumps(info).encode()) + async def _handle_stats_request(self, msg: Msg) -> None: """Handle a stats message.""" stats = self.stats().to_dict() diff --git a/tests/test_micro_service.py b/tests/test_micro_service.py index 9692ba76..a01bad31 100644 --- a/tests/test_micro_service.py +++ b/tests/test_micro_service.py @@ -6,12 +6,7 @@ import nats.micro -import pytest - -from nats import connect - from nats.micro import * -from nats.micro.api import * from nats.micro.service import * from nats.micro.request import * @@ -26,7 +21,7 @@ def test_endpoint_config(self): @async_test async def test_service_basics(self): - nc = await connect() + nc = await nats.connect() svcs = [] async def add_handler(request: Request): @@ -348,14 +343,10 @@ async def noop_handler(request: Request): metadata={"basic": "schema"}, ) - async def error_cb(err): - print(err) - - nc = await nats.connect(error_cb=error_cb) + nc = await nats.connect() await nc.flush() svc = await add_service(nc, service_config) - # await svc.start() await svc.add_endpoint(endpoint_config) sub_tests = { @@ -450,7 +441,6 @@ async def error_cb(err): for name, data in sub_tests.items(): with self.subTest(name=name): - print(f"Sending request to {data['subject']}") response = await nc.request(data["subject"], timeout=1) response_data = json.loads(response.data) expected_response = data["expected_response"] @@ -538,120 +528,54 @@ async def handler(request: Request): assert stats.endpoints[0].processing_time > 0 assert stats.endpoints[0].average_processing_time > 0 - assert stats.endpoints[0].data == data.get("expected_stats", {}) + assert stats.endpoints[0].data == data.get("expected_stats") await svc.stop() - # async def test_request_respond(self): - # sub_tests = { - # "byte_response": { - # "name": "byte response", - # "respond_data": b"OK", - # "expected_response": b"OK", - # }, - # "byte_response_with_headers": { - # "name": "byte response, with headers", - # "respond_headers": {"key": ["value"]}, - # "respond_data": b"OK", - # "expected_response": b"OK", - # }, - # # "byte_response_connection_closed": { - # # "respond_data": b"OK", - # # "with_respond_error": Error("io.nats.micro.v1.respond_error"), - # # }, - # # { - # # "name": "struct response", - # # "respond_data": X(a="abc", b=5), - # # "expected_response": b'{"a":"abc","b":5}', - # # }, - # # { - # # "name": "invalid response data", - # # "respond_data": lambda: None, - # # "with_respond_error": Error("io.nats.micro.v1.marshal_response_error"), - # # }, - # # { - # # "name": "generic error", - # # "err_description": "oops", - # # "err_code": "500", - # # "err_data": b"error!", - # # "expected_message": "oops", - # # "expected_code": "500", - # # }, - # # { - # # "name": "generic error, with headers", - # # "respond_headers": Headers({"key": ["value"]}), - # # "err_description": "oops", - # # "err_code": "500", - # # "err_data": b"error!", - # # "expected_message": "oops", - # # "expected_code": "500", - # # }, - # # { - # # "name": "error without response payload", - # # "err_description": "oops", - # # "err_code": "500", - # # "expected_message": "oops", - # # "expected_code": "500", - # # }, - # # { - # # "name": "missing error code", - # # "err_description": "oops", - # # "with_respond_error": Error("io.nats.micro.v1.arg_required_error"), - # # }, - # # { - # # "name": "missing error description", - # # "err_code": "500", - # # "with_respond_error": Error("io.nats.micro.v1.arg_required_error"), - # # }, - # } - - # for name, data in sub_tests.items(): - # with self.subTest(name=name): - # async def handler(request: Request): - # await request.respond( - # data["respond_data"], - # headers=data.get("respond_headers"), - # ) - - # svc = await add_service( - # self.nc, - # ServiceConfig( - # name="CoolService", - # version="0.1.0", - # description="test service", - # endpoint=EndpointConfig(_subject="test.func", handler=handler), - # ), - # ) - # await svc.start() # Explicitly start the service - - # request = nc.new_request("test.func", b"req") - # request.headers = Headers({"key": ["value"]}) - - # response = await nc.request(data["request"], headers="", timeout=0.5) - - # if "with_respond_error" in test: - # assert isinstance(response.error, test["with_respond_error"]) - # await svc.stop() - # continue + @async_test + async def test_request_respond(self): + sub_tests = { + "empty_response": { + "respond_data": b"", + "expected_response": b"", + }, + "byte_response": { + "respond_data": b"OK", + "expected_response": b"OK", + }, + "byte_response_with_headers": { + "respond_headers": {"key": "value"}, + "respond_data": b"OK", + "expected_response": b"OK", + "expected_headers": {"key": "value"}, + }, + } - # if "err_code" in test: - # assert response.headers["Status"] == test["expected_code"] - # assert response.headers["Message"] == test["expected_message"] - # assert response.headers == { - # "Status": [test["expected_code"]], - # "Message": [test["expected_message"]], - # **test.get("respond_headers", {}), - # } - # await svc.stop() - # continue + nc = await nats.connect() + for name, data in sub_tests.items(): + with self.subTest(name=name): + async def handler(request: Request): + await request.respond( + data["respond_data"], + headers=data.get("respond_headers"), + ) + + svc = await add_service( + nc, + ServiceConfig( + name="CoolService", + version="0.1.0", + description="test service", + ), + ) + await svc.add_endpoint(EndpointConfig(name="test.func", handler=handler)) - # assert response.data == test["expected_response"] - # assert response.headers == test.get("respond_headers", Headers()) + response = await nc.request("test.func", data["respond_data"], headers=data.get("respond_headers"), timeout=0.5) - # await svc.stop() + assert response.data == data["expected_response"] + assert response.headers == data.get("expected_headers") - async def test_request_error(self): - pass + await svc.stop() def test_control_subject(self): sub_tests = {