1
1
from inspect import isawaitable
2
- from typing import Any , AsyncIterable , AsyncIterator , Dict , Optional , Type , Union
2
+ from typing import (
3
+ Any ,
4
+ AsyncIterable ,
5
+ AsyncIterator ,
6
+ Awaitable ,
7
+ Dict ,
8
+ Optional ,
9
+ Type ,
10
+ Union ,
11
+ cast ,
12
+ )
3
13
4
14
from ..error import GraphQLError , located_error
5
15
from ..execution .collect_fields import collect_fields
11
21
)
12
22
from ..execution .values import get_argument_values
13
23
from ..language import DocumentNode
14
- from ..pyutils import Path , inspect
24
+ from ..pyutils import AwaitableOrValue , Path , inspect
15
25
from ..type import GraphQLFieldResolver , GraphQLSchema
16
26
from .map_async_iterator import MapAsyncIterator
17
27
18
28
19
29
__all__ = ["subscribe" , "create_source_event_stream" ]
20
30
21
31
22
- async def subscribe (
32
+ def subscribe (
23
33
schema : GraphQLSchema ,
24
34
document : DocumentNode ,
25
35
root_value : Any = None ,
@@ -29,7 +39,7 @@ async def subscribe(
29
39
field_resolver : Optional [GraphQLFieldResolver ] = None ,
30
40
subscribe_field_resolver : Optional [GraphQLFieldResolver ] = None ,
31
41
execution_context_class : Optional [Type [ExecutionContext ]] = None ,
32
- ) -> Union [AsyncIterator [ExecutionResult ], ExecutionResult ]:
42
+ ) -> AwaitableOrValue [ Union [AsyncIterator [ExecutionResult ], ExecutionResult ] ]:
33
43
"""Create a GraphQL subscription.
34
44
35
45
Implements the "Subscribe" algorithm described in the GraphQL spec.
@@ -49,7 +59,7 @@ async def subscribe(
49
59
If the operation succeeded, the coroutine will yield an AsyncIterator, which yields
50
60
a stream of ExecutionResults representing the response stream.
51
61
"""
52
- result_or_stream = await create_source_event_stream (
62
+ result_or_stream = create_source_event_stream (
53
63
schema ,
54
64
document ,
55
65
root_value ,
@@ -59,8 +69,6 @@ async def subscribe(
59
69
subscribe_field_resolver ,
60
70
execution_context_class ,
61
71
)
62
- if isinstance (result_or_stream , ExecutionResult ):
63
- return result_or_stream
64
72
65
73
async def map_source_to_response (payload : Any ) -> ExecutionResult :
66
74
"""Map source to response.
@@ -84,11 +92,28 @@ async def map_source_to_response(payload: Any) -> ExecutionResult:
84
92
)
85
93
return await result if isawaitable (result ) else result
86
94
95
+ if (execution_context_class or ExecutionContext ).is_awaitable (result_or_stream ):
96
+ awaitable_result_or_stream = cast (Awaitable , result_or_stream )
97
+
98
+ # noinspection PyShadowingNames
99
+ async def await_result () -> Any :
100
+ result_or_stream = await awaitable_result_or_stream
101
+ if isinstance (result_or_stream , ExecutionResult ):
102
+ return result_or_stream
103
+ return MapAsyncIterator (result_or_stream , map_source_to_response )
104
+
105
+ return await_result ()
106
+
107
+ if isinstance (result_or_stream , ExecutionResult ):
108
+ return result_or_stream
109
+
87
110
# Map every source value to a ExecutionResult value as described above.
88
- return MapAsyncIterator (result_or_stream , map_source_to_response )
111
+ return MapAsyncIterator (
112
+ cast (AsyncIterable [Any ], result_or_stream ), map_source_to_response
113
+ )
89
114
90
115
91
- async def create_source_event_stream (
116
+ def create_source_event_stream (
92
117
schema : GraphQLSchema ,
93
118
document : DocumentNode ,
94
119
root_value : Any = None ,
@@ -97,7 +122,7 @@ async def create_source_event_stream(
97
122
operation_name : Optional [str ] = None ,
98
123
subscribe_field_resolver : Optional [GraphQLFieldResolver ] = None ,
99
124
execution_context_class : Optional [Type [ExecutionContext ]] = None ,
100
- ) -> Union [AsyncIterable [Any ], ExecutionResult ]:
125
+ ) -> AwaitableOrValue [ Union [AsyncIterable [Any ], ExecutionResult ] ]:
101
126
"""Create source event stream
102
127
103
128
Implements the "CreateSourceEventStream" algorithm described in the GraphQL
@@ -145,12 +170,28 @@ async def create_source_event_stream(
145
170
return ExecutionResult (data = None , errors = context )
146
171
147
172
try :
148
- return await execute_subscription (context )
173
+ event_stream = execute_subscription (context )
149
174
except GraphQLError as error :
150
175
return ExecutionResult (data = None , errors = [error ])
151
176
177
+ if context .is_awaitable (event_stream ):
178
+ awaitable_event_stream = cast (Awaitable , event_stream )
179
+
180
+ # noinspection PyShadowingNames
181
+ async def await_event_stream () -> Union [AsyncIterable [Any ], ExecutionResult ]:
182
+ try :
183
+ return await awaitable_event_stream
184
+ except GraphQLError as error :
185
+ return ExecutionResult (data = None , errors = [error ])
152
186
153
- async def execute_subscription (context : ExecutionContext ) -> AsyncIterable [Any ]:
187
+ return await_event_stream ()
188
+
189
+ return event_stream
190
+
191
+
192
+ def execute_subscription (
193
+ context : ExecutionContext ,
194
+ ) -> AwaitableOrValue [AsyncIterable [Any ]]:
154
195
schema = context .schema
155
196
156
197
root_type = schema .subscription_type
@@ -191,19 +232,33 @@ async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
191
232
# AsyncIterable yielding raw payloads.
192
233
resolve_fn = field_def .subscribe or context .subscribe_field_resolver
193
234
194
- event_stream = resolve_fn (context .root_value , info , ** args )
195
- if context .is_awaitable (event_stream ):
196
- event_stream = await event_stream
197
- if isinstance (event_stream , Exception ):
198
- raise event_stream
235
+ result = resolve_fn (context .root_value , info , ** args )
236
+ if context .is_awaitable (result ):
199
237
200
- # Assert field returned an event stream, otherwise yield an error.
201
- if not isinstance (event_stream , AsyncIterable ):
202
- raise GraphQLError (
203
- "Subscription field must return AsyncIterable."
204
- f" Received: { inspect (event_stream )} ."
205
- )
238
+ # noinspection PyShadowingNames
239
+ async def await_result () -> AsyncIterable [Any ]:
240
+ try :
241
+ return assert_event_stream (await result )
242
+ except Exception as error :
243
+ raise located_error (error , field_nodes , path .as_list ())
244
+
245
+ return await_result ()
246
+
247
+ return assert_event_stream (result )
206
248
207
- return event_stream
208
249
except Exception as error :
209
250
raise located_error (error , field_nodes , path .as_list ())
251
+
252
+
253
+ def assert_event_stream (result : Any ) -> AsyncIterable :
254
+ if isinstance (result , Exception ):
255
+ raise result
256
+
257
+ # Assert field returned an event stream, otherwise yield an error.
258
+ if not isinstance (result , AsyncIterable ):
259
+ raise GraphQLError (
260
+ "Subscription field must return AsyncIterable."
261
+ f" Received: { inspect (result )} ."
262
+ )
263
+
264
+ return result
0 commit comments