Skip to content

Commit 689a70d

Browse files
committed
refactor: move subscribe code to execute file
Replicates graphql/graphql-js@e24f426
1 parent e30f166 commit 689a70d

File tree

3 files changed

+241
-265
lines changed

3 files changed

+241
-265
lines changed

src/graphql/execution/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,18 @@
55
"""
66

77
from .execute import (
8+
create_source_event_stream,
89
execute,
910
execute_sync,
1011
default_field_resolver,
1112
default_type_resolver,
13+
subscribe,
1214
ExecutionContext,
1315
ExecutionResult,
1416
FormattedExecutionResult,
1517
Middleware,
1618
)
1719
from .map_async_iterator import MapAsyncIterator
18-
from .subscribe import subscribe, create_source_event_stream
1920
from .middleware import MiddlewareManager
2021
from .values import get_argument_values, get_directive_values, get_variable_values
2122

src/graphql/execution/execute.py

+239
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import (
77
Any,
88
AsyncIterable,
9+
AsyncIterator,
910
Awaitable,
1011
Callable,
1112
Dict,
@@ -58,16 +59,19 @@
5859
is_object_type,
5960
)
6061
from .collect_fields import collect_fields, collect_subfields
62+
from .map_async_iterator import MapAsyncIterator
6163
from .middleware import MiddlewareManager
6264
from .values import get_argument_values, get_variable_values
6365

6466

6567
__all__ = [
6668
"assert_valid_execution_arguments",
69+
"create_source_event_stream",
6770
"default_field_resolver",
6871
"default_type_resolver",
6972
"execute",
7073
"execute_sync",
74+
"subscribe",
7175
"ExecutionResult",
7276
"ExecutionContext",
7377
"FormattedExecutionResult",
@@ -1222,3 +1226,238 @@ def default_field_resolver(source: Any, info: GraphQLResolveInfo, **args: Any) -
12221226
if callable(value):
12231227
return value(info, **args)
12241228
return value
1229+
1230+
1231+
def subscribe(
1232+
schema: GraphQLSchema,
1233+
document: DocumentNode,
1234+
root_value: Any = None,
1235+
context_value: Any = None,
1236+
variable_values: Optional[Dict[str, Any]] = None,
1237+
operation_name: Optional[str] = None,
1238+
field_resolver: Optional[GraphQLFieldResolver] = None,
1239+
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
1240+
execution_context_class: Optional[Type[ExecutionContext]] = None,
1241+
) -> AwaitableOrValue[Union[AsyncIterator[ExecutionResult], ExecutionResult]]:
1242+
"""Create a GraphQL subscription.
1243+
1244+
Implements the "Subscribe" algorithm described in the GraphQL spec.
1245+
1246+
Returns a coroutine object which yields either an AsyncIterator (if successful) or
1247+
an ExecutionResult (client error). The coroutine will raise an exception if a server
1248+
error occurs.
1249+
1250+
If the client-provided arguments to this function do not result in a compliant
1251+
subscription, a GraphQL Response (ExecutionResult) with descriptive errors and no
1252+
data will be returned.
1253+
1254+
If the source stream could not be created due to faulty subscription resolver logic
1255+
or underlying systems, the coroutine object will yield a single ExecutionResult
1256+
containing ``errors`` and no ``data``.
1257+
1258+
If the operation succeeded, the coroutine will yield an AsyncIterator, which yields
1259+
a stream of ExecutionResults representing the response stream.
1260+
"""
1261+
result_or_stream = create_source_event_stream(
1262+
schema,
1263+
document,
1264+
root_value,
1265+
context_value,
1266+
variable_values,
1267+
operation_name,
1268+
subscribe_field_resolver,
1269+
execution_context_class,
1270+
)
1271+
1272+
async def map_source_to_response(payload: Any) -> ExecutionResult:
1273+
"""Map source to response.
1274+
1275+
For each payload yielded from a subscription, map it over the normal GraphQL
1276+
:func:`~graphql.execute` function, with ``payload`` as the ``root_value``.
1277+
This implements the "MapSourceToResponseEvent" algorithm described in the
1278+
GraphQL specification. The :func:`~graphql.execute` function provides the
1279+
"ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the
1280+
"ExecuteQuery" algorithm, for which :func:`~graphql.execute` is also used.
1281+
"""
1282+
result = execute(
1283+
schema,
1284+
document,
1285+
payload,
1286+
context_value,
1287+
variable_values,
1288+
operation_name,
1289+
field_resolver,
1290+
execution_context_class=execution_context_class,
1291+
)
1292+
return await result if isawaitable(result) else result
1293+
1294+
if (execution_context_class or ExecutionContext).is_awaitable(result_or_stream):
1295+
awaitable_result_or_stream = cast(Awaitable, result_or_stream)
1296+
1297+
# noinspection PyShadowingNames
1298+
async def await_result() -> Any:
1299+
result_or_stream = await awaitable_result_or_stream
1300+
if isinstance(result_or_stream, ExecutionResult):
1301+
return result_or_stream
1302+
return MapAsyncIterator(result_or_stream, map_source_to_response)
1303+
1304+
return await_result()
1305+
1306+
if isinstance(result_or_stream, ExecutionResult):
1307+
return result_or_stream
1308+
1309+
# Map every source value to a ExecutionResult value as described above.
1310+
return MapAsyncIterator(
1311+
cast(AsyncIterable[Any], result_or_stream), map_source_to_response
1312+
)
1313+
1314+
1315+
def create_source_event_stream(
1316+
schema: GraphQLSchema,
1317+
document: DocumentNode,
1318+
root_value: Any = None,
1319+
context_value: Any = None,
1320+
variable_values: Optional[Dict[str, Any]] = None,
1321+
operation_name: Optional[str] = None,
1322+
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
1323+
execution_context_class: Optional[Type[ExecutionContext]] = None,
1324+
) -> AwaitableOrValue[Union[AsyncIterable[Any], ExecutionResult]]:
1325+
"""Create source event stream
1326+
1327+
Implements the "CreateSourceEventStream" algorithm described in the GraphQL
1328+
specification, resolving the subscription source event stream.
1329+
1330+
Returns a coroutine that yields an AsyncIterable.
1331+
1332+
If the client-provided arguments to this function do not result in a compliant
1333+
subscription, a GraphQL Response (ExecutionResult) with descriptive errors and no
1334+
data will be returned.
1335+
1336+
If the source stream could not be created due to faulty subscription resolver logic
1337+
or underlying systems, the coroutine object will yield a single ExecutionResult
1338+
containing ``errors`` and no ``data``.
1339+
1340+
A source event stream represents a sequence of events, each of which triggers a
1341+
GraphQL execution for that event.
1342+
1343+
This may be useful when hosting the stateful subscription service in a different
1344+
process or machine than the stateless GraphQL execution engine, or otherwise
1345+
separating these two steps. For more on this, see the "Supporting Subscriptions
1346+
at Scale" information in the GraphQL spec.
1347+
"""
1348+
# If arguments are missing or incorrectly typed, this is an internal developer
1349+
# mistake which should throw an early error.
1350+
assert_valid_execution_arguments(schema, document, variable_values)
1351+
1352+
if not execution_context_class:
1353+
execution_context_class = ExecutionContext
1354+
1355+
# If a valid context cannot be created due to incorrect arguments,
1356+
# a "Response" with only errors is returned.
1357+
context = execution_context_class.build(
1358+
schema,
1359+
document,
1360+
root_value,
1361+
context_value,
1362+
variable_values,
1363+
operation_name,
1364+
subscribe_field_resolver=subscribe_field_resolver,
1365+
)
1366+
1367+
# Return early errors if execution context failed.
1368+
if isinstance(context, list):
1369+
return ExecutionResult(data=None, errors=context)
1370+
1371+
try:
1372+
event_stream = execute_subscription(context)
1373+
except GraphQLError as error:
1374+
return ExecutionResult(data=None, errors=[error])
1375+
1376+
if context.is_awaitable(event_stream):
1377+
awaitable_event_stream = cast(Awaitable, event_stream)
1378+
1379+
# noinspection PyShadowingNames
1380+
async def await_event_stream() -> Union[AsyncIterable[Any], ExecutionResult]:
1381+
try:
1382+
return await awaitable_event_stream
1383+
except GraphQLError as error:
1384+
return ExecutionResult(data=None, errors=[error])
1385+
1386+
return await_event_stream()
1387+
1388+
return event_stream
1389+
1390+
1391+
def execute_subscription(
1392+
context: ExecutionContext,
1393+
) -> AwaitableOrValue[AsyncIterable[Any]]:
1394+
schema = context.schema
1395+
1396+
root_type = schema.subscription_type
1397+
if root_type is None:
1398+
raise GraphQLError(
1399+
"Schema is not configured to execute subscription operation.",
1400+
context.operation,
1401+
)
1402+
1403+
root_fields = collect_fields(
1404+
schema,
1405+
context.fragments,
1406+
context.variable_values,
1407+
root_type,
1408+
context.operation.selection_set,
1409+
)
1410+
response_name, field_nodes = next(iter(root_fields.items()))
1411+
field_name = field_nodes[0].name.value
1412+
field_def = schema.get_field(root_type, field_name)
1413+
1414+
if not field_def:
1415+
raise GraphQLError(
1416+
f"The subscription field '{field_name}' is not defined.", field_nodes
1417+
)
1418+
1419+
path = Path(None, response_name, root_type.name)
1420+
info = context.build_resolve_info(field_def, field_nodes, root_type, path)
1421+
1422+
# Implements the "ResolveFieldEventStream" algorithm from GraphQL specification.
1423+
# It differs from "ResolveFieldValue" due to providing a different `resolveFn`.
1424+
1425+
try:
1426+
# Build a dictionary of arguments from the field.arguments AST, using the
1427+
# variables scope to fulfill any variable references.
1428+
args = get_argument_values(field_def, field_nodes[0], context.variable_values)
1429+
1430+
# Call the `subscribe()` resolver or the default resolver to produce an
1431+
# AsyncIterable yielding raw payloads.
1432+
resolve_fn = field_def.subscribe or context.subscribe_field_resolver
1433+
1434+
result = resolve_fn(context.root_value, info, **args)
1435+
if context.is_awaitable(result):
1436+
1437+
# noinspection PyShadowingNames
1438+
async def await_result() -> AsyncIterable[Any]:
1439+
try:
1440+
return assert_event_stream(await result)
1441+
except Exception as error:
1442+
raise located_error(error, field_nodes, path.as_list())
1443+
1444+
return await_result()
1445+
1446+
return assert_event_stream(result)
1447+
1448+
except Exception as error:
1449+
raise located_error(error, field_nodes, path.as_list())
1450+
1451+
1452+
def assert_event_stream(result: Any) -> AsyncIterable:
1453+
if isinstance(result, Exception):
1454+
raise result
1455+
1456+
# Assert field returned an event stream, otherwise yield an error.
1457+
if not isinstance(result, AsyncIterable):
1458+
raise GraphQLError(
1459+
"Subscription field must return AsyncIterable."
1460+
f" Received: {inspect(result)}."
1461+
)
1462+
1463+
return result

0 commit comments

Comments
 (0)