From 5edc052614d5bbdabe1b800a8d0dd4925351ad81 Mon Sep 17 00:00:00 2001
From: amaurya
Date: Thu, 6 Jun 2024 02:08:00 +0000
Subject: [PATCH] Efficient restart

---
 deepspeed/checkpoint/utils.py                  |  4 +-
 deepspeed/runtime/bf16_optimizer.py            | 11 ++--
 .../torch_sn_async_checkpoint_engine.py        | 52 +++++-------------
 deepspeed/runtime/engine.py                    | 12 ++--
 deepspeed/runtime/pipe/engine.py               |  2 +-
 deepspeed/runtime/pipe/module.py               | 40 ++++++++++++++
 6 files changed, 67 insertions(+), 54 deletions(-)

diff --git a/deepspeed/checkpoint/utils.py b/deepspeed/checkpoint/utils.py
index 493d6f3e681e..d4a1dee92456 100644
--- a/deepspeed/checkpoint/utils.py
+++ b/deepspeed/checkpoint/utils.py
@@ -51,8 +51,8 @@ def clone_tensors_for_torch_save(item, device=torch.device('cpu')):
         - copy of ``item`` with cloned tensors on target device
     """
     if torch.is_tensor(item):
-        # return item.detach().clone().to(device)
-        return item.clone().detach()
+        return item.detach().clone().to(device)
+        # return item.clone().detach()
         # return item.contiguous().detach()
     elif isinstance(item, list):
         return [clone_tensors_for_torch_save(v, device) for v in item]
diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index aaa836bf1c31..1250897136c4 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -415,13 +415,14 @@ def load_state_dict(self,
                         checkpoint_folder,
                         load_optimizer_states=True,
                         load_from_fp32_weights=False,
-                        load_serial=None):
+                        load_serial=None,
+                        is_datastates_llm=False):
         if checkpoint_folder:
             self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
         else:
-            self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights)
+            self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights, is_datastates_llm)

-    def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False):
+    def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False, is_datastates_llm=False):
         dp_rank = dist.get_rank(group=self.dp_process_group)
         current_rank_sd = state_dict_list[dp_rank]
@@ -431,9 +432,9 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l
         ckpt_version = pkg_version.parse(ckpt_version)

         self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad)
-
+
         if load_optimizer_states:
-            self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])
+            self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE], is_datastates_llm)

         if load_from_fp32_weights:
             for current, saved in zip(self.fp32_groups_flat_partition,
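The bf16_optimizer.py hunk threads a new `is_datastates_llm` keyword from the engine's load path down to the inner optimizer's `load_state_dict`, so only the innermost loader has to know about the special checkpoint engine. The utils.py hunk is the more subtle change: it restores `detach().clone().to(device)` so saved tensors are value snapshots on the target device (CPU by default) rather than references to live GPU tensors that a background writer could race against. A minimal standalone sketch of that pattern (the function name is illustrative, not DeepSpeed API):

```python
import torch

def clone_for_async_save(obj, device=torch.device("cpu")):
    """Recursively detach + clone tensors onto `device` so a background
    checkpoint writer sees a frozen copy, not the live training tensors."""
    if torch.is_tensor(obj):
        # detach() drops autograd history, clone() snapshots the values,
        # and .to(device) moves the copy off the accelerator.
        return obj.detach().clone().to(device)
    if isinstance(obj, list):
        return [clone_for_async_save(v, device) for v in obj]
    if isinstance(obj, dict):
        return {k: clone_for_async_save(v, device) for k, v in obj.items()}
    return obj
```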
diff --git a/deepspeed/runtime/checkpoint_engine/torch_sn_async_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/torch_sn_async_checkpoint_engine.py
index e7bea7e3033b..bc6d5140ba46 100644
--- a/deepspeed/runtime/checkpoint_engine/torch_sn_async_checkpoint_engine.py
+++ b/deepspeed/runtime/checkpoint_engine/torch_sn_async_checkpoint_engine.py
@@ -13,7 +13,7 @@
 from collections import deque
 import sys
 import logging
-
+from deepspeed.utils import groups

 class TSNAsyncCheckpointEngine(CheckpointEngine):
@@ -21,7 +21,6 @@ def __init__(self, config_params, r):
         t = time.time()
         super().__init__(config_params, r)
         self.rank = r
-        # print("<<<<<<<<<<< Inited on rank ", self.rank)
         self.prev_sn = deque()
         logger = logging.getLogger("torchsnapshot.scheduler")
         logger.setLevel(logging.DEBUG)
@@ -53,10 +52,9 @@ def _to_statedict(self, ele, snapshot):

     @instrument_w_nvtx
     def save(self, state_dict, path: str):
-        # logger.info(f"[TSNAsyncCheckpointEngine][Rank {self.rank}] Starting ckpt {path} at {time.time_ns()}")
+        logger.info(f"[TSNAsyncCheckpointEngine][Rank {self.rank}] Starting ckpt {path} at {time.time_ns()}")
         t = time.time()
         try:
-            # x = self._to_statedict(state_dict, {})
             p = Snapshot.async_take(path=path,
                                     app_state={"objects": StateDict(ckpt=state_dict)},
                                     replicated=[]
@@ -64,7 +62,7 @@ def save(self, state_dict, path: str):
             self.prev_sn.append((path, p))
             # p.wait()
             # Snapshot.take(path=path, app_state={"objects": StateDict(ckpt=state_dict)}, replicated=[])
-            logger.info(f"[TSNAsyncCheckpointEngine][Rank {self.rank}] Saved {path}. in time {time.time()-t} started at {time.time_ns()}")
+            # logger.info(f"[TSNAsyncCheckpointEngine][Rank {self.rank}] Saved {path}. in time {time.time()-t} started at {time.time_ns()}")
         except Exception as e:
             print(f"TSNAsyncCheckpointEngine][Rank {self.rank}] Async checkpoint failed with error: {e}")
             sys.exit(-1)
@@ -72,58 +70,32 @@ def save(self, state_dict, path: str):
         return None

     def load(self, path: str, map_location=None):
-        logger.info(f"[TSNAsyncCheckpointEngine] Loading checkpoint from {path}...")
-        partition = torch.load(path, map_location=map_location)
-        logger.info(f"[TSNAsyncCheckpointEngine] Loaded checkpoint from {path}.")
-        return partition
+        snapshot = Snapshot(path=path)
+        partition = {"objects": StateDict(ckpt={})}
+        snapshot.restore(app_state=partition)
+        return partition["objects"]["ckpt"]

     def commit(self, tag):
         logger.info(f"[TSNAsyncCheckpointEngine] Checkpoint {tag} is ready now!")
         return True

     def wait(self, prev_version = -1):
-        # while len(self.prev_sn) > 0:
-        #     try:
-        #         (path, p) = self.prev_sn.popleft()
-        #         # logger.info(f"[TSNAsyncCheckpointEngine][Rank {self.rank}] In wait for {len(self.prev_sn)} for path {path}.")
-        #         # for i, (x, y) in enumerate(self.prev_sn):
-        #         #     print(i, x, y, y.done())
-        #         if not p.done():
-        #             # logger.info(f"[TSNAsyncCheckpointEngine] Waiting for {path}.")
-        #             p.wait()
-        #         # for i, (x, y) in enumerate(self.prev_sn):
-        #         #     if y.done():
-        #         #         logger.info(f"[TSNAsyncCheckpointEngine] Done checkpointing {i}, {x}, {y}.")
-        #         #         del self.prev_sn[i]
-        #         #         break
-        #     except Exception as e:
-        #         print(f"TSNAsyncCheckpointEngine][Rank {self.rank}] Async checkpoint waiting failed with error: {e}")
-        #         sys.exit(-1)
         return True

     def shutdown(self):
-        # self.wait()
         t = time.time()
         while len(self.prev_sn) > 0:
             try:
                 inner_t = time.time()
                 (path, p) = self.prev_sn.popleft()
-                # logger.info(f"[TSNAsyncCheckpointEngine][Rank {self.rank}] In wait for {len(self.prev_sn)} for path {path}.")
-                # for i, (x, y) in enumerate(self.prev_sn):
-                #     print(i, x, y, y.done())
-                # while not p.done():
-                #     pass
                 while not p.done():
-                    # logger.info(f"[TSNAsyncCheckpointEngine] Waiting for {path}.")
+                    logger.info(f"[TSNAsyncCheckpointEngine] Waiting for {path}.")
                     p.wait()
-                logger.info(f"[TSNAsyncCheckpointEngine][Rank {self.rank}] at time {time.time_ns()} time {time.time()-inner_t} len {len(self.prev_sn)} for path {path}.")
-                # for i, (x, y) in enumerate(self.prev_sn):
-                #     if y.done():
-                #         logger.info(f"[TSNAsyncCheckpointEngine] Done checkpointing {i}, {x}, {y}.")
-                #         del self.prev_sn[i]
-                #         break
             except Exception as e:
                 print(f"TSNAsyncCheckpointEngine][Rank {self.rank}] Async checkpoint waiting failed with error: {e}")
                 sys.exit(-1)
         logger.info(f"[TSNAsyncCheckpointEngine] Shutdown took time {time.time()-t}")
-        return True
\ No newline at end of file
+        return True
+
+    def __del__(self):
+        self.shutdown()
\ No newline at end of file
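The engine above now follows torchsnapshot's async pattern end to end: save() launches Snapshot.async_take and queues the pending handle, load() restores a snapshot in place instead of torch.load-ing a raw file, and shutdown() (also invoked from the new __del__) drains the queue. The newly added `from deepspeed.utils import groups` import is unused in this diff and could be dropped. A condensed, self-contained sketch of that lifecycle, assuming torchsnapshot is installed (the path and tensor contents are illustrative):

```python
import torch
from torchsnapshot import Snapshot, StateDict

state = StateDict(ckpt={"weights": torch.randn(4, 4)})

# Launch the snapshot in the background; training continues meanwhile.
pending = Snapshot.async_take(path="/tmp/ckpt/step-100",
                              app_state={"objects": state})

# Later (before the next checkpoint, or at shutdown) drain the handle.
if not pending.done():
    pending.wait()  # blocks until the snapshot is durable on storage

# Restore in place into an empty StateDict, as the new load() does.
target = {"objects": StateDict(ckpt={})}
Snapshot(path="/tmp/ckpt/step-100").restore(app_state=target)
restored = target["objects"]["ckpt"]
```

Note the design choice this diff makes: wait() is now a no-op and all draining happens in shutdown(), so the most recent snapshot is only guaranteed durable at teardown.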
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index fc12919b40db..430c4698e4d7 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -2804,7 +2804,6 @@ def load_checkpoint(self,
                                          load_lr_scheduler_states=load_lr_scheduler_states,
                                          load_module_only=load_module_only,
                                          custom_load_fn=custom_load_fn)
-
         load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled())
         if load_zero_checkpoint:
             if load_optimizer_states and not load_module_only:
@@ -2998,13 +2997,13 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
         zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
         if zero_sd_list is None:
             return False
-
         self.optimizer.load_state_dict(state_dict_list=zero_sd_list,
                                        load_optimizer_states=load_optimizer_states,
                                        load_from_fp32_weights=self.zero_load_from_fp32_weights(),
                                        checkpoint_folder=checkpoint_folder,
-                                       load_serial=load_serial)
-
+                                       load_serial=load_serial,
+                                       is_datastates_llm="DataStatesCheckpointEngine" in str(type(self.checkpoint_engine)))
+
         if self.load_universal_checkpoint():
             logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}')
         else:
@@ -3205,8 +3204,9 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True,
         # if self._config is not None and not self._config.veloc_ckpt_config and not self._config.async_ckpt_config:
         self.checkpoint_engine.commit(tag)
         if save_latest and rank == 0:
-            with open(os.path.join(save_dir, 'latest'), 'w') as fd:
-                fd.write(tag)
+            if not os.path.exists(os.path.join(save_dir, 'latest')):
+                with open(os.path.join(save_dir, 'latest'), 'w') as fd:
+                    fd.write(tag)
         dist.barrier()
diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index 22cd397b132a..1818cf655513 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -1410,7 +1410,7 @@ def _exec_schedule(self, pipe_schedule):
                 # Equivalent to: self._exec_forward_pass(buffer_id=0)
                 self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self)
                 self._exec_instr(**cmd.kwargs)
-                print(f"[Rank {self.global_rank}] <<<{str(type(cmd))}:{time.time()-ts}>>>")
+                # print(f"[Rank {self.global_rank}] <<<{str(type(cmd))}:{time.time()-ts}>>>")

         # if type(cmd) == schedule.OptimizerStep:
         #     print("In after optimizer step")
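Two engine.py behaviors above warrant notes. First, the checkpoint-engine detection compares class names (`"DataStatesCheckpointEngine" in str(type(self.checkpoint_engine))`), which avoids importing an optional backend but silently stops matching if the class is ever renamed; an isinstance check behind a guarded import would be stricter. Second, the 'latest' pointer file is now written only when absent, so later checkpoints never advance it: that suits restart experiments that always resume from one fixed tag, but it would surprise general users. A minimal sketch of the guarded write (the helper name is hypothetical, not DeepSpeed API):

```python
import os

def write_latest_tag_once(save_dir: str, tag: str) -> None:
    # Hypothetical helper mirroring the guard added in save_checkpoint:
    # the pointer is created on the first save and never overwritten,
    # so every restart resumes from the first recorded tag.
    latest_path = os.path.join(save_dir, "latest")
    if not os.path.exists(latest_path):
        with open(latest_path, "w") as fd:
            fd.write(tag)
```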
diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py
index 45d420bea291..10ebea7633c5 100644
--- a/deepspeed/runtime/pipe/module.py
+++ b/deepspeed/runtime/pipe/module.py
@@ -22,6 +22,7 @@
 from deepspeed.accelerator import get_accelerator
 from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
 import time
+from concurrent.futures import ThreadPoolExecutor

 class PipelineError(Exception):
     """Errors related to the use of deepspeed.PipelineModule """
@@ -604,7 +605,27 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=Fal
         checkpoint_engine.save(final_state_dict, model_ckpt_path)
         # logger.info(f"[SaveStateDict] ckpt engine save took {time.time()-t} for path {model_ckpt_path}")

+    def datastates_llm_sd_loader(self, layer, strict, model_ckpt_list, version, checkpoint_engine, mp_world_size, mp_rank, module_key=None, is_pipe_parallel=True):
+        sd_loader = SDLoaderFactory.get_sd_loader(model_ckpt_list,
+                                                  version=version,
+                                                  checkpoint_engine=checkpoint_engine)
+        load_path, checkpoint, _ = sd_loader.load(mp_world_size, mp_rank, module_key=module_key, is_pipe_parallel=is_pipe_parallel)
+        layer.load_state_dict(checkpoint, strict=strict)
+        return
+
     def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
+        num_loader_threads = 1
+        is_datastates_llm = False
+        futures = []
+        t = time.time()
+        sd_loader_time = 0
+        sd_loader_load_time = 0
+        sd_layer_load_sd_time = 0
+        if "DataStatesCheckpointEngine" in str(type(checkpoint_engine)):
+            num_loader_threads = 4
+            is_datastates_llm = True
+        executor = ThreadPoolExecutor(max_workers=num_loader_threads)
         for idx, layer in enumerate(self.forward_funcs):
             # Functions, etc. will not have state_dicts
             if not hasattr(layer, 'load_state_dict'):
@@ -615,19 +636,38 @@ def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
             mp_rank = self._grid.get_slice_parallel_rank()
             mp_world_size = self._grid.get_slice_parallel_world_size()

+            if is_datastates_llm:
+                version = 2.0
+                module_key = None
+                is_pipe_parallel = True
+                f = executor.submit(self.datastates_llm_sd_loader, layer, strict, model_ckpt_list, version, checkpoint_engine, mp_world_size, mp_rank, module_key, is_pipe_parallel)
+                futures.append(f)
+                continue
+
+            intime = time.time()
             sd_loader = SDLoaderFactory.get_sd_loader(model_ckpt_list, version=2.0, checkpoint_engine=checkpoint_engine)
+            sd_loader_time += time.time()-intime
+            intime = time.time()
             load_path, checkpoint, _ = sd_loader.load(mp_world_size, mp_rank, module_key=None, is_pipe_parallel=True)
+            sd_loader_load_time += time.time()-intime
+            intime = time.time()
             layer.load_state_dict(checkpoint, strict=strict)
+            sd_layer_load_sd_time += time.time()-intime

             # if self._grid.data_parallel_id == 0:
             #     logger.info(
             #         f'RANK={self.global_rank} Loaded layer={idx+self._local_start} file={load_path}'
             #     )

+        if is_datastates_llm:
+            executor.shutdown(wait=True)
+        logger.info(f"[Rank {dist.get_rank()}] Loaded the layer_* in {time.time()-t}, sd_init {sd_loader_time}, sd_load: {sd_loader_load_time}, layer_load_sd: {sd_layer_load_sd_time}")
+
         self._synchronize_tied_weights()

     def _is_checkpointable(self, funcs):
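load_state_dir above fans each layer's checkpoint load out to a ThreadPoolExecutor (four workers for the DataStates engine, one otherwise) and joins with executor.shutdown(wait=True). One caveat worth knowing: shutdown(wait=True) waits for completion but does not propagate worker exceptions, and the collected futures list is never inspected, so a failed layer load would pass silently. A minimal sketch of the same fan-out that surfaces errors (the loader callback is illustrative, not the DeepSpeed API):

```python
from concurrent.futures import ThreadPoolExecutor

def load_layers_parallel(layers, ckpt_lists, load_one, max_workers=4):
    """Fan out one load task per pipeline layer and join at the end.
    `load_one(layer, ckpt_list)` must be thread-safe."""
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(load_one, layer, ckpts)
                   for layer, ckpts in zip(layers, ckpt_lists)]
        for f in futures:
            f.result()  # re-raises any exception from the worker thread
```

The with-block is equivalent to calling shutdown(wait=True); calling result() on each future is what turns silent worker failures into visible ones.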