add blacklist mechanism and scheduling in specified order
Caozhou1995 committed Dec 22, 2023
1 parent 0af3faa commit 814d042
Showing 3 changed files with 244 additions and 0 deletions.
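
The diff below introduces two new tuner_cfg keys: "invalid_strategy" (a blacklist of strategy combinations to prune) and "schedule_prior" (strategy patterns to schedule first). A hypothetical configuration fragment might look like the following; the key names come from the code below, while the concrete strategy strings and values are invented for illustration:

# Hypothetical auto-tuner configuration fragment; values are examples only.
tuner_cfg = {
    # Blacklist: prune any candidate with dp_degree != 1 and mp_degree == 2.
    "invalid_strategy": ["dp*_mp2"],
    # Priority scheduling: run candidates with mp_degree > 1 first.
    "schedule_prior": ["mp*"],
    # ... other existing auto-tuner settings ...
}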
python/paddle/distributed/auto_tuner/prune.py (96 additions, 0 deletions)

@@ -593,3 +593,99 @@ def prune_by_sharding_overlap(tuner_cfg, cur_cfg, history_cfgs=[]):
        if not result[tuner_cfg['metric_cfg']['name']]:
            return True
    return False


def is_invalid(cur_cfg, invalid_strategy):
    """Return True if cur_cfg matches every dimension of any strategy string in invalid_strategy."""
    mapping = {
        "dp_degree": "dp",
        "mp_degree": "mp",
        "pp_degree": "pp",
        "vpp_degree": "vpp",
        "micro_batch_size": "mbs",
        "sharding_degree": "sharding",
        "sharding_stage": "stage",
        "use_recompute": "recompute",
        "recompute_granularity": "granularity",
    }
    granularity_mapping = {0: "full", 1: "full_attn", 2: "core_attn"}
    reversed_mapping = {}
    for key in mapping:
        reversed_mapping[mapping[key]] = key

    for strategy in invalid_strategy:
        assert isinstance(strategy, str)
        dims = strategy.split("_")
        has_matched = 0
        for dim in dims:
            matched = None
            for key in reversed_mapping:
                if dim.startswith(key):
                    matched = key
                    break
            if matched:
                value = dim[len(matched)]
                # * means this strategy turned on
                if matched in ["dp", "mp", "pp", "vpp", "sharding"]:
                    if value == "*":
                        if cur_cfg[reversed_mapping[matched]] != 1:
                            has_matched += 1
                            continue
                    else:
                        value = int(value)
                        if cur_cfg[reversed_mapping[matched]] == value:
                            has_matched += 1
                            continue
                elif matched == "recompute":
                    if value == "*":
                        if cur_cfg[reversed_mapping[matched]]:
                            has_matched += 1
                            continue
                    else:
                        value = bool(int(value))
                        if cur_cfg[reversed_mapping[matched]] == value:
                            has_matched += 1
                            continue
                elif matched == "stage":
                    if value == "*":
                        if cur_cfg[reversed_mapping["sharding"]] != 1:
                            has_matched += 1
                            continue
                    else:
                        value = int(value)
                        if cur_cfg[reversed_mapping[matched]] == value:
                            has_matched += 1
                            continue
                elif matched == "mbs":
                    if value == "*":
                        has_matched += 1
                        continue
                    else:
                        value = int(value)
                        if cur_cfg[reversed_mapping[matched]] == value:
                            has_matched += 1
                            continue
                elif matched == "granularity":
                    if value == "*":
                        # granularity only counts when recompute is enabled
                        if cur_cfg[reversed_mapping["recompute"]]:
                            has_matched += 1
                            continue
                    else:
                        value = int(value)
                        granularity = granularity_mapping[value]
                        if cur_cfg[reversed_mapping[matched]] == granularity:
                            has_matched += 1
                            continue
        if has_matched == len(dims):
            return True
    return False


@register_prune
def prune_by_invalid_strategy(tuner_cfg, cur_cfg, history_cfgs=[]):
    """Prune cur_cfg if it hits any strategy listed in tuner_cfg["invalid_strategy"]."""
    if tuner_cfg.get("invalid_strategy", None):
        invalid_strategy = tuner_cfg["invalid_strategy"]
        assert isinstance(invalid_strategy, list)
        if is_invalid(cur_cfg, invalid_strategy):
            return True

    return False
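
An illustration of how the blacklist matching behaves, assuming the import path implied by the file location above; the config dicts are invented and carry only the keys the strategy string reads:

from paddle.distributed.auto_tuner.prune import is_invalid  # assumed import path

# "dp*_mp2" blacklists configs where dp_degree != 1 AND mp_degree == 2.
print(is_invalid({"dp_degree": 2, "mp_degree": 2}, ["dp*_mp2"]))  # True  -> pruned
print(is_invalid({"dp_degree": 1, "mp_degree": 2}, ["dp*_mp2"]))  # False -> kept
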
python/paddle/distributed/auto_tuner/search.py (25 additions, 0 deletions)

@@ -49,17 +49,42 @@ def __init__(self, tuner_cfg):
        super().__init__(tuner_cfg)
        self.idx = 0
        self.all_tasks = search_all(tuner_cfg)
        need_baseline = self.tuner_cfg.get("need_baseline", False)
        self.baseline = None
        if need_baseline:
            from .utils import memory_sort

            self.all_tasks.sort(key=memory_sort)
        self.previous_cfg = None

    def search_once(self, history_cfgs):
        new_cfg = None
        stop = False
        if self.previous_cfg:
            if self.previous_cfg.get("time", -1) > 0:
                if self.baseline is None:
                    from .utils import performance_sort

                    self.baseline = self.previous_cfg
                    self.all_tasks[self.idx :] = sorted(
                        self.all_tasks[self.idx : len(self.all_tasks)],
                        key=performance_sort,
                    )
        if self.tuner_cfg.get("schedule_prior", False):
            from .utils import sort_by_sepecial

            self.all_tasks[self.idx :] = sort_by_sepecial(
                self.all_tasks[self.idx :], self.tuner_cfg
            )
        while not stop:
            if self.idx < len(self.all_tasks):
                new_cfg = self.all_tasks[self.idx]
                self.idx += 1
                stop = not self.prune(self.tuner_cfg, new_cfg, history_cfgs)
            else:
                return None
        if self.previous_cfg is None:
            self.previous_cfg = new_cfg
        return new_cfg


python/paddle/distributed/auto_tuner/utils.py (123 additions, 0 deletions)

@@ -414,9 +414,132 @@ def search_all(tuner_cfg):
    logger.info(
        f"{search_space_size_before_prune - search_space_size_after_prune} tasks are pruned before launching."
    )
    if tuner_cfg.get("schedule_prior", False):
        pruned_all_cfgs = sort_by_sepecial(pruned_all_cfgs, tuner_cfg)
    return pruned_all_cfgs


def sort_by_sepecial(cfgs, tuner_cfg):
    """Move configs that match a tuner_cfg["schedule_prior"] strategy to the front of the list."""
    assert tuner_cfg.get("schedule_prior", False)
    prior_strategy = tuner_cfg["schedule_prior"]
    prior_strategy.sort(reverse=True)
    for strategy in prior_strategy:
        idx = 0
        matched_count = 0
        while idx < len(cfgs):
            cfg = cfgs[idx]
            if _matched(cfg, strategy):
                cfgs.pop(idx)
                cfgs.insert(0, cfg)
                matched_count += 1
            idx += 1
        tmp = cfgs[:matched_count]
        tmp.reverse()
        cfgs[:matched_count] = tmp
    return cfgs


def memory_sort(cfg):
    # ascending order by default
    return (
        -cfg['mp_degree'],
        -cfg['pp_degree'],
        -cfg['vpp_degree'],
        -cfg["sharding_degree"],
        -cfg["sharding_stage"],
        cfg["micro_batch_size"],
        -cfg["use_recompute"],
    )


def performance_sort(cfg):
    # larger micro batch size first
    return -cfg["micro_batch_size"]


def _matched(cur_cfg, strategy):
    """Return True if cur_cfg matches every dimension of the given strategy string."""
    mapping = {
        "dp_degree": "dp",
        "mp_degree": "mp",
        "pp_degree": "pp",
        "vpp_degree": "vpp",
        "micro_batch_size": "mbs",
        "sharding_degree": "sharding",
        "sharding_stage": "stage",
        "use_recompute": "recompute",
        "recompute_granularity": "granularity",
    }
    granularity_mapping = {0: "full", 1: "full_attn", 2: "core_attn"}
    reversed_mapping = {}
    for key in mapping:
        reversed_mapping[mapping[key]] = key

    assert isinstance(strategy, str)
    dims = strategy.split("_")
    has_matched = 0
    for dim in dims:
        matched = None
        for key in reversed_mapping:
            if dim.startswith(key):
                matched = key
                break
        if matched:
            value = dim[len(matched)]
            # * means this strategy turned on
            if matched in ["dp", "mp", "pp", "vpp", "sharding"]:
                if value == "*":
                    if cur_cfg[reversed_mapping[matched]] > 1:
                        has_matched += 1
                        continue
                else:
                    value = int(value)
                    if cur_cfg[reversed_mapping[matched]] == value:
                        has_matched += 1
                        continue
            elif matched == "recompute":
                if value == "*":
                    if cur_cfg[reversed_mapping[matched]]:
                        has_matched += 1
                        continue
                else:
                    value = bool(int(value))
                    if cur_cfg[reversed_mapping[matched]] == value:
                        has_matched += 1
                        continue
            elif matched == "stage":
                if value == "*":
                    if cur_cfg[reversed_mapping["sharding"]] > 1:
                        has_matched += 1
                        continue
                else:
                    value = int(value)
                    if cur_cfg[reversed_mapping[matched]] == value:
                        has_matched += 1
                        continue
            elif matched == "mbs":
                if value == "*":
                    has_matched += 1
                    continue
                else:
                    value = int(value)
                    if cur_cfg[reversed_mapping[matched]] == value:
                        has_matched += 1
                        continue
            elif matched == "granularity":
                if value == "*":
                    # granularity only counts when recompute is enabled
                    if cur_cfg[reversed_mapping["recompute"]]:
                        has_matched += 1
                        continue
                else:
                    value = int(value)
                    granularity = granularity_mapping[value]
                    if cur_cfg[reversed_mapping[matched]] == granularity:
                        has_matched += 1
                        continue
    if has_matched == len(dims):
        return True
    return False


def _param2range(param_from_json_file, max_value, param_key):
    """Convert a param from json file to candidates range."""
    selected_range = None
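
As an illustration, memory_sort and performance_sort are plain sort keys for sorted / list.sort; the config dicts below are invented, carry only the keys the two functions read, and the import path is assumed from the file location above:

from paddle.distributed.auto_tuner.utils import memory_sort, performance_sort  # assumed import path

cfgs = [
    {"mp_degree": 1, "pp_degree": 1, "vpp_degree": 1, "sharding_degree": 1,
     "sharding_stage": 1, "micro_batch_size": 4, "use_recompute": False},
    {"mp_degree": 4, "pp_degree": 2, "vpp_degree": 1, "sharding_degree": 2,
     "sharding_stage": 2, "micro_batch_size": 1, "use_recompute": True},
]
# memory_sort puts configs with larger parallel degrees (lower per-device memory) first.
print(sorted(cfgs, key=memory_sort)[0]["mp_degree"])                # 4
# performance_sort puts configs with larger micro batch sizes first.
print(sorted(cfgs, key=performance_sort)[0]["micro_batch_size"])    # 4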
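
Similarly, a small invented example of the priority scheduling: with schedule_prior set to ["mp*"], sort_by_sepecial moves every candidate whose mp_degree is greater than 1 to the front of the task list while keeping relative order (again assuming the import path from the file location above):

from paddle.distributed.auto_tuner.utils import sort_by_sepecial  # assumed import path

cfgs = [
    {"mp_degree": 1, "pp_degree": 2},
    {"mp_degree": 2, "pp_degree": 1},
    {"mp_degree": 4, "pp_degree": 1},
]
ordered = sort_by_sepecial(cfgs, {"schedule_prior": ["mp*"]})
# ordered: mp_degree == 2, then mp_degree == 4, then mp_degree == 1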
