Merge branch 'main' into homogeneous-balancing
hendrikmakait committed Feb 4, 2025
2 parents 2827f6f + b86b714 · commit f97cf57
Showing 9 changed files with 230 additions and 113 deletions.
4 changes: 3 additions & 1 deletion distributed/dashboard/tests/test_scheduler_bokeh.py
@@ -137,7 +137,9 @@ async def test_counters(c, s, a, b):
await asyncio.sleep(0.01)


@gen_cluster(client=True)
@gen_cluster(
client=True, config={"distributed.scheduler.work-stealing-interval": "100ms"}
)
async def test_stealing_events(c, s, a, b):
se = StealingEvents(s)

2 changes: 1 addition & 1 deletion distributed/distributed.yaml
@@ -21,7 +21,7 @@ distributed:
idle-timeout: null # Shut down after this duration, like "1h" or "30 minutes"
no-workers-timeout: null # If a task remains unrunnable for longer than this, it fails.
work-stealing: True # workers should steal tasks from each other
work-stealing-interval: 100ms # Callback time for work stealing
work-stealing-interval: 1s # Callback time for work stealing
worker-saturation: 1.1 # Send this fraction of nthreads root tasks to workers
rootish-taskgroup: 5 # number of dependencies of a rootish tg
rootish-taskgroup-dependencies: 5 # number of dependencies of the dependencies of the rootish tg
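
The default stealing callback interval moves from 100ms to 1s here; tests that rely on frequent stealing now pin the old 100ms value explicitly via their gen_cluster config (see the test changes in this commit). As a minimal sketch of restoring the faster cadence for a locally started scheduler, assuming only the standard dask.config override mechanism:

import dask
from distributed import Client

# Illustrative override: a scheduler created while this config is active runs
# its work-stealing callback every 100ms instead of the new 1s default.
with dask.config.set({"distributed.scheduler.work-stealing-interval": "100ms"}):
    client = Client()  # the LocalCluster started here picks up the override

The override only affects schedulers started after it is set; an already-running scheduler keeps whatever interval it was configured with.
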
8 changes: 6 additions & 2 deletions distributed/http/scheduler/tests/test_stealing_http.py
@@ -25,7 +25,9 @@ async def test_prometheus(c, s, a, b):
assert active_metrics == expected_metrics


@gen_cluster(client=True)
@gen_cluster(
client=True, config={"distributed.scheduler.work-stealing-interval": "100ms"}
)
async def test_prometheus_collect_count_total_by_cost_multipliers(c, s, a, b):
pytest.importorskip("prometheus_client")

@@ -58,7 +60,9 @@ async def fetch_metrics_by_cost_multipliers():
assert count == expected_count


@gen_cluster(client=True)
@gen_cluster(
client=True, config={"distributed.scheduler.work-stealing-interval": "100ms"}
)
async def test_prometheus_collect_cost_total_by_cost_multipliers(c, s, a, b):
pytest.importorskip("prometheus_client")

84 changes: 28 additions & 56 deletions distributed/scheduler.py
@@ -1674,9 +1674,6 @@ class SchedulerState:
#: Subset of tasks that exist in memory on more than one worker
replicated_tasks: set[TaskState]

#: Tasks with unknown duration, grouped by prefix
#: {task prefix: {ts, ts, ...}}
unknown_durations: dict[str, set[TaskState]]
task_groups: dict[str, TaskGroup]
task_prefixes: dict[str, TaskPrefix]
task_metadata: dict[Key, Any]
@@ -1776,7 +1773,6 @@ def __init__(
self.task_metadata = {}
self.total_nthreads = 0
self.total_nthreads_history = [(time(), 0)]
self.unknown_durations = {}
self.queued = queued
self.unrunnable = unrunnable
self.validate = validate
@@ -1855,7 +1851,6 @@ def __pdict__(self) -> dict[str, Any]:
"unrunnable": self.unrunnable,
"queued": self.queued,
"n_tasks": self.n_tasks,
"unknown_durations": self.unknown_durations,
"validate": self.validate,
"tasks": self.tasks,
"task_groups": self.task_groups,
@@ -1907,7 +1902,6 @@ def _clear_task_state(self) -> None:
self.task_prefixes,
self.task_groups,
self.task_metadata,
self.unknown_durations,
self.replicated_tasks,
):
collection.clear()
@@ -1931,22 +1925,37 @@ def total_occupancy(self) -> float:
self._network_occ_global,
)

def _get_prefix_duration(self, prefix: TaskPrefix) -> float:
"""Get the estimated computation cost of the given task prefix
(not including any communication cost).
If no data has been observed, the value of
`distributed.scheduler.default-task-durations` is used. If none is set
for this task, `distributed.scheduler.unknown-task-duration` is used
instead.
See Also
--------
WorkStealing.get_task_duration
"""
# TODO: Deal with unknown tasks better
assert prefix is not None
duration = prefix.duration_average
if duration < 0:
if prefix.max_exec_time > 0:
duration = 2 * prefix.max_exec_time
else:
duration = self.UNKNOWN_TASK_DURATION
return duration

def _calc_occupancy(
self,
task_prefix_count: dict[str, int],
network_occ: float,
) -> float:
res = 0.0
for prefix_name, count in task_prefix_count.items():
# TODO: Deal with unknown tasks better
prefix = self.task_prefixes[prefix_name]
assert prefix is not None
duration = prefix.duration_average
if duration < 0:
if prefix.max_exec_time > 0:
duration = 2 * prefix.max_exec_time
else:
duration = self.UNKNOWN_TASK_DURATION
duration = self._get_prefix_duration(self.task_prefixes[prefix_name])
res += duration * count
occ = res + network_occ / self.bandwidth
assert occ >= 0, (occ, res, network_occ, self.bandwidth)
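
In words, _get_prefix_duration prefers the prefix's recorded average (which may be seeded from distributed.scheduler.default-task-durations), falls back to twice the longest execution time observed so far, and finally to the configured unknown-task duration. A standalone sketch of that fallback chain, where ToyPrefix and the 0.5 s constant are illustrative stand-ins rather than the scheduler's actual classes or defaults:

from dataclasses import dataclass

UNKNOWN_TASK_DURATION = 0.5  # stand-in for distributed.scheduler.unknown-task-duration


@dataclass
class ToyPrefix:
    duration_average: float  # -1.0 means no completed task has been recorded yet
    max_exec_time: float     # longest execution seen so far, 0.0 if none


def estimate(prefix: ToyPrefix) -> float:
    if prefix.duration_average >= 0:
        return prefix.duration_average   # recorded (or seeded) average wins
    if prefix.max_exec_time > 0:
        return 2 * prefix.max_exec_time  # pessimistic guess from a partial observation
    return UNKNOWN_TASK_DURATION         # default for brand-new prefixes


assert estimate(ToyPrefix(0.2, 1.0)) == 0.2
assert estimate(ToyPrefix(-1.0, 3.0)) == 6.0
assert estimate(ToyPrefix(-1.0, 0.0)) == 0.5
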
@@ -2536,13 +2545,6 @@ def _transition_processing_memory(
action=startstop["action"],
)

s = self.unknown_durations.pop(ts.prefix.name, set())
steal = self.extensions.get("stealing")
if steal:
for tts in s:
if tts.processing_on:
steal.recalculate_cost(tts)

############################
# Update State Information #
############################
@@ -3171,26 +3173,6 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float:
nbytes = sum(dts.get_nbytes() for dts in deps)
return nbytes / self.bandwidth

def get_task_duration(self, ts: TaskState) -> float:
"""Get the estimated computation cost of the given task (not including
any communication cost).
If no data has been observed, value of
`distributed.scheduler.default-task-durations` are used. If none is set
for this task, `distributed.scheduler.unknown-task-duration` is used
instead.
"""
prefix = ts.prefix
duration: float = prefix.duration_average
if duration >= 0:
return duration

s = self.unknown_durations.get(prefix.name)
if s is None:
self.unknown_durations[prefix.name] = s = set()
s.add(ts)
return self.UNKNOWN_TASK_DURATION

def valid_workers(self, ts: TaskState) -> set[WorkerState] | None:
"""Return set of currently valid workers for key
@@ -3569,20 +3551,15 @@ def _client_releases_keys(
elif ts.state != "erred" and not ts.waiters:
recommendations[ts.key] = "released"

def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]:
def _task_to_msg(self, ts: TaskState) -> dict[str, Any]:
"""Convert a single computational task to a message"""
# FIXME: The duration attribute is not used on worker. We could save ourselves the
# time to compute and submit this
if duration < 0:
duration = self.get_task_duration(ts)
ts.run_id = next(TaskState._run_id_iterator)
assert ts.priority, ts
msg: dict[str, Any] = {
"op": "compute-task",
"key": ts.key,
"run_id": ts.run_id,
"priority": ts.priority,
"duration": duration,
"stimulus_id": f"compute-task-{time()}",
"who_has": {
dts.key: tuple(ws.address for ws in (dts.who_has or ()))
@@ -6003,12 +5980,10 @@ async def remove_client_from_events() -> None:
cleanup_delay, remove_client_from_events
)

def send_task_to_worker(
self, worker: str, ts: TaskState, duration: float = -1
) -> None:
def send_task_to_worker(self, worker: str, ts: TaskState) -> None:
"""Send a single computational task to a worker"""
try:
msg = self._task_to_msg(ts, duration)
msg = self._task_to_msg(ts)
self.worker_send(worker, msg)
except Exception as e:
logger.exception(e)
@@ -8859,10 +8834,7 @@ def adaptive_target(self, target_duration=None):
queued = take(100, concat([self.queued, self.unrunnable.keys()]))
queued_occupancy = 0
for ts in queued:
if ts.prefix.duration_average == -1:
queued_occupancy += self.UNKNOWN_TASK_DURATION
else:
queued_occupancy += ts.prefix.duration_average
queued_occupancy += self._get_prefix_duration(ts.prefix)

tasks_ready = len(self.queued) + len(self.unrunnable)
if tasks_ready > 100:
56 changes: 40 additions & 16 deletions distributed/stealing.py
@@ -88,6 +88,9 @@ class WorkStealing(SchedulerPlugin):
metrics: dict[str, dict[int, float]]
_in_flight_event: asyncio.Event
_request_counter: int
#: Tasks with unknown duration, grouped by prefix
#: {task prefix: {ts, ts, ...}}
unknown_durations: dict[str, set[TaskState]]

def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
@@ -111,6 +114,7 @@ def __init__(self, scheduler: Scheduler):
self.in_flight_occupancy = defaultdict(int)
self.in_flight_tasks = defaultdict(int)
self._in_flight_event = asyncio.Event()
self.unknown_durations = {}
self.metrics = {
"request_count_total": defaultdict(int),
"request_cost_total": defaultdict(int),
@@ -188,6 +192,13 @@ def transition(
ts = self.scheduler.tasks[key]
self.remove_key_from_stealable(ts)
self._remove_from_in_flight(ts)

if finish == "memory":
s = self.unknown_durations.pop(ts.prefix.name, set())
for tts in s:
if tts.processing_on:
self.recalculate_cost(tts)

if finish == "processing":
ts = self.scheduler.tasks[key]
self.put_key_in_stealable(ts)
@@ -223,13 +234,27 @@ def recalculate_cost(self, ts: TaskState) -> None:

def put_key_in_stealable(self, ts: TaskState) -> None:
cost_multiplier, level = self.steal_time_ratio(ts)
if cost_multiplier is not None:
assert level is not None
assert ts.processing_on
ws = ts.processing_on
worker = ws.address
self.stealable[worker][level].add(ts)
self.key_stealable[ts] = (worker, level)

if cost_multiplier is None:
return

prefix = ts.prefix
duration = self.scheduler._get_prefix_duration(prefix)

assert level is not None
assert ts.processing_on
ws = ts.processing_on
worker = ws.address
self.stealable[worker][level].add(ts)
self.key_stealable[ts] = (worker, level)

if duration == ts.prefix.duration_average:
return

if prefix.name not in self.unknown_durations:
self.unknown_durations[prefix.name] = set()

self.unknown_durations[prefix.name].add(ts)

def remove_key_from_stealable(self, ts: TaskState) -> None:
result = self.key_stealable.pop(ts, None)
@@ -255,7 +280,7 @@ def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, Non
if not ts.dependencies: # no dependencies fast path
return 0, 0

compute_time = self.scheduler.get_task_duration(ts)
compute_time = self.scheduler._get_prefix_duration(ts.prefix)

if not compute_time:
# occupancy/ws.processing[ts] is only allowed to be zero for
@@ -301,12 +326,9 @@ def move_task_request(

# TODO: occupancy no longer concats linearly so we can't easily
# assume that the network cost would go down by that much
victim_duration = self.scheduler.get_task_duration(
ts
) + self.scheduler.get_comm_cost(ts, victim)
thief_duration = self.scheduler.get_task_duration(
ts
) + self.scheduler.get_comm_cost(ts, thief)
compute = self.scheduler._get_prefix_duration(ts.prefix)
victim_duration = compute + self.scheduler.get_comm_cost(ts, victim)
thief_duration = compute + self.scheduler.get_comm_cost(ts, thief)

self.scheduler.stream_comms[victim.address].send(
{"op": "steal-request", "key": key, "stimulus_id": stimulus_id}
@@ -457,8 +479,7 @@ def balance(self) -> None:
occ_victim = self._combined_occupancy(victim)
comm_cost_thief = self.scheduler.get_comm_cost(ts, thief)
comm_cost_victim = self.scheduler.get_comm_cost(ts, victim)
compute = self.scheduler.get_task_duration(ts)

compute = self.scheduler._get_prefix_duration(ts.prefix)
if (
occ_thief + comm_cost_thief + compute
<= occ_victim - (comm_cost_victim + compute) / 2
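
The guard above moves a task only if the thief, after paying its communication cost plus the task's compute time, still ends up well below the victim's remaining occupancy. A quick numeric illustration with made-up values:

# Made-up occupancies (in seconds), purely to illustrate the inequality used in balance():
occ_thief, occ_victim = 1.0, 10.0
comm_cost_thief, comm_cost_victim = 0.3, 0.1
compute = 2.0  # _get_prefix_duration(ts.prefix) for the candidate task

worth_stealing = (
    occ_thief + comm_cost_thief + compute
    <= occ_victim - (comm_cost_victim + compute) / 2
)
print(worth_stealing)  # True: 3.3 <= 8.95, so the task is a candidate to move
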
@@ -483,6 +504,8 @@ def balance(self) -> None:
occ_thief = self._combined_occupancy(thief)
nproc_thief = self._combined_nprocessing(thief)

# FIXME: In the worst case, the victim may have 3x the amount of work
# of the thief when this aborts balancing.
if not self.scheduler.is_unoccupied(
thief, occ_thief, nproc_thief
):
@@ -514,6 +537,7 @@ def restart(self, scheduler: Any) -> None:
s.clear()

self.key_stealable.clear()
self.unknown_durations.clear()

def story(self, *keys_or_ts: str | TaskState) -> list:
keys = {key.key if not isinstance(key, str) else key for key in keys_or_ts}
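
Taken together, the stealing plugin now owns the unknown-duration bookkeeping: put_key_in_stealable records a prefix whose estimate is not yet backed by a measurement, and the first task of that prefix to reach memory pops the entry and recalculates the cost of any still-processing siblings. A rough, timing-sensitive sketch of that lifecycle in the style of the test suites below, using distributed's test utilities; treat it as illustrative rather than a drop-in test:

import asyncio

from distributed import wait
from distributed.utils_test import gen_cluster, slowinc


@gen_cluster(client=True)
async def test_unknown_duration_lifecycle(c, s, a, b):
    extension = s.extensions["stealing"]
    fut = c.submit(slowinc, 1, delay=0.5)

    # Wait until the scheduler has dispatched the task to a worker.
    while not any(ts.state == "processing" for ts in s.tasks.values()):
        await asyncio.sleep(0.01)

    # No slowinc result has been recorded yet, so the prefix is tracked as unknown.
    assert "slowinc" in extension.unknown_durations

    await wait(fut)

    # The first completed result pops the prefix from the tracking dict.
    assert "slowinc" not in extension.unknown_durations
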
21 changes: 12 additions & 9 deletions distributed/tests/test_scheduler.py
@@ -2788,24 +2788,26 @@ async def test_retire_workers_bad_params(c, s, a, b):
@gen_cluster(
client=True, config={"distributed.scheduler.default-task-durations": {"inc": 100}}
)
async def test_get_task_duration(c, s, a, b):
async def test_get_prefix_duration(c, s, a, b):
future = c.submit(inc, 1)
await future
assert 10 < s.task_prefixes["inc"].duration_average < 100

ts_pref1 = s.new_task("inc-abcdefab", None, "released")
assert 10 < s.get_task_duration(ts_pref1) < 100
assert 10 < s._get_prefix_duration(ts_pref1.prefix) < 100

extension = s.extensions["stealing"]
# make sure get_task_duration adds TaskStates to unknown dict
assert len(s.unknown_durations) == 0
assert len(extension.unknown_durations) == 0
x = c.submit(slowinc, 1, delay=0.5)
while len(s.tasks) < 3:
await asyncio.sleep(0.01)

ts = s.tasks[x.key]
assert s.get_task_duration(ts) == 0.5 # default
assert len(s.unknown_durations) == 1
assert len(s.unknown_durations["slowinc"]) == 1
assert s._get_prefix_duration(ts.prefix) == 0.5 # default

assert len(extension.unknown_durations) == 1
assert len(extension.unknown_durations["slowinc"]) == 1


@gen_cluster(client=True)
Expand Down Expand Up @@ -3338,10 +3340,11 @@ async def test_unknown_task_duration_config(client, s, a, b):
future = client.submit(slowinc, 1)
while not s.tasks:
await asyncio.sleep(0.001)
assert sum(s.get_task_duration(ts) for ts in s.tasks.values()) == 3600
assert len(s.unknown_durations) == 1
assert sum(s._get_prefix_duration(ts.prefix) for ts in s.tasks.values()) == 3600
extension = s.extensions["stealing"]
assert len(extension.unknown_durations) == 1
await wait(future)
assert len(s.unknown_durations) == 0
assert len(extension.unknown_durations) == 0


@gen_cluster()
Expand Down