36
36
from ....cluster .api import ClusterAPI
37
37
from ....lifecycle .api import LifecycleAPI
38
38
from ....meta .api import MetaAPI
39
- from ....subtask import SubtaskGraph
39
+ from ....subtask import Subtask , SubtaskGraph
40
40
from ....subtask .utils import iter_input_data_keys , iter_output_data
41
41
from ..api import TaskExecutor , ExecutionChunkResult , register_executor_cls
42
+ from .context import RayExecutionContext
42
43
43
44
ray = lazy_import ("ray" )
44
45
logger = logging .getLogger (__name__ )
@@ -55,7 +56,7 @@ def execute_subtask(
55
56
ensure_coverage ()
56
57
subtask_chunk_graph = deserialize (* subtask_chunk_graph )
57
58
# inputs = [i[1] for i in inputs]
58
- context = dict (zip (input_keys , inputs ))
59
+ context = RayExecutionContext (zip (input_keys , inputs ))
59
60
# optimize chunk graph.
60
61
subtask_chunk_graph = optimize (subtask_chunk_graph )
61
62
# from data_key to results
@@ -117,7 +118,7 @@ async def create(
117
118
session_id : str ,
118
119
address : str ,
119
120
task ,
120
- tile_context ,
121
+ tile_context : TileContext ,
121
122
** kwargs
122
123
) -> "TaskExecutor" :
123
124
ray_executor = ray .remote (execute_subtask )
@@ -159,14 +160,10 @@ async def execute_subtask_graph(
159
160
result_keys = {chunk .key for chunk in chunk_graph .result_chunks }
160
161
for subtask in subtask_graph .topological_iter ():
161
162
subtask_chunk_graph = subtask .chunk_graph
162
- chunk_key_to_data_keys = get_chunk_key_to_data_keys (subtask_chunk_graph )
163
- key_to_input = {
164
- key : context [key ]
165
- for key , _ in iter_input_data_keys (
166
- subtask , subtask_chunk_graph , chunk_key_to_data_keys
167
- )
168
- }
169
- output_keys = self ._get_output_keys (subtask_chunk_graph )
163
+ key_to_input = await self ._load_subtask_inputs (
164
+ stage_id , subtask , subtask_chunk_graph , context
165
+ )
166
+ output_keys = self ._get_subtask_output_keys (subtask_chunk_graph )
170
167
output_meta_keys = result_keys & output_keys
171
168
output_count = len (output_keys ) + bool (output_meta_keys )
172
169
output_object_refs = self ._ray_executor .options (
@@ -250,8 +247,44 @@ async def get_progress(self) -> float:
250
247
    async def cancel(self):
        """Cancel execution.

        NOTE(review): the visible body contains no statements beyond this
        docstring — submitted Ray tasks do not appear to be cancelled here.
        Presumably a placeholder to satisfy the ``TaskExecutor`` interface;
        confirm intended behavior before relying on cancellation.
        """
250
    async def _load_subtask_inputs(
        self, stage_id: str, subtask: Subtask, chunk_graph: ChunkGraph, context: Dict
    ):
        """
        Load a dict of input key to object ref of subtask from context.

        It updates the context if the input object refs are fetched from
        the meta service.

        Parameters
        ----------
        stage_id : str
            Stage identifier, used only for logging here.
        subtask : Subtask
            The subtask whose input data keys are being resolved.
        chunk_graph : ChunkGraph
            The subtask's chunk graph, used to map chunk keys to data keys.
        context : Dict
            Mapping of data key -> Ray object ref; mutated in place when
            missing refs are fetched from the meta service.

        Returns
        -------
        dict
            Mapping of input data key -> object ref for this subtask.
        """
        key_to_input = {}
        # Deferred meta-service calls, accumulated so they can be issued
        # as a single batched request below (oscar delay/batch protocol).
        key_to_get_meta = {}
        chunk_key_to_data_keys = get_chunk_key_to_data_keys(chunk_graph)
        for key, _ in iter_input_data_keys(
            subtask, chunk_graph, chunk_key_to_data_keys
        ):
            if key in context:
                # Fast path: the object ref was produced by an earlier
                # subtask in this execution and is already cached.
                key_to_input[key] = context[key]
            else:
                key_to_get_meta[key] = self._meta_api.get_chunk_meta.delay(
                    key, fields=["object_refs"]
                )
        if key_to_get_meta:
            logger.info(
                "Fetch %s metas and update context of stage %s.",
                len(key_to_get_meta),
                stage_id,
            )
            # Issue all deferred lookups in one batched meta-service call.
            meta_list = await self._meta_api.get_chunk_meta.batch(
                *key_to_get_meta.values()
            )
            for key, meta in zip(key_to_get_meta.keys(), meta_list):
                # Takes the first object ref of the chunk; assumes each
                # chunk maps to a single ref here — TODO confirm for
                # multi-ref (e.g. sharded) chunks.
                object_ref = meta["object_refs"][0]
                key_to_input[key] = object_ref
                # Cache in the shared context so later subtasks in this
                # stage skip the meta-service round trip.
                context[key] = object_ref
        return key_to_input
285
+
253
286
@staticmethod
254
- def _get_output_keys (chunk_graph ):
287
+ def _get_subtask_output_keys (chunk_graph : ChunkGraph ):
255
288
output_keys = {}
256
289
for chunk in chunk_graph .results :
257
290
if isinstance (chunk .op , VirtualOperand ):
0 commit comments