From 2c24e99a04b802d73aacf185b2003ddb1ee5ac76 Mon Sep 17 00:00:00 2001 From: Siming Dai <908660116@qq.com> Date: Mon, 14 Oct 2024 11:37:08 +0800 Subject: [PATCH 1/5] [Unified Checkpoint] Support expert parallel (#9055) * update code --- .../trainer/plugins/unified_checkpoint.py | 239 +++++++++++++++--- 1 file changed, 205 insertions(+), 34 deletions(-) diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py index f35b23f95050..5f3cda836d3b 100644 --- a/paddlenlp/trainer/plugins/unified_checkpoint.py +++ b/paddlenlp/trainer/plugins/unified_checkpoint.py @@ -27,6 +27,11 @@ from paddle.distributed import fleet from tqdm.auto import tqdm +try: + from paddle.base import core +except: + core = None + from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM from paddlenlp.trainer.argparser import strtobool from paddlenlp.trainer.trainer_utils import ExplicitEnum @@ -389,7 +394,7 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str) ) return - if self.args.dataset_rank == 0: + if self.args.dataset_rank == 0 or self.args.use_expert_parallel: load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True) def save_non_merge_optimizer(self, model, optimizer, output_dir): @@ -422,6 +427,26 @@ def save_non_merge_optimizer(self, model, optimizer, output_dir): for key in list(master_weights.keys()): master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) + no_sync_kname = [] + model_state_dict = get_expected_state_dict(model) + for k, v in model_state_dict.items(): + if getattr(v, "no_sync", False): + no_sync_kname.append(k) + + hcg = fleet.get_hybrid_communicate_group() + dp_group = hcg.get_data_parallel_group() + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 + if self.args.use_expert_parallel: + for k in list(optim_state_dict.keys()): + model_k = k.split("/")[0] + if dp_rank > 0 and model_k not in no_sync_kname: + optim_state_dict.pop(k) + if master_weights is not None: + for k in list(master_weights.keys()): + model_k = k.split("/")[0] + if dp_rank > 0 and model_k not in no_sync_kname: + master_weights.pop(k) + optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix) @@ -462,7 +487,10 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint): key_name = key.split("/") static_name = struct2static_name_mappings[key_name[0]] if has_master_weights: - key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) else: key_name = "_".join([static_name, key_name[1]]) with device_guard(): @@ -568,7 +596,7 @@ def load_unified_optimizer(self, args, model, optimizer, resume_from_checkpoint) ) # If not having merge optimizer, then load non-merge optimizer. 
if not has_merge_optimizer_safetensors: - if self.args.data_parallel_rank == 0: + if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel: returned_optim_state_dict = self.load_non_merge_optimizer( model, optimizer, @@ -588,7 +616,7 @@ def load_unified_optimizer(self, args, model, optimizer, resume_from_checkpoint) ) return returned_optim_state_dict - if self.args.data_parallel_rank == 0: + if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel: returned_optim_state_dict = load_unified_optimizer_locally( self.args, model, optimizer, resume_from_checkpoint, safe_serialization=True ) @@ -651,8 +679,11 @@ def save_single_card_optimizer(self, model, optimizer, output_dir): static2struct_name_mappings = {} state_dict = get_expected_state_dict(model) + fp32_weight = {} for k, v in state_dict.items(): static2struct_name_mappings[v.name] = k + if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32: + fp32_weight[k] = v # rename optimizer param for key in list(optim_state_dict.keys()): @@ -662,6 +693,7 @@ def save_single_card_optimizer(self, model, optimizer, output_dir): if master_weights is not None: for key in list(master_weights.keys()): master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) + master_weights.update(fp32_weight) # save index json index_optimizer_file, index_master_weight_file = {}, {} @@ -744,7 +776,7 @@ def unlink_shared_memory(self): def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False): """ - Only dataset_rank == 0 can enter this function. + Only dataset_rank == 0 or using expert parallel can enter this function. """ index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=True) @@ -755,7 +787,14 @@ def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, sa loaded_keys = sharded_metadata["all_checkpoint_keys"] model_state_dict = get_expected_state_dict(model) - expected_keys = set(list(model_state_dict.keys())) + # If using expert parallel, when dp_rank > 0, need to modify the expected_keys here. + if not args.use_expert_parallel or (args.use_expert_parallel and args.data_parallel_rank == 0): + expected_keys = set(list(model_state_dict.keys())) + else: + expected_keys = set() + for key in model_state_dict.keys(): + if getattr(model_state_dict[key], "no_sync", False): + expected_keys.add(key) missing_keys = expected_keys - set(loaded_keys) use_fast_set = True @@ -889,11 +928,17 @@ def unified_checkpoint_into_shards( weights_name = SAFE_WEIGHTS_NAME if safe_serialization else PADDLE_WEIGHTS_NAME shard_file = get_sharded_file_name(args, weights_name) + # renumerize shard_file name for expert_parallel. 
+ if args.use_expert_parallel: + shard_file = rename_shard_file(args, shard_file, weights_name) + for key, weight in state_dict.items(): index_weight_file[key] = shard_file total_size += weight.numel().item() * dtype_byte_size(weight.dtype) - index_file_list, total_size_list = gather_sharded_object(index_weight_file, total_size) + index_file_list, total_size_list = gather_sharded_object( + index_weight_file, total_size, use_expert_parallel=args.use_expert_parallel + ) sharded_index = get_sharded_index( index_file_list, total_size_list, @@ -931,7 +976,7 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin model_keys = list(model_state_dict.keys()) struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} # get optimizer param mappings - expected_keys = get_expected_keys(sharded_metadata, model, optimizer) + expected_keys = get_expected_keys(args, sharded_metadata, model, optimizer) # This should always be a list but, just to be sure. if not isinstance(resolved_archive_file, list): @@ -955,7 +1000,7 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin index_filename=os.path.join(resume_from_checkpoint, index_filename_master_weights), ) - expected_keys_mw = get_expected_keys(sharded_metadata_mw, model, optimizer) + expected_keys_mw = get_expected_keys(args, sharded_metadata_mw, model, optimizer, is_master_weights=True) if not isinstance(resolved_archive_file_mw, list): resolved_archive_file_mw = [resolved_archive_file_mw] if len(resolved_archive_file_mw) > 1: @@ -1005,7 +1050,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected key_name = key.split("/") static_name = struct2static_name_mappings[key_name[0]] if has_master_weights: - key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) else: key_name = "_".join([static_name, key_name[1]]) returned_optim_state_dict[key_name] = state_dict_optim.pop(key) @@ -1049,8 +1097,13 @@ def unified_optimizer_into_shards( # get optimizer param mappings static2struct_name_mappings = {} state_dict = get_expected_state_dict(model) + fp32_weight = {} for k, v in state_dict.items(): static2struct_name_mappings[v.name] = k + if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32: + if args.dataset_rank > 0: # deal with different dataset rank. 
+ continue + fp32_weight[k] = v # rename optimizer param for key in list(optim_state_dict.keys()): @@ -1060,6 +1113,7 @@ def unified_optimizer_into_shards( if master_weights is not None: for key in list(master_weights.keys()): master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) + master_weights.update(fp32_weight) # filter optimizer param if master_weights is not None: @@ -1087,6 +1141,7 @@ def unified_optimizer_into_shards( optim_state_dict, tp_actions, filter_optim_keys, + state_dict if args.use_expert_parallel else None, ) paddle.device.cuda.empty_cache() @@ -1096,6 +1151,7 @@ def unified_optimizer_into_shards( master_weights, tp_actions, filter_master_keys, + state_dict if args.use_expert_parallel else None, ) paddle.device.cuda.empty_cache() @@ -1119,12 +1175,18 @@ def unified_optimizer_into_shards( total_master_weight_size += weight.numel().item() * dtype_byte_size(weight.dtype) index_optimizer_filelist, total_optim_size_list = gather_sharded_object( - index_optimizer_file, total_optim_size, is_optimizer=True + index_optimizer_file, + total_optim_size, + is_optimizer=True, + use_expert_parallel=args.use_expert_parallel, ) sharded_optim_index = get_sharded_index(index_optimizer_filelist, total_optim_size_list) if master_weights is not None: index_master_weight_filelist, total_master_weight_size_list = gather_sharded_object( - index_master_weight_file, total_master_weight_size, is_optimizer=True + index_master_weight_file, + total_master_weight_size, + is_optimizer=True, + use_expert_parallel=args.use_expert_parallel, ) sharded_master_weight_index = get_sharded_index(index_master_weight_filelist, total_master_weight_size_list) @@ -1175,15 +1237,20 @@ def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serializa # To decide whether to load the checkpoint locally, or need to dynamically send tensors across machines. local_resume = True - if args.dataset_rank == 0: + if args.dataset_rank == 0 or args.use_expert_parallel: hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() pp_group = hcg.get_pipe_parallel_group() + dp_group = hcg.get_data_parallel_group() + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 need_files = set() state_dict = get_expected_state_dict(model) for key in state_dict.keys(): filename = index["weight_map"][key] + # When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0. 
+ if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False): + continue need_files.add(filename) diff_filelist = list(need_files.difference(set(existed_files))) num_diff = paddle.to_tensor([len(diff_filelist)]) @@ -1191,6 +1258,8 @@ def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serializa dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=tp_group) if pp_group.nranks > 1: dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group) + if args.use_expert_parallel and dp_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group) if num_diff.item() == 0: local_resume = True else: @@ -1243,8 +1312,10 @@ def check_unified_optimizer(args, model, optimizer, resume_from_checkpoint, safe hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() pp_group = hcg.get_pipe_parallel_group() + dp_group = hcg.get_data_parallel_group() sharding_group = hcg.get_sharding_parallel_group() sharding_rank = sharding_group.rank + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 struct2static_name_mappings = {k: v.name for k, v in model.state_dict().items()} if sharding_group.nranks > 1: param2rank = optimizer._param2rank @@ -1269,9 +1340,10 @@ def check_complete(all_filenames): def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, typename_set=None): # To decide whether to load the checkpoint locally, or need to dynamically distribute the checkpoint. local_resume = True - if args.data_parallel_rank == 0: + if args.data_parallel_rank == 0 or args.use_expert_parallel: need_files = set() state_dict = get_expected_state_dict(model) + for key in state_dict.keys(): if sharding_group.nranks > 1: static_name = struct2static_name_mappings.get(key, None) @@ -1279,6 +1351,13 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, if param_rank != sharding_rank: continue + # When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0. + if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False): + continue + + if is_master_weights and state_dict[key].dtype == core.VarDesc.VarType.FP32: + continue + if not is_master_weights: for type_name in typename_set: type_key = key + "/" + type_name @@ -1296,6 +1375,8 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group) if sharding_group.nranks > 1: dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=sharding_group) + if args.use_expert_parallel and dp_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group) if num_diff.item() == 0: local_resume = True @@ -1548,8 +1629,10 @@ def load_unified_optimizer_dynamically(args, model, optimizer, resume_from_check for key in index["weight_map"].keys(): _, typename = key.split("/") typename_set.add(typename) - struct2static_name_mappings = {k: v.name for k, v in get_expected_state_dict(model).items()} - static2struct_name_mappings = {v.name: k for k, v in get_expected_state_dict(model).items()} + + model_state_dict = get_expected_state_dict(model) + struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} + static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()} # Get send_table and recv_table. 
The send table indicates which workers are responsible for sending tensors, and the recv table indicates which workers should receive the tensors. send_table, recv_table = create_optimizer_dispatch_table( args, @@ -1671,7 +1754,10 @@ def check_optimizer_param(parameter): key_name = key.split("/") static_name = struct2static_name_mappings[key_name[0]] if has_master_weights: - key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) else: key_name = "_".join([static_name, key_name[1]]) optim_state_dict[key_name] = optim_state_dict.pop(key) @@ -1745,9 +1831,10 @@ def load_single_card_optimizer(args, model, optimizer, resume_from_checkpoint: s key_name = key.split("/") static_name = struct2static_name_mappings[key_name[0]] if has_master_weights: - key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) - else: - key_name = "_".join([static_name, key_name[1]]) + if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) returned_optim_state_dict[key_name] = state_dict_optim.pop(key) returned_optim_state_dict[key_name].name = key_name if has_master_weights: @@ -1872,26 +1959,29 @@ def distributed_send_recv( def get_sharded_file_name(args, file_name, is_optimizer=False): if not is_optimizer: + sd_degree = args.sharding_parallel_degree if args.sharding_parallel_degree > 1 else 1 + size = sd_degree if args.use_expert_parallel else args.dataset_world_size shard_file = file_name.replace( ".pdparams", - f"-{args.logical_process_index + 1:05d}-of-{args.world_size//args.dataset_world_size:05d}.pdparams", + f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdparams", ) shard_file = shard_file.replace( ".safetensors", - f"-{args.logical_process_index + 1:05d}-of-{args.world_size//args.dataset_world_size:05d}.safetensors", + f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.safetensors", ) else: hcg = fleet.get_hybrid_communicate_group() dp_group = hcg.get_data_parallel_group() + size = dp_group.nranks if not args.use_expert_parallel else 1 shard_file = file_name.replace( - ".pdparams", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.pdparams" + ".pdparams", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdparams" ) shard_file = shard_file.replace( ".safetensors", - f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.safetensors", + f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.safetensors", ) shard_file = shard_file.replace( - ".pdopt", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.pdopt" + ".pdopt", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdopt" ) return shard_file @@ -1935,7 +2025,7 @@ def reduce_master_weights_status(has_master_weights=False): return data.item() > 0 -def gather_sharded_object(index_file, total_size, is_optimizer=False): +def gather_sharded_object(index_file, total_size, is_optimizer=False, use_expert_parallel=False): index_file_list, total_size_list = [], [] @@ -1969,6 +2059,17 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False): if len(index_file_list) == 0 and len(total_size_list) == 0: index_file_list = [index_file] 
total_size_list = [total_size] + + if use_expert_parallel: + data_group = hcg.get_data_parallel_group() + if data_group.nranks > 1: + data_index_file_list = [] + data_total_size_list = [] + dist.all_gather_object(data_index_file_list, index_file_list, data_group) + dist.all_gather_object(data_total_size_list, total_size_list, data_group) + index_file_list = flatten_list(data_index_file_list) + total_size_list = flatten_list(data_total_size_list) + if is_optimizer: sharding_group = hcg.get_sharding_parallel_group() if sharding_group.nranks > 1: @@ -1982,16 +2083,58 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False): return index_file_list, total_size_list +def rename_shard_file(args, shard_file, file_name): + """rename shard file when using expert_parallel.""" + assert args.use_expert_parallel, "only expert_parallel need to use this function" + + shard_file_list = [] + + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + pp_group = hcg.get_pipe_parallel_group() + data_group = hcg.get_data_parallel_group() + + if tp_group.nranks > 1: + dist.all_gather_object(shard_file_list, shard_file, tp_group) + if pp_group.nranks > 1: + pp_shard_file_list = [] + dist.all_gather_object( + pp_shard_file_list, shard_file_list if len(shard_file_list) > 0 else shard_file, pp_group + ) + shard_file_list = flatten_list(pp_shard_file_list) + if data_group.nranks > 1: + data_shard_file_list = [] + dist.all_gather_object( + data_shard_file_list, shard_file_list if len(shard_file_list) > 0 else shard_file, data_group + ) + shard_file_list = flatten_list(data_shard_file_list) + + new_index = shard_file_list.index(shard_file) + sd_degree = args.sharding_parallel_degree if args.sharding_parallel_degree > 1 else 1 + shard_file = file_name.replace( + ".pdparams", + f"-{new_index + 1:05d}-of-{args.world_size//sd_degree:05d}.pdparams", + ) + shard_file = shard_file.replace( + ".safetensors", + f"-{new_index + 1:05d}-of-{args.world_size//sd_degree:05d}.safetensors", + ) + return shard_file + + def generate_base_static_name(vname): # return base static name and specific type name, like [embedding_0.w_0, moment1_0] if FP32_MASTER in vname: vname = vname.split("_" + FP32_MASTER + "_") return vname[0], vname[1] else: - vname = vname.split(".") - a = vname[0] + "." + vname[1][:3] - b = vname[1][4:] - return a, b + # Directly deal with type names, for example: moe_gate_1_moment1_0. + type_names = optimizer_scalar_name + optimizer_non_scaler_name + for name in type_names: + if name in vname: + a = vname.split(name)[0][:-1] + b = name + return a, b def filter_params(model_to_save, state_dict, is_optimizer=False): @@ -2087,7 +2230,9 @@ def merge_large_tensor_parallel(tensor, tp_group, tp_action, dst_rank, is_dst): def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() + dp_group = hcg.get_data_parallel_group() tp_rank = tp_group.rank + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 # filter actions for pipeline mode if hcg.get_pipe_parallel_group().nranks > 1: @@ -2105,6 +2250,9 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): continue key = filter_keys[i] tensor = state_dict[key] + # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0. 
+ if dp_rank > 0 and not getattr(tensor, "no_sync", False): + continue if key in tp_actions: # Get tensor size tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks @@ -2128,16 +2276,24 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): if len(tp_actions) > 0: for x in tp_actions.keys(): - logger.warning(f"key <{x}> need to merge tensor parallel but we can't find in model state.") + logger.debug(f"key <{x}> need to merge tensor parallel but we can't find in model state.") return state_dict_to_save -def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys): +def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys, model_state_dict=None): # Core function for UC hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() + dp_group = hcg.get_data_parallel_group() tp_rank = tp_group.rank + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 + + no_sync_kname = [] + if model_state_dict is not None: + for k, v in model_state_dict.items(): + if getattr(v, "no_sync", False): + no_sync_kname.append(k) state_dict_to_save = {} max_key_len = max([len(_) for _ in all_filter_keys]) @@ -2149,6 +2305,9 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys) # get base model key model_key = filter_keys[i].split("/")[0] tensor = state_dict[filter_keys[i]] + # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0. + if dp_rank > 0 and model_key not in no_sync_kname: + continue if model_key in tp_actions: # for example: beta1, beta2 if tensor.numel().item() == 1: @@ -2217,7 +2376,7 @@ def get_optimizer_shard_files(optimizer_path, index_filename): return shard_filenames, sharded_metadata -def get_expected_keys(sharded_metadata, model, optimizer): +def get_expected_keys(args, sharded_metadata, model, optimizer, is_master_weights=False): hcg = fleet.get_hybrid_communicate_group() sharding_group = hcg.get_sharding_parallel_group() sharding_rank = sharding_group.rank @@ -2225,11 +2384,23 @@ def get_expected_keys(sharded_metadata, model, optimizer): if in_sharding_parallel_model: params2rank = optimizer._param2rank - struct2static_name_mappings = {k: v.name for k, v in get_expected_state_dict(model).items()} + model_state_dict = get_expected_state_dict(model) + struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} expected_keys = [] for key in list(sharded_metadata["all_optimizer_keys"]): key_name = key.split("/")[0] + if ( + is_master_weights + and key_name in model_state_dict + and model_state_dict[key_name].dtype == core.VarDesc.VarType.FP32 + ): + continue + + if args.use_expert_parallel and args.data_parallel_rank > 0: + if key_name in model_state_dict and not getattr(model_state_dict[key_name], "no_sync", False): + continue + static_name = struct2static_name_mappings.get(key_name, None) if in_sharding_parallel_model: From 0ea031c3d93e43168093f1f2706ab75029c45457 Mon Sep 17 00:00:00 2001 From: Weiguo Zhu Date: Mon, 14 Oct 2024 22:40:03 +0800 Subject: [PATCH 2/5] [Unified Checkpoint] Fix generation config save (#9223) --- paddlenlp/trainer/plugins/unified_checkpoint.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py index 5f3cda836d3b..54465d8ec816 100644 --- a/paddlenlp/trainer/plugins/unified_checkpoint.py +++ b/paddlenlp/trainer/plugins/unified_checkpoint.py @@ -360,6 
+360,9 @@ def save_unified_checkpoint(self, model, optimizer, output_dir): config_to_save.architectures = [model_to_save.__class__.__name__] if self.args.should_save: config_to_save.save_pretrained(save_directory) + # save generation config + if model_to_save.can_generate(): + model_to_save.generation_config.save_pretrained(save_directory) paddle.device.cuda.empty_cache() if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and self.args.should_save: @@ -667,6 +670,10 @@ def save_single_card_checkpoint(self, model_to_save, output_dir): config_to_save.architectures = [model_to_save.__class__.__name__] config_to_save.save_pretrained(output_dir) + # save generation config + if model_to_save.can_generate(): + model_to_save.generation_config.save_pretrained(output_dir) + def save_single_card_optimizer(self, model, optimizer, output_dir): """ "Save optimizer for non-distributed environment.""" # Split into optimizer params and master weights. From 6af261d78de169bf3682ac4c050b113796fca518 Mon Sep 17 00:00:00 2001 From: Siming Dai <908660116@qq.com> Date: Thu, 26 Sep 2024 16:25:56 +0800 Subject: [PATCH 3/5] [Unified Checkpoint] update async_save_info in develop (#9173) --- paddlenlp/trainer/trainer.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index fd25e6adbcb4..e6f10223bb7a 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -2297,16 +2297,7 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op self.model_wrapped.get_all_parameters(convert2cpu=True) if self.args.should_save_model_state: - unified_checkpoint_config_backup = self.args.unified_checkpoint_config - # backup and remove unified_checkpoint_config for not trine stage - if not self.is_in_train: - self.args.unified_checkpoint_config = [] - self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel) - - # recover unified_checkpoint_config for not trine stage - if not self.is_in_train: - self.args.unified_checkpoint_config = unified_checkpoint_config_backup else: if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: os.makedirs(output_dir, exist_ok=True) @@ -2584,10 +2575,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ # Save a trained model and configuration using `save_pretrained()`. 
# They can then be reloaded using `from_pretrained()` - local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0)) if ( strtobool(os.getenv("FLAG_LLM_PDC", "False")) - and local_rank == 0 + and paddle.distributed.get_rank() == 0 and self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config ): @@ -2598,9 +2588,10 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim, "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config, } - if not os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")): - with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f: - json.dump(save_info, f) + if os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")): # afs cannot overwrite + os.remove(os.path.join(self.args.logging_dir, "async_save_info.json")) + with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f: + json.dump(save_info, f) if self.args.should_save: if self.tokenizer is not None: @@ -2609,7 +2600,17 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) if self.args.unified_checkpoint: + unified_checkpoint_config_backup = self.args.unified_checkpoint_config + # backup and remove unified_checkpoint_config for not trine stage + if not self.is_in_train: + self.args.unified_checkpoint_config = [] + self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir) + + # recover unified_checkpoint_config for not trine stage + if not self.is_in_train: + self.args.unified_checkpoint_config = unified_checkpoint_config_backup + return merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel From 4f66a199ff7a6bfc4d2c87c0a4ab04f13394ed2e Mon Sep 17 00:00:00 2001 From: Siming Dai <908660116@qq.com> Date: Wed, 16 Oct 2024 14:53:20 +0800 Subject: [PATCH 4/5] [Unified Checkpoint] update async save logic (#9274) * update async save signal * fix async save hang --- .../trainer/plugins/unified_checkpoint.py | 53 +++++++++++--- paddlenlp/trainer/trainer.py | 71 +++++++++++++------ paddlenlp/trainer/trainer_utils.py | 8 ++- paddlenlp/trainer/training_args.py | 5 ++ 4 files changed, 103 insertions(+), 34 deletions(-) diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py index 54465d8ec816..4c5b54a20ddb 100644 --- a/paddlenlp/trainer/plugins/unified_checkpoint.py +++ b/paddlenlp/trainer/plugins/unified_checkpoint.py @@ -140,7 +140,6 @@ def __init__(self, args): self._process_master_weight = None self._process_optimizer_weight = None self._lock = None - self._shared_save_path = None self._shared_save_model_flag = None self._shared_save_master_weight_flag = None self._shared_save_optimizer_flag = None @@ -148,13 +147,18 @@ def __init__(self, args): if "async_save" in self.args.unified_checkpoint_config: self._lock = multiprocessing.Lock() self._shared_save_model_path = multiprocessing.Array("c", 100000) + self._shared_save_model_signal_path = multiprocessing.Array("c", 100000) self._shared_save_master_weight_path = multiprocessing.Array("c", 100000) + self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000) self._shared_save_optimizer_path = multiprocessing.Array("c", 100000) + self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000) 
self._shared_save_model_flag = multiprocessing.Array("i", 1) self._shared_save_master_weight_flag = multiprocessing.Array("i", 1) self._shared_save_optimizer_flag = multiprocessing.Array("i", 1) - def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_type="model_weight"): + def _file_save_async_or_sync( + self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight" + ): if is_sync: for k in list(state_dict.keys()): if isinstance(state_dict[k], paddle.Tensor): @@ -169,6 +173,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty meta_dict = self._meta_dict_model shared_save_flag = self._shared_save_model_flag shared_save_path = self._shared_save_model_path + shared_save_signal_path = self._shared_save_model_signal_path if self._process_model_weight is None: self._process_model_weight = multiprocessing.Process( target=self._save_file_async_in_process, @@ -177,12 +182,14 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty self._shm_model_weight.name, self._shared_save_model_flag, self._shared_save_model_path, + self._shared_save_model_signal_path, self._lock, state_dict_type, self.global_rank, ), ) self._process_model_weight.start() + process = self._process_model_weight elif state_dict_type == "master_weight": if self._shm_master_weight is None: self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict) @@ -191,6 +198,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty meta_dict = self._meta_dict_master_weight shared_save_flag = self._shared_save_master_weight_flag shared_save_path = self._shared_save_master_weight_path + shared_save_signal_path = self._shared_save_master_weight_signal_path if self._process_master_weight is None: self._process_master_weight = multiprocessing.Process( target=self._save_file_async_in_process, @@ -199,6 +207,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty self._shm_master_weight.name, self._shared_save_master_weight_flag, self._shared_save_master_weight_path, + self._shared_save_master_weight_signal_path, self._lock, "model_weight" if "skip_save_model_weight" in self.args.unified_checkpoint_config @@ -207,6 +216,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty ), ) self._process_master_weight.start() + process = self._process_master_weight elif state_dict_type == "optimizer_weight": if self._shm_optimizer_weight is None: self._meta_dict_optim, buffer_size = create_meta_dict(state_dict) @@ -215,6 +225,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty meta_dict = self._meta_dict_optim shared_save_flag = self._shared_save_optimizer_flag shared_save_path = self._shared_save_optimizer_path + shared_save_signal_path = self._shared_save_optimizer_signal_path if self._process_optimizer_weight is None: self._process_optimizer_weight = multiprocessing.Process( target=self._save_file_async_in_process, @@ -223,21 +234,26 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty self._shm_optimizer_weight.name, self._shared_save_optimizer_flag, self._shared_save_optimizer_path, + self._shared_save_optimizer_signal_path, self._lock, state_dict_type, self.global_rank, ), ) self._process_optimizer_weight.start() + process = self._process_optimizer_weight while True: # wait until no process is saving. 
flag_value = shared_save_flag[0] if flag_value == 0: break + if not process.is_alive(): + raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.") time.sleep(0.5) logger.info(f"Wait for the previous save process to finish saving {state_dict_type}") # only save model weight or save master weight, we enter this loop. self._reset_and_update(shared_save_path, path) + self._reset_and_update(shared_save_signal_path, signal_path) _traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf) with self._lock: shared_save_flag[0] = 1 @@ -248,6 +264,7 @@ def _save_file_async_in_process( shm_name, shared_save_flag, shared_save_path, + shared_save_signal_path, lock, state_dict_type, global_rank, @@ -261,11 +278,12 @@ def _save_file_async_in_process( continue if flag_value == 1: # need to save path = shared_save_path[:].decode("utf-8").rstrip("\x00") + signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00") logger.info(f"Start to async save {path}") state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array safe_save_file(state_dict, path, {"format": "np"}) del state_dict - saved_signal_path = os.path.join(os.path.dirname(path), f".{state_dict_type}.done.{global_rank}") + saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}") paddle.save(global_rank, saved_signal_path) with lock: shared_save_flag[0] = 0 @@ -280,7 +298,7 @@ def _reset_and_update(self, shared_array, new_value): encoded_value = new_value.encode("utf-8") shared_array[: len(encoded_value)] = encoded_value - def save_unified_checkpoint(self, model, optimizer, output_dir): + def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None): """save unified checkpoint Args: @@ -317,6 +335,8 @@ def save_unified_checkpoint(self, model, optimizer, output_dir): save_directory = output_dir os.makedirs(save_directory, exist_ok=True) + if signal_dir is not None: + os.makedirs(signal_dir, exist_ok=True) # only for async save # save model weights if not skip_save_model_weight: @@ -329,6 +349,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir): self._file_save_async_or_sync( state_dict, path=os.path.join(save_directory, shard_file), + signal_path=signal_dir, is_sync=is_sync_save, state_dict_type="model_weight", ) @@ -400,7 +421,7 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str) if self.args.dataset_rank == 0 or self.args.use_expert_parallel: load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True) - def save_non_merge_optimizer(self, model, optimizer, output_dir): + def save_non_merge_optimizer(self, model, optimizer, output_dir, signal_dir): paddle.device.cuda.empty_cache() optim_state_dict = nested_copy(optimizer.state_dict()) master_weights = None @@ -459,12 +480,14 @@ def save_non_merge_optimizer(self, model, optimizer, output_dir): self._file_save_async_or_sync( optim_state_dict, path=os.path.join(output_dir, optimizer_name), + signal_path=signal_dir, is_sync=is_sync_save, state_dict_type="optimizer_weight", ) self._file_save_async_or_sync( master_weights, path=os.path.join(output_dir, master_weights_name), + signal_path=signal_dir, is_sync=is_sync_save, state_dict_type="master_weight", ) @@ -514,22 +537,23 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint): return returned_optim_state_dict - def save_unified_optimizer(self, model, optimizer, output_dir): + def save_unified_optimizer(self, model, optimizer, 
output_dir, signal_dir): """save unified optimizer Args: model (PretrainedModel): model used to get key mapping. optimizer (Optimizer): optimizer to save output_dir (str): Save directory. + signal_dir (str): Asynchronous saving signal directory. """ if "ignore_merge_optimizer" in self.args.unified_checkpoint_config: - self.save_non_merge_optimizer(model, optimizer, output_dir) + self.save_non_merge_optimizer(model, optimizer, output_dir, signal_dir) return if paddle.distributed.get_world_size() <= 1: - self.save_single_card_optimizer(model, optimizer, output_dir) + self.save_single_card_optimizer(model, optimizer, output_dir) # no need to save signal return # Split into naive optimizer params and master weights. @@ -545,6 +569,8 @@ def save_unified_optimizer(self, model, optimizer, output_dir): save_directory = output_dir os.makedirs(save_directory, exist_ok=True) + if signal_dir is not None: + os.makedirs(signal_dir, exist_ok=True) is_sync_save = True if "async_save" in self.args.unified_checkpoint_config: @@ -552,6 +578,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir): self._file_save_async_or_sync( optim_state_dict, path=os.path.join(save_directory, shard_optim_file), + signal_path=signal_dir, is_sync=is_sync_save, state_dict_type="optimizer_weight", ) @@ -559,6 +586,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir): self._file_save_async_or_sync( master_weight_state_dict, path=os.path.join(save_directory, shard_master_weight_file), + signal_path=signal_dir, is_sync=is_sync_save, state_dict_type="master_weight", ) @@ -754,14 +782,20 @@ def unlink_shared_memory(self): if self._shared_save_model_flag is not None: while self._shared_save_model_flag[0] > 0: # async process is saving + if not self._process_model_weight.is_alive(): + raise RuntimeError("The process that saves model_weight has been killed unexpectedly.") time.sleep(0.5) self._shared_save_model_flag[0] = -1 if self._shared_save_master_weight_flag is not None: while self._shared_save_master_weight_flag[0] > 0: + if not self._process_master_weight.is_alive(): + raise RuntimeError("The process that saves master_weight has been killed unexpectedly.") time.sleep(0.5) self._shared_save_master_weight_flag[0] = -1 if self._shared_save_optimizer_flag is not None: while self._shared_save_optimizer_flag[0] > 0: + if not self._process_optimizer_weight.is_alive(): + raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.") time.sleep(0.5) self._shared_save_optimizer_flag[0] = -1 @@ -778,7 +812,8 @@ def unlink_shared_memory(self): self._shm_optimizer_weight.unlink() self._shm_optimizer_weight = None - dist.barrier() + if paddle.distributed.get_world_size() > 1: + dist.barrier() def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False): diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index e6f10223bb7a..cfe3ad193327 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -2283,7 +2283,12 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle return loss.detach() - def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Optional[bool] = False): + def save_model( + self, + output_dir: Optional[str] = None, + merge_tensor_parallel: Optional[bool] = False, + signal_dir: Optional[str] = None, + ): """ Will save the model, so you can reload it using `from_pretrained()`. 
@@ -2293,17 +2298,20 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op if output_dir is None: output_dir = self.args.output_dir + if signal_dir is None: + signal_dir = self.args.output_signal_dir + if ShardingOption.FULL_SHARD in self.args.sharding: self.model_wrapped.get_all_parameters(convert2cpu=True) if self.args.should_save_model_state: - self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel) + self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel, signal_dir=signal_dir) else: if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: - os.makedirs(output_dir, exist_ok=True) + os.makedirs(signal_dir, exist_ok=True) if self.is_in_train: global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 - paddle.save(global_rank, os.path.join(output_dir, f".model_weight.done.{global_rank}")) + paddle.save(global_rank, os.path.join(signal_dir, f".model_weight.done.{global_rank}")) if strtobool(os.getenv("FLAG_LLM_PDC", "False")): # save model_done file to ensure model is complete @@ -2319,9 +2327,9 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op and "async_save" in self.args.unified_checkpoint_config and not self.is_in_train ): - os.makedirs(output_dir, exist_ok=True) + os.makedirs(signal_dir, exist_ok=True) global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 - paddle.save(self.state.global_step, os.path.join(output_dir, f".model_weight.done.{global_rank}")) + paddle.save(self.state.global_step, os.path.join(signal_dir, f".model_weight.done.{global_rank}")) def _filter_moe_no_sync_optimizer_params(self): """ @@ -2332,7 +2340,7 @@ def _filter_moe_no_sync_optimizer_params(self): filter_optimzier_state_dict = OrderedDict() param_names_in_master_weights = list(optimzier_state_dict["master_weights"].keys()) if self.args.bf16 else [] filter_optimzier_state_dict["master_weights"] = OrderedDict() - for k, v in state_dict.items(): + for _, v in state_dict.items(): if getattr(v, "no_sync", False): if v.name in param_names_in_master_weights: filter_optimzier_state_dict["master_weights"][v.name] = optimzier_state_dict["master_weights"][ @@ -2351,15 +2359,17 @@ def _save_checkpoint(self, model, metrics=None): checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" run_dir = self.args.output_dir + run_signal_dir = self.args.output_signal_dir output_dir = os.path.join(run_dir, checkpoint_folder) + signal_dir = os.path.join(run_signal_dir, checkpoint_folder) if isinstance(self.model, LoRAModel) and (self.model.quantized or self.args.pipeline_parallel_degree > 1): - self.save_model(output_dir) + self.save_model(output_dir, False, signal_dir) elif isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM): - self.save_model(output_dir, True) + self.save_model(output_dir, True, signal_dir) else: - self.save_model(output_dir) + self.save_model(output_dir, False, signal_dir) # only save model state dict, ignore optimizer and scheduler if not self.args.ignore_save_lr_and_optim: @@ -2375,6 +2385,7 @@ def _save_checkpoint(self, model, metrics=None): self.model, self.optimizer, output_dir, + signal_dir, ) else: if self.dp_group.rank > 0: # this should only work for MoE saving @@ -2397,10 +2408,10 @@ def _save_checkpoint(self, model, metrics=None): else: if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: global_rank = 
paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 - os.makedirs(output_dir, exist_ok=True) - paddle.save(global_rank, os.path.join(output_dir, f".optimizer_weight.done.{global_rank}")) + os.makedirs(signal_dir, exist_ok=True) + paddle.save(global_rank, os.path.join(signal_dir, f".optimizer_weight.done.{global_rank}")) if "skip_save_model_weight" not in self.args.unified_checkpoint_config: - paddle.save(global_rank, os.path.join(output_dir, f".master_weight.done.{global_rank}")) + paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}")) if self.args.should_save or self.args.use_expert_parallel: if not self.args.use_hybrid_parallel: logger.info("Saving optimizer files.") @@ -2409,6 +2420,7 @@ def _save_checkpoint(self, model, metrics=None): self.model, self.optimizer, output_dir, + signal_dir, ) else: if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel: @@ -2433,10 +2445,10 @@ def _save_checkpoint(self, model, metrics=None): if self.args.unified_checkpoint and not self.args.use_hybrid_parallel: if "async_save" in self.args.unified_checkpoint_config: global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 - os.makedirs(output_dir, exist_ok=True) - paddle.save(global_rank, os.path.join(output_dir, f".optimizer_weight.done.{global_rank}")) + os.makedirs(signal_dir, exist_ok=True) + paddle.save(global_rank, os.path.join(signal_dir, f".optimizer_weight.done.{global_rank}")) if "skip_save_model_weight" not in self.args.unified_checkpoint_config: - paddle.save(global_rank, os.path.join(output_dir, f".master_weight.done.{global_rank}")) + paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}")) self.runtime_timer.stop() # Determine the new best metric / best model checkpoint @@ -2485,7 +2497,7 @@ def _save_checkpoint(self, model, metrics=None): # For hybrid parallel training, the checkpoint files maybe on different node. need_to_rotate_checkpoints = False if self.args.use_hybrid_parallel: - if self.dp_group.rank <= 0: + if self.dp_group.rank <= 0 or self.args.use_expert_parallel: need_to_rotate_checkpoints = True else: need_to_rotate_checkpoints = self.args.should_save_model_state @@ -2494,6 +2506,7 @@ def _save_checkpoint(self, model, metrics=None): need_to_rotate_checkpoints = need_to_rotate_checkpoints and self.args.local_rank == 0 if need_to_rotate_checkpoints: self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) + self._rotate_checkpoints(use_mtime=False, output_dir=run_signal_dir) if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and not ("async_save" in self.args.unified_checkpoint_config): # save checkpoint_done file to ensure checkpoint is complete @@ -2568,10 +2581,23 @@ def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None: # ignore_errors for shared disks between train nodes. shutil.rmtree(checkpoint, ignore_errors=True) - def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_parallel=False): + def _save( + self, + output_dir: Optional[str] = None, + state_dict=None, + merge_tensor_parallel=False, + signal_dir: Optional[str] = None, + ): output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True) logger.info(f"Saving model checkpoint to {output_dir}") + + # signal_dir is used for asynchronous saving situations. 
+ if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: + signal_dir = signal_dir if signal_dir is not None else self.args.output_signal_dir + os.makedirs(signal_dir, exist_ok=True) + logger.info(f"Saving model checkpoint finish signal to {signal_dir}") + # Save a trained model and configuration using `save_pretrained()`. # They can then be reloaded using `from_pretrained()` @@ -2581,16 +2607,15 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ and self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config ): - os.makedirs(self.args.logging_dir, exist_ok=True) world_size = paddle.distributed.get_world_size() save_info = { "world_size": world_size, "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim, "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config, } - if os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")): # afs cannot overwrite - os.remove(os.path.join(self.args.logging_dir, "async_save_info.json")) - with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f: + if os.path.exists(os.path.join(signal_dir, "async_save_info.json")): # afs cannot overwrite + os.remove(os.path.join(signal_dir, "async_save_info.json")) + with open(os.path.join(signal_dir, "async_save_info.json"), "w") as f: json.dump(save_info, f) if self.args.should_save: @@ -2605,7 +2630,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ if not self.is_in_train: self.args.unified_checkpoint_config = [] - self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir) + self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir, signal_dir) # recover unified_checkpoint_config for not trine stage if not self.is_in_train: diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py index 0588ea3530ee..ca816b585e3b 100644 --- a/paddlenlp/trainer/trainer_utils.py +++ b/paddlenlp/trainer/trainer_utils.py @@ -256,7 +256,7 @@ def _check_checkpoint_files(folder_path, world_size, ignore_save_lr_and_optim, s return a -def get_last_checkpoint(folder, uc_async_save=False): +def get_last_checkpoint(folder, signal_folder=None, uc_async_save=False): content = os.listdir(folder) checkpoints = [ path @@ -266,6 +266,9 @@ def get_last_checkpoint(folder, uc_async_save=False): if len(checkpoints) == 0: return + if uc_async_save: + assert signal_folder is not None + if strtobool(os.getenv("FLAG_LLM_PDC", "False")): for i in sorted(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0]), reverse=True): current_path = os.path.join(folder, i) @@ -275,11 +278,12 @@ def get_last_checkpoint(folder, uc_async_save=False): return current_path else: saving_info = paddle.load(distributed_file(os.path.join(current_path, ".saving_info"))) + current_signal_path = os.path.join(signal_folder, i) pre_world_size = saving_info.get("world_size", 1) ignore_save_lr_and_optim = saving_info.get("ignore_save_lr_and_optim", False) skip_save_model_weight = saving_info.get("skip_save_model_weight", False) if _check_checkpoint_files( - current_path, pre_world_size, ignore_save_lr_and_optim, skip_save_model_weight + current_signal_path, pre_world_size, ignore_save_lr_and_optim, skip_save_model_weight ): return current_path return diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index e33f16f0ce9e..50345a2485eb 
100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -436,6 +436,7 @@ class TrainingArguments: }, ) logging_dir: Optional[str] = field(default=None, metadata={"help": "VisualDL log dir."}) + output_signal_dir: Optional[str] = field(default=None, metadata={"help": "Asynchronous saving signal dir."}) logging_strategy: IntervalStrategy = field( default="steps", metadata={"help": "The logging strategy to use."}, @@ -897,6 +898,10 @@ def __post_init__(self): self.logging_dir = os.path.join(self.output_dir, default_logdir()) if self.logging_dir is not None: self.logging_dir = os.path.expanduser(self.logging_dir) + if self.output_signal_dir is None and self.output_dir is not None: + self.output_signal_dir = self.output_dir + if self.output_signal_dir is not None: + self.output_signal_dir = os.path.expanduser(self.output_signal_dir) if self.disable_tqdm is None: self.disable_tqdm = False # logger.getEffectiveLevel() > logging.WARN From 8f6e88b06f5e1ef7e74d4c343ffff7202df22fe3 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Wed, 16 Oct 2024 15:35:30 +0800 Subject: [PATCH 5/5] bug fix --- paddlenlp/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index cfe3ad193327..39ad7d8ea85c 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -2506,7 +2506,7 @@ def _save_checkpoint(self, model, metrics=None): need_to_rotate_checkpoints = need_to_rotate_checkpoints and self.args.local_rank == 0 if need_to_rotate_checkpoints: self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) - self._rotate_checkpoints(use_mtime=False, output_dir=run_signal_dir) + self._rotate_checkpoints(use_mtime=True, output_dir=run_signal_dir) if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and not ("async_save" in self.args.unified_checkpoint_config): # save checkpoint_done file to ensure checkpoint is complete
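
Two mechanisms introduced by this series are easier to follow outside the raw diffs, so the sketches below restate them in plain Python. Both helper names (`filter_expert_parallel_state`, `is_async_checkpoint_complete`) are hypothetical, written only for illustration; the actual logic lives inline in `unified_checkpoint.py` and `trainer_utils.py` as shown in the patches above.

First, the expert-parallel rule from PATCH 1/5: when `use_expert_parallel` is enabled, data-parallel ranks other than rank 0 keep only tensors flagged `no_sync` (expert-local parameters that are not replicated across the data-parallel group), both when saving shards and when building `expected_keys` for loading. A minimal sketch of that rule, assuming `no_sync` is set on expert parameters:

    def filter_expert_parallel_state(state_dict, dp_rank, use_expert_parallel):
        # Illustrative only: dp_rank 0 keeps everything; other data-parallel
        # ranks keep just the expert-local (no_sync) tensors, mirroring the
        # dp_rank > 0 checks added throughout unified_checkpoint.py.
        if not use_expert_parallel or dp_rank == 0:
            return dict(state_dict)
        return {k: v for k, v in state_dict.items() if getattr(v, "no_sync", False)}

Second, the async-save signal flow from PATCH 4/5: checkpoint payloads are written to `output_dir`, while completion markers named `.{state_dict_type}.done.{global_rank}` plus the `async_save_info.json` metadata go to `output_signal_dir`, and `get_last_checkpoint` validates the signal folder before trusting a checkpoint on resume. The following is only a rough approximation of that completeness check; the real validation is `_check_checkpoint_files` in `trainer_utils.py`, whose exact rules are not shown in this diff:

    import os

    def is_async_checkpoint_complete(signal_dir, world_size,
                                     skip_save_model_weight=False,
                                     ignore_save_lr_and_optim=False):
        # Hypothetical helper: require one ".done" marker per rank for each
        # state-dict type the run was configured to save asynchronously.
        files = os.listdir(signal_dir) if os.path.isdir(signal_dir) else []

        def done_count(prefix):
            return sum(1 for name in files if name.startswith(prefix))

        complete = True
        if not skip_save_model_weight:
            complete = complete and done_count(".model_weight.done") == world_size
        if not ignore_save_lr_and_optim:
            complete = complete and done_count(".optimizer_weight.done") == world_size
        return complete

Because the markers carry the writing rank in their suffix, a partially written checkpoint (for example, after the background save process is killed, which the new `is_alive()` checks now surface as a RuntimeError) simply fails this count and the trainer falls back to an earlier checkpoint directory.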