Skip to content

Commit e30f166

Browse files
committedNov 20, 2022
subscribe: stay synchronous when possible
This (breaking!) change aligns the return types of `execute` and `subscribe` (as well as `create_source_event_stream`) with respect to returning values or awaitables. Replicates graphql/graphql-js@6d42ced
1 parent 5950470 commit e30f166

File tree

4 files changed

+213
-72
lines changed

4 files changed

+213
-72
lines changed
 

‎src/graphql/execution/subscribe.py

+79-24
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
from inspect import isawaitable
2-
from typing import Any, AsyncIterable, AsyncIterator, Dict, Optional, Type, Union
2+
from typing import (
3+
Any,
4+
AsyncIterable,
5+
AsyncIterator,
6+
Awaitable,
7+
Dict,
8+
Optional,
9+
Type,
10+
Union,
11+
cast,
12+
)
313

414
from ..error import GraphQLError, located_error
515
from ..execution.collect_fields import collect_fields
@@ -11,15 +21,15 @@
1121
)
1222
from ..execution.values import get_argument_values
1323
from ..language import DocumentNode
14-
from ..pyutils import Path, inspect
24+
from ..pyutils import AwaitableOrValue, Path, inspect
1525
from ..type import GraphQLFieldResolver, GraphQLSchema
1626
from .map_async_iterator import MapAsyncIterator
1727

1828

1929
__all__ = ["subscribe", "create_source_event_stream"]
2030

2131

