|
6 | 6 | from typing import (
|
7 | 7 | Any,
|
8 | 8 | AsyncIterable,
|
| 9 | + AsyncIterator, |
9 | 10 | Awaitable,
|
10 | 11 | Callable,
|
11 | 12 | Dict,
|
|
58 | 59 | is_object_type,
|
59 | 60 | )
|
60 | 61 | from .collect_fields import collect_fields, collect_subfields
|
| 62 | +from .map_async_iterator import MapAsyncIterator |
61 | 63 | from .middleware import MiddlewareManager
|
62 | 64 | from .values import get_argument_values, get_variable_values
|
63 | 65 |
|
64 | 66 |
|
65 | 67 | __all__ = [
|
66 | 68 | "assert_valid_execution_arguments",
|
| 69 | + "create_source_event_stream", |
67 | 70 | "default_field_resolver",
|
68 | 71 | "default_type_resolver",
|
69 | 72 | "execute",
|
70 | 73 | "execute_sync",
|
| 74 | + "subscribe", |
71 | 75 | "ExecutionResult",
|
72 | 76 | "ExecutionContext",
|
73 | 77 | "FormattedExecutionResult",
|
@@ -1222,3 +1226,238 @@ def default_field_resolver(source: Any, info: GraphQLResolveInfo, **args: Any) -
|
1222 | 1226 | if callable(value):
|
1223 | 1227 | return value(info, **args)
|
1224 | 1228 | return value
|
| 1229 | + |
| 1230 | + |
| 1231 | +def subscribe( |
| 1232 | + schema: GraphQLSchema, |
| 1233 | + document: DocumentNode, |
| 1234 | + root_value: Any = None, |
| 1235 | + context_value: Any = None, |
| 1236 | + variable_values: Optional[Dict[str, Any]] = None, |
| 1237 | + operation_name: Optional[str] = None, |
| 1238 | + field_resolver: Optional[GraphQLFieldResolver] = None, |
| 1239 | + subscribe_field_resolver: Optional[GraphQLFieldResolver] = None, |
| 1240 | + execution_context_class: Optional[Type[ExecutionContext]] = None, |
| 1241 | +) -> AwaitableOrValue[Union[AsyncIterator[ExecutionResult], ExecutionResult]]: |
| 1242 | + """Create a GraphQL subscription. |
| 1243 | +
|
| 1244 | + Implements the "Subscribe" algorithm described in the GraphQL spec. |
| 1245 | +
|
| 1246 | + Returns a coroutine object which yields either an AsyncIterator (if successful) or |
| 1247 | + an ExecutionResult (client error). The coroutine will raise an exception if a server |
| 1248 | + error occurs. |
| 1249 | +
|
| 1250 | + If the client-provided arguments to this function do not result in a compliant |
| 1251 | + subscription, a GraphQL Response (ExecutionResult) with descriptive errors and no |
| 1252 | + data will be returned. |
| 1253 | +
|
| 1254 | + If the source stream could not be created due to faulty subscription resolver logic |
| 1255 | + or underlying systems, the coroutine object will yield a single ExecutionResult |
| 1256 | + containing ``errors`` and no ``data``. |
| 1257 | +
|
| 1258 | + If the operation succeeded, the coroutine will yield an AsyncIterator, which yields |
| 1259 | + a stream of ExecutionResults representing the response stream. |
| 1260 | + """ |
| 1261 | + result_or_stream = create_source_event_stream( |
| 1262 | + schema, |
| 1263 | + document, |
| 1264 | + root_value, |
| 1265 | + context_value, |
| 1266 | + variable_values, |
| 1267 | + operation_name, |
| 1268 | + subscribe_field_resolver, |
| 1269 | + execution_context_class, |
| 1270 | + ) |
| 1271 | + |
| 1272 | + async def map_source_to_response(payload: Any) -> ExecutionResult: |
| 1273 | + """Map source to response. |
| 1274 | +
|
| 1275 | + For each payload yielded from a subscription, map it over the normal GraphQL |
| 1276 | + :func:`~graphql.execute` function, with ``payload`` as the ``root_value``. |
| 1277 | + This implements the "MapSourceToResponseEvent" algorithm described in the |
| 1278 | + GraphQL specification. The :func:`~graphql.execute` function provides the |
| 1279 | + "ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the |
| 1280 | + "ExecuteQuery" algorithm, for which :func:`~graphql.execute` is also used. |
| 1281 | + """ |
| 1282 | + result = execute( |
| 1283 | + schema, |
| 1284 | + document, |
| 1285 | + payload, |
| 1286 | + context_value, |
| 1287 | + variable_values, |
| 1288 | + operation_name, |
| 1289 | + field_resolver, |
| 1290 | + execution_context_class=execution_context_class, |
| 1291 | + ) |
| 1292 | + return await result if isawaitable(result) else result |
| 1293 | + |
| 1294 | + if (execution_context_class or ExecutionContext).is_awaitable(result_or_stream): |
| 1295 | + awaitable_result_or_stream = cast(Awaitable, result_or_stream) |
| 1296 | + |
| 1297 | + # noinspection PyShadowingNames |
| 1298 | + async def await_result() -> Any: |
| 1299 | + result_or_stream = await awaitable_result_or_stream |
| 1300 | + if isinstance(result_or_stream, ExecutionResult): |
| 1301 | + return result_or_stream |
| 1302 | + return MapAsyncIterator(result_or_stream, map_source_to_response) |
| 1303 | + |
| 1304 | + return await_result() |
| 1305 | + |
| 1306 | + if isinstance(result_or_stream, ExecutionResult): |
| 1307 | + return result_or_stream |
| 1308 | + |
| 1309 | + # Map every source value to a ExecutionResult value as described above. |
| 1310 | + return MapAsyncIterator( |
| 1311 | + cast(AsyncIterable[Any], result_or_stream), map_source_to_response |
| 1312 | + ) |
| 1313 | + |
| 1314 | + |
| 1315 | +def create_source_event_stream( |
| 1316 | + schema: GraphQLSchema, |
| 1317 | + document: DocumentNode, |
| 1318 | + root_value: Any = None, |
| 1319 | + context_value: Any = None, |
| 1320 | + variable_values: Optional[Dict[str, Any]] = None, |
| 1321 | + operation_name: Optional[str] = None, |
| 1322 | + subscribe_field_resolver: Optional[GraphQLFieldResolver] = None, |
| 1323 | + execution_context_class: Optional[Type[ExecutionContext]] = None, |
| 1324 | +) -> AwaitableOrValue[Union[AsyncIterable[Any], ExecutionResult]]: |
| 1325 | + """Create source event stream |
| 1326 | +
|
| 1327 | + Implements the "CreateSourceEventStream" algorithm described in the GraphQL |
| 1328 | + specification, resolving the subscription source event stream. |
| 1329 | +
|
| 1330 | + Returns a coroutine that yields an AsyncIterable. |
| 1331 | +
|
| 1332 | + If the client-provided arguments to this function do not result in a compliant |
| 1333 | + subscription, a GraphQL Response (ExecutionResult) with descriptive errors and no |
| 1334 | + data will be returned. |
| 1335 | +
|
| 1336 | + If the source stream could not be created due to faulty subscription resolver logic |
| 1337 | + or underlying systems, the coroutine object will yield a single ExecutionResult |
| 1338 | + containing ``errors`` and no ``data``. |
| 1339 | +
|
| 1340 | + A source event stream represents a sequence of events, each of which triggers a |
| 1341 | + GraphQL execution for that event. |
| 1342 | +
|
| 1343 | + This may be useful when hosting the stateful subscription service in a different |
| 1344 | + process or machine than the stateless GraphQL execution engine, or otherwise |
| 1345 | + separating these two steps. For more on this, see the "Supporting Subscriptions |
| 1346 | + at Scale" information in the GraphQL spec. |
| 1347 | + """ |
| 1348 | + # If arguments are missing or incorrectly typed, this is an internal developer |
| 1349 | + # mistake which should throw an early error. |
| 1350 | + assert_valid_execution_arguments(schema, document, variable_values) |
| 1351 | + |
| 1352 | + if not execution_context_class: |
| 1353 | + execution_context_class = ExecutionContext |
| 1354 | + |
| 1355 | + # If a valid context cannot be created due to incorrect arguments, |
| 1356 | + # a "Response" with only errors is returned. |
| 1357 | + context = execution_context_class.build( |
| 1358 | + schema, |
| 1359 | + document, |
| 1360 | + root_value, |
| 1361 | + context_value, |
| 1362 | + variable_values, |
| 1363 | + operation_name, |
| 1364 | + subscribe_field_resolver=subscribe_field_resolver, |
| 1365 | + ) |
| 1366 | + |
| 1367 | + # Return early errors if execution context failed. |
| 1368 | + if isinstance(context, list): |
| 1369 | + return ExecutionResult(data=None, errors=context) |
| 1370 | + |
| 1371 | + try: |
| 1372 | + event_stream = execute_subscription(context) |
| 1373 | + except GraphQLError as error: |
| 1374 | + return ExecutionResult(data=None, errors=[error]) |
| 1375 | + |
| 1376 | + if context.is_awaitable(event_stream): |
| 1377 | + awaitable_event_stream = cast(Awaitable, event_stream) |
| 1378 | + |
| 1379 | + # noinspection PyShadowingNames |
| 1380 | + async def await_event_stream() -> Union[AsyncIterable[Any], ExecutionResult]: |
| 1381 | + try: |
| 1382 | + return await awaitable_event_stream |
| 1383 | + except GraphQLError as error: |
| 1384 | + return ExecutionResult(data=None, errors=[error]) |
| 1385 | + |
| 1386 | + return await_event_stream() |
| 1387 | + |
| 1388 | + return event_stream |
| 1389 | + |
| 1390 | + |
| 1391 | +def execute_subscription( |
| 1392 | + context: ExecutionContext, |
| 1393 | +) -> AwaitableOrValue[AsyncIterable[Any]]: |
| 1394 | + schema = context.schema |
| 1395 | + |
| 1396 | + root_type = schema.subscription_type |
| 1397 | + if root_type is None: |
| 1398 | + raise GraphQLError( |
| 1399 | + "Schema is not configured to execute subscription operation.", |
| 1400 | + context.operation, |
| 1401 | + ) |
| 1402 | + |
| 1403 | + root_fields = collect_fields( |
| 1404 | + schema, |
| 1405 | + context.fragments, |
| 1406 | + context.variable_values, |
| 1407 | + root_type, |
| 1408 | + context.operation.selection_set, |
| 1409 | + ) |
| 1410 | + response_name, field_nodes = next(iter(root_fields.items())) |
| 1411 | + field_name = field_nodes[0].name.value |
| 1412 | + field_def = schema.get_field(root_type, field_name) |
| 1413 | + |
| 1414 | + if not field_def: |
| 1415 | + raise GraphQLError( |
| 1416 | + f"The subscription field '{field_name}' is not defined.", field_nodes |
| 1417 | + ) |
| 1418 | + |
| 1419 | + path = Path(None, response_name, root_type.name) |
| 1420 | + info = context.build_resolve_info(field_def, field_nodes, root_type, path) |
| 1421 | + |
| 1422 | + # Implements the "ResolveFieldEventStream" algorithm from GraphQL specification. |
| 1423 | + # It differs from "ResolveFieldValue" due to providing a different `resolveFn`. |
| 1424 | + |
| 1425 | + try: |
| 1426 | + # Build a dictionary of arguments from the field.arguments AST, using the |
| 1427 | + # variables scope to fulfill any variable references. |
| 1428 | + args = get_argument_values(field_def, field_nodes[0], context.variable_values) |
| 1429 | + |
| 1430 | + # Call the `subscribe()` resolver or the default resolver to produce an |
| 1431 | + # AsyncIterable yielding raw payloads. |
| 1432 | + resolve_fn = field_def.subscribe or context.subscribe_field_resolver |
| 1433 | + |
| 1434 | + result = resolve_fn(context.root_value, info, **args) |
| 1435 | + if context.is_awaitable(result): |
| 1436 | + |
| 1437 | + # noinspection PyShadowingNames |
| 1438 | + async def await_result() -> AsyncIterable[Any]: |
| 1439 | + try: |
| 1440 | + return assert_event_stream(await result) |
| 1441 | + except Exception as error: |
| 1442 | + raise located_error(error, field_nodes, path.as_list()) |
| 1443 | + |
| 1444 | + return await_result() |
| 1445 | + |
| 1446 | + return assert_event_stream(result) |
| 1447 | + |
| 1448 | + except Exception as error: |
| 1449 | + raise located_error(error, field_nodes, path.as_list()) |
| 1450 | + |
| 1451 | + |
| 1452 | +def assert_event_stream(result: Any) -> AsyncIterable: |
| 1453 | + if isinstance(result, Exception): |
| 1454 | + raise result |
| 1455 | + |
| 1456 | + # Assert field returned an event stream, otherwise yield an error. |
| 1457 | + if not isinstance(result, AsyncIterable): |
| 1458 | + raise GraphQLError( |
| 1459 | + "Subscription field must return AsyncIterable." |
| 1460 | + f" Received: {inspect(result)}." |
| 1461 | + ) |
| 1462 | + |
| 1463 | + return result |
0 commit comments