Skip to content

Commit 27fbf54

Browse files
committed
Simplify API & fix case
1 parent d814d7c commit 27fbf54

File tree

5 files changed

+29
-34
lines changed

5 files changed

+29
-34
lines changed

mars/deploy/oscar/tests/test_local.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@
9393
"serialization": {},
9494
"most_calls": DICT_NOT_EMPTY,
9595
"slow_calls": DICT_NOT_EMPTY,
96-
# "band_subtasks": DICT_NOT_EMPTY,
97-
# "slow_subtasks": DICT_NOT_EMPTY,
96+
"band_subtasks": {},
97+
"slow_subtasks": {},
9898
}
9999
}
100100
EXPECT_PROFILING_STRUCTURE_NO_SLOW = copy.deepcopy(EXPECT_PROFILING_STRUCTURE)

mars/services/scheduling/api/oscar.py

+4-12
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ async def cancel_subtasks(
9999
self,
100100
subtask_ids: List[str],
101101
kill_timeout: Union[float, int] = None,
102-
wait: bool = False,
103102
):
104103
"""
105104
Cancel pending and running subtasks.
@@ -111,18 +110,11 @@ async def cancel_subtasks(
111110
kill_timeout
112111
timeout, in seconds, to kill the actor process forcibly
113112
"""
114-
if wait:
115-
await self._manager_ref.cancel_subtasks(
116-
subtask_ids, kill_timeout=kill_timeout
117-
)
118-
else:
119-
await self._manager_ref.cancel_subtasks.tell(
120-
subtask_ids, kill_timeout=kill_timeout
121-
)
113+
await self._manager_ref.cancel_subtasks(subtask_ids, kill_timeout=kill_timeout)
122114

123115
async def finish_subtasks(
124116
self,
125-
subtask_results: List[SubtaskResult],
117+
subtask_ids: List[str],
126118
bands: List[Tuple] = None,
127119
schedule_next: bool = True,
128120
):
@@ -132,14 +124,14 @@ async def finish_subtasks(
132124
133125
Parameters
134126
----------
135-
subtask_results
127+
subtask_ids
136128
ids of subtasks, which must be in finished states
137129
bands
138130
bands of subtasks to mark as finished
139131
schedule_next
140132
whether to schedule succeeding subtasks
141133
"""
142-
await self._manager_ref.finish_subtasks(subtask_results, bands, schedule_next)
134+
await self._manager_ref.finish_subtasks.tell(subtask_ids, bands, schedule_next)
143135

144136

145137
class MockSchedulingAPI(SchedulingAPI):

mars/services/scheduling/supervisor/manager.py

+12-15
Original file line numberDiff line numberDiff line change
@@ -244,16 +244,14 @@ async def set_subtask_result(self, result: SubtaskResult, band: BandType):
244244

245245
async def finish_subtasks(
246246
self,
247-
subtask_results: List[SubtaskResult],
247+
subtask_ids: List[str],
248248
bands: List[BandType] = None,
249249
schedule_next: bool = True,
250250
):
251-
subtask_ids = [result.subtask_id for result in subtask_results]
252251
logger.debug("Finished subtasks %s.", subtask_ids)
253252
band_tasks = defaultdict(lambda: 0)
254253
bands = bands or [None] * len(subtask_ids)
255-
for result, subtask_band in zip(subtask_results, bands):
256-
subtask_id = result.subtask_id
254+
for subtask_id, subtask_band in zip(subtask_ids, bands):
257255
subtask_info = self._subtask_infos.get(subtask_id, None)
258256

259257
if subtask_info is not None:
@@ -265,13 +263,13 @@ async def finish_subtasks(
265263
"stage_id": subtask_info.subtask.stage_id,
266264
},
267265
)
268-
self._subtask_summaries[subtask_id] = subtask_info.to_summary(
269-
is_finished=True,
270-
is_cancelled=result.status == SubtaskStatus.cancelled,
271-
)
266+
if subtask_id not in self._subtask_summaries:
267+
self._subtask_summaries[subtask_id] = subtask_info.to_summary(
268+
is_finished=True,
269+
)
272270
subtask_info.end_time = time.time()
273271
self._speculation_execution_scheduler.finish_subtask(subtask_info)
274-
# Cancel subtask on other bands.
272+
# Cancel subtask on other bands.
275273
aio_task = subtask_info.band_futures.pop(subtask_band, None)
276274
if aio_task:
277275
yield aio_task
@@ -414,9 +412,8 @@ async def cancel_task_in_band(band):
414412

415413
info = self._subtask_infos[subtask_id]
416414
info.cancel_pending = True
417-
raw_tasks_to_cancel = list(info.band_futures.values())
418415

419-
if not raw_tasks_to_cancel:
416+
if not info.band_futures:
420417
# not submitted yet: mark subtasks as cancelled
421418
result = SubtaskResult(
422419
subtask_id=info.subtask.subtask_id,
@@ -435,13 +432,13 @@ async def cancel_task_in_band(band):
435432
)
436433
band_to_futures[band].append(future)
437434

438-
for band in band_to_futures:
439-
cancel_tasks.append(asyncio.create_task(cancel_task_in_band(band)))
440-
435+
# Dequeue first; otherwise subtasks may leak from the queues
441436
if queued_subtask_ids:
442-
# Don't use `finish_subtasks` because it may remove queued
443437
await self._queueing_ref.remove_queued_subtasks(queued_subtask_ids)
444438

439+
for band in band_to_futures:
440+
cancel_tasks.append(asyncio.create_task(cancel_task_in_band(band)))
441+
445442
if cancel_tasks:
446443
yield asyncio.gather(*cancel_tasks)
447444

mars/services/scheduling/tests/test_service.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ async def set_subtask_result(self, subtask_result: SubtaskResult):
5050
for event in self._events[subtask_result.subtask_id]:
5151
event.set()
5252
self._events.pop(subtask_result.subtask_id, None)
53-
await scheduling_api.finish_subtasks([subtask_result], subtask_result.bands)
53+
await scheduling_api.finish_subtasks(
54+
[subtask_result.subtask_id], subtask_result.bands
55+
)
5456

5557
def _return_result(self, subtask_id: str):
5658
result = self._results[subtask_id]

mars/services/task/execution/mars/stage.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def __init__(
8282
# status
8383
self._done = asyncio.Event()
8484
self._cancelled = asyncio.Event()
85+
self._terminated = asyncio.Event()
8586

8687
# add metrics
8788
self._stage_execution_time = Metrics.gauge(
@@ -149,7 +150,7 @@ async def set_subtask_result(self, result: SubtaskResult, band: BandType = None)
149150
if all_done or error_or_cancelled:
150151
# tell scheduling to finish subtasks
151152
await self._scheduling_api.finish_subtasks(
152-
[result], bands=[band], schedule_next=not error_or_cancelled
153+
[result.subtask_id], bands=[band], schedule_next=not error_or_cancelled
153154
)
154155
if self.result.status != TaskStatus.terminated:
155156
self.result = TaskResult(
@@ -162,6 +163,7 @@ async def set_subtask_result(self, result: SubtaskResult, band: BandType = None)
162163
error=result.error,
163164
traceback=result.traceback,
164165
)
166+
self._terminated.set()
165167
if not all_done and error_or_cancelled:
166168
if result.status == SubtaskStatus.errored:
167169
logger.exception(
@@ -184,8 +186,7 @@ async def set_subtask_result(self, result: SubtaskResult, band: BandType = None)
184186
)
185187
# if error or cancel, cancel all submitted subtasks
186188
await self._scheduling_api.cancel_subtasks(
187-
list(self._submitted_subtask_ids),
188-
wait=False,
189+
list(self._submitted_subtask_ids)
189190
)
190191
self._schedule_done()
191192
cost_time_secs = self.result.end_time - self.result.start_time
@@ -219,7 +220,9 @@ async def set_subtask_result(self, result: SubtaskResult, band: BandType = None)
219220
# all predecessors finished
220221
to_schedule_subtasks.append(succ_subtask)
221222
await self._schedule_subtasks(to_schedule_subtasks)
222-
await self._scheduling_api.finish_subtasks([result], bands=[band])
223+
await self._scheduling_api.finish_subtasks(
224+
[result.subtask_id], bands=[band]
225+
)
223226

224227
async def run(self):
225228
if len(self.subtask_graph) == 0:
@@ -234,6 +237,7 @@ async def run(self):
234237

235238
# wait for completion
236239
await self._done.wait()
240+
await self._terminated.wait()
237241
if self.error_or_cancelled():
238242
if self.result.error is not None:
239243
raise self.result.error.with_traceback(self.result.traceback)

0 commit comments

Comments
 (0)