From ab5acdbe1782f56831a36ee894a7154f21dd9ee2 Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Mon, 26 Aug 2024 12:02:42 +0200 Subject: [PATCH 01/23] subscription over distinct SSE connection implementation and tests --- ariadne/asgi/handlers/http.py | 417 ++++++++++++++++++++++- tests/asgi/test_sse.py | 169 +++++++++ tests/conftest.py | 9 + tests_integrations/fastapi/test_sse.py | 76 +++++ tests_integrations/starlette/test_sse.py | 71 ++++ 5 files changed, 736 insertions(+), 6 deletions(-) create mode 100644 tests/asgi/test_sse.py create mode 100644 tests_integrations/fastapi/test_sse.py create mode 100644 tests_integrations/starlette/test_sse.py diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index 3fe1a5b4..e0716719 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -1,22 +1,60 @@ +import asyncio import json +import logging +from asyncio import Lock +from functools import partial from http import HTTPStatus from inspect import isawaitable -from typing import Any, Optional, Type, Union, cast - -from graphql import DocumentNode, MiddlewareManager +from io import StringIO +from typing import ( + Any, + Optional, + cast, + Dict, + AsyncGenerator, + Callable, + Awaitable, + Literal, + get_args, + List, + Union, +) +from typing import Type + +from anyio import ( + get_cancelled_exc_class, + CancelScope, + sleep, + move_on_after, + create_task_group, +) +from graphql import DocumentNode +from graphql import MiddlewareManager from starlette.datastructures import UploadFile from starlette.requests import Request -from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse, Response +from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse +from starlette.responses import Response from starlette.types import Receive, Scope, Send +from .base import GraphQLHttpHandlerBase +from ... import format_error from ...constants import ( DATA_TYPE_JSON, DATA_TYPE_MULTIPART, ) -from ...exceptions import HttpBadRequestError, HttpError +from ...exceptions import HttpBadRequestError +from ...exceptions import HttpError from ...explorer import Explorer from ...file_uploads import combine_multipart_data +from ...graphql import ( + ExecutionResult, + GraphQLError, + parse_query, + subscribe, + validate_data, +) from ...graphql import graphql +from ...logger import log_error from ...types import ( ContextValue, ExtensionList, @@ -25,7 +63,240 @@ MiddlewareList, Middlewares, ) -from .base import GraphQLHttpHandlerBase + +EVENT_TYPES = Literal["next", "complete"] + + +class GraphQLServerSentEvent: + """GraphQLServerSentEvent is a class that represents a single Server-Sent Event + as defined in the GraphQL SSE Protocol specification + (https://github.com/enisdenjo/graphql-sse/blob/master/PROTOCOL.md) + """ + + DEFAULT_SEPARATOR = "\r\n" + + def __init__( + self, + event: EVENT_TYPES, + result: Optional[ExecutionResult] = None, + ): + """Initializes the Server-Sent Event + # Required arguments + `event`: the type of the event. Either "next" or "complete" + + # Optional arguments + `result`: an `ExecutionResult` or a `dict` that represents the result of the operation + """ + assert event in get_args(EVENT_TYPES), f"Invalid event type: {event}" + self.event = event + self.result = result + self.logger = logging.Logger("GraphQLServerSentEvent") + + def _write_to_buffer( + self, buffer: StringIO, name: str, value: Optional[str] + ) -> StringIO: + """Writes a SSE field to the buffered SSE event representation + + Returns the `StringIO` buffer with the field written to it + + # Required arguments + `buffer`: the `StringIO` buffer to write to + `name`: the name of the field + `value`: the value of the field + """ + if value is not None: + buffer.write(f"{name}: {value}{self.DEFAULT_SEPARATOR}") + return buffer + + def encode_execution_result(self) -> str: + """Encodes the execution result into a single line JSON string + + Returns the JSON string representation of the execution result + """ + payload: Dict[str, Any] = {} + if self.result.data: + payload["data"] = self.result.data + if self.result.errors: + errors = [] + for error in self.result.errors: + log_error(error, self.logger) + errors.append(format_error(error)) + payload["errors"] = errors + + return json.dumps(payload) + + def __str__(self) -> str: + """Returns the string representation of the Server-Sent Event""" + buffer = StringIO() + buffer = self._write_to_buffer(buffer, "event", self.event) + buffer = self._write_to_buffer( + buffer, + "data", + ( + self.encode_execution_result() + if self.event == "next" and self.result + else "" + ), + ) + buffer.write(self.DEFAULT_SEPARATOR) + + return buffer.getvalue() + + +class ServerSentEventResponse(Response): + """Sends GraphQL SSE events using EvenSource protocol using Starlette's Response class + based on the implementation https://github.com/sysid/sse-starlette/ + """ + + # Sends a ping event to the client every 15 seconds to overcome proxy timeout issues + DEFAULT_PING_INTERVAL = 15 + + def __init__( + self, + generator: AsyncGenerator[GraphQLServerSentEvent, Any], + send_timeout: Optional[int] = None, + ping_interval: Optional[int] = None, + headers: Optional[Dict[str, str]] = None, + encoding: Optional[str] = None, + *args, + **kwargs, + ): + """Initializes the a SSE Response that send events generated by an async generator + + # Required arguments + `generator`: an async generator that yields `GraphQLServerSentEvent` objects + + # Optional arguments + `send_timeout`: the timeout in seconds to send an event to the client + `ping_interval`: the interval in seconds to send a ping event to the client, overrides + the DEFAULT_PING_INTERVAL of 15 seconds + `headers`: a dictionary of headers to be sent with the response + `encoding`: the encoding to use for the response + """ + super().__init__(*args, **kwargs) + self.generator = generator + self.status_code = HTTPStatus.OK + self.send_timeout = send_timeout + self.ping_interval = ping_interval or self.DEFAULT_PING_INTERVAL + self.encoding = encoding or "utf-8" + self.content = None + + _headers: dict[str, str] = {} + if headers is not None: + _headers.update(headers) + # mandatory for servers-sent events headers + # allow cache control header to be set by user to support fan out proxies + # https://www.fastly.com/blog/server-sent-events-fastly + _headers.setdefault("Cache-Control", "no-cache") + _headers.setdefault("Connection", "keep-alive") + _headers.setdefault("X-Accel-Buffering", "no") + _headers.setdefault("Transfer-Encoding", "chunked") + self.media_type = "text/event-stream" + self.init_headers(_headers) + + self._send_lock = Lock() + + @staticmethod + async def listen_for_disconnect(receive: Receive) -> None: + """Listens for the client disconnect event and stops the streaming by exiting the infinite loop + this triggers the anyio CancelScope to cancel the TaskGroup + + # Required arguments + `receive`: the starlette Receive object + """ + while True: + message = await receive() + if message["type"] == "http.disconnect": + logging.debug(f"Got event: http.disconnect. Stop streaming...") + break + + def encode_event(self, event: GraphQLServerSentEvent) -> bytes: + """Encodes the GraphQLServerSentEvent into a bytes object + + # Required arguments + `event`: the GraphQLServerSentEvent object + """ + return str(event).encode(self.encoding) + + async def _ping(self, send: Send) -> None: + """Sends a ping event to the client every `ping_interval` seconds gets cancelled if the client disconnects + through the anyio CancelScope of the TaskGroup + + # Required arguments + `send`: the starlette Send object + """ + while True: + await sleep(self.ping_interval) + async with self._send_lock: + await send( + { + "type": "http.response.body", + "body": self.encode_event(GraphQLServerSentEvent(event="next")), + "more_body": True, + } + ) + + async def send_events(self, send: Send) -> None: + """Sends the events generated by the async generator to the client + + # Required arguments + `send`: the starlette Send object + + """ + async with self._send_lock: + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + + try: + async for event in self.generator: + async with self._send_lock: + with move_on_after(self.send_timeout) as timeout: + await send( + { + "type": "http.response.body", + "body": self.encode_event(event), + "more_body": True, + } + ) + + if timeout.cancel_called: + raise asyncio.TimeoutError() + + except (get_cancelled_exc_class(),) as e: + logging.warning(e) + finally: + with CancelScope(shield=True): + async with self._send_lock: + await send( + {"type": "http.response.body", "body": b"", "more_body": False} + ) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """The main entrypoint for the ServerSentEventResponse which is called by starlette + + # Required arguments + `scope`: the starlette Scope object + + `receive`: the starlette Receive object + + `send`: the starlette Send object + + """ + async with create_task_group() as task_group: + + async def wrap_cancelling(func: Callable[[], Awaitable[None]]) -> None: + await func() + task_group.cancel_scope.cancel() + + task_group.start_soon(wrap_cancelling, partial(self._ping, send)) + task_group.start_soon(wrap_cancelling, partial(self.send_events, send)) + # this will cancel the task group when the client disconnects + await wrap_cancelling(partial(self.listen_for_disconnect, receive)) class GraphQLHTTPHandler(GraphQLHttpHandlerBase): @@ -123,6 +394,10 @@ async def handle_request(self, request: Request) -> Response: return await self.render_explorer(request, self.explorer) if request.method == "POST": + accept = request.headers.get("Accept", "") + accept = accept.split(",")[0] + if accept == "text/event-stream": + return await self.handle_sse_request(request) return await self.graphql_http_server(request) return self.handle_not_allowed_method(request) @@ -425,3 +700,133 @@ def handle_not_allowed_method(self, request: Request): return Response(headers=allow_header) return Response(status_code=HTTPStatus.METHOD_NOT_ALLOWED, headers=allow_header) + + async def handle_sse_request(self, request: Request) -> Response: + """Handles the HTTP request with GraphQL Subscription query using Server-Sent Events. + + # Required arguments + + `request`: the starlette `Request` instance + """ + + try: + data = await self.extract_data_from_request(request) + query = await self.get_query_from_sse_request(request, data) + + if self.schema is None: + raise TypeError( + "schema is not set, call configure method to initialize it" + ) + + validate_data(data) + context_value = await self.get_context_for_request(request, data) + return ServerSentEventResponse( + generator=self.sse_subscribe_to_graphql(query, data, context_value) + ) + except (HttpError, TypeError, GraphQLError) as error: + log_error(error, self.logger) + if not isinstance(error, GraphQLError): + error_message = ( + (error.message or error.status) + if isinstance(error, HttpError) + else str(error) + ) + error = GraphQLError(error_message, original_error=error) + return ServerSentEventResponse( + generator=self.sse_generate_error_response([error]) + ) + + async def sse_generate_error_response( + self, errors: List[GraphQLError] + ) -> AsyncGenerator[GraphQLServerSentEvent, Any]: + """A Server-Sent Event response generator for the errors + To be passed to a ServerSentEventResponse instance + + # Required arguments + + `errors`: a list of `GraphQLError` instances + """ + + yield GraphQLServerSentEvent( + event="next", result=ExecutionResult(errors=errors) + ) + yield GraphQLServerSentEvent(event="complete") + + async def sse_subscribe_to_graphql( + self, query_document: DocumentNode, data: Any, context_value: Any + ): + """Main SSE subscription generator for the GraphQL query. Yields `GraphQLServerSentEvent` instances + and is to be consumed by a `ServerSentEventResponse` instance + + # Required arguments + + `query_document`: an already parsed GraphQL query. + + `data`: a `dict` with query data (`query` string, optionally `operationName` + string and `variables` dictionary). + + `context_value`: a context value to make accessible as 'context' attribute + of second argument (`info`) passed to resolvers and source functions. + """ + + success, results = await subscribe( + self.schema, # type: ignore + data, + context_value=context_value, + root_value=self.root_value, + query_document=query_document, + query_validator=self.query_validator, + validation_rules=self.validation_rules, + debug=self.debug, + introspection=self.introspection, + logger=self.logger, + error_formatter=self.error_formatter, + ) + + if not success: + if not isinstance(results, list): + error_payload = cast(List[dict], [results]) + else: + error_payload = results + + # This needs to be handled better, subscribe returns preformatted errors + yield GraphQLServerSentEvent( + event="next", + result=ExecutionResult( + errors=[ + GraphQLError(message=error.get("message")) + for error in error_payload + ] + ), + ) + else: + results = cast(AsyncGenerator, results) + try: + async for result in results: + yield GraphQLServerSentEvent(event="next", result=result) + except (Exception, GraphQLError) as error: + if not isinstance(error, GraphQLError): + error = GraphQLError(str(error), original_error=error) + log_error(error, self.logger) + yield GraphQLServerSentEvent( + event="next", result=ExecutionResult(errors=[error]) + ) + + yield GraphQLServerSentEvent(event="complete") + + async def get_query_from_sse_request( + self, request: Request, data: Any + ) -> DocumentNode: + """Extracts GraphQL query from SSE request. + + Returns a `DocumentNode` with parsed query. + + # Required arguments + + `request`: the starlette `Request` instance + + `data`: an additional data parameter to potentially extract the query from + """ + + context_value = await self.get_context_for_request(request, data) + return parse_query(context_value, self.query_parser, data) diff --git a/tests/asgi/test_sse.py b/tests/asgi/test_sse.py new file mode 100644 index 00000000..988a1ad2 --- /dev/null +++ b/tests/asgi/test_sse.py @@ -0,0 +1,169 @@ +import json +from http import HTTPStatus +from typing import List, Dict, Any +from unittest.mock import Mock +from graphql import parse, GraphQLError + +from starlette.testclient import TestClient +from httpx import Response +import pytest +from ariadne.asgi import GraphQL + +SSE_HEADER = {"Accept": "text/event-stream"} + + +def get_sse_events(response: Response) -> List[Dict[str, Any]]: + events = [] + for event in response.text.split("\r\n\r\n"): + if len(event.strip()) == 0: + continue + event, data = event.split("\r\n", 1) + event = event.replace("event: ", "") + data = data.replace("data: ", "") + data = json.loads(data) if len(data) > 0 else None + events.append({"event": event, "data": data}) + return events + + +@pytest.fixture +def sse_client(schema): + app = GraphQL(schema, introspection=False) + return TestClient(app, headers=SSE_HEADER) + + +def test_sse_headers(sse_client): + response = sse_client.post("/", json={"query": "subscription { ping }"}) + assert response.status_code == HTTPStatus.OK + assert response.headers["Cache-Control"] == "no-cache" + assert response.headers["Connection"] == "keep-alive" + assert response.headers["Transfer-Encoding"] == "chunked" + assert response.headers["X-Accel-Buffering"] == "no" + + +def test_field_can_be_subscribed_to_using_sse(sse_client): + response = sse_client.post("/", json={"query": "subscription { ping }"}) + events = get_sse_events(response) + assert len(events) == 2 + assert events[0]["data"]["data"] == {"ping": "pong"} + assert events[1]["event"] == "complete" + + +def test_non_subscription_query_cannot_be_executed_using_sse( + sse_client, +): + response = sse_client.post( + "/", + json={ + "query": "query Hello($name: String){ hello(name: $name) }", + "variables": {"name": "John"}, + }, + ) + events = get_sse_events(response) + assert len(events) == 2 + assert events[0]["data"].get("errors") is not None + + +def test_invalid_query_is_handled_using_sse(sse_client): + response = sse_client.post("/", json={"query": "query Invalid { error other }"}) + events = get_sse_events(response) + assert len(events) == 2 + assert events[0]["data"].get("errors") is not None + + +def test_custom_query_parser_is_used_for_subscription_over_sse(schema): + mock_parser = Mock(return_value=parse("subscription { testContext }")) + app = GraphQL( + schema, + query_parser=mock_parser, + context_value={"test": "I'm context"}, + root_value={"test": "I'm root"}, + ) + + client = TestClient(app, headers=SSE_HEADER) + response = client.post("/", json={"query": "subscription { testRoot }"}) + + events = get_sse_events(response) + assert len(events) == 2 + assert events[0]["data"]["data"] == {"testContext": "I'm context"} + assert events[1]["event"] == "complete" + + +@pytest.mark.parametrize( + ("errors"), + [ + ([]), + ([GraphQLError("Nope")]), + ], +) +def test_custom_query_validator_is_used_for_subscription_over_sse(schema, errors): + mock_validator = Mock(return_value=errors) + app = GraphQL( + schema, + query_validator=mock_validator, + context_value={"test": "I'm context"}, + root_value={"test": "I'm root"}, + ) + + client = TestClient(app, headers=SSE_HEADER) + response = client.post( + "/", + json={ + "operationName": None, + "query": "subscription { testContext }", + "variables": None, + }, + ) + + events = get_sse_events(response) + if not errors: + assert len(events) == 2 + assert events[0] == { + "event": "next", + "data": {"data": {"testContext": "I'm context"}}, + } + assert events[1] == {"event": "complete", "data": None} + else: + assert len(events) == 2 + assert events[0]["data"]["errors"][0]["message"] == "Nope" + + +def test_schema_not_set_graphql_sse(): + app = GraphQL(None) + + client = TestClient(app, headers=SSE_HEADER) + response = client.post( + "/", + json={ + "operationName": None, + "query": "subscription { testContext }", + "variables": None, + }, + ) + + events = get_sse_events(response) + assert len(events) == 2 + assert ( + events[0]["data"]["errors"][0]["message"] + == "schema is not set, call configure method to initialize it" + ) + + +def test_ping_is_send_sse(sse_client): + response = sse_client.post("/", json={"query": "subscription { testSlow }"}) + events = get_sse_events(response) + assert len(events) == 4 + assert events[0]["event"] == "next" + assert events[0]["data"]["data"] == {"testSlow": "slow"} + assert events[1]["event"] == "next" + assert events[1]["data"] is None + assert events[2]["event"] == "next" + assert events[2]["data"]["data"] == {"testSlow": "slow"} + assert events[3]["event"] == "complete" + + +def test_resolver_error_is_handled_sse(sse_client): + response = sse_client.post("/", json={"query": "subscription { resolverError }"}) + events = get_sse_events(response) + assert len(events) == 2 + assert events[0]["data"]["errors"][0]["message"] == "Test exception" + assert events[1]["event"] == "complete" diff --git a/tests/conftest.py b/tests/conftest.py index 9333a782..837aa8dc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import asyncio from collections.abc import Mapping import pytest @@ -37,6 +38,7 @@ def type_defs(): sourceError: Boolean testContext: String testRoot: String + testSlow: String } """ @@ -172,6 +174,12 @@ async def test_root_generator(root, *_): yield {"testRoot": root.get("test")} +async def test_slow_generator(_, info): + yield {"testSlow": "slow"} + await asyncio.sleep(20) + yield {"testSlow": "slow"} + + @pytest.fixture def subscriptions(): subscription = SubscriptionType() @@ -181,6 +189,7 @@ def subscriptions(): subscription.set_source("sourceError", error_generator) subscription.set_source("testContext", test_context_generator) subscription.set_source("testRoot", test_root_generator) + subscription.set_source("testSlow", test_slow_generator) return subscription diff --git a/tests_integrations/fastapi/test_sse.py b/tests_integrations/fastapi/test_sse.py new file mode 100644 index 00000000..eae596c4 --- /dev/null +++ b/tests_integrations/fastapi/test_sse.py @@ -0,0 +1,76 @@ +from fastapi import FastAPI, Request +from starlette.testclient import TestClient + +from ariadne import SubscriptionType, make_executable_schema +from ariadne.asgi import GraphQL +from ariadne.asgi.handlers import GraphQLTransportWSHandler + +subscription_type = SubscriptionType() + + +@subscription_type.source("counter") +async def counter_source(*_): + yield 1 + + +@subscription_type.field("counter") +async def counter_resolve(obj, *_): + return obj + + +schema = make_executable_schema( + """ + type Query { + _unused: String + } + + type Subscription { + counter: Int! + } + """, + subscription_type, +) + +app = FastAPI() +graphql = GraphQL( + schema, + websocket_handler=GraphQLTransportWSHandler(), +) + + +@app.post("/graphql") +async def graphql_route(request: Request): + return await graphql.handle_request(request) + + +app.mount("/mounted", graphql) + +client = TestClient(app, headers={"Accept": "text/event-stream"}) + + +def test_run_graphql_subscription_through_route(): + response = client.post( + "/graphql", + json={ + "operationName": None, + "query": "subscription { counter }", + "variables": None, + }, + ) + + assert response.status_code == 200 + assert '{"data": {"counter": 1}}' in response.text + + +def test_run_graphql_subscription_through_mount(): + response = client.post( + "/mounted", + json={ + "operationName": None, + "query": "subscription { counter }", + "variables": None, + }, + ) + + assert response.status_code == 200 + assert '{"data": {"counter": 1}}' in response.text diff --git a/tests_integrations/starlette/test_sse.py b/tests_integrations/starlette/test_sse.py new file mode 100644 index 00000000..58caa934 --- /dev/null +++ b/tests_integrations/starlette/test_sse.py @@ -0,0 +1,71 @@ +from starlette.applications import Starlette +from starlette.routing import Mount, Route +from starlette.testclient import TestClient + +from ariadne import SubscriptionType, make_executable_schema +from ariadne.asgi import GraphQL + +subscription_type = SubscriptionType() + + +@subscription_type.source("counter") +async def counter_source(*_): + yield 1 + + +@subscription_type.field("counter") +async def counter_resolve(obj, *_): + return obj + + +schema = make_executable_schema( + """ + type Query { + _unused: String + } + + type Subscription { + counter: Int! + } + """, + subscription_type, +) + +graphql = GraphQL(schema) + +app = Starlette( + routes=[ + Route("/graphql", methods=["POST"], endpoint=graphql.handle_request), + Mount("/mounted", graphql), + ], +) + +client = TestClient(app, headers={"Accept": "text/event-stream"}) + + +def test_run_graphql_subscription_through_route(): + response = client.post( + "/graphql", + json={ + "operationName": None, + "query": "subscription { counter }", + "variables": None, + }, + ) + + assert response.status_code == 200 + assert '{"data": {"counter": 1}}' in response.text + + +def test_run_graphql_subscription_through_mount(): + response = client.post( + "/mounted", + json={ + "operationName": None, + "query": "subscription { counter }", + "variables": None, + }, + ) + + assert response.status_code == 200 + assert '{"data": {"counter": 1}}' in response.text From e0410befe7790eb506b1dcee0fc1f31603298955 Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Mon, 26 Aug 2024 12:08:17 +0200 Subject: [PATCH 02/23] add changelog entry --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 21981d9d..4b3cfddd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ - Fixed tests websockets after starlette update. - Added `share_enabled` param to `ExplorerPlayground` to enable share playground feature. - Added support for nested attribute resolution in alias resolvers. +- Added support for subscriptions over a distinct Server-Sent-Events connection as per (https://github.com/enisdenjo/graphql-sse/blob/master/PROTOCOL.md). ## 0.23 (2024-03-18) From 870db53c13328e9b9eaa6a98e9f0c1df667e85dd Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Mon, 26 Aug 2024 13:55:30 +0200 Subject: [PATCH 03/23] fix linting --- ariadne/asgi/handlers/http.py | 17 +++++++++-------- tests/conftest.py | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index e0716719..37933e85 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -153,12 +153,12 @@ class ServerSentEventResponse(Response): def __init__( self, + *args, generator: AsyncGenerator[GraphQLServerSentEvent, Any], send_timeout: Optional[int] = None, ping_interval: Optional[int] = None, headers: Optional[Dict[str, str]] = None, encoding: Optional[str] = None, - *args, **kwargs, ): """Initializes the a SSE Response that send events generated by an async generator @@ -198,8 +198,8 @@ def __init__( @staticmethod async def listen_for_disconnect(receive: Receive) -> None: - """Listens for the client disconnect event and stops the streaming by exiting the infinite loop - this triggers the anyio CancelScope to cancel the TaskGroup + """Listens for the client disconnect event and stops the streaming by exiting the infinite + loop this triggers the anyio CancelScope to cancel the TaskGroup # Required arguments `receive`: the starlette Receive object @@ -207,7 +207,7 @@ async def listen_for_disconnect(receive: Receive) -> None: while True: message = await receive() if message["type"] == "http.disconnect": - logging.debug(f"Got event: http.disconnect. Stop streaming...") + logging.debug("Got event: http.disconnect. Stop streaming...") break def encode_event(self, event: GraphQLServerSentEvent) -> bytes: @@ -219,8 +219,8 @@ def encode_event(self, event: GraphQLServerSentEvent) -> bytes: return str(event).encode(self.encoding) async def _ping(self, send: Send) -> None: - """Sends a ping event to the client every `ping_interval` seconds gets cancelled if the client disconnects - through the anyio CancelScope of the TaskGroup + """Sends a ping event to the client every `ping_interval` seconds gets + cancelled if the client disconnects through the anyio CancelScope of the TaskGroup # Required arguments `send`: the starlette Send object @@ -755,8 +755,9 @@ async def sse_generate_error_response( async def sse_subscribe_to_graphql( self, query_document: DocumentNode, data: Any, context_value: Any ): - """Main SSE subscription generator for the GraphQL query. Yields `GraphQLServerSentEvent` instances - and is to be consumed by a `ServerSentEventResponse` instance + """Main SSE subscription generator for the GraphQL query. + Yields `GraphQLServerSentEvent` instances and is to be consumed by a + `ServerSentEventResponse` instance # Required arguments diff --git a/tests/conftest.py b/tests/conftest.py index 837aa8dc..5dd06e35 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -174,7 +174,7 @@ async def test_root_generator(root, *_): yield {"testRoot": root.get("test")} -async def test_slow_generator(_, info): +async def test_slow_generator(*_): yield {"testSlow": "slow"} await asyncio.sleep(20) yield {"testSlow": "slow"} From dba57008eaade81417a1d4cf54e1dcf5bfeb6b38 Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Mon, 26 Aug 2024 14:03:46 +0200 Subject: [PATCH 04/23] fix mypy errors --- ariadne/asgi/handlers/http.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index 37933e85..24ad711e 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -114,9 +114,9 @@ def encode_execution_result(self) -> str: Returns the JSON string representation of the execution result """ payload: Dict[str, Any] = {} - if self.result.data: + if self.result is not None and self.result.data is not None: payload["data"] = self.result.data - if self.result.errors: + if self.result is not None and self.result.errors is not None: errors = [] for error in self.result.errors: log_error(error, self.logger) @@ -710,7 +710,7 @@ async def handle_sse_request(self, request: Request) -> Response: """ try: - data = await self.extract_data_from_request(request) + data: Any = await self.extract_data_from_request(request) query = await self.get_query_from_sse_request(request, data) if self.schema is None: @@ -795,7 +795,7 @@ async def sse_subscribe_to_graphql( event="next", result=ExecutionResult( errors=[ - GraphQLError(message=error.get("message")) + GraphQLError(message=cast(str, error.get("message", ""))) for error in error_payload ] ), From d6d9f9f44c064b0309fc4f92a09f4099b54ea4e3 Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Mon, 26 Aug 2024 14:08:15 +0200 Subject: [PATCH 05/23] fix pylint error, only visible in ci pipeline --- ariadne/asgi/handlers/http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index 24ad711e..184282b7 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -181,7 +181,7 @@ def __init__( self.encoding = encoding or "utf-8" self.content = None - _headers: dict[str, str] = {} + _headers: Dict[str, str] = {} if headers is not None: _headers.update(headers) # mandatory for servers-sent events headers From cd3680a22fea348b21886e81649e2f59f750db02 Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Thu, 29 Aug 2024 16:18:20 +0200 Subject: [PATCH 06/23] fix content-length header being set --- ariadne/asgi/handlers/http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index 184282b7..b06c5bef 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -179,7 +179,7 @@ def __init__( self.send_timeout = send_timeout self.ping_interval = ping_interval or self.DEFAULT_PING_INTERVAL self.encoding = encoding or "utf-8" - self.content = None + self.body = None _headers: Dict[str, str] = {} if headers is not None: From 7ebd431f9bdd2988328b06f741325bb87d7c6ab3 Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Thu, 29 Aug 2024 16:58:54 +0200 Subject: [PATCH 07/23] align ping message with graphql-sse implementation (https://github.com/enisdenjo/graphql-sse/blob/e8bef032422a7d38a670dc6d18204c4f5dfab6c8/src/handler.ts#L516) --- ariadne/asgi/handlers/http.py | 4 ++-- tests/asgi/test_sse.py | 16 ++++++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index b06c5bef..5833597a 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -179,7 +179,7 @@ def __init__( self.send_timeout = send_timeout self.ping_interval = ping_interval or self.DEFAULT_PING_INTERVAL self.encoding = encoding or "utf-8" - self.body = None + self.body = None # type: ignore _headers: Dict[str, str] = {} if headers is not None: @@ -231,7 +231,7 @@ async def _ping(self, send: Send) -> None: await send( { "type": "http.response.body", - "body": self.encode_event(GraphQLServerSentEvent(event="next")), + "body": ":\r\n\r\n".encode(self.encoding), "more_body": True, } ) diff --git a/tests/asgi/test_sse.py b/tests/asgi/test_sse.py index 988a1ad2..27c8cf44 100644 --- a/tests/asgi/test_sse.py +++ b/tests/asgi/test_sse.py @@ -17,11 +17,15 @@ def get_sse_events(response: Response) -> List[Dict[str, Any]]: for event in response.text.split("\r\n\r\n"): if len(event.strip()) == 0: continue - event, data = event.split("\r\n", 1) - event = event.replace("event: ", "") - data = data.replace("data: ", "") - data = json.loads(data) if len(data) > 0 else None - events.append({"event": event, "data": data}) + if "\r\n" not in event: + # ping message + events.append({"event": "", "data": None}) + else: + event, data = event.split("\r\n", 1) + event = event.replace("event: ", "") + data = data.replace("data: ", "") + data = json.loads(data) if len(data) > 0 else None + events.append({"event": event, "data": data}) return events @@ -154,7 +158,7 @@ def test_ping_is_send_sse(sse_client): assert len(events) == 4 assert events[0]["event"] == "next" assert events[0]["data"]["data"] == {"testSlow": "slow"} - assert events[1]["event"] == "next" + assert events[1]["event"] == "" assert events[1]["data"] is None assert events[2]["event"] == "next" assert events[2]["data"]["data"] == {"testSlow": "slow"} From 08b3db03a9289ec8691992b68f17c0fce7a86cfb Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Thu, 12 Dec 2024 14:29:39 +0100 Subject: [PATCH 08/23] move sse to separate handler, add tests for configuration options, cleanup --- ariadne/asgi/handlers/http.py | 432 +-------------------- ariadne/contrib/sse.py | 456 +++++++++++++++++++++++ tests/asgi/test_sse.py | 49 ++- tests_integrations/fastapi/test_sse.py | 2 + tests_integrations/starlette/test_sse.py | 3 +- 5 files changed, 523 insertions(+), 419 deletions(-) create mode 100644 ariadne/contrib/sse.py diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index 5833597a..00d183ca 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -1,60 +1,22 @@ -import asyncio import json -import logging -from asyncio import Lock -from functools import partial from http import HTTPStatus from inspect import isawaitable -from io import StringIO -from typing import ( - Any, - Optional, - cast, - Dict, - AsyncGenerator, - Callable, - Awaitable, - Literal, - get_args, - List, - Union, -) -from typing import Type - -from anyio import ( - get_cancelled_exc_class, - CancelScope, - sleep, - move_on_after, - create_task_group, -) -from graphql import DocumentNode -from graphql import MiddlewareManager +from typing import Any, Optional, Type, Union, cast + +from graphql import DocumentNode, MiddlewareManager from starlette.datastructures import UploadFile from starlette.requests import Request -from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse -from starlette.responses import Response +from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse, Response from starlette.types import Receive, Scope, Send -from .base import GraphQLHttpHandlerBase -from ... import format_error from ...constants import ( DATA_TYPE_JSON, DATA_TYPE_MULTIPART, ) -from ...exceptions import HttpBadRequestError -from ...exceptions import HttpError +from ...exceptions import HttpBadRequestError, HttpError from ...explorer import Explorer from ...file_uploads import combine_multipart_data -from ...graphql import ( - ExecutionResult, - GraphQLError, - parse_query, - subscribe, - validate_data, -) from ...graphql import graphql -from ...logger import log_error from ...types import ( ContextValue, ExtensionList, @@ -63,240 +25,7 @@ MiddlewareList, Middlewares, ) - -EVENT_TYPES = Literal["next", "complete"] - - -class GraphQLServerSentEvent: - """GraphQLServerSentEvent is a class that represents a single Server-Sent Event - as defined in the GraphQL SSE Protocol specification - (https://github.com/enisdenjo/graphql-sse/blob/master/PROTOCOL.md) - """ - - DEFAULT_SEPARATOR = "\r\n" - - def __init__( - self, - event: EVENT_TYPES, - result: Optional[ExecutionResult] = None, - ): - """Initializes the Server-Sent Event - # Required arguments - `event`: the type of the event. Either "next" or "complete" - - # Optional arguments - `result`: an `ExecutionResult` or a `dict` that represents the result of the operation - """ - assert event in get_args(EVENT_TYPES), f"Invalid event type: {event}" - self.event = event - self.result = result - self.logger = logging.Logger("GraphQLServerSentEvent") - - def _write_to_buffer( - self, buffer: StringIO, name: str, value: Optional[str] - ) -> StringIO: - """Writes a SSE field to the buffered SSE event representation - - Returns the `StringIO` buffer with the field written to it - - # Required arguments - `buffer`: the `StringIO` buffer to write to - `name`: the name of the field - `value`: the value of the field - """ - if value is not None: - buffer.write(f"{name}: {value}{self.DEFAULT_SEPARATOR}") - return buffer - - def encode_execution_result(self) -> str: - """Encodes the execution result into a single line JSON string - - Returns the JSON string representation of the execution result - """ - payload: Dict[str, Any] = {} - if self.result is not None and self.result.data is not None: - payload["data"] = self.result.data - if self.result is not None and self.result.errors is not None: - errors = [] - for error in self.result.errors: - log_error(error, self.logger) - errors.append(format_error(error)) - payload["errors"] = errors - - return json.dumps(payload) - - def __str__(self) -> str: - """Returns the string representation of the Server-Sent Event""" - buffer = StringIO() - buffer = self._write_to_buffer(buffer, "event", self.event) - buffer = self._write_to_buffer( - buffer, - "data", - ( - self.encode_execution_result() - if self.event == "next" and self.result - else "" - ), - ) - buffer.write(self.DEFAULT_SEPARATOR) - - return buffer.getvalue() - - -class ServerSentEventResponse(Response): - """Sends GraphQL SSE events using EvenSource protocol using Starlette's Response class - based on the implementation https://github.com/sysid/sse-starlette/ - """ - - # Sends a ping event to the client every 15 seconds to overcome proxy timeout issues - DEFAULT_PING_INTERVAL = 15 - - def __init__( - self, - *args, - generator: AsyncGenerator[GraphQLServerSentEvent, Any], - send_timeout: Optional[int] = None, - ping_interval: Optional[int] = None, - headers: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - **kwargs, - ): - """Initializes the a SSE Response that send events generated by an async generator - - # Required arguments - `generator`: an async generator that yields `GraphQLServerSentEvent` objects - - # Optional arguments - `send_timeout`: the timeout in seconds to send an event to the client - `ping_interval`: the interval in seconds to send a ping event to the client, overrides - the DEFAULT_PING_INTERVAL of 15 seconds - `headers`: a dictionary of headers to be sent with the response - `encoding`: the encoding to use for the response - """ - super().__init__(*args, **kwargs) - self.generator = generator - self.status_code = HTTPStatus.OK - self.send_timeout = send_timeout - self.ping_interval = ping_interval or self.DEFAULT_PING_INTERVAL - self.encoding = encoding or "utf-8" - self.body = None # type: ignore - - _headers: Dict[str, str] = {} - if headers is not None: - _headers.update(headers) - # mandatory for servers-sent events headers - # allow cache control header to be set by user to support fan out proxies - # https://www.fastly.com/blog/server-sent-events-fastly - _headers.setdefault("Cache-Control", "no-cache") - _headers.setdefault("Connection", "keep-alive") - _headers.setdefault("X-Accel-Buffering", "no") - _headers.setdefault("Transfer-Encoding", "chunked") - self.media_type = "text/event-stream" - self.init_headers(_headers) - - self._send_lock = Lock() - - @staticmethod - async def listen_for_disconnect(receive: Receive) -> None: - """Listens for the client disconnect event and stops the streaming by exiting the infinite - loop this triggers the anyio CancelScope to cancel the TaskGroup - - # Required arguments - `receive`: the starlette Receive object - """ - while True: - message = await receive() - if message["type"] == "http.disconnect": - logging.debug("Got event: http.disconnect. Stop streaming...") - break - - def encode_event(self, event: GraphQLServerSentEvent) -> bytes: - """Encodes the GraphQLServerSentEvent into a bytes object - - # Required arguments - `event`: the GraphQLServerSentEvent object - """ - return str(event).encode(self.encoding) - - async def _ping(self, send: Send) -> None: - """Sends a ping event to the client every `ping_interval` seconds gets - cancelled if the client disconnects through the anyio CancelScope of the TaskGroup - - # Required arguments - `send`: the starlette Send object - """ - while True: - await sleep(self.ping_interval) - async with self._send_lock: - await send( - { - "type": "http.response.body", - "body": ":\r\n\r\n".encode(self.encoding), - "more_body": True, - } - ) - - async def send_events(self, send: Send) -> None: - """Sends the events generated by the async generator to the client - - # Required arguments - `send`: the starlette Send object - - """ - async with self._send_lock: - await send( - { - "type": "http.response.start", - "status": self.status_code, - "headers": self.raw_headers, - } - ) - - try: - async for event in self.generator: - async with self._send_lock: - with move_on_after(self.send_timeout) as timeout: - await send( - { - "type": "http.response.body", - "body": self.encode_event(event), - "more_body": True, - } - ) - - if timeout.cancel_called: - raise asyncio.TimeoutError() - - except (get_cancelled_exc_class(),) as e: - logging.warning(e) - finally: - with CancelScope(shield=True): - async with self._send_lock: - await send( - {"type": "http.response.body", "body": b"", "more_body": False} - ) - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """The main entrypoint for the ServerSentEventResponse which is called by starlette - - # Required arguments - `scope`: the starlette Scope object - - `receive`: the starlette Receive object - - `send`: the starlette Send object - - """ - async with create_task_group() as task_group: - - async def wrap_cancelling(func: Callable[[], Awaitable[None]]) -> None: - await func() - task_group.cancel_scope.cancel() - - task_group.start_soon(wrap_cancelling, partial(self._ping, send)) - task_group.start_soon(wrap_cancelling, partial(self.send_events, send)) - # this will cancel the task group when the client disconnects - await wrap_cancelling(partial(self.listen_for_disconnect, receive)) +from .base import GraphQLHttpHandlerBase class GraphQLHTTPHandler(GraphQLHttpHandlerBase): @@ -364,6 +93,16 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: response = await self.handle_request(request) await response(scope, receive, send) + async def handle_request_override(self, request: Request) -> Response | None: + """Override the default request handling logic in subclasses. + Is called in the `handle_request` method before the default logic. + If None is returned, the default logic is executed. + + # Required arguments: + `request`: the `Request` instance from Starlette or FastAPI. + """ + return None + async def handle_request(self, request: Request) -> Response: """Handle GraphQL request and return response for the client. @@ -386,6 +125,10 @@ async def handle_request(self, request: Request) -> Response: `request`: the `Request` instance from Starlette or FastAPI. """ + response = await self.handle_request_override(request) + if response is not None: + return response + if request.method == "GET": if self.execute_get_queries and request.query_params.get("query"): return await self.graphql_http_server(request) @@ -394,10 +137,6 @@ async def handle_request(self, request: Request) -> Response: return await self.render_explorer(request, self.explorer) if request.method == "POST": - accept = request.headers.get("Accept", "") - accept = accept.split(",")[0] - if accept == "text/event-stream": - return await self.handle_sse_request(request) return await self.graphql_http_server(request) return self.handle_not_allowed_method(request) @@ -700,134 +439,3 @@ def handle_not_allowed_method(self, request: Request): return Response(headers=allow_header) return Response(status_code=HTTPStatus.METHOD_NOT_ALLOWED, headers=allow_header) - - async def handle_sse_request(self, request: Request) -> Response: - """Handles the HTTP request with GraphQL Subscription query using Server-Sent Events. - - # Required arguments - - `request`: the starlette `Request` instance - """ - - try: - data: Any = await self.extract_data_from_request(request) - query = await self.get_query_from_sse_request(request, data) - - if self.schema is None: - raise TypeError( - "schema is not set, call configure method to initialize it" - ) - - validate_data(data) - context_value = await self.get_context_for_request(request, data) - return ServerSentEventResponse( - generator=self.sse_subscribe_to_graphql(query, data, context_value) - ) - except (HttpError, TypeError, GraphQLError) as error: - log_error(error, self.logger) - if not isinstance(error, GraphQLError): - error_message = ( - (error.message or error.status) - if isinstance(error, HttpError) - else str(error) - ) - error = GraphQLError(error_message, original_error=error) - return ServerSentEventResponse( - generator=self.sse_generate_error_response([error]) - ) - - async def sse_generate_error_response( - self, errors: List[GraphQLError] - ) -> AsyncGenerator[GraphQLServerSentEvent, Any]: - """A Server-Sent Event response generator for the errors - To be passed to a ServerSentEventResponse instance - - # Required arguments - - `errors`: a list of `GraphQLError` instances - """ - - yield GraphQLServerSentEvent( - event="next", result=ExecutionResult(errors=errors) - ) - yield GraphQLServerSentEvent(event="complete") - - async def sse_subscribe_to_graphql( - self, query_document: DocumentNode, data: Any, context_value: Any - ): - """Main SSE subscription generator for the GraphQL query. - Yields `GraphQLServerSentEvent` instances and is to be consumed by a - `ServerSentEventResponse` instance - - # Required arguments - - `query_document`: an already parsed GraphQL query. - - `data`: a `dict` with query data (`query` string, optionally `operationName` - string and `variables` dictionary). - - `context_value`: a context value to make accessible as 'context' attribute - of second argument (`info`) passed to resolvers and source functions. - """ - - success, results = await subscribe( - self.schema, # type: ignore - data, - context_value=context_value, - root_value=self.root_value, - query_document=query_document, - query_validator=self.query_validator, - validation_rules=self.validation_rules, - debug=self.debug, - introspection=self.introspection, - logger=self.logger, - error_formatter=self.error_formatter, - ) - - if not success: - if not isinstance(results, list): - error_payload = cast(List[dict], [results]) - else: - error_payload = results - - # This needs to be handled better, subscribe returns preformatted errors - yield GraphQLServerSentEvent( - event="next", - result=ExecutionResult( - errors=[ - GraphQLError(message=cast(str, error.get("message", ""))) - for error in error_payload - ] - ), - ) - else: - results = cast(AsyncGenerator, results) - try: - async for result in results: - yield GraphQLServerSentEvent(event="next", result=result) - except (Exception, GraphQLError) as error: - if not isinstance(error, GraphQLError): - error = GraphQLError(str(error), original_error=error) - log_error(error, self.logger) - yield GraphQLServerSentEvent( - event="next", result=ExecutionResult(errors=[error]) - ) - - yield GraphQLServerSentEvent(event="complete") - - async def get_query_from_sse_request( - self, request: Request, data: Any - ) -> DocumentNode: - """Extracts GraphQL query from SSE request. - - Returns a `DocumentNode` with parsed query. - - # Required arguments - - `request`: the starlette `Request` instance - - `data`: an additional data parameter to potentially extract the query from - """ - - context_value = await self.get_context_for_request(request, data) - return parse_query(context_value, self.query_parser, data) diff --git a/ariadne/contrib/sse.py b/ariadne/contrib/sse.py new file mode 100644 index 00000000..2ce52191 --- /dev/null +++ b/ariadne/contrib/sse.py @@ -0,0 +1,456 @@ +import asyncio +import json +import logging +from asyncio import Lock +from functools import partial +from http import HTTPStatus +from io import StringIO +from typing import ( + Any, + Optional, + cast, + AsyncGenerator, + List, + Literal, + get_args, + Dict, + Callable, + Awaitable, + Type, +) + +from anyio import ( + get_cancelled_exc_class, + CancelScope, + sleep, + move_on_after, + create_task_group, +) +from graphql import DocumentNode +from graphql import MiddlewareManager +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import Receive, Scope, Send + +from .. import format_error +from ..asgi.handlers import GraphQLHTTPHandler +from ..exceptions import HttpError +from ..graphql import ( + ExecutionResult, + GraphQLError, + parse_query, + subscribe, + validate_data, +) +from ..logger import log_error +from ..types import Extensions, Middlewares + +EVENT_TYPES = Literal["next", "complete"] + + +class GraphQLServerSentEvent: + """GraphQLServerSentEvent is a class that represents a single Server-Sent Event + as defined in the GraphQL SSE Protocol specification + (https://github.com/enisdenjo/graphql-sse/blob/master/PROTOCOL.md) + """ + + DEFAULT_SEPARATOR = "\r\n" + + def __init__( + self, + event: EVENT_TYPES, + result: Optional[ExecutionResult] = None, + ): + """Initializes the Server-Sent Event + # Required arguments + `event`: the type of the event. Either "next" or "complete" + + # Optional arguments + `result`: an `ExecutionResult` or a `dict` that represents the result of the operation + """ + assert event in get_args(EVENT_TYPES), f"Invalid event type: {event}" + self.event = event + self.result = result + self.logger = logging.Logger("GraphQLServerSentEvent") + + def __str__(self) -> str: + """Returns the string representation of the Server-Sent Event""" + buffer = StringIO() + buffer = self._write_to_buffer(buffer, "event", self.event) + buffer = self._write_to_buffer( + buffer, + "data", + ( + self.encode_execution_result() + if self.event == "next" and self.result + else "" + ), + ) + buffer.write(self.DEFAULT_SEPARATOR) + + return buffer.getvalue() + + def _write_to_buffer( + self, buffer: StringIO, name: str, value: Optional[str] + ) -> StringIO: + """Writes an SSE field to the buffered SSE event representation + + Returns the `StringIO` buffer with the field written to it + + # Required arguments + `buffer`: the `StringIO` buffer to write to + `name`: the name of the field + `value`: the value of the field + """ + if value is not None: + buffer.write(f"{name}: {value}{self.DEFAULT_SEPARATOR}") + return buffer + + def encode_execution_result(self) -> str: + """Encodes the execution result into a single line JSON string + + Returns the JSON string representation of the execution result + """ + payload: Dict[str, Any] = {} + if self.result is not None and self.result.data is not None: + payload["data"] = self.result.data + if self.result is not None and self.result.errors is not None: + errors = [] + for error in self.result.errors: + errors.append(format_error(error)) + payload["errors"] = errors + + return json.dumps(payload) + + +class ServerSentEventResponse(Response): + """Sends GraphQL SSE events using the EventSource protocol using Starlette's Response class + based on the implementation https://github.com/sysid/sse-starlette/ + """ + + # Sends a ping event to the client every 15 seconds to overcome proxy timeout issues + DEFAULT_PING_INTERVAL = 15 + + def __init__( + self, + *args, + generator: AsyncGenerator[GraphQLServerSentEvent, Any], + send_timeout: Optional[int] = None, + ping_interval: Optional[int] = None, + headers: Optional[Dict[str, str]] = None, + **kwargs, + ): + """Initializes an SSE Response that sends events generated by an async generator + + # Required arguments + `generator`: an async generator that yields `GraphQLServerSentEvent` objects + + # Optional arguments + `send_timeout`: the timeout in seconds to send an event to the client + `ping_interval`: the interval in seconds to send a ping event to the client, overrides + the DEFAULT_PING_INTERVAL of 15 seconds + `headers`: a dictionary of headers to be sent with the response + `encoding`: the encoding to use for the response + """ + super().__init__(*args, **kwargs) + self.generator = generator + self.status_code = HTTPStatus.OK + self.send_timeout = send_timeout + self.ping_interval = ping_interval or self.DEFAULT_PING_INTERVAL + self.body = None # type: ignore + + _headers: Dict[str, str] = {} + if headers is not None: + _headers.update(headers) + # mandatory for servers-sent events headers + # allow cache control header to be set by user to support fan out proxies + # https://www.fastly.com/blog/server-sent-events-fastly + _headers.setdefault("Cache-Control", "no-cache") + _headers.setdefault("Connection", "keep-alive") + _headers.setdefault("X-Accel-Buffering", "no") + _headers.setdefault("Transfer-Encoding", "chunked") + self.media_type = "text/event-stream" + self.init_headers(_headers) + + self._send_lock = Lock() + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """The main entrypoint for the ServerSentEventResponse which is called by starlette + + # Required arguments + `scope`: the starlette Scope object + + `receive`: the starlette Receive object + + `send`: the starlette Send object + + """ + async with create_task_group() as task_group: + + async def wrap_cancelling(func: Callable[[], Awaitable[None]]) -> None: + await func() + task_group.cancel_scope.cancel() + + task_group.start_soon(wrap_cancelling, partial(self._ping, send)) + task_group.start_soon(wrap_cancelling, partial(self.send_events, send)) + # this will cancel the task group when the client disconnects + await wrap_cancelling(partial(self.listen_for_disconnect, receive)) + + async def _ping(self, send: Send) -> None: + """Sends a ping event to the client every `ping_interval` seconds gets + cancelled if the client disconnects through the anyio CancelScope of the TaskGroup + + # Required arguments + `send`: the starlette Send object + """ + while True: + await sleep(self.ping_interval) + async with self._send_lock: + await send( + { + "type": "http.response.body", + # always encode as utf-8 as per https://html.spec.whatwg.org/multipage/server-sent-events.html#sse-processing-model + "body": ":\r\n\r\n".encode("utf-8"), + "more_body": True, + } + ) + + async def send_events(self, send: Send) -> None: + """Sends the events generated by the async generator to the client + + # Required arguments + `send`: the starlette Send object + + """ + async with self._send_lock: + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + + try: + async for event in self.generator: + async with self._send_lock: + with move_on_after(self.send_timeout) as timeout: + await send( + { + "type": "http.response.body", + "body": self.encode_event(event), + "more_body": True, + } + ) + + if timeout.cancel_called: + raise asyncio.TimeoutError() + + except (get_cancelled_exc_class(),) as e: + logging.warning(e) + finally: + with CancelScope(shield=True): + async with self._send_lock: + await send( + {"type": "http.response.body", "body": b"", "more_body": False} + ) + + @staticmethod + async def listen_for_disconnect(receive: Receive) -> None: + """Listens for the client disconnect event and stops the streaming by exiting the infinite + loop. This triggers the anyio CancelScope to cancel the TaskGroup + + # Required arguments + `receive`: the starlette Receive object + """ + while True: + message = await receive() + if message["type"] == "http.disconnect": + logging.debug("Got event: http.disconnect. Stop streaming...") + break + + @staticmethod + def encode_event(event: GraphQLServerSentEvent) -> bytes: + """Encodes the GraphQLServerSentEvent into a bytes object + + # Required arguments + `event`: the GraphQLServerSentEvent object + """ + # always encode as utf-8 as per https://html.spec.whatwg.org/multipage/server-sent-events.html#sse-processing-model + return str(event).encode("utf-8") + + +class GraphQLHTTPSSEHandler(GraphQLHTTPHandler): + """Extension to the default GraphQLHTTPHandler to also handle Server-Sent Events as per + the GraphQL SSE Protocol specification. This handler only supports the defined `Distinct connections mode` + due to its statelessness. This implementation is based on the specification as of commit + 80cf75b5952d1a065c95bdbd6a74304c90dbe2c5. For more information see the specification + (https://github.com/enisdenjo/graphql-sse/blob/master/PROTOCOL.md) + """ + + def __init__( + self, + extensions: Optional[Extensions] = None, + middleware: Optional[Middlewares] = None, + middleware_manager_class: Optional[Type[MiddlewareManager]] = None, + send_timeout: Optional[int] = None, + ping_interval: Optional[int] = None, + default_response_headers: Optional[Dict[str, str]] = None, + ): + super().__init__(extensions, middleware, middleware_manager_class) + self.send_timeout = send_timeout + self.ping_interval = ping_interval + self.default_response_headers = default_response_headers + + async def handle_request_override(self, request: Request) -> Response | None: + """Overrides the handle_request_override method to handle Server-Sent Events + + # Required arguments + `request`: the starlette or FastAPI `Request` instance + + """ + + if request.method == "POST": + accept = request.headers.get("Accept", "").split(",") + accept = [a.strip() for a in accept] + if "text/event-stream" in accept: + return await self.handle_sse_request(request) + return None + + async def handle_sse_request(self, request: Request) -> Response: + """Handles the HTTP request with GraphQL Subscription query using Server-Sent Events. + + # Required arguments + + `request`: the starlette `Request` instance + """ + + try: + data: Any = await self.extract_data_from_request(request) + query = await self.get_query_from_sse_request(request, data) + + if self.schema is None: + raise TypeError( + "schema is not set, call configure method to initialize it" + ) + + validate_data(data) + context_value = await self.get_context_for_request(request, data) + return ServerSentEventResponse( + generator=self.sse_subscribe_to_graphql(query, data, context_value), + ping_interval=self.ping_interval, + send_timeout=self.send_timeout, + headers=self.default_response_headers, + ) + except (HttpError, TypeError, GraphQLError) as error: + log_error(error, self.logger) + if not isinstance(error, GraphQLError): + error_message = ( + (error.message or error.status) + if isinstance(error, HttpError) + else str(error) + ) + error = GraphQLError(error_message, original_error=error) + return ServerSentEventResponse( + generator=self.sse_generate_error_response([error]), + ping_interval=self.ping_interval, + send_timeout=self.send_timeout, + headers=self.default_response_headers, + ) + + async def get_query_from_sse_request( + self, request: Request, data: Any + ) -> DocumentNode: + """Extracts GraphQL query from SSE request. + + Returns a `DocumentNode` with parsed query. + + # Required arguments + + `request`: the starlette `Request` instance + + `data`: an additional data parameter to potentially extract the query from + """ + + context_value = await self.get_context_for_request(request, data) + return parse_query(context_value, self.query_parser, data) + + async def sse_subscribe_to_graphql( + self, query_document: DocumentNode, data: Any, context_value: Any + ): + """Main SSE subscription generator for the GraphQL query. + Yields `GraphQLServerSentEvent` instances and is to be consumed by a + `ServerSentEventResponse` instance + + # Required arguments + + `query_document`: an already parsed GraphQL query. + + `data`: a `dict` with query data (`query` string, optionally `operationName` + string and `variables` dictionary). + + `context_value`: a context value to make accessible as 'context' attribute + of second argument (`info`) passed to resolvers and source functions. + """ + + success, results = await subscribe( + self.schema, # type: ignore + data, + context_value=context_value, + root_value=self.root_value, + query_document=query_document, + query_validator=self.query_validator, + validation_rules=self.validation_rules, + debug=self.debug, + introspection=self.introspection, + logger=self.logger, + error_formatter=self.error_formatter, + ) + + if not success: + if not isinstance(results, list): + error_payload = cast(List[dict], [results]) + else: + error_payload = results + + # This needs to be handled better, subscribe returns preformatted errors + yield GraphQLServerSentEvent( + event="next", + result=ExecutionResult( + errors=[ + GraphQLError(message=cast(str, error.get("message", ""))) + for error in error_payload + ] + ), + ) + else: + results = cast(AsyncGenerator, results) + try: + async for result in results: + yield GraphQLServerSentEvent(event="next", result=result) + except (Exception, GraphQLError) as error: + if not isinstance(error, GraphQLError): + error = GraphQLError(str(error), original_error=error) + log_error(error, self.logger) + yield GraphQLServerSentEvent( + event="next", result=ExecutionResult(errors=[error]) + ) + + yield GraphQLServerSentEvent(event="complete") + + @staticmethod + async def sse_generate_error_response( + errors: List[GraphQLError], + ) -> AsyncGenerator[GraphQLServerSentEvent, Any]: + """A Server-Sent Event response generator for the errors + To be passed to a ServerSentEventResponse instance + + # Required arguments + + `errors`: a list of `GraphQLError` instances + """ + + yield GraphQLServerSentEvent( + event="next", result=ExecutionResult(errors=errors) + ) + yield GraphQLServerSentEvent(event="complete") diff --git a/tests/asgi/test_sse.py b/tests/asgi/test_sse.py index 27c8cf44..ad7ed216 100644 --- a/tests/asgi/test_sse.py +++ b/tests/asgi/test_sse.py @@ -2,12 +2,14 @@ from http import HTTPStatus from typing import List, Dict, Any from unittest.mock import Mock -from graphql import parse, GraphQLError -from starlette.testclient import TestClient -from httpx import Response import pytest +from graphql import parse, GraphQLError +from httpx import Response +from starlette.testclient import TestClient + from ariadne.asgi import GraphQL +from ariadne.contrib.sse import GraphQLHTTPSSEHandler SSE_HEADER = {"Accept": "text/event-stream"} @@ -31,7 +33,13 @@ def get_sse_events(response: Response) -> List[Dict[str, Any]]: @pytest.fixture def sse_client(schema): - app = GraphQL(schema, introspection=False) + app = GraphQL( + schema, + http_handler=GraphQLHTTPSSEHandler( + default_response_headers={"Test_Header": "test"} + ), + introspection=False, + ) return TestClient(app, headers=SSE_HEADER) @@ -78,6 +86,7 @@ def test_custom_query_parser_is_used_for_subscription_over_sse(schema): mock_parser = Mock(return_value=parse("subscription { testContext }")) app = GraphQL( schema, + http_handler=GraphQLHTTPSSEHandler(), query_parser=mock_parser, context_value={"test": "I'm context"}, root_value={"test": "I'm root"}, @@ -87,6 +96,7 @@ def test_custom_query_parser_is_used_for_subscription_over_sse(schema): response = client.post("/", json={"query": "subscription { testRoot }"}) events = get_sse_events(response) + print(response) assert len(events) == 2 assert events[0]["data"]["data"] == {"testContext": "I'm context"} assert events[1]["event"] == "complete" @@ -103,6 +113,7 @@ def test_custom_query_validator_is_used_for_subscription_over_sse(schema, errors mock_validator = Mock(return_value=errors) app = GraphQL( schema, + http_handler=GraphQLHTTPSSEHandler(), query_validator=mock_validator, context_value={"test": "I'm context"}, root_value={"test": "I'm root"}, @@ -132,7 +143,7 @@ def test_custom_query_validator_is_used_for_subscription_over_sse(schema, errors def test_schema_not_set_graphql_sse(): - app = GraphQL(None) + app = GraphQL(None, http_handler=GraphQLHTTPSSEHandler()) client = TestClient(app, headers=SSE_HEADER) response = client.post( @@ -158,16 +169,42 @@ def test_ping_is_send_sse(sse_client): assert len(events) == 4 assert events[0]["event"] == "next" assert events[0]["data"]["data"] == {"testSlow": "slow"} - assert events[1]["event"] == "" + assert events[1]["event"] == "" # ping assert events[1]["data"] is None assert events[2]["event"] == "next" assert events[2]["data"]["data"] == {"testSlow": "slow"} assert events[3]["event"] == "complete" +def test_custom_ping_interval(schema): + app = GraphQL( + schema, + http_handler=GraphQLHTTPSSEHandler(ping_interval=10), + introspection=False, + ) + sse_client = TestClient(app, headers=SSE_HEADER) + response = sse_client.post("/", json={"query": "subscription { testSlow }"}) + events = get_sse_events(response) + assert len(events) == 5 + assert events[0]["event"] == "next" + assert events[0]["data"]["data"] == {"testSlow": "slow"} + assert events[1]["event"] == "" # ping + assert events[1]["data"] is None + assert events[2]["event"] == "" # second ping + assert events[2]["data"] is None + assert events[3]["event"] == "next" + assert events[3]["data"]["data"] == {"testSlow": "slow"} + assert events[4]["event"] == "complete" + + def test_resolver_error_is_handled_sse(sse_client): response = sse_client.post("/", json={"query": "subscription { resolverError }"}) events = get_sse_events(response) assert len(events) == 2 assert events[0]["data"]["errors"][0]["message"] == "Test exception" assert events[1]["event"] == "complete" + + +def test_default_headers_are_applied(sse_client): + response = sse_client.post("/", json={"query": "subscription { ping }"}) + assert response.headers["Test_Header"] == "test" diff --git a/tests_integrations/fastapi/test_sse.py b/tests_integrations/fastapi/test_sse.py index eae596c4..400d969b 100644 --- a/tests_integrations/fastapi/test_sse.py +++ b/tests_integrations/fastapi/test_sse.py @@ -4,6 +4,7 @@ from ariadne import SubscriptionType, make_executable_schema from ariadne.asgi import GraphQL from ariadne.asgi.handlers import GraphQLTransportWSHandler +from ariadne.contrib.sse import GraphQLHTTPSSEHandler subscription_type = SubscriptionType() @@ -34,6 +35,7 @@ async def counter_resolve(obj, *_): app = FastAPI() graphql = GraphQL( schema, + http_handler=GraphQLHTTPSSEHandler(), websocket_handler=GraphQLTransportWSHandler(), ) diff --git a/tests_integrations/starlette/test_sse.py b/tests_integrations/starlette/test_sse.py index 58caa934..a65430b0 100644 --- a/tests_integrations/starlette/test_sse.py +++ b/tests_integrations/starlette/test_sse.py @@ -4,6 +4,7 @@ from ariadne import SubscriptionType, make_executable_schema from ariadne.asgi import GraphQL +from ariadne.contrib.sse import GraphQLHTTPSSEHandler subscription_type = SubscriptionType() @@ -31,7 +32,7 @@ async def counter_resolve(obj, *_): subscription_type, ) -graphql = GraphQL(schema) +graphql = GraphQL(schema, http_handler=GraphQLHTTPSSEHandler()) app = Starlette( routes=[ From fc493010ba81c1157b73055698027616cccb49a1 Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Thu, 12 Dec 2024 14:34:10 +0100 Subject: [PATCH 09/23] remove pipe symbol for compatibility with older python versions --- ariadne/asgi/handlers/http.py | 2 +- ariadne/contrib/sse.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index 00d183ca..8d261f60 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -93,7 +93,7 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: response = await self.handle_request(request) await response(scope, receive, send) - async def handle_request_override(self, request: Request) -> Response | None: + async def handle_request_override(self, request: Request) -> Optional[Response]: """Override the default request handling logic in subclasses. Is called in the `handle_request` method before the default logic. If None is returned, the default logic is executed. diff --git a/ariadne/contrib/sse.py b/ariadne/contrib/sse.py index 2ce52191..e6b4be0c 100644 --- a/ariadne/contrib/sse.py +++ b/ariadne/contrib/sse.py @@ -302,7 +302,7 @@ def __init__( self.ping_interval = ping_interval self.default_response_headers = default_response_headers - async def handle_request_override(self, request: Request) -> Response | None: + async def handle_request_override(self, request: Request) -> Optional[Response]: """Overrides the handle_request_override method to handle Server-Sent Events # Required arguments From 0243e1e54542d6a724ac5ed48e000c209edb9dc1 Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Thu, 12 Dec 2024 15:22:11 +0100 Subject: [PATCH 10/23] format and linting, update requirements.txt in fastapi integration tests to fix Starlette TestClient error, manually set anyio dependency lower for python 3.8 integration tests as current version is not available for 3.8 --- ariadne/asgi/handlers/http.py | 4 ++-- ariadne/contrib/sse.py | 13 +++++++----- tests/asgi/test_sse.py | 2 +- tests_integrations/fastapi/requirements.txt | 23 ++++++++++++--------- 4 files changed, 24 insertions(+), 18 deletions(-) diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index 8d261f60..9432fecc 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -93,13 +93,13 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: response = await self.handle_request(request) await response(scope, receive, send) - async def handle_request_override(self, request: Request) -> Optional[Response]: + async def handle_request_override(self, _: Request) -> Optional[Response]: """Override the default request handling logic in subclasses. Is called in the `handle_request` method before the default logic. If None is returned, the default logic is executed. # Required arguments: - `request`: the `Request` instance from Starlette or FastAPI. + `_`: the `Request` instance from Starlette or FastAPI. """ return None diff --git a/ariadne/contrib/sse.py b/ariadne/contrib/sse.py index e6b4be0c..cbc5e506 100644 --- a/ariadne/contrib/sse.py +++ b/ariadne/contrib/sse.py @@ -209,7 +209,8 @@ async def _ping(self, send: Send) -> None: await send( { "type": "http.response.body", - # always encode as utf-8 as per https://html.spec.whatwg.org/multipage/server-sent-events.html#sse-processing-model + # always encode as utf-8 as per + # https://html.spec.whatwg.org/multipage/server-sent-events.html#sse-processing-model "body": ":\r\n\r\n".encode("utf-8"), "more_body": True, } @@ -276,15 +277,17 @@ def encode_event(event: GraphQLServerSentEvent) -> bytes: # Required arguments `event`: the GraphQLServerSentEvent object """ - # always encode as utf-8 as per https://html.spec.whatwg.org/multipage/server-sent-events.html#sse-processing-model + # always encode as utf-8 as per + # https://html.spec.whatwg.org/multipage/server-sent-events.html#sse-processing-model return str(event).encode("utf-8") class GraphQLHTTPSSEHandler(GraphQLHTTPHandler): """Extension to the default GraphQLHTTPHandler to also handle Server-Sent Events as per - the GraphQL SSE Protocol specification. This handler only supports the defined `Distinct connections mode` - due to its statelessness. This implementation is based on the specification as of commit - 80cf75b5952d1a065c95bdbd6a74304c90dbe2c5. For more information see the specification + the GraphQL SSE Protocol specification. This handler only supports the defined + `Distinct connections mode` due to its statelessness. This implementation is based on + the specification as of commit 80cf75b5952d1a065c95bdbd6a74304c90dbe2c5. + For more information see the specification (https://github.com/enisdenjo/graphql-sse/blob/master/PROTOCOL.md) """ diff --git a/tests/asgi/test_sse.py b/tests/asgi/test_sse.py index ad7ed216..05a67257 100644 --- a/tests/asgi/test_sse.py +++ b/tests/asgi/test_sse.py @@ -179,7 +179,7 @@ def test_ping_is_send_sse(sse_client): def test_custom_ping_interval(schema): app = GraphQL( schema, - http_handler=GraphQLHTTPSSEHandler(ping_interval=10), + http_handler=GraphQLHTTPSSEHandler(ping_interval=8), introspection=False, ) sse_client = TestClient(app, headers=SSE_HEADER) diff --git a/tests_integrations/fastapi/requirements.txt b/tests_integrations/fastapi/requirements.txt index b039b774..3524c884 100644 --- a/tests_integrations/fastapi/requirements.txt +++ b/tests_integrations/fastapi/requirements.txt @@ -1,27 +1,30 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # pip-compile --output-file=requirements.txt requirements.in # -annotated-types==0.6.0 +annotated-types==0.7.0 # via pydantic -anyio==3.7.1 +anyio==4.5.2 # via starlette -fastapi==0.109.1 +exceptiongroup==1.2.2 + # via anyio +fastapi==0.115.6 # via -r requirements.in -idna==3.7 +idna==3.10 # via anyio -pydantic==2.4.2 +pydantic==2.10.3 # via fastapi -pydantic-core==2.10.1 +pydantic-core==2.27.1 # via pydantic -sniffio==1.3.0 +sniffio==1.3.1 # via anyio -starlette==0.35.1 +starlette==0.41.3 # via fastapi -typing-extensions==4.8.0 +typing-extensions==4.12.2 # via + # anyio # fastapi # pydantic # pydantic-core From 1a1b4042d6351c28d2fa29479f794849b627ae08 Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Mon, 26 Aug 2024 12:02:42 +0200 Subject: [PATCH 11/23] subscription over distinct SSE connection implementation and tests --- ariadne/asgi/handlers/http.py | 417 ++++++++++++++++++++++- tests/asgi/test_sse.py | 169 +++++++++ tests/conftest.py | 9 + tests_integrations/fastapi/test_sse.py | 76 +++++ tests_integrations/starlette/test_sse.py | 71 ++++ 5 files changed, 736 insertions(+), 6 deletions(-) create mode 100644 tests/asgi/test_sse.py create mode 100644 tests_integrations/fastapi/test_sse.py create mode 100644 tests_integrations/starlette/test_sse.py diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index 3fe1a5b4..e0716719 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -1,22 +1,60 @@ +import asyncio import json +import logging +from asyncio import Lock +from functools import partial from http import HTTPStatus from inspect import isawaitable -from typing import Any, Optional, Type, Union, cast - -from graphql import DocumentNode, MiddlewareManager +from io import StringIO +from typing import ( + Any, + Optional, + cast, + Dict, + AsyncGenerator, + Callable, + Awaitable, + Literal, + get_args, + List, + Union, +) +from typing import Type + +from anyio import ( + get_cancelled_exc_class, + CancelScope, + sleep, + move_on_after, + create_task_group, +) +from graphql import DocumentNode +from graphql import MiddlewareManager from starlette.datastructures import UploadFile from starlette.requests import Request -from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse, Response +from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse +from starlette.responses import Response from starlette.types import Receive, Scope, Send +from .base import GraphQLHttpHandlerBase +from ... import format_error from ...constants import ( DATA_TYPE_JSON, DATA_TYPE_MULTIPART, ) -from ...exceptions import HttpBadRequestError, HttpError +from ...exceptions import HttpBadRequestError +from ...exceptions import HttpError from ...explorer import Explorer from ...file_uploads import combine_multipart_data +from ...graphql import ( + ExecutionResult, + GraphQLError, + parse_query, + subscribe, + validate_data, +) from ...graphql import graphql +from ...logger import log_error from ...types import ( ContextValue, ExtensionList, @@ -25,7 +63,240 @@ MiddlewareList, Middlewares, ) -from .base import GraphQLHttpHandlerBase + +EVENT_TYPES = Literal["next", "complete"] + + +class GraphQLServerSentEvent: + """GraphQLServerSentEvent is a class that represents a single Server-Sent Event + as defined in the GraphQL SSE Protocol specification + (https://github.com/enisdenjo/graphql-sse/blob/master/PROTOCOL.md) + """ + + DEFAULT_SEPARATOR = "\r\n" + + def __init__( + self, + event: EVENT_TYPES, + result: Optional[ExecutionResult] = None, + ): + """Initializes the Server-Sent Event + # Required arguments + `event`: the type of the event. Either "next" or "complete" + + # Optional arguments + `result`: an `ExecutionResult` or a `dict` that represents the result of the operation + """ + assert event in get_args(EVENT_TYPES), f"Invalid event type: {event}" + self.event = event + self.result = result + self.logger = logging.Logger("GraphQLServerSentEvent") + + def _write_to_buffer( + self, buffer: StringIO, name: str, value: Optional[str] + ) -> StringIO: + """Writes a SSE field to the buffered SSE event representation + + Returns the `StringIO` buffer with the field written to it + + # Required arguments + `buffer`: the `StringIO` buffer to write to + `name`: the name of the field + `value`: the value of the field + """ + if value is not None: + buffer.write(f"{name}: {value}{self.DEFAULT_SEPARATOR}") + return buffer + + def encode_execution_result(self) -> str: + """Encodes the execution result into a single line JSON string + + Returns the JSON string representation of the execution result + """ + payload: Dict[str, Any] = {} + if self.result.data: + payload["data"] = self.result.data + if self.result.errors: + errors = [] + for error in self.result.errors: + log_error(error, self.logger) + errors.append(format_error(error)) + payload["errors"] = errors + + return json.dumps(payload) + + def __str__(self) -> str: + """Returns the string representation of the Server-Sent Event""" + buffer = StringIO() + buffer = self._write_to_buffer(buffer, "event", self.event) + buffer = self._write_to_buffer( + buffer, + "data", + ( + self.encode_execution_result() + if self.event == "next" and self.result + else "" + ), + ) + buffer.write(self.DEFAULT_SEPARATOR) + + return buffer.getvalue() + + +class ServerSentEventResponse(Response): + """Sends GraphQL SSE events using EvenSource protocol using Starlette's Response class + based on the implementation https://github.com/sysid/sse-starlette/ + """ + + # Sends a ping event to the client every 15 seconds to overcome proxy timeout issues + DEFAULT_PING_INTERVAL = 15 + + def __init__( + self, + generator: AsyncGenerator[GraphQLServerSentEvent, Any], + send_timeout: Optional[int] = None, + ping_interval: Optional[int] = None, + headers: Optional[Dict[str, str]] = None, + encoding: Optional[str] = None, + *args, + **kwargs, + ): + """Initializes the a SSE Response that send events generated by an async generator + + # Required arguments + `generator`: an async generator that yields `GraphQLServerSentEvent` objects + + # Optional arguments + `send_timeout`: the timeout in seconds to send an event to the client + `ping_interval`: the interval in seconds to send a ping event to the client, overrides + the DEFAULT_PING_INTERVAL of 15 seconds + `headers`: a dictionary of headers to be sent with the response + `encoding`: the encoding to use for the response + """ + super().__init__(*args, **kwargs) + self.generator = generator + self.status_code = HTTPStatus.OK + self.send_timeout = send_timeout + self.ping_interval = ping_interval or self.DEFAULT_PING_INTERVAL + self.encoding = encoding or "utf-8" + self.content = None + + _headers: dict[str, str] = {} + if headers is not None: + _headers.update(headers) + # mandatory for servers-sent events headers + # allow cache control header to be set by user to support fan out proxies + # https://www.fastly.com/blog/server-sent-events-fastly + _headers.setdefault("Cache-Control", "no-cache") + _headers.setdefault("Connection", "keep-alive") + _headers.setdefault("X-Accel-Buffering", "no") + _headers.setdefault("Transfer-Encoding", "chunked") + self.media_type = "text/event-stream" + self.init_headers(_headers) + + self._send_lock = Lock() + + @staticmethod + async def listen_for_disconnect(receive: Receive) -> None: + """Listens for the client disconnect event and stops the streaming by exiting the infinite loop + this triggers the anyio CancelScope to cancel the TaskGroup + + # Required arguments + `receive`: the starlette Receive object + """ + while True: + message = await receive() + if message["type"] == "http.disconnect": + logging.debug(f"Got event: http.disconnect. Stop streaming...") + break + + def encode_event(self, event: GraphQLServerSentEvent) -> bytes: + """Encodes the GraphQLServerSentEvent into a bytes object + + # Required arguments + `event`: the GraphQLServerSentEvent object + """ + return str(event).encode(self.encoding) + + async def _ping(self, send: Send) -> None: + """Sends a ping event to the client every `ping_interval` seconds gets cancelled if the client disconnects + through the anyio CancelScope of the TaskGroup + + # Required arguments + `send`: the starlette Send object + """ + while True: + await sleep(self.ping_interval) + async with self._send_lock: + await send( + { + "type": "http.response.body", + "body": self.encode_event(GraphQLServerSentEvent(event="next")), + "more_body": True, + } + ) + + async def send_events(self, send: Send) -> None: + """Sends the events generated by the async generator to the client + + # Required arguments + `send`: the starlette Send object + + """ + async with self._send_lock: + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + + try: + async for event in self.generator: + async with self._send_lock: + with move_on_after(self.send_timeout) as timeout: + await send( + { + "type": "http.response.body", + "body": self.encode_event(event), + "more_body": True, + } + ) + + if timeout.cancel_called: + raise asyncio.TimeoutError() + + except (get_cancelled_exc_class(),) as e: + logging.warning(e) + finally: + with CancelScope(shield=True): + async with self._send_lock: + await send( + {"type": "http.response.body", "body": b"", "more_body": False} + ) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """The main entrypoint for the ServerSentEventResponse which is called by starlette + + # Required arguments + `scope`: the starlette Scope object + + `receive`: the starlette Receive object + + `send`: the starlette Send object + + """ + async with create_task_group() as task_group: + + async def wrap_cancelling(func: Callable[[], Awaitable[None]]) -> None: + await func() + task_group.cancel_scope.cancel() + + task_group.start_soon(wrap_cancelling, partial(self._ping, send)) + task_group.start_soon(wrap_cancelling, partial(self.send_events, send)) + # this will cancel the task group when the client disconnects + await wrap_cancelling(partial(self.listen_for_disconnect, receive)) class GraphQLHTTPHandler(GraphQLHttpHandlerBase): @@ -123,6 +394,10 @@ async def handle_request(self, request: Request) -> Response: return await self.render_explorer(request, self.explorer) if request.method == "POST": + accept = request.headers.get("Accept", "") + accept = accept.split(",")[0] + if accept == "text/event-stream": + return await self.handle_sse_request(request) return await self.graphql_http_server(request) return self.handle_not_allowed_method(request) @@ -425,3 +700,133 @@ def handle_not_allowed_method(self, request: Request): return Response(headers=allow_header) return Response(status_code=HTTPStatus.METHOD_NOT_ALLOWED, headers=allow_header) + + async def handle_sse_request(self, request: Request) -> Response: + """Handles the HTTP request with GraphQL Subscription query using Server-Sent Events. + + # Required arguments + + `request`: the starlette `Request` instance + """ + + try: + data = await self.extract_data_from_request(request) + query = await self.get_query_from_sse_request(request, data) + + if self.schema is None: + raise TypeError( + "schema is not set, call configure method to initialize it" + ) + + validate_data(data) + context_value = await self.get_context_for_request(request, data) + return ServerSentEventResponse( + generator=self.sse_subscribe_to_graphql(query, data, context_value) + ) + except (HttpError, TypeError, GraphQLError) as error: + log_error(error, self.logger) + if not isinstance(error, GraphQLError): + error_message = ( + (error.message or error.status) + if isinstance(error, HttpError) + else str(error) + ) + error = GraphQLError(error_message, original_error=error) + return ServerSentEventResponse( + generator=self.sse_generate_error_response([error]) + ) + + async def sse_generate_error_response( + self, errors: List[GraphQLError] + ) -> AsyncGenerator[GraphQLServerSentEvent, Any]: + """A Server-Sent Event response generator for the errors + To be passed to a ServerSentEventResponse instance + + # Required arguments + + `errors`: a list of `GraphQLError` instances + """ + + yield GraphQLServerSentEvent( + event="next", result=ExecutionResult(errors=errors) + ) + yield GraphQLServerSentEvent(event="complete") + + async def sse_subscribe_to_graphql( + self, query_document: DocumentNode, data: Any, context_value: Any + ): + """Main SSE subscription generator for the GraphQL query. Yields `GraphQLServerSentEvent` instances + and is to be consumed by a `ServerSentEventResponse` instance + + # Required arguments + + `query_document`: an already parsed GraphQL query. + + `data`: a `dict` with query data (`query` string, optionally `operationName` + string and `variables` dictionary). + + `context_value`: a context value to make accessible as 'context' attribute + of second argument (`info`) passed to resolvers and source functions. + """ + + success, results = await subscribe( + self.schema, # type: ignore + data, + context_value=context_value, + root_value=self.root_value, + query_document=query_document, + query_validator=self.query_validator, + validation_rules=self.validation_rules, + debug=self.debug, + introspection=self.introspection, + logger=self.logger, + error_formatter=self.error_formatter, + ) + + if not success: + if not isinstance(results, list): + error_payload = cast(List[dict], [results]) + else: + error_payload = results + + # This needs to be handled better, subscribe returns preformatted errors + yield GraphQLServerSentEvent( + event="next", + result=ExecutionResult( + errors=[ + GraphQLError(message=error.get("message")) + for error in error_payload + ] + ), + ) + else: + results = cast(AsyncGenerator, results) + try: + async for result in results: + yield GraphQLServerSentEvent(event="next", result=result) + except (Exception, GraphQLError) as error: + if not isinstance(error, GraphQLError): + error = GraphQLError(str(error), original_error=error) + log_error(error, self.logger) + yield GraphQLServerSentEvent( + event="next", result=ExecutionResult(errors=[error]) + ) + + yield GraphQLServerSentEvent(event="complete") + + async def get_query_from_sse_request( + self, request: Request, data: Any + ) -> DocumentNode: + """Extracts GraphQL query from SSE request. + + Returns a `DocumentNode` with parsed query. + + # Required arguments + + `request`: the starlette `Request` instance + + `data`: an additional data parameter to potentially extract the query from + """ + + context_value = await self.get_context_for_request(request, data) + return parse_query(context_value, self.query_parser, data) diff --git a/tests/asgi/test_sse.py b/tests/asgi/test_sse.py new file mode 100644 index 00000000..988a1ad2 --- /dev/null +++ b/tests/asgi/test_sse.py @@ -0,0 +1,169 @@ +import json +from http import HTTPStatus +from typing import List, Dict, Any +from unittest.mock import Mock +from graphql import parse, GraphQLError + +from starlette.testclient import TestClient +from httpx import Response +import pytest +from ariadne.asgi import GraphQL + +SSE_HEADER = {"Accept": "text/event-stream"} + + +def get_sse_events(response: Response) -> List[Dict[str, Any]]: + events = [] + for event in response.text.split("\r\n\r\n"): + if len(event.strip()) == 0: + continue + event, data = event.split("\r\n", 1) + event = event.replace("event: ", "") + data = data.replace("data: ", "") + data = json.loads(data) if len(data) > 0 else None + events.append({"event": event, "data": data}) + return events + + +@pytest.fixture +def sse_client(schema): + app = GraphQL(schema, introspection=False) + return TestClient(app, headers=SSE_HEADER) + + +def test_sse_headers(sse_client): + response = sse_client.post("/", json={"query": "subscription { ping }"}) + assert response.status_code == HTTPStatus.OK + assert response.headers["Cache-Control"] == "no-cache" + assert response.headers["Connection"] == "keep-alive" + assert response.headers["Transfer-Encoding"] == "chunked" + assert response.headers["X-Accel-Buffering"] == "no" + + +def test_field_can_be_subscribed_to_using_sse(sse_client): + response = sse_client.post("/", json={"query": "subscription { ping }"}) + events = get_sse_events(response) + assert len(events) == 2 + assert events[0]["data"]["data"] == {"ping": "pong"} + assert events[1]["event"] == "complete" + + +def test_non_subscription_query_cannot_be_executed_using_sse( + sse_client, +): + response = sse_client.post( + "/", + json={ + "query": "query Hello($name: String){ hello(name: $name) }", + "variables": {"name": "John"}, + }, + ) + events = get_sse_events(response) + assert len(events) == 2 + assert events[0]["data"].get("errors") is not None + + +def test_invalid_query_is_handled_using_sse(sse_client): + response = sse_client.post("/", json={"query": "query Invalid { error other }"}) + events = get_sse_events(response) + assert len(events) == 2 + assert events[0]["data"].get("errors") is not None + + +def test_custom_query_parser_is_used_for_subscription_over_sse(schema): + mock_parser = Mock(return_value=parse("subscription { testContext }")) + app = GraphQL( + schema, + query_parser=mock_parser, + context_value={"test": "I'm context"}, + root_value={"test": "I'm root"}, + ) + + client = TestClient(app, headers=SSE_HEADER) + response = client.post("/", json={"query": "subscription { testRoot }"}) + + events = get_sse_events(response) + assert len(events) == 2 + assert events[0]["data"]["data"] == {"testContext": "I'm context"} + assert events[1]["event"] == "complete" + + +@pytest.mark.parametrize( + ("errors"), + [ + ([]), + ([GraphQLError("Nope")]), + ], +) +def test_custom_query_validator_is_used_for_subscription_over_sse(schema, errors): + mock_validator = Mock(return_value=errors) + app = GraphQL( + schema, + query_validator=mock_validator, + context_value={"test": "I'm context"}, + root_value={"test": "I'm root"}, + ) + + client = TestClient(app, headers=SSE_HEADER) + response = client.post( + "/", + json={ + "operationName": None, + "query": "subscription { testContext }", + "variables": None, + }, + ) + + events = get_sse_events(response) + if not errors: + assert len(events) == 2 + assert events[0] == { + "event": "next", + "data": {"data": {"testContext": "I'm context"}}, + } + assert events[1] == {"event": "complete", "data": None} + else: + assert len(events) == 2 + assert events[0]["data"]["errors"][0]["message"] == "Nope" + + +def test_schema_not_set_graphql_sse(): + app = GraphQL(None) + + client = TestClient(app, headers=SSE_HEADER) + response = client.post( + "/", + json={ + "operationName": None, + "query": "subscription { testContext }", + "variables": None, + }, + ) + + events = get_sse_events(response) + assert len(events) == 2 + assert ( + events[0]["data"]["errors"][0]["message"] + == "schema is not set, call configure method to initialize it" + ) + + +def test_ping_is_send_sse(sse_client): + response = sse_client.post("/", json={"query": "subscription { testSlow }"}) + events = get_sse_events(response) + assert len(events) == 4 + assert events[0]["event"] == "next" + assert events[0]["data"]["data"] == {"testSlow": "slow"} + assert events[1]["event"] == "next" + assert events[1]["data"] is None + assert events[2]["event"] == "next" + assert events[2]["data"]["data"] == {"testSlow": "slow"} + assert events[3]["event"] == "complete" + + +def test_resolver_error_is_handled_sse(sse_client): + response = sse_client.post("/", json={"query": "subscription { resolverError }"}) + events = get_sse_events(response) + assert len(events) == 2 + assert events[0]["data"]["errors"][0]["message"] == "Test exception" + assert events[1]["event"] == "complete" diff --git a/tests/conftest.py b/tests/conftest.py index 9333a782..837aa8dc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import asyncio from collections.abc import Mapping import pytest @@ -37,6 +38,7 @@ def type_defs(): sourceError: Boolean testContext: String testRoot: String + testSlow: String } """ @@ -172,6 +174,12 @@ async def test_root_generator(root, *_): yield {"testRoot": root.get("test")} +async def test_slow_generator(_, info): + yield {"testSlow": "slow"} + await asyncio.sleep(20) + yield {"testSlow": "slow"} + + @pytest.fixture def subscriptions(): subscription = SubscriptionType() @@ -181,6 +189,7 @@ def subscriptions(): subscription.set_source("sourceError", error_generator) subscription.set_source("testContext", test_context_generator) subscription.set_source("testRoot", test_root_generator) + subscription.set_source("testSlow", test_slow_generator) return subscription diff --git a/tests_integrations/fastapi/test_sse.py b/tests_integrations/fastapi/test_sse.py new file mode 100644 index 00000000..eae596c4 --- /dev/null +++ b/tests_integrations/fastapi/test_sse.py @@ -0,0 +1,76 @@ +from fastapi import FastAPI, Request +from starlette.testclient import TestClient + +from ariadne import SubscriptionType, make_executable_schema +from ariadne.asgi import GraphQL +from ariadne.asgi.handlers import GraphQLTransportWSHandler + +subscription_type = SubscriptionType() + + +@subscription_type.source("counter") +async def counter_source(*_): + yield 1 + + +@subscription_type.field("counter") +async def counter_resolve(obj, *_): + return obj + + +schema = make_executable_schema( + """ + type Query { + _unused: String + } + + type Subscription { + counter: Int! + } + """, + subscription_type, +) + +app = FastAPI() +graphql = GraphQL( + schema, + websocket_handler=GraphQLTransportWSHandler(), +) + + +@app.post("/graphql") +async def graphql_route(request: Request): + return await graphql.handle_request(request) + + +app.mount("/mounted", graphql) + +client = TestClient(app, headers={"Accept": "text/event-stream"}) + + +def test_run_graphql_subscription_through_route(): + response = client.post( + "/graphql", + json={ + "operationName": None, + "query": "subscription { counter }", + "variables": None, + }, + ) + + assert response.status_code == 200 + assert '{"data": {"counter": 1}}' in response.text + + +def test_run_graphql_subscription_through_mount(): + response = client.post( + "/mounted", + json={ + "operationName": None, + "query": "subscription { counter }", + "variables": None, + }, + ) + + assert response.status_code == 200 + assert '{"data": {"counter": 1}}' in response.text diff --git a/tests_integrations/starlette/test_sse.py b/tests_integrations/starlette/test_sse.py new file mode 100644 index 00000000..58caa934 --- /dev/null +++ b/tests_integrations/starlette/test_sse.py @@ -0,0 +1,71 @@ +from starlette.applications import Starlette +from starlette.routing import Mount, Route +from starlette.testclient import TestClient + +from ariadne import SubscriptionType, make_executable_schema +from ariadne.asgi import GraphQL + +subscription_type = SubscriptionType() + + +@subscription_type.source("counter") +async def counter_source(*_): + yield 1 + + +@subscription_type.field("counter") +async def counter_resolve(obj, *_): + return obj + + +schema = make_executable_schema( + """ + type Query { + _unused: String + } + + type Subscription { + counter: Int! + } + """, + subscription_type, +) + +graphql = GraphQL(schema) + +app = Starlette( + routes=[ + Route("/graphql", methods=["POST"], endpoint=graphql.handle_request), + Mount("/mounted", graphql), + ], +) + +client = TestClient(app, headers={"Accept": "text/event-stream"}) + + +def test_run_graphql_subscription_through_route(): + response = client.post( + "/graphql", + json={ + "operationName": None, + "query": "subscription { counter }", + "variables": None, + }, + ) + + assert response.status_code == 200 + assert '{"data": {"counter": 1}}' in response.text + + +def test_run_graphql_subscription_through_mount(): + response = client.post( + "/mounted", + json={ + "operationName": None, + "query": "subscription { counter }", + "variables": None, + }, + ) + + assert response.status_code == 200 + assert '{"data": {"counter": 1}}' in response.text From 7f39e37e91b98b6d2b00887f896e2833b2fd59fe Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Mon, 26 Aug 2024 12:08:17 +0200 Subject: [PATCH 12/23] Fix changelog entry conflicts --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 42e33025..9ecd3548 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ - Added `share_enabled` param to `ExplorerPlayground` to enable share playground feature. - Added support for nested attribute resolution in alias resolvers. - Replaced regexes in the Apollo Federation implementation with cleaner approach using GraphQL AST. +- Added support for subscriptions over a distinct Server-Sent-Events connection as per (https://github.com/enisdenjo/graphql-sse/blob/master/PROTOCOL.md). ## 0.23 (2024-03-18) From 224586e943a254ad77add757113a517e3ac757a7 Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Mon, 26 Aug 2024 13:55:30 +0200 Subject: [PATCH 13/23] fix linting --- ariadne/asgi/handlers/http.py | 17 +++++++++-------- tests/conftest.py | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index e0716719..37933e85 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -153,12 +153,12 @@ class ServerSentEventResponse(Response): def __init__( self, + *args, generator: AsyncGenerator[GraphQLServerSentEvent, Any], send_timeout: Optional[int] = None, ping_interval: Optional[int] = None, headers: Optional[Dict[str, str]] = None, encoding: Optional[str] = None, - *args, **kwargs, ): """Initializes the a SSE Response that send events generated by an async generator @@ -198,8 +198,8 @@ def __init__( @staticmethod async def listen_for_disconnect(receive: Receive) -> None: - """Listens for the client disconnect event and stops the streaming by exiting the infinite loop - this triggers the anyio CancelScope to cancel the TaskGroup + """Listens for the client disconnect event and stops the streaming by exiting the infinite + loop this triggers the anyio CancelScope to cancel the TaskGroup # Required arguments `receive`: the starlette Receive object @@ -207,7 +207,7 @@ async def listen_for_disconnect(receive: Receive) -> None: while True: message = await receive() if message["type"] == "http.disconnect": - logging.debug(f"Got event: http.disconnect. Stop streaming...") + logging.debug("Got event: http.disconnect. Stop streaming...") break def encode_event(self, event: GraphQLServerSentEvent) -> bytes: @@ -219,8 +219,8 @@ def encode_event(self, event: GraphQLServerSentEvent) -> bytes: return str(event).encode(self.encoding) async def _ping(self, send: Send) -> None: - """Sends a ping event to the client every `ping_interval` seconds gets cancelled if the client disconnects - through the anyio CancelScope of the TaskGroup + """Sends a ping event to the client every `ping_interval` seconds gets + cancelled if the client disconnects through the anyio CancelScope of the TaskGroup # Required arguments `send`: the starlette Send object @@ -755,8 +755,9 @@ async def sse_generate_error_response( async def sse_subscribe_to_graphql( self, query_document: DocumentNode, data: Any, context_value: Any ): - """Main SSE subscription generator for the GraphQL query. Yields `GraphQLServerSentEvent` instances - and is to be consumed by a `ServerSentEventResponse` instance + """Main SSE subscription generator for the GraphQL query. + Yields `GraphQLServerSentEvent` instances and is to be consumed by a + `ServerSentEventResponse` instance # Required arguments diff --git a/tests/conftest.py b/tests/conftest.py index 837aa8dc..5dd06e35 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -174,7 +174,7 @@ async def test_root_generator(root, *_): yield {"testRoot": root.get("test")} -async def test_slow_generator(_, info): +async def test_slow_generator(*_): yield {"testSlow": "slow"} await asyncio.sleep(20) yield {"testSlow": "slow"} From 56cceab0e767f487bfe97b7db68fadd33ee7a31e Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Mon, 26 Aug 2024 14:03:46 +0200 Subject: [PATCH 14/23] fix mypy errors --- ariadne/asgi/handlers/http.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index 37933e85..24ad711e 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -114,9 +114,9 @@ def encode_execution_result(self) -> str: Returns the JSON string representation of the execution result """ payload: Dict[str, Any] = {} - if self.result.data: + if self.result is not None and self.result.data is not None: payload["data"] = self.result.data - if self.result.errors: + if self.result is not None and self.result.errors is not None: errors = [] for error in self.result.errors: log_error(error, self.logger) @@ -710,7 +710,7 @@ async def handle_sse_request(self, request: Request) -> Response: """ try: - data = await self.extract_data_from_request(request) + data: Any = await self.extract_data_from_request(request) query = await self.get_query_from_sse_request(request, data) if self.schema is None: @@ -795,7 +795,7 @@ async def sse_subscribe_to_graphql( event="next", result=ExecutionResult( errors=[ - GraphQLError(message=error.get("message")) + GraphQLError(message=cast(str, error.get("message", ""))) for error in error_payload ] ), From 47c3cf3276e1d32419235cdd467b8a885b4829b1 Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Mon, 26 Aug 2024 14:08:15 +0200 Subject: [PATCH 15/23] fix pylint error, only visible in ci pipeline --- ariadne/asgi/handlers/http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index 24ad711e..184282b7 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -181,7 +181,7 @@ def __init__( self.encoding = encoding or "utf-8" self.content = None - _headers: dict[str, str] = {} + _headers: Dict[str, str] = {} if headers is not None: _headers.update(headers) # mandatory for servers-sent events headers From a6a798521dd26777885ced1d9d3a2ff3d6ee0276 Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Thu, 29 Aug 2024 16:18:20 +0200 Subject: [PATCH 16/23] fix content-length header being set --- ariadne/asgi/handlers/http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index 184282b7..b06c5bef 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -179,7 +179,7 @@ def __init__( self.send_timeout = send_timeout self.ping_interval = ping_interval or self.DEFAULT_PING_INTERVAL self.encoding = encoding or "utf-8" - self.content = None + self.body = None _headers: Dict[str, str] = {} if headers is not None: From 3ec7f6c05fda8a6feda53f37ec2405866c96b670 Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Thu, 29 Aug 2024 16:58:54 +0200 Subject: [PATCH 17/23] align ping message with graphql-sse implementation (https://github.com/enisdenjo/graphql-sse/blob/e8bef032422a7d38a670dc6d18204c4f5dfab6c8/src/handler.ts#L516) --- ariadne/asgi/handlers/http.py | 4 ++-- tests/asgi/test_sse.py | 16 ++++++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index b06c5bef..5833597a 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -179,7 +179,7 @@ def __init__( self.send_timeout = send_timeout self.ping_interval = ping_interval or self.DEFAULT_PING_INTERVAL self.encoding = encoding or "utf-8" - self.body = None + self.body = None # type: ignore _headers: Dict[str, str] = {} if headers is not None: @@ -231,7 +231,7 @@ async def _ping(self, send: Send) -> None: await send( { "type": "http.response.body", - "body": self.encode_event(GraphQLServerSentEvent(event="next")), + "body": ":\r\n\r\n".encode(self.encoding), "more_body": True, } ) diff --git a/tests/asgi/test_sse.py b/tests/asgi/test_sse.py index 988a1ad2..27c8cf44 100644 --- a/tests/asgi/test_sse.py +++ b/tests/asgi/test_sse.py @@ -17,11 +17,15 @@ def get_sse_events(response: Response) -> List[Dict[str, Any]]: for event in response.text.split("\r\n\r\n"): if len(event.strip()) == 0: continue - event, data = event.split("\r\n", 1) - event = event.replace("event: ", "") - data = data.replace("data: ", "") - data = json.loads(data) if len(data) > 0 else None - events.append({"event": event, "data": data}) + if "\r\n" not in event: + # ping message + events.append({"event": "", "data": None}) + else: + event, data = event.split("\r\n", 1) + event = event.replace("event: ", "") + data = data.replace("data: ", "") + data = json.loads(data) if len(data) > 0 else None + events.append({"event": event, "data": data}) return events @@ -154,7 +158,7 @@ def test_ping_is_send_sse(sse_client): assert len(events) == 4 assert events[0]["event"] == "next" assert events[0]["data"]["data"] == {"testSlow": "slow"} - assert events[1]["event"] == "next" + assert events[1]["event"] == "" assert events[1]["data"] is None assert events[2]["event"] == "next" assert events[2]["data"]["data"] == {"testSlow": "slow"} From c7724f7c9f8bb2fa144ddb6638fe870c8fc95416 Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Thu, 12 Dec 2024 14:29:39 +0100 Subject: [PATCH 18/23] move sse to separate handler, add tests for configuration options, cleanup --- ariadne/asgi/handlers/http.py | 432 +-------------------- ariadne/contrib/sse.py | 456 +++++++++++++++++++++++ tests/asgi/test_sse.py | 49 ++- tests_integrations/fastapi/test_sse.py | 2 + tests_integrations/starlette/test_sse.py | 3 +- 5 files changed, 523 insertions(+), 419 deletions(-) create mode 100644 ariadne/contrib/sse.py diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index 5833597a..00d183ca 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -1,60 +1,22 @@ -import asyncio import json -import logging -from asyncio import Lock -from functools import partial from http import HTTPStatus from inspect import isawaitable -from io import StringIO -from typing import ( - Any, - Optional, - cast, - Dict, - AsyncGenerator, - Callable, - Awaitable, - Literal, - get_args, - List, - Union, -) -from typing import Type - -from anyio import ( - get_cancelled_exc_class, - CancelScope, - sleep, - move_on_after, - create_task_group, -) -from graphql import DocumentNode -from graphql import MiddlewareManager +from typing import Any, Optional, Type, Union, cast + +from graphql import DocumentNode, MiddlewareManager from starlette.datastructures import UploadFile from starlette.requests import Request -from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse -from starlette.responses import Response +from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse, Response from starlette.types import Receive, Scope, Send -from .base import GraphQLHttpHandlerBase -from ... import format_error from ...constants import ( DATA_TYPE_JSON, DATA_TYPE_MULTIPART, ) -from ...exceptions import HttpBadRequestError -from ...exceptions import HttpError +from ...exceptions import HttpBadRequestError, HttpError from ...explorer import Explorer from ...file_uploads import combine_multipart_data -from ...graphql import ( - ExecutionResult, - GraphQLError, - parse_query, - subscribe, - validate_data, -) from ...graphql import graphql -from ...logger import log_error from ...types import ( ContextValue, ExtensionList, @@ -63,240 +25,7 @@ MiddlewareList, Middlewares, ) - -EVENT_TYPES = Literal["next", "complete"] - - -class GraphQLServerSentEvent: - """GraphQLServerSentEvent is a class that represents a single Server-Sent Event - as defined in the GraphQL SSE Protocol specification - (https://github.com/enisdenjo/graphql-sse/blob/master/PROTOCOL.md) - """ - - DEFAULT_SEPARATOR = "\r\n" - - def __init__( - self, - event: EVENT_TYPES, - result: Optional[ExecutionResult] = None, - ): - """Initializes the Server-Sent Event - # Required arguments - `event`: the type of the event. Either "next" or "complete" - - # Optional arguments - `result`: an `ExecutionResult` or a `dict` that represents the result of the operation - """ - assert event in get_args(EVENT_TYPES), f"Invalid event type: {event}" - self.event = event - self.result = result - self.logger = logging.Logger("GraphQLServerSentEvent") - - def _write_to_buffer( - self, buffer: StringIO, name: str, value: Optional[str] - ) -> StringIO: - """Writes a SSE field to the buffered SSE event representation - - Returns the `StringIO` buffer with the field written to it - - # Required arguments - `buffer`: the `StringIO` buffer to write to - `name`: the name of the field - `value`: the value of the field - """ - if value is not None: - buffer.write(f"{name}: {value}{self.DEFAULT_SEPARATOR}") - return buffer - - def encode_execution_result(self) -> str: - """Encodes the execution result into a single line JSON string - - Returns the JSON string representation of the execution result - """ - payload: Dict[str, Any] = {} - if self.result is not None and self.result.data is not None: - payload["data"] = self.result.data - if self.result is not None and self.result.errors is not None: - errors = [] - for error in self.result.errors: - log_error(error, self.logger) - errors.append(format_error(error)) - payload["errors"] = errors - - return json.dumps(payload) - - def __str__(self) -> str: - """Returns the string representation of the Server-Sent Event""" - buffer = StringIO() - buffer = self._write_to_buffer(buffer, "event", self.event) - buffer = self._write_to_buffer( - buffer, - "data", - ( - self.encode_execution_result() - if self.event == "next" and self.result - else "" - ), - ) - buffer.write(self.DEFAULT_SEPARATOR) - - return buffer.getvalue() - - -class ServerSentEventResponse(Response): - """Sends GraphQL SSE events using EvenSource protocol using Starlette's Response class - based on the implementation https://github.com/sysid/sse-starlette/ - """ - - # Sends a ping event to the client every 15 seconds to overcome proxy timeout issues - DEFAULT_PING_INTERVAL = 15 - - def __init__( - self, - *args, - generator: AsyncGenerator[GraphQLServerSentEvent, Any], - send_timeout: Optional[int] = None, - ping_interval: Optional[int] = None, - headers: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - **kwargs, - ): - """Initializes the a SSE Response that send events generated by an async generator - - # Required arguments - `generator`: an async generator that yields `GraphQLServerSentEvent` objects - - # Optional arguments - `send_timeout`: the timeout in seconds to send an event to the client - `ping_interval`: the interval in seconds to send a ping event to the client, overrides - the DEFAULT_PING_INTERVAL of 15 seconds - `headers`: a dictionary of headers to be sent with the response - `encoding`: the encoding to use for the response - """ - super().__init__(*args, **kwargs) - self.generator = generator - self.status_code = HTTPStatus.OK - self.send_timeout = send_timeout - self.ping_interval = ping_interval or self.DEFAULT_PING_INTERVAL - self.encoding = encoding or "utf-8" - self.body = None # type: ignore - - _headers: Dict[str, str] = {} - if headers is not None: - _headers.update(headers) - # mandatory for servers-sent events headers - # allow cache control header to be set by user to support fan out proxies - # https://www.fastly.com/blog/server-sent-events-fastly - _headers.setdefault("Cache-Control", "no-cache") - _headers.setdefault("Connection", "keep-alive") - _headers.setdefault("X-Accel-Buffering", "no") - _headers.setdefault("Transfer-Encoding", "chunked") - self.media_type = "text/event-stream" - self.init_headers(_headers) - - self._send_lock = Lock() - - @staticmethod - async def listen_for_disconnect(receive: Receive) -> None: - """Listens for the client disconnect event and stops the streaming by exiting the infinite - loop this triggers the anyio CancelScope to cancel the TaskGroup - - # Required arguments - `receive`: the starlette Receive object - """ - while True: - message = await receive() - if message["type"] == "http.disconnect": - logging.debug("Got event: http.disconnect. Stop streaming...") - break - - def encode_event(self, event: GraphQLServerSentEvent) -> bytes: - """Encodes the GraphQLServerSentEvent into a bytes object - - # Required arguments - `event`: the GraphQLServerSentEvent object - """ - return str(event).encode(self.encoding) - - async def _ping(self, send: Send) -> None: - """Sends a ping event to the client every `ping_interval` seconds gets - cancelled if the client disconnects through the anyio CancelScope of the TaskGroup - - # Required arguments - `send`: the starlette Send object - """ - while True: - await sleep(self.ping_interval) - async with self._send_lock: - await send( - { - "type": "http.response.body", - "body": ":\r\n\r\n".encode(self.encoding), - "more_body": True, - } - ) - - async def send_events(self, send: Send) -> None: - """Sends the events generated by the async generator to the client - - # Required arguments - `send`: the starlette Send object - - """ - async with self._send_lock: - await send( - { - "type": "http.response.start", - "status": self.status_code, - "headers": self.raw_headers, - } - ) - - try: - async for event in self.generator: - async with self._send_lock: - with move_on_after(self.send_timeout) as timeout: - await send( - { - "type": "http.response.body", - "body": self.encode_event(event), - "more_body": True, - } - ) - - if timeout.cancel_called: - raise asyncio.TimeoutError() - - except (get_cancelled_exc_class(),) as e: - logging.warning(e) - finally: - with CancelScope(shield=True): - async with self._send_lock: - await send( - {"type": "http.response.body", "body": b"", "more_body": False} - ) - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """The main entrypoint for the ServerSentEventResponse which is called by starlette - - # Required arguments - `scope`: the starlette Scope object - - `receive`: the starlette Receive object - - `send`: the starlette Send object - - """ - async with create_task_group() as task_group: - - async def wrap_cancelling(func: Callable[[], Awaitable[None]]) -> None: - await func() - task_group.cancel_scope.cancel() - - task_group.start_soon(wrap_cancelling, partial(self._ping, send)) - task_group.start_soon(wrap_cancelling, partial(self.send_events, send)) - # this will cancel the task group when the client disconnects - await wrap_cancelling(partial(self.listen_for_disconnect, receive)) +from .base import GraphQLHttpHandlerBase class GraphQLHTTPHandler(GraphQLHttpHandlerBase): @@ -364,6 +93,16 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: response = await self.handle_request(request) await response(scope, receive, send) + async def handle_request_override(self, request: Request) -> Response | None: + """Override the default request handling logic in subclasses. + Is called in the `handle_request` method before the default logic. + If None is returned, the default logic is executed. + + # Required arguments: + `request`: the `Request` instance from Starlette or FastAPI. + """ + return None + async def handle_request(self, request: Request) -> Response: """Handle GraphQL request and return response for the client. @@ -386,6 +125,10 @@ async def handle_request(self, request: Request) -> Response: `request`: the `Request` instance from Starlette or FastAPI. """ + response = await self.handle_request_override(request) + if response is not None: + return response + if request.method == "GET": if self.execute_get_queries and request.query_params.get("query"): return await self.graphql_http_server(request) @@ -394,10 +137,6 @@ async def handle_request(self, request: Request) -> Response: return await self.render_explorer(request, self.explorer) if request.method == "POST": - accept = request.headers.get("Accept", "") - accept = accept.split(",")[0] - if accept == "text/event-stream": - return await self.handle_sse_request(request) return await self.graphql_http_server(request) return self.handle_not_allowed_method(request) @@ -700,134 +439,3 @@ def handle_not_allowed_method(self, request: Request): return Response(headers=allow_header) return Response(status_code=HTTPStatus.METHOD_NOT_ALLOWED, headers=allow_header) - - async def handle_sse_request(self, request: Request) -> Response: - """Handles the HTTP request with GraphQL Subscription query using Server-Sent Events. - - # Required arguments - - `request`: the starlette `Request` instance - """ - - try: - data: Any = await self.extract_data_from_request(request) - query = await self.get_query_from_sse_request(request, data) - - if self.schema is None: - raise TypeError( - "schema is not set, call configure method to initialize it" - ) - - validate_data(data) - context_value = await self.get_context_for_request(request, data) - return ServerSentEventResponse( - generator=self.sse_subscribe_to_graphql(query, data, context_value) - ) - except (HttpError, TypeError, GraphQLError) as error: - log_error(error, self.logger) - if not isinstance(error, GraphQLError): - error_message = ( - (error.message or error.status) - if isinstance(error, HttpError) - else str(error) - ) - error = GraphQLError(error_message, original_error=error) - return ServerSentEventResponse( - generator=self.sse_generate_error_response([error]) - ) - - async def sse_generate_error_response( - self, errors: List[GraphQLError] - ) -> AsyncGenerator[GraphQLServerSentEvent, Any]: - """A Server-Sent Event response generator for the errors - To be passed to a ServerSentEventResponse instance - - # Required arguments - - `errors`: a list of `GraphQLError` instances - """ - - yield GraphQLServerSentEvent( - event="next", result=ExecutionResult(errors=errors) - ) - yield GraphQLServerSentEvent(event="complete") - - async def sse_subscribe_to_graphql( - self, query_document: DocumentNode, data: Any, context_value: Any - ): - """Main SSE subscription generator for the GraphQL query. - Yields `GraphQLServerSentEvent` instances and is to be consumed by a - `ServerSentEventResponse` instance - - # Required arguments - - `query_document`: an already parsed GraphQL query. - - `data`: a `dict` with query data (`query` string, optionally `operationName` - string and `variables` dictionary). - - `context_value`: a context value to make accessible as 'context' attribute - of second argument (`info`) passed to resolvers and source functions. - """ - - success, results = await subscribe( - self.schema, # type: ignore - data, - context_value=context_value, - root_value=self.root_value, - query_document=query_document, - query_validator=self.query_validator, - validation_rules=self.validation_rules, - debug=self.debug, - introspection=self.introspection, - logger=self.logger, - error_formatter=self.error_formatter, - ) - - if not success: - if not isinstance(results, list): - error_payload = cast(List[dict], [results]) - else: - error_payload = results - - # This needs to be handled better, subscribe returns preformatted errors - yield GraphQLServerSentEvent( - event="next", - result=ExecutionResult( - errors=[ - GraphQLError(message=cast(str, error.get("message", ""))) - for error in error_payload - ] - ), - ) - else: - results = cast(AsyncGenerator, results) - try: - async for result in results: - yield GraphQLServerSentEvent(event="next", result=result) - except (Exception, GraphQLError) as error: - if not isinstance(error, GraphQLError): - error = GraphQLError(str(error), original_error=error) - log_error(error, self.logger) - yield GraphQLServerSentEvent( - event="next", result=ExecutionResult(errors=[error]) - ) - - yield GraphQLServerSentEvent(event="complete") - - async def get_query_from_sse_request( - self, request: Request, data: Any - ) -> DocumentNode: - """Extracts GraphQL query from SSE request. - - Returns a `DocumentNode` with parsed query. - - # Required arguments - - `request`: the starlette `Request` instance - - `data`: an additional data parameter to potentially extract the query from - """ - - context_value = await self.get_context_for_request(request, data) - return parse_query(context_value, self.query_parser, data) diff --git a/ariadne/contrib/sse.py b/ariadne/contrib/sse.py new file mode 100644 index 00000000..2ce52191 --- /dev/null +++ b/ariadne/contrib/sse.py @@ -0,0 +1,456 @@ +import asyncio +import json +import logging +from asyncio import Lock +from functools import partial +from http import HTTPStatus +from io import StringIO +from typing import ( + Any, + Optional, + cast, + AsyncGenerator, + List, + Literal, + get_args, + Dict, + Callable, + Awaitable, + Type, +) + +from anyio import ( + get_cancelled_exc_class, + CancelScope, + sleep, + move_on_after, + create_task_group, +) +from graphql import DocumentNode +from graphql import MiddlewareManager +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import Receive, Scope, Send + +from .. import format_error +from ..asgi.handlers import GraphQLHTTPHandler +from ..exceptions import HttpError +from ..graphql import ( + ExecutionResult, + GraphQLError, + parse_query, + subscribe, + validate_data, +) +from ..logger import log_error +from ..types import Extensions, Middlewares + +EVENT_TYPES = Literal["next", "complete"] + + +class GraphQLServerSentEvent: + """GraphQLServerSentEvent is a class that represents a single Server-Sent Event + as defined in the GraphQL SSE Protocol specification + (https://github.com/enisdenjo/graphql-sse/blob/master/PROTOCOL.md) + """ + + DEFAULT_SEPARATOR = "\r\n" + + def __init__( + self, + event: EVENT_TYPES, + result: Optional[ExecutionResult] = None, + ): + """Initializes the Server-Sent Event + # Required arguments + `event`: the type of the event. Either "next" or "complete" + + # Optional arguments + `result`: an `ExecutionResult` or a `dict` that represents the result of the operation + """ + assert event in get_args(EVENT_TYPES), f"Invalid event type: {event}" + self.event = event + self.result = result + self.logger = logging.Logger("GraphQLServerSentEvent") + + def __str__(self) -> str: + """Returns the string representation of the Server-Sent Event""" + buffer = StringIO() + buffer = self._write_to_buffer(buffer, "event", self.event) + buffer = self._write_to_buffer( + buffer, + "data", + ( + self.encode_execution_result() + if self.event == "next" and self.result + else "" + ), + ) + buffer.write(self.DEFAULT_SEPARATOR) + + return buffer.getvalue() + + def _write_to_buffer( + self, buffer: StringIO, name: str, value: Optional[str] + ) -> StringIO: + """Writes an SSE field to the buffered SSE event representation + + Returns the `StringIO` buffer with the field written to it + + # Required arguments + `buffer`: the `StringIO` buffer to write to + `name`: the name of the field + `value`: the value of the field + """ + if value is not None: + buffer.write(f"{name}: {value}{self.DEFAULT_SEPARATOR}") + return buffer + + def encode_execution_result(self) -> str: + """Encodes the execution result into a single line JSON string + + Returns the JSON string representation of the execution result + """ + payload: Dict[str, Any] = {} + if self.result is not None and self.result.data is not None: + payload["data"] = self.result.data + if self.result is not None and self.result.errors is not None: + errors = [] + for error in self.result.errors: + errors.append(format_error(error)) + payload["errors"] = errors + + return json.dumps(payload) + + +class ServerSentEventResponse(Response): + """Sends GraphQL SSE events using the EventSource protocol using Starlette's Response class + based on the implementation https://github.com/sysid/sse-starlette/ + """ + + # Sends a ping event to the client every 15 seconds to overcome proxy timeout issues + DEFAULT_PING_INTERVAL = 15 + + def __init__( + self, + *args, + generator: AsyncGenerator[GraphQLServerSentEvent, Any], + send_timeout: Optional[int] = None, + ping_interval: Optional[int] = None, + headers: Optional[Dict[str, str]] = None, + **kwargs, + ): + """Initializes an SSE Response that sends events generated by an async generator + + # Required arguments + `generator`: an async generator that yields `GraphQLServerSentEvent` objects + + # Optional arguments + `send_timeout`: the timeout in seconds to send an event to the client + `ping_interval`: the interval in seconds to send a ping event to the client, overrides + the DEFAULT_PING_INTERVAL of 15 seconds + `headers`: a dictionary of headers to be sent with the response + `encoding`: the encoding to use for the response + """ + super().__init__(*args, **kwargs) + self.generator = generator + self.status_code = HTTPStatus.OK + self.send_timeout = send_timeout + self.ping_interval = ping_interval or self.DEFAULT_PING_INTERVAL + self.body = None # type: ignore + + _headers: Dict[str, str] = {} + if headers is not None: + _headers.update(headers) + # mandatory for servers-sent events headers + # allow cache control header to be set by user to support fan out proxies + # https://www.fastly.com/blog/server-sent-events-fastly + _headers.setdefault("Cache-Control", "no-cache") + _headers.setdefault("Connection", "keep-alive") + _headers.setdefault("X-Accel-Buffering", "no") + _headers.setdefault("Transfer-Encoding", "chunked") + self.media_type = "text/event-stream" + self.init_headers(_headers) + + self._send_lock = Lock() + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """The main entrypoint for the ServerSentEventResponse which is called by starlette + + # Required arguments + `scope`: the starlette Scope object + + `receive`: the starlette Receive object + + `send`: the starlette Send object + + """ + async with create_task_group() as task_group: + + async def wrap_cancelling(func: Callable[[], Awaitable[None]]) -> None: + await func() + task_group.cancel_scope.cancel() + + task_group.start_soon(wrap_cancelling, partial(self._ping, send)) + task_group.start_soon(wrap_cancelling, partial(self.send_events, send)) + # this will cancel the task group when the client disconnects + await wrap_cancelling(partial(self.listen_for_disconnect, receive)) + + async def _ping(self, send: Send) -> None: + """Sends a ping event to the client every `ping_interval` seconds gets + cancelled if the client disconnects through the anyio CancelScope of the TaskGroup + + # Required arguments + `send`: the starlette Send object + """ + while True: + await sleep(self.ping_interval) + async with self._send_lock: + await send( + { + "type": "http.response.body", + # always encode as utf-8 as per https://html.spec.whatwg.org/multipage/server-sent-events.html#sse-processing-model + "body": ":\r\n\r\n".encode("utf-8"), + "more_body": True, + } + ) + + async def send_events(self, send: Send) -> None: + """Sends the events generated by the async generator to the client + + # Required arguments + `send`: the starlette Send object + + """ + async with self._send_lock: + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + + try: + async for event in self.generator: + async with self._send_lock: + with move_on_after(self.send_timeout) as timeout: + await send( + { + "type": "http.response.body", + "body": self.encode_event(event), + "more_body": True, + } + ) + + if timeout.cancel_called: + raise asyncio.TimeoutError() + + except (get_cancelled_exc_class(),) as e: + logging.warning(e) + finally: + with CancelScope(shield=True): + async with self._send_lock: + await send( + {"type": "http.response.body", "body": b"", "more_body": False} + ) + + @staticmethod + async def listen_for_disconnect(receive: Receive) -> None: + """Listens for the client disconnect event and stops the streaming by exiting the infinite + loop. This triggers the anyio CancelScope to cancel the TaskGroup + + # Required arguments + `receive`: the starlette Receive object + """ + while True: + message = await receive() + if message["type"] == "http.disconnect": + logging.debug("Got event: http.disconnect. Stop streaming...") + break + + @staticmethod + def encode_event(event: GraphQLServerSentEvent) -> bytes: + """Encodes the GraphQLServerSentEvent into a bytes object + + # Required arguments + `event`: the GraphQLServerSentEvent object + """ + # always encode as utf-8 as per https://html.spec.whatwg.org/multipage/server-sent-events.html#sse-processing-model + return str(event).encode("utf-8") + + +class GraphQLHTTPSSEHandler(GraphQLHTTPHandler): + """Extension to the default GraphQLHTTPHandler to also handle Server-Sent Events as per + the GraphQL SSE Protocol specification. This handler only supports the defined `Distinct connections mode` + due to its statelessness. This implementation is based on the specification as of commit + 80cf75b5952d1a065c95bdbd6a74304c90dbe2c5. For more information see the specification + (https://github.com/enisdenjo/graphql-sse/blob/master/PROTOCOL.md) + """ + + def __init__( + self, + extensions: Optional[Extensions] = None, + middleware: Optional[Middlewares] = None, + middleware_manager_class: Optional[Type[MiddlewareManager]] = None, + send_timeout: Optional[int] = None, + ping_interval: Optional[int] = None, + default_response_headers: Optional[Dict[str, str]] = None, + ): + super().__init__(extensions, middleware, middleware_manager_class) + self.send_timeout = send_timeout + self.ping_interval = ping_interval + self.default_response_headers = default_response_headers + + async def handle_request_override(self, request: Request) -> Response | None: + """Overrides the handle_request_override method to handle Server-Sent Events + + # Required arguments + `request`: the starlette or FastAPI `Request` instance + + """ + + if request.method == "POST": + accept = request.headers.get("Accept", "").split(",") + accept = [a.strip() for a in accept] + if "text/event-stream" in accept: + return await self.handle_sse_request(request) + return None + + async def handle_sse_request(self, request: Request) -> Response: + """Handles the HTTP request with GraphQL Subscription query using Server-Sent Events. + + # Required arguments + + `request`: the starlette `Request` instance + """ + + try: + data: Any = await self.extract_data_from_request(request) + query = await self.get_query_from_sse_request(request, data) + + if self.schema is None: + raise TypeError( + "schema is not set, call configure method to initialize it" + ) + + validate_data(data) + context_value = await self.get_context_for_request(request, data) + return ServerSentEventResponse( + generator=self.sse_subscribe_to_graphql(query, data, context_value), + ping_interval=self.ping_interval, + send_timeout=self.send_timeout, + headers=self.default_response_headers, + ) + except (HttpError, TypeError, GraphQLError) as error: + log_error(error, self.logger) + if not isinstance(error, GraphQLError): + error_message = ( + (error.message or error.status) + if isinstance(error, HttpError) + else str(error) + ) + error = GraphQLError(error_message, original_error=error) + return ServerSentEventResponse( + generator=self.sse_generate_error_response([error]), + ping_interval=self.ping_interval, + send_timeout=self.send_timeout, + headers=self.default_response_headers, + ) + + async def get_query_from_sse_request( + self, request: Request, data: Any + ) -> DocumentNode: + """Extracts GraphQL query from SSE request. + + Returns a `DocumentNode` with parsed query. + + # Required arguments + + `request`: the starlette `Request` instance + + `data`: an additional data parameter to potentially extract the query from + """ + + context_value = await self.get_context_for_request(request, data) + return parse_query(context_value, self.query_parser, data) + + async def sse_subscribe_to_graphql( + self, query_document: DocumentNode, data: Any, context_value: Any + ): + """Main SSE subscription generator for the GraphQL query. + Yields `GraphQLServerSentEvent` instances and is to be consumed by a + `ServerSentEventResponse` instance + + # Required arguments + + `query_document`: an already parsed GraphQL query. + + `data`: a `dict` with query data (`query` string, optionally `operationName` + string and `variables` dictionary). + + `context_value`: a context value to make accessible as 'context' attribute + of second argument (`info`) passed to resolvers and source functions. + """ + + success, results = await subscribe( + self.schema, # type: ignore + data, + context_value=context_value, + root_value=self.root_value, + query_document=query_document, + query_validator=self.query_validator, + validation_rules=self.validation_rules, + debug=self.debug, + introspection=self.introspection, + logger=self.logger, + error_formatter=self.error_formatter, + ) + + if not success: + if not isinstance(results, list): + error_payload = cast(List[dict], [results]) + else: + error_payload = results + + # This needs to be handled better, subscribe returns preformatted errors + yield GraphQLServerSentEvent( + event="next", + result=ExecutionResult( + errors=[ + GraphQLError(message=cast(str, error.get("message", ""))) + for error in error_payload + ] + ), + ) + else: + results = cast(AsyncGenerator, results) + try: + async for result in results: + yield GraphQLServerSentEvent(event="next", result=result) + except (Exception, GraphQLError) as error: + if not isinstance(error, GraphQLError): + error = GraphQLError(str(error), original_error=error) + log_error(error, self.logger) + yield GraphQLServerSentEvent( + event="next", result=ExecutionResult(errors=[error]) + ) + + yield GraphQLServerSentEvent(event="complete") + + @staticmethod + async def sse_generate_error_response( + errors: List[GraphQLError], + ) -> AsyncGenerator[GraphQLServerSentEvent, Any]: + """A Server-Sent Event response generator for the errors + To be passed to a ServerSentEventResponse instance + + # Required arguments + + `errors`: a list of `GraphQLError` instances + """ + + yield GraphQLServerSentEvent( + event="next", result=ExecutionResult(errors=errors) + ) + yield GraphQLServerSentEvent(event="complete") diff --git a/tests/asgi/test_sse.py b/tests/asgi/test_sse.py index 27c8cf44..ad7ed216 100644 --- a/tests/asgi/test_sse.py +++ b/tests/asgi/test_sse.py @@ -2,12 +2,14 @@ from http import HTTPStatus from typing import List, Dict, Any from unittest.mock import Mock -from graphql import parse, GraphQLError -from starlette.testclient import TestClient -from httpx import Response import pytest +from graphql import parse, GraphQLError +from httpx import Response +from starlette.testclient import TestClient + from ariadne.asgi import GraphQL +from ariadne.contrib.sse import GraphQLHTTPSSEHandler SSE_HEADER = {"Accept": "text/event-stream"} @@ -31,7 +33,13 @@ def get_sse_events(response: Response) -> List[Dict[str, Any]]: @pytest.fixture def sse_client(schema): - app = GraphQL(schema, introspection=False) + app = GraphQL( + schema, + http_handler=GraphQLHTTPSSEHandler( + default_response_headers={"Test_Header": "test"} + ), + introspection=False, + ) return TestClient(app, headers=SSE_HEADER) @@ -78,6 +86,7 @@ def test_custom_query_parser_is_used_for_subscription_over_sse(schema): mock_parser = Mock(return_value=parse("subscription { testContext }")) app = GraphQL( schema, + http_handler=GraphQLHTTPSSEHandler(), query_parser=mock_parser, context_value={"test": "I'm context"}, root_value={"test": "I'm root"}, @@ -87,6 +96,7 @@ def test_custom_query_parser_is_used_for_subscription_over_sse(schema): response = client.post("/", json={"query": "subscription { testRoot }"}) events = get_sse_events(response) + print(response) assert len(events) == 2 assert events[0]["data"]["data"] == {"testContext": "I'm context"} assert events[1]["event"] == "complete" @@ -103,6 +113,7 @@ def test_custom_query_validator_is_used_for_subscription_over_sse(schema, errors mock_validator = Mock(return_value=errors) app = GraphQL( schema, + http_handler=GraphQLHTTPSSEHandler(), query_validator=mock_validator, context_value={"test": "I'm context"}, root_value={"test": "I'm root"}, @@ -132,7 +143,7 @@ def test_custom_query_validator_is_used_for_subscription_over_sse(schema, errors def test_schema_not_set_graphql_sse(): - app = GraphQL(None) + app = GraphQL(None, http_handler=GraphQLHTTPSSEHandler()) client = TestClient(app, headers=SSE_HEADER) response = client.post( @@ -158,16 +169,42 @@ def test_ping_is_send_sse(sse_client): assert len(events) == 4 assert events[0]["event"] == "next" assert events[0]["data"]["data"] == {"testSlow": "slow"} - assert events[1]["event"] == "" + assert events[1]["event"] == "" # ping assert events[1]["data"] is None assert events[2]["event"] == "next" assert events[2]["data"]["data"] == {"testSlow": "slow"} assert events[3]["event"] == "complete" +def test_custom_ping_interval(schema): + app = GraphQL( + schema, + http_handler=GraphQLHTTPSSEHandler(ping_interval=10), + introspection=False, + ) + sse_client = TestClient(app, headers=SSE_HEADER) + response = sse_client.post("/", json={"query": "subscription { testSlow }"}) + events = get_sse_events(response) + assert len(events) == 5 + assert events[0]["event"] == "next" + assert events[0]["data"]["data"] == {"testSlow": "slow"} + assert events[1]["event"] == "" # ping + assert events[1]["data"] is None + assert events[2]["event"] == "" # second ping + assert events[2]["data"] is None + assert events[3]["event"] == "next" + assert events[3]["data"]["data"] == {"testSlow": "slow"} + assert events[4]["event"] == "complete" + + def test_resolver_error_is_handled_sse(sse_client): response = sse_client.post("/", json={"query": "subscription { resolverError }"}) events = get_sse_events(response) assert len(events) == 2 assert events[0]["data"]["errors"][0]["message"] == "Test exception" assert events[1]["event"] == "complete" + + +def test_default_headers_are_applied(sse_client): + response = sse_client.post("/", json={"query": "subscription { ping }"}) + assert response.headers["Test_Header"] == "test" diff --git a/tests_integrations/fastapi/test_sse.py b/tests_integrations/fastapi/test_sse.py index eae596c4..400d969b 100644 --- a/tests_integrations/fastapi/test_sse.py +++ b/tests_integrations/fastapi/test_sse.py @@ -4,6 +4,7 @@ from ariadne import SubscriptionType, make_executable_schema from ariadne.asgi import GraphQL from ariadne.asgi.handlers import GraphQLTransportWSHandler +from ariadne.contrib.sse import GraphQLHTTPSSEHandler subscription_type = SubscriptionType() @@ -34,6 +35,7 @@ async def counter_resolve(obj, *_): app = FastAPI() graphql = GraphQL( schema, + http_handler=GraphQLHTTPSSEHandler(), websocket_handler=GraphQLTransportWSHandler(), ) diff --git a/tests_integrations/starlette/test_sse.py b/tests_integrations/starlette/test_sse.py index 58caa934..a65430b0 100644 --- a/tests_integrations/starlette/test_sse.py +++ b/tests_integrations/starlette/test_sse.py @@ -4,6 +4,7 @@ from ariadne import SubscriptionType, make_executable_schema from ariadne.asgi import GraphQL +from ariadne.contrib.sse import GraphQLHTTPSSEHandler subscription_type = SubscriptionType() @@ -31,7 +32,7 @@ async def counter_resolve(obj, *_): subscription_type, ) -graphql = GraphQL(schema) +graphql = GraphQL(schema, http_handler=GraphQLHTTPSSEHandler()) app = Starlette( routes=[ From f2ada34469c3d791371ffc485367e7a3d4941744 Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Thu, 12 Dec 2024 14:34:10 +0100 Subject: [PATCH 19/23] remove pipe symbol for compatibility with older python versions --- ariadne/asgi/handlers/http.py | 2 +- ariadne/contrib/sse.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index 00d183ca..8d261f60 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -93,7 +93,7 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: response = await self.handle_request(request) await response(scope, receive, send) - async def handle_request_override(self, request: Request) -> Response | None: + async def handle_request_override(self, request: Request) -> Optional[Response]: """Override the default request handling logic in subclasses. Is called in the `handle_request` method before the default logic. If None is returned, the default logic is executed. diff --git a/ariadne/contrib/sse.py b/ariadne/contrib/sse.py index 2ce52191..e6b4be0c 100644 --- a/ariadne/contrib/sse.py +++ b/ariadne/contrib/sse.py @@ -302,7 +302,7 @@ def __init__( self.ping_interval = ping_interval self.default_response_headers = default_response_headers - async def handle_request_override(self, request: Request) -> Response | None: + async def handle_request_override(self, request: Request) -> Optional[Response]: """Overrides the handle_request_override method to handle Server-Sent Events # Required arguments From 5a18fc393449e7aeaf5154a627d357f3b4ed6b5a Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Thu, 12 Dec 2024 15:22:11 +0100 Subject: [PATCH 20/23] format and linting, update requirements.txt in fastapi integration tests to fix Starlette TestClient error, manually set anyio dependency lower for python 3.8 integration tests as current version is not available for 3.8 --- ariadne/asgi/handlers/http.py | 4 ++-- ariadne/contrib/sse.py | 13 ++++++++----- tests/asgi/test_sse.py | 2 +- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index 8d261f60..9432fecc 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -93,13 +93,13 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: response = await self.handle_request(request) await response(scope, receive, send) - async def handle_request_override(self, request: Request) -> Optional[Response]: + async def handle_request_override(self, _: Request) -> Optional[Response]: """Override the default request handling logic in subclasses. Is called in the `handle_request` method before the default logic. If None is returned, the default logic is executed. # Required arguments: - `request`: the `Request` instance from Starlette or FastAPI. + `_`: the `Request` instance from Starlette or FastAPI. """ return None diff --git a/ariadne/contrib/sse.py b/ariadne/contrib/sse.py index e6b4be0c..cbc5e506 100644 --- a/ariadne/contrib/sse.py +++ b/ariadne/contrib/sse.py @@ -209,7 +209,8 @@ async def _ping(self, send: Send) -> None: await send( { "type": "http.response.body", - # always encode as utf-8 as per https://html.spec.whatwg.org/multipage/server-sent-events.html#sse-processing-model + # always encode as utf-8 as per + # https://html.spec.whatwg.org/multipage/server-sent-events.html#sse-processing-model "body": ":\r\n\r\n".encode("utf-8"), "more_body": True, } @@ -276,15 +277,17 @@ def encode_event(event: GraphQLServerSentEvent) -> bytes: # Required arguments `event`: the GraphQLServerSentEvent object """ - # always encode as utf-8 as per https://html.spec.whatwg.org/multipage/server-sent-events.html#sse-processing-model + # always encode as utf-8 as per + # https://html.spec.whatwg.org/multipage/server-sent-events.html#sse-processing-model return str(event).encode("utf-8") class GraphQLHTTPSSEHandler(GraphQLHTTPHandler): """Extension to the default GraphQLHTTPHandler to also handle Server-Sent Events as per - the GraphQL SSE Protocol specification. This handler only supports the defined `Distinct connections mode` - due to its statelessness. This implementation is based on the specification as of commit - 80cf75b5952d1a065c95bdbd6a74304c90dbe2c5. For more information see the specification + the GraphQL SSE Protocol specification. This handler only supports the defined + `Distinct connections mode` due to its statelessness. This implementation is based on + the specification as of commit 80cf75b5952d1a065c95bdbd6a74304c90dbe2c5. + For more information see the specification (https://github.com/enisdenjo/graphql-sse/blob/master/PROTOCOL.md) """ diff --git a/tests/asgi/test_sse.py b/tests/asgi/test_sse.py index ad7ed216..05a67257 100644 --- a/tests/asgi/test_sse.py +++ b/tests/asgi/test_sse.py @@ -179,7 +179,7 @@ def test_ping_is_send_sse(sse_client): def test_custom_ping_interval(schema): app = GraphQL( schema, - http_handler=GraphQLHTTPSSEHandler(ping_interval=10), + http_handler=GraphQLHTTPSSEHandler(ping_interval=8), introspection=False, ) sse_client = TestClient(app, headers=SSE_HEADER) From 753cb350109c18ba4dbaf7335733dfeb4cf38f11 Mon Sep 17 00:00:00 2001 From: Dan Plischke Date: Fri, 20 Dec 2024 19:38:40 +0100 Subject: [PATCH 21/23] remove changelog entry as 0.24 has already been published --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9641dde2..a667d160 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,6 @@ - Added `share_enabled` param to `ExplorerPlayground` to enable share playground feature. - Added support for nested attribute resolution in alias resolvers. - Replaced regexes in the Apollo Federation implementation with cleaner approach using GraphQL AST. -- Added support for subscriptions over a distinct Server-Sent-Events connection as per (https://github.com/enisdenjo/graphql-sse/blob/master/PROTOCOL.md). ## 0.23 (2024-03-18) From 8f9e338e00f0f741f29c5d1782a40120f21e86b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Wed, 29 Jan 2025 20:08:12 +0100 Subject: [PATCH 22/23] update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a667d160..8c622c76 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,10 @@ # CHANGELOG +## 0.25 (UNRELEASED) + +- Added support for GraphQL subscriptions over the Server-Sent Events (SSE). + ## 0.24 (2024-12-19) - Added validation for directive declarations in `make_executable_schema` to prevent schema creation with undeclared directives. From ad491e4bfb9197ae1b8d5e4337fe0a3d79835890 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Wed, 29 Jan 2025 20:08:28 +0100 Subject: [PATCH 23/23] Markdown formatting --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c622c76..5b722df8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ - Added support for GraphQL subscriptions over the Server-Sent Events (SSE). + ## 0.24 (2024-12-19) - Added validation for directive declarations in `make_executable_schema` to prevent schema creation with undeclared directives.