@@ -244,16 +244,14 @@ async def set_subtask_result(self, result: SubtaskResult, band: BandType):
244
244
245
245
async def finish_subtasks (
246
246
self ,
247
- subtask_results : List [SubtaskResult ],
247
+ subtask_ids : List [str ],
248
248
bands : List [BandType ] = None ,
249
249
schedule_next : bool = True ,
250
250
):
251
- subtask_ids = [result .subtask_id for result in subtask_results ]
252
251
logger .debug ("Finished subtasks %s." , subtask_ids )
253
252
band_tasks = defaultdict (lambda : 0 )
254
253
bands = bands or [None ] * len (subtask_ids )
255
- for result , subtask_band in zip (subtask_results , bands ):
256
- subtask_id = result .subtask_id
254
+ for subtask_id , subtask_band in zip (subtask_ids , bands ):
257
255
subtask_info = self ._subtask_infos .get (subtask_id , None )
258
256
259
257
if subtask_info is not None :
@@ -265,13 +263,13 @@ async def finish_subtasks(
265
263
"stage_id" : subtask_info .subtask .stage_id ,
266
264
},
267
265
)
268
- self . _subtask_summaries [ subtask_id ] = subtask_info . to_summary (
269
- is_finished = True ,
270
- is_cancelled = result . status == SubtaskStatus . cancelled ,
271
- )
266
+ if subtask_id not in self . _subtask_summaries :
267
+ self . _subtask_summaries [ subtask_id ] = subtask_info . to_summary (
268
+ is_finished = True ,
269
+ )
272
270
subtask_info .end_time = time .time ()
273
271
self ._speculation_execution_scheduler .finish_subtask (subtask_info )
274
- # Cancel subtask on other bands.
272
+ # Cancel subtask on other bands.
275
273
aio_task = subtask_info .band_futures .pop (subtask_band , None )
276
274
if aio_task :
277
275
yield aio_task
@@ -414,9 +412,8 @@ async def cancel_task_in_band(band):
414
412
415
413
info = self ._subtask_infos [subtask_id ]
416
414
info .cancel_pending = True
417
- raw_tasks_to_cancel = list (info .band_futures .values ())
418
415
419
- if not raw_tasks_to_cancel :
416
+ if not info . band_futures :
420
417
# not submitted yet: mark subtasks as cancelled
421
418
result = SubtaskResult (
422
419
subtask_id = info .subtask .subtask_id ,
@@ -435,13 +432,13 @@ async def cancel_task_in_band(band):
435
432
)
436
433
band_to_futures [band ].append (future )
437
434
438
- for band in band_to_futures :
439
- cancel_tasks .append (asyncio .create_task (cancel_task_in_band (band )))
440
-
435
+ # Dequeue first as it is possible to leak subtasks from queues
441
436
if queued_subtask_ids :
442
- # Don't use `finish_subtasks` because it may remove queued
443
437
await self ._queueing_ref .remove_queued_subtasks (queued_subtask_ids )
444
438
439
+ for band in band_to_futures :
440
+ cancel_tasks .append (asyncio .create_task (cancel_task_in_band (band )))
441
+
445
442
if cancel_tasks :
446
443
yield asyncio .gather (* cancel_tasks )
447
444
0 commit comments