
Commit

Efficient restart
amaurya committed Jun 6, 2024
1 parent 7dddee7 commit 5edc052
Showing 6 changed files with 69 additions and 54 deletions.
4 changes: 2 additions & 2 deletions deepspeed/checkpoint/utils.py
@@ -51,8 +51,8 @@ def clone_tensors_for_torch_save(item, device=torch.device('cpu')):
- copy of ``item`` with cloned tensors on target device
"""
if torch.is_tensor(item):
# return item.detach().clone().to(device)
return item.clone().detach()
return item.detach().clone().to(device)
# return item.clone().detach()
# return item.contiguous().detach()
elif isinstance(item, list):
return [clone_tensors_for_torch_save(v, device) for v in item]
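The edit above restores the .to(device) call, so cloned tensors are detached and moved to the target device (CPU by default) before torch.save touches them. A minimal usage sketch of the helper after this change; the tensor shape and the /tmp path are illustrative:

import torch
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save

# A (possibly GPU-resident) parameter is detached, cloned, and moved to CPU,
# so the saved copy neither tracks gradients nor pins accelerator memory.
param = torch.randn(4, 4, requires_grad=True)
cpu_copy = clone_tensors_for_torch_save(param)  # device defaults to torch.device('cpu')
assert not cpu_copy.requires_grad and cpu_copy.device.type == 'cpu'
torch.save({'weight': cpu_copy}, '/tmp/weight_snapshot.pt')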
11 changes: 6 additions & 5 deletions deepspeed/runtime/bf16_optimizer.py
@@ -415,13 +415,14 @@ def load_state_dict(self,
checkpoint_folder,
load_optimizer_states=True,
load_from_fp32_weights=False,
load_serial=None):
load_serial=None,
is_datastates_llm=False):
if checkpoint_folder:
self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
else:
self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights)
self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights, is_datastates_llm)

def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False):
def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False, is_datastates_llm=False):

dp_rank = dist.get_rank(group=self.dp_process_group)
current_rank_sd = state_dict_list[dp_rank]
@@ -431,9 +432,9 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l
ckpt_version = pkg_version.parse(ckpt_version)

self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad)

if load_optimizer_states:
self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])
self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE], is_datastates_llm)

if load_from_fp32_weights:
for current, saved in zip(self.fp32_groups_flat_partition,
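This file now threads an is_datastates_llm flag from load_state_dict through _load_legacy_checkpoint into the wrapped base optimizer's load_state_dict. A minimal sketch of that call chain with stand-in classes (not the real DeepSpeed ones), assuming the patched base optimizer accepts the extra argument:

class PatchedBaseOptimizer:
    def load_state_dict(self, state_dict, is_datastates_llm=False):
        # A DataStates-aware optimizer could lazily materialize tensors here.
        self.loaded = state_dict
        self.datastates = is_datastates_llm

class BF16OptimizerSketch:
    def __init__(self, base_optimizer):
        self.optimizer = base_optimizer

    def _load_legacy_checkpoint(self, current_rank_sd, load_optimizer_states=True,
                                is_datastates_llm=False):
        if load_optimizer_states:
            # Mirrors the patched call: the flag is simply forwarded.
            self.optimizer.load_state_dict(current_rank_sd['base_optimizer_state'],
                                           is_datastates_llm)

opt = BF16OptimizerSketch(PatchedBaseOptimizer())
opt._load_legacy_checkpoint({'base_optimizer_state': {'step': 0}}, is_datastates_llm=True)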
@@ -13,15 +13,14 @@
from collections import deque
import sys
import logging

from deepspeed.utils import groups

class TSNAsyncCheckpointEngine(CheckpointEngine):

def __init__(self, config_params, r):
t = time.time()
super().__init__(config_params, r)
self.rank = r
# print("<<<<<<<<<<< Inited on rank ", self.rank)
self.prev_sn = deque()
logger = logging.getLogger("torchsnapshot.scheduler")
logger.setLevel(logging.DEBUG)
@@ -53,77 +52,50 @@ def _to_statedict(self, ele, snapshot):

@instrument_w_nvtx
def save(self, state_dict, path: str):
# logger.info(f"[TSNAsyncCheckpointEngine][Rank {self.rank}] Starting ckpt {path} at {time.time_ns()}")
logger.info(f"[TSNAsyncCheckpointEngine][Rank {self.rank}] Starting ckpt {path} at {time.time_ns()}")
t = time.time()
try:
# x = self._to_statedict(state_dict, {})
p = Snapshot.async_take(path=path,
app_state={"objects": StateDict(ckpt=state_dict)},
replicated=[]
)
self.prev_sn.append((path, p))
# p.wait()
# Snapshot.take(path=path, app_state={"objects": StateDict(ckpt=state_dict)}, replicated=[])
logger.info(f"[TSNAsyncCheckpointEngine][Rank {self.rank}] Saved {path}. in time {time.time()-t} started at {time.time_ns()}")
# logger.info(f"[TSNAsyncCheckpointEngine][Rank {self.rank}] Saved {path}. in time {time.time()-t} started at {time.time_ns()}")
except Exception as e:
print(f"TSNAsyncCheckpointEngine][Rank {self.rank}] Async checkpoint failed with error: {e}")
sys.exit(-1)
# import pdb; pdb.set_trace()
return None

def load(self, path: str, map_location=None):
logger.info(f"[TSNAsyncCheckpointEngine] Loading checkpoint from {path}...")
partition = torch.load(path, map_location=map_location)
logger.info(f"[TSNAsyncCheckpointEngine] Loaded checkpoint from {path}.")
return partition
snapshot = Snapshot(path=path)
partition={"objects": StateDict(ckpt={})}
snapshot.restore(app_state=partition)
return partition["objects"]["ckpt"]

def commit(self, tag):
logger.info(f"[TSNAsyncCheckpointEngine] Checkpoint {tag} is ready now!")
return True

def wait(self, prev_version = -1):
# while len(self.prev_sn) > 0:
# try:
# (path, p) = self.prev_sn.popleft()
# # logger.info(f"[TSNAsyncCheckpointEngine][Rank {self.rank}] In wait for {len(self.prev_sn)} for path {path}.")
# # for i, (x, y) in enumerate(self.prev_sn):
# # print(i, x, y, y.done())
# if not p.done():
# # logger.info(f"[TSNAsyncCheckpointEngine] Waiting for {path}.")
# p.wait()
# # for i, (x, y) in enumerate(self.prev_sn):
# # if y.done():
# # logger.info(f"[TSNAsyncCheckpointEngine] Done checkpointing {i}, {x}, {y}.")
# # del self.prev_sn[i]
# # break
# except Exception as e:
# print(f"TSNAsyncCheckpointEngine][Rank {self.rank}] Async checkpoint waiting failed with error: {e}")
# sys.exit(-1)
return True

def shutdown(self):
# self.wait()
t = time.time()
while len(self.prev_sn) > 0:
try:
inner_t = time.time()
(path, p) = self.prev_sn.popleft()
# logger.info(f"[TSNAsyncCheckpointEngine][Rank {self.rank}] In wait for {len(self.prev_sn)} for path {path}.")
# for i, (x, y) in enumerate(self.prev_sn):
# print(i, x, y, y.done())
# while not p.done():
# pass
while not p.done():
# logger.info(f"[TSNAsyncCheckpointEngine] Waiting for {path}.")
logger.info(f"[TSNAsyncCheckpointEngine] Waiting for {path}.")
p.wait()
logger.info(f"[TSNAsyncCheckpointEngine][Rank {self.rank}] at time {time.time_ns()} time {time.time()-inner_t} len {len(self.prev_sn)} for path {path}.")
# for i, (x, y) in enumerate(self.prev_sn):
# if y.done():
# logger.info(f"[TSNAsyncCheckpointEngine] Done checkpointing {i}, {x}, {y}.")
# del self.prev_sn[i]
# break
except Exception as e:
print(f"TSNAsyncCheckpointEngine][Rank {self.rank}] Async checkpoint waiting failed with error: {e}")
sys.exit(-1)
logger.info(f"[TSNAsyncCheckpointEngine] Shutdown took time {time.time()-t}")
return True
return True

def __del__(self):
self.shutdown()
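The reworked engine saves with torchsnapshot's asynchronous API, keeps the pending snapshots in a deque, and restores through Snapshot.restore instead of torch.load. A minimal standalone sketch of that save/wait/restore cycle, mirroring the save(), shutdown(), and load() methods above; the path and the example state are illustrative:

import torch
from torchsnapshot import Snapshot, StateDict

state = StateDict(ckpt={'step': 7, 'weights': torch.randn(2, 2)})

# Kick off an asynchronous snapshot; training can continue while it flushes.
pending = Snapshot.async_take(path='/tmp/ckpt/step7', app_state={'objects': state})
pending.wait()  # drain before shutdown, as the engine's shutdown() loop does

# Restore in place: torchsnapshot repopulates the supplied StateDict container.
restored = {'objects': StateDict(ckpt={})}
Snapshot(path='/tmp/ckpt/step7').restore(app_state=restored)
print(restored['objects']['ckpt']['step'])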
14 changes: 8 additions & 6 deletions deepspeed/runtime/engine.py
@@ -2666,6 +2666,8 @@ def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, f

if checkpoint.get(FROZEN_PARAM_FRAGMENTS, None) is not None:
saved_frozen_params = checkpoint[FROZEN_PARAM_FRAGMENTS]
print(f"In load_module_state_dict")
import pdb; pdb.set_trace();
for param in self.module.parameters():
if param.requires_grad:
continue
@@ -2804,7 +2806,6 @@ def load_checkpoint(self,
load_lr_scheduler_states=load_lr_scheduler_states,
load_module_only=load_module_only,
custom_load_fn=custom_load_fn)

load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled())
if load_zero_checkpoint:
if load_optimizer_states and not load_module_only:
@@ -2998,13 +2999,13 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
if zero_sd_list is None:
return False

self.optimizer.load_state_dict(state_dict_list=zero_sd_list,
load_optimizer_states=load_optimizer_states,
load_from_fp32_weights=self.zero_load_from_fp32_weights(),
checkpoint_folder=checkpoint_folder,
load_serial=load_serial)

load_serial=load_serial,
is_datastates_llm="DataStatesCheckpointEngine" in str(type(self.checkpoint_engine)))

if self.load_universal_checkpoint():
logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}')
else:
@@ -3205,8 +3206,9 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True,
# if self._config is not None and not self._config.veloc_ckpt_config and not self._config.async_ckpt_config:
self.checkpoint_engine.commit(tag)
if save_latest and rank == 0:
with open(os.path.join(save_dir, 'latest'), 'w') as fd:
fd.write(tag)
if not os.path.exists(os.path.join(save_dir, 'latest')):
with open(os.path.join(save_dir, 'latest'), 'w') as fd:
fd.write(tag)

dist.barrier()

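Two behaviors change in engine.py: _load_zero_checkpoint now tells the optimizer whether a DataStates checkpoint engine is active by inspecting the engine's class name, and the 'latest' marker file is only written when it does not already exist. A minimal sketch of the class-name test with placeholder engine classes (the real ones are not imported here):

class DataStatesCheckpointEngine:  # placeholder standing in for the real engine
    pass

class TorchCheckpointEngine:       # placeholder for any other engine
    pass

def uses_datastates(checkpoint_engine) -> bool:
    # Same string-based test as the patched _load_zero_checkpoint: the call
    # site needs no import of, or dependency on, the DataStates package.
    return 'DataStatesCheckpointEngine' in str(type(checkpoint_engine))

assert uses_datastates(DataStatesCheckpointEngine())
assert not uses_datastates(TorchCheckpointEngine())

Matching on the class name keeps the engine module decoupled from DataStates, at the cost of a check that breaks silently if the class is ever renamed.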
2 changes: 1 addition & 1 deletion deepspeed/runtime/pipe/engine.py
@@ -1410,7 +1410,7 @@ def _exec_schedule(self, pipe_schedule):
# Equivalent to: self._exec_forward_pass(buffer_id=0)
self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self)
self._exec_instr(**cmd.kwargs)
print(f"[Rank {self.global_rank}] <<<{str(type(cmd))}:{time.time()-ts}>>>")
# print(f"[Rank {self.global_rank}] <<<{str(type(cmd))}:{time.time()-ts}>>>")

# if type(cmd) == schedule.OptimizerStep:
# print("In after optimizer step")
40 changes: 40 additions & 0 deletions deepspeed/runtime/pipe/module.py
@@ -22,6 +22,7 @@
from deepspeed.accelerator import get_accelerator
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
import time
from concurrent.futures import ThreadPoolExecutor

class PipelineError(Exception):
"""Errors related to the use of deepspeed.PipelineModule """
@@ -604,7 +605,27 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=Fal
checkpoint_engine.save(final_state_dict, model_ckpt_path)
# logger.info(f"[SaveStateDict] ckpt engine save took {time.time()-t} for path {model_ckpt_path}")


def datastates_llm_sd_loader(self, layer, strict, model_ckpt_list, version, checkpoint_engine, mp_world_size, mp_rank, module_key=None, is_pipe_parallel=True):
sd_loader = SDLoaderFactory.get_sd_loader(model_ckpt_list,
version=version,
checkpoint_engine=checkpoint_engine)
load_path, checkpoint, _ = sd_loader.load(mp_world_size, mp_rank, module_key=module_key, is_pipe_parallel=is_pipe_parallel)
layer.load_state_dict(checkpoint, strict=strict)
return

def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
num_loader_threads = 1
is_datastates_llm = False
futures = []
t = time.time()
sd_loader_time = 0
sd_loader_load_time = 0
sd_layer_load_sd_time = 0
if "DataStatesCheckpointEngine" in str(type(checkpoint_engine)):
num_loader_threads = 4
is_datastates_llm = True
executor = ThreadPoolExecutor(max_workers=num_loader_threads)
for idx, layer in enumerate(self.forward_funcs):
# Functions, etc. will not have state_dicts
if not hasattr(layer, 'load_state_dict'):
@@ -615,19 +636,38 @@ def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
mp_rank = self._grid.get_slice_parallel_rank()
mp_world_size = self._grid.get_slice_parallel_world_size()

if is_datastates_llm:
version=2.0
module_key = None
is_pipe_parallel = True
f = executor.submit(self.datastates_llm_sd_loader, layer, strict, model_ckpt_list, version, checkpoint_engine, mp_world_size, mp_rank, module_key, is_pipe_parallel)
futures.append(f)
continue
intime = time.time()
sd_loader = SDLoaderFactory.get_sd_loader(model_ckpt_list,
version=2.0,
checkpoint_engine=checkpoint_engine)
sd_loader_time += time.time()-intime
intime = time.time()
load_path, checkpoint, _ = sd_loader.load(mp_world_size, mp_rank, module_key=None, is_pipe_parallel=True)
sd_loader_load_time += time.time()-intime

intime = time.time()
layer.load_state_dict(checkpoint, strict=strict)
sd_layer_load_sd_time += time.time()-intime

# if self._grid.data_parallel_id == 0:
# logger.info(
# f'RANK={self.global_rank} Loaded layer={idx+self._local_start} file={load_path}'
# )

if is_datastates_llm:
executor.shutdown(wait=True)
logger.info(f"[Rank {dist.get_rank()}] Loaded the layer_* in {time.time()-t}, sd_init {sd_loader_time}, sd_load: {sd_loader_load_time}, layer_load_sd: {sd_layer_load_sd_time}")

self._synchronize_tied_weights()



def _is_checkpointable(self, funcs):

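When a DataStates engine is detected, load_state_dir now fans the per-layer loads out to a four-thread ThreadPoolExecutor and waits for all of them before synchronizing tied weights. A minimal sketch of that fan-out pattern with stand-in loader and layer objects; the file names and payload are illustrative, not the DeepSpeed checkpoint layout:

from concurrent.futures import ThreadPoolExecutor

def load_one_layer(layer, ckpt_path):
    # Stand-in for datastates_llm_sd_loader: fetch a state dict, then load it.
    layer.update({'loaded_from': ckpt_path})

layers = [dict() for _ in range(8)]
ckpt_paths = [f'layer_{i:02d}-model_states.pt' for i in range(8)]

with ThreadPoolExecutor(max_workers=4) as executor:  # 4 workers, as in the patch
    futures = [executor.submit(load_one_layer, layer, path)
               for layer, path in zip(layers, ckpt_paths)]
    for f in futures:
        f.result()  # surfaces loader exceptions; the patch only waits via shutdown(wait=True)

assert all('loaded_from' in layer for layer in layers)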