22-
async def subscribe(
32+
def subscribe(
2333
schema: GraphQLSchema,
2434
document: DocumentNode,
2535
root_value: Any = None,
@@ -29,7 +39,7 @@ async def subscribe(
2939
field_resolver: Optional[GraphQLFieldResolver] = None,
3040
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
3141
execution_context_class: Optional[Type[ExecutionContext]] = None,
32-
) -> Union[AsyncIterator[ExecutionResult], ExecutionResult]:
42+
) -> AwaitableOrValue[Union[AsyncIterator[ExecutionResult], ExecutionResult]]:
3343
"""Create a GraphQL subscription.
3444
3545
Implements the "Subscribe" algorithm described in the GraphQL spec.
@@ -49,7 +59,7 @@ async def subscribe(
4959
If the operation succeeded, the coroutine will yield an AsyncIterator, which yields
5060
a stream of ExecutionResults representing the response stream.
5161
"""
52-
result_or_stream = await create_source_event_stream(
62+
result_or_stream = create_source_event_stream(
5363
schema,
5464
document,
5565
root_value,
@@ -59,8 +69,6 @@ async def subscribe(
5969
subscribe_field_resolver,
6070
execution_context_class,
6171
)
62-
if isinstance(result_or_stream, ExecutionResult):
63-
return result_or_stream
6472

6573
async def map_source_to_response(payload: Any) -> ExecutionResult:
6674
"""Map source to response.
@@ -84,11 +92,28 @@ async def map_source_to_response(payload: Any) -> ExecutionResult:
8492
)
8593
return await result if isawaitable(result) else result
8694

95+
if (execution_context_class or ExecutionContext).is_awaitable(result_or_stream):
96+
awaitable_result_or_stream = cast(Awaitable, result_or_stream)
97+
98+
# noinspection PyShadowingNames
99+
async def await_result() -> Any:
100+
result_or_stream = await awaitable_result_or_stream
101+
if isinstance(result_or_stream, ExecutionResult):
102+
return result_or_stream
103+
return MapAsyncIterator(result_or_stream, map_source_to_response)
104+
105+
return await_result()
106+
107+
if isinstance(result_or_stream, ExecutionResult):
108+
return result_or_stream
109+
87110
# Map every source value to a ExecutionResult value as described above.
88-
return MapAsyncIterator(result_or_stream, map_source_to_response)
111+
return MapAsyncIterator(
112+
cast(AsyncIterable[Any], result_or_stream), map_source_to_response
113+
)
89114

90115

91-
async def create_source_event_stream(
116+
def create_source_event_stream(
92117
schema: GraphQLSchema,
93118
document: DocumentNode,
94119
root_value: Any = None,
@@ -97,7 +122,7 @@ async def create_source_event_stream(
97122
operation_name: Optional[str] = None,
98123
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
99124
execution_context_class: Optional[Type[ExecutionContext]] = None,
100-
) -> Union[AsyncIterable[Any], ExecutionResult]:
125+
) -> AwaitableOrValue[Union[AsyncIterable[Any], ExecutionResult]]:
101126
"""Create source event stream
102127
103128
Implements the "CreateSourceEventStream" algorithm described in the GraphQL
@@ -145,12 +170,28 @@ async def create_source_event_stream(
145170
return ExecutionResult(data=None, errors=context)
146171

147172
try:
148-
return await execute_subscription(context)
173+
event_stream = execute_subscription(context)
149174
except GraphQLError as error:
150175
return ExecutionResult(data=None, errors=[error])
151176

177+
if context.is_awaitable(event_stream):
178+
awaitable_event_stream = cast(Awaitable, event_stream)
179+
180+
# noinspection PyShadowingNames
181+
async def await_event_stream() -> Union[AsyncIterable[Any], ExecutionResult]:
182+
try:
183+
return await awaitable_event_stream
184+
except GraphQLError as error:
185+
return ExecutionResult(data=None, errors=[error])
152186

153-
async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
187+
return await_event_stream()
188+
189+
return event_stream
190+
191+
192+
def execute_subscription(
193+
context: ExecutionContext,
194+
) -> AwaitableOrValue[AsyncIterable[Any]]:
154195
schema = context.schema
155196

156197
root_type = schema.subscription_type
@@ -191,19 +232,33 @@ async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
191232
# AsyncIterable yielding raw payloads.
192233
resolve_fn = field_def.subscribe or context.subscribe_field_resolver
193234

194-
event_stream = resolve_fn(context.root_value, info, **args)
195-
if context.is_awaitable(event_stream):
196-
event_stream = await event_stream
197-
if isinstance(event_stream, Exception):
198-
raise event_stream
235+
result = resolve_fn(context.root_value, info, **args)
236+
if context.is_awaitable(result):
199237

200-
# Assert field returned an event stream, otherwise yield an error.
201-
if not isinstance(event_stream, AsyncIterable):
202-
raise GraphQLError(
203-
"Subscription field must return AsyncIterable."
204-
f" Received: {inspect(event_stream)}."
205-
)
238+
# noinspection PyShadowingNames
239+
async def await_result() -> AsyncIterable[Any]:
240+
try:
241+
return assert_event_stream(await result)
242+
except Exception as error:
243+
raise located_error(error, field_nodes, path.as_list())
244+
245+
return await_result()
246+
247+
return assert_event_stream(result)
206248

207-
return event_stream
208249
except Exception as error:
209250
raise located_error(error, field_nodes, path.as_list())
251+
252+
253+
def assert_event_stream(result: Any) -> AsyncIterable:
254+
if isinstance(result, Exception):
255+
raise result
256+
257+
# Assert field returned an event stream, otherwise yield an error.
258+
if not isinstance(result, AsyncIterable):
259+
raise GraphQLError(
260+
"Subscription field must return AsyncIterable."
261+
f" Received: {inspect(result)}."
262+
)
263+
264+
return result

‎tests/execution/test_customize.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class Root:
6767
async def custom_foo():
6868
yield {"foo": "FooValue"}
6969

70-
subscription = await subscribe(
70+
subscription = subscribe(
7171
schema,
7272
document=parse("subscription { foo }"),
7373
root_value=Root(),
@@ -111,7 +111,7 @@ def resolve_foo(message, _info):
111111
)
112112

113113
document = parse("subscription { foo }")
114-
subscription = await subscribe(
114+
subscription = subscribe(
115115
schema,
116116
document,
117117
context_value={},

0 commit comments

Comments
 (0)