From 5451d31486b95957864e21cbe1151f4968554ab0 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Wed, 21 Aug 2024 13:53:49 +0800 Subject: [PATCH 01/16] [Unified checkpoint] update optimizer async save signal --- paddlenlp/trainer/trainer.py | 7 ++++++- paddlenlp/trainer/trainer_utils.py | 3 ++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index b77c45b1427c..58207834a0bc 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -2305,7 +2305,12 @@ def _save_checkpoint(self, model, metrics=None): self._save_ckpt_func(state_dict, save_path) with open(saved_signal_path, mode="w+") as f: f.write("1") - + else: + if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: + global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 + paddle.save(global_rank, os.path.join(output_dir, f".optimizer_weight.done.{global_rank}")) + if "skip_save_model_weight" not in self.args.unified_checkpoint_config: + paddle.save(global_rank, os.path.join(output_dir, f".master_weight.done.{global_rank}")) if self.args.should_save or self.args.use_expert_parallel: if not self.args.use_hybrid_parallel: logger.info("Saving optimizer files.") diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py index a385e36550de..86504648cc48 100644 --- a/paddlenlp/trainer/trainer_utils.py +++ b/paddlenlp/trainer/trainer_utils.py @@ -46,6 +46,7 @@ from ..transformers.tokenizer_utils_base import BatchEncoding from ..utils.import_utils import is_paddle_cuda_available, is_psutil_available from ..utils.log import logger +from .utils.helper import distributed_file __all__ = [ "TrainOutput", @@ -273,7 +274,7 @@ def get_last_checkpoint(folder, uc_async_save=False): if os.path.exists(os.path.join(current_path, ".checkpoint_done")): return current_path else: - saving_info = paddle.load(os.path.join(current_path, ".saving_info")) + saving_info = paddle.load(distributed_file(os.path.join(current_path, ".saving_info"))) pre_world_size = saving_info.get("world_size", 1) ignore_save_lr_and_optim = saving_info.get("ignore_save_lr_and_optim", False) skip_save_model_weight = saving_info.get("skip_save_model_weight", False) From 15e83e29676c018822eacb1b885285fc333fda8d Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 6 Sep 2024 22:48:13 +0800 Subject: [PATCH 02/16] update paddlepaddle --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index bacd415e9620..5640dee891bc 100644 --- a/Makefile +++ b/Makefile @@ -46,7 +46,7 @@ unit-test: .PHONY: install install: - pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html + pip install --pre paddlepaddle -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/ pip install -r requirements-dev.txt pip install -r requirements.txt pip install -r paddlenlp/experimental/autonlp/requirements.txt From 6837b2f7f29ca0c0496ce8df36224d4cfb67c4d2 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Tue, 10 Sep 2024 11:31:14 +0800 Subject: [PATCH 03/16] split param --- .../trainer/plugins/unified_checkpoint.py | 32 ++++++++++++++++--- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py index 56183485cca8..61c56096534a 100644 --- a/paddlenlp/trainer/plugins/unified_checkpoint.py +++ 
b/paddlenlp/trainer/plugins/unified_checkpoint.py @@ -29,7 +29,7 @@ from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM from paddlenlp.trainer.argparser import strtobool -from paddlenlp.trainer.trainer_utils import ExplicitEnum +from paddlenlp.trainer.trainer_utils import ExplicitEnum, ShardingOption from paddlenlp.trainer.utils.helper import distributed_file, distributed_isfile from paddlenlp.transformers.model_utils import ( PretrainedModel, @@ -489,15 +489,37 @@ def save_unified_optimizer(self, model, optimizer, output_dir): output_dir (str): Save directory. """ + if paddle.distributed.get_world_size() <= 1: + self.save_single_card_optimizer(model, optimizer, output_dir) + return + + if ( + self.args.sharding_parallel_degree > 1 + and ShardingOption.SHARD_OP in self.args.sharding + and "split_param" in self.args.sharding_parallel_config + ): + pass if "ignore_merge_optimizer" in self.args.unified_checkpoint_config: self.save_non_merge_optimizer(model, optimizer, output_dir) return - if paddle.distributed.get_world_size() <= 1: - self.save_single_card_optimizer(model, optimizer, output_dir) - return - + print(type(optimizer), type(optimizer._inner_opt)) + # print([p.name for p in optimizer._inner_opt._parameter_list]) + print(optimizer._inner_opt._comm_buffer_list[0]) + + # comm_buffer + for buffer in optimizer._inner_opt._comm_buffer_list: + print(buffer.buffer_size) + # print([p.name for p in buffer._params]) + for key in buffer._sharding_param_grad_view.keys(): + print( + key, + buffer._sharding_param_grad_view[key]._param_begin, + buffer._sharding_param_grad_view[key]._param_end, + ) + paddle.distributed.barrier() + raise ValueError # Split into naive optimizer params and master weights. results = unified_optimizer_into_shards(self.args, model, optimizer, safe_serialization=True) master_weight_state_dict = None From 633d742f05b8082b6abe1b4e44287ec8e0396f21 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Tue, 24 Sep 2024 14:13:51 +0800 Subject: [PATCH 04/16] add save for split param --- .../trainer/plugins/unified_checkpoint.py | 151 +++++++++++++----- paddlenlp/trainer/trainer.py | 1 - 2 files changed, 114 insertions(+), 38 deletions(-) diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py index 61c56096534a..866d92b81b52 100644 --- a/paddlenlp/trainer/plugins/unified_checkpoint.py +++ b/paddlenlp/trainer/plugins/unified_checkpoint.py @@ -389,15 +389,8 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str) if self.args.dataset_rank == 0: load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True) - def save_non_merge_optimizer(self, model, optimizer, output_dir): + def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, output_dir): paddle.device.cuda.empty_cache() - optim_state_dict = nested_copy(optimizer.state_dict()) - master_weights = None - if "master_weights" in optim_state_dict.keys(): - master_weights = optim_state_dict["master_weights"] - optim_state_dict.pop("master_weights") - if "LR_Scheduler" in optim_state_dict.keys(): - optim_state_dict.pop("LR_Scheduler") # gather global master_weights status. 
global_master_weights = reduce_master_weights_status(master_weights is not None) @@ -498,30 +491,24 @@ def save_unified_optimizer(self, model, optimizer, output_dir): and ShardingOption.SHARD_OP in self.args.sharding and "split_param" in self.args.sharding_parallel_config ): - pass + optim_state_dict, master_weights = self.gather_split_param_for_optimizer(optimizer) + else: + optim_state_dict = nested_copy(optimizer.state_dict()) + master_weights = None + if "master_weights" in optim_state_dict.keys(): + master_weights = optim_state_dict["master_weights"] + optim_state_dict.pop("master_weights") + if "LR_Scheduler" in optim_state_dict.keys(): + optim_state_dict.pop("LR_Scheduler") if "ignore_merge_optimizer" in self.args.unified_checkpoint_config: - self.save_non_merge_optimizer(model, optimizer, output_dir) + self.save_non_merge_optimizer(model, optim_state_dict, master_weights, output_dir) return - print(type(optimizer), type(optimizer._inner_opt)) - # print([p.name for p in optimizer._inner_opt._parameter_list]) - print(optimizer._inner_opt._comm_buffer_list[0]) - - # comm_buffer - for buffer in optimizer._inner_opt._comm_buffer_list: - print(buffer.buffer_size) - # print([p.name for p in buffer._params]) - for key in buffer._sharding_param_grad_view.keys(): - print( - key, - buffer._sharding_param_grad_view[key]._param_begin, - buffer._sharding_param_grad_view[key]._param_end, - ) - paddle.distributed.barrier() - raise ValueError # Split into naive optimizer params and master weights. - results = unified_optimizer_into_shards(self.args, model, optimizer, safe_serialization=True) + results = unified_optimizer_into_shards( + self.args, model, optim_state_dict, master_weights, safe_serialization=True + ) master_weight_state_dict = None if len(results) == 1: optim_state_dict, shard_optim_file, sharded_optim_index = results[0] @@ -530,7 +517,6 @@ def save_unified_optimizer(self, model, optimizer, output_dir): master_weight_state_dict, shard_master_weight_file, sharded_master_weight_index = results[1] paddle.device.cuda.empty_cache() - save_directory = output_dir os.makedirs(save_directory, exist_ok=True) @@ -567,7 +553,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir): with open(master_path, "w") as f: json.dump(sharded_master_weight_index, f, indent=4) - def load_unified_optimizer(self, args, model, optimizer, resume_from_checkpoint): + def load_unified_optimizer(self, model, optimizer, resume_from_checkpoint): """Load potential model checkpoint Args: @@ -724,6 +710,70 @@ def save_single_card_optimizer(self, model, optimizer, output_dir): state_dict_type="master_weight", ) + def gather_split_param_for_optimizer(self, optimizer): + hcg = fleet.get_hybrid_communicate_group() + sharding_group = hcg.get_sharding_parallel_group() + global_rank = dist.get_rank() + param_slice_info = {} + param_shape_info = {} + for buffer in optimizer._inner_opt._comm_buffer_list: + for key in buffer._sharding_param_grad_view.keys(): + param_slice_info[key] = ( + buffer._sharding_param_grad_view[key]._param_begin, + buffer._sharding_param_grad_view[key]._param_end, + ) + param_shape_info[key] = ( + buffer._sharding_param_grad_view[key]._param.shape, + buffer._sharding_param_grad_view[key]._param.numel().item(), + ) + param_slice_info["global_rank"] = global_rank + param_slice_info_list = [] + dist.all_gather_object(param_slice_info_list, param_slice_info, group=sharding_group) + + optim_state_dict = nested_copy(optimizer.state_dict()) + master_weights = None + if "master_weights" in 
optim_state_dict.keys(): + master_weights = optim_state_dict.pop("master_weights") + if "LR_Scheduler" in optim_state_dict.keys(): + optim_state_dict.pop("LR_Scheduler") + + # deal with optimizer param + partial_tensor_list = [] + for key in list(optim_state_dict.keys()): + static_name, _ = generate_base_static_name(key) + if static_name in param_slice_info.keys(): + if optim_state_dict[key].numel().item() == 1: # for example: beta1, beta2 + continue + begin, end = param_slice_info[static_name] + shape, numel = param_shape_info[static_name] + if end - begin == numel: # full tensor + optim_state_dict[key] = optim_state_dict[key].reshape(shape) + elif end <= begin: # empty tensor + continue + else: # partial tensor, end > begin but end - begin < numel + partial_tensor_list.append(static_name) + + send_table = {} + recv_table = {} + for key in partial_tensor_list: + sharding_ranklist = [] + for slice_info in param_slice_info_list: + begin, end = slice_info[key] + if end > begin: + sharding_ranklist.append((slice_info["global_rank"], begin, end)) + recv_table[key] = sharding_ranklist[0][0] # which sharding_rank to recv the splited tensor + send_table[key] = [(rank, begin, end) for rank, begin, end in sharding_ranklist] + + distributed_send_recv_splited_param( + optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False + ) + if master_weights is not None: + distributed_send_recv_splited_param( + master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True + ) + + return optim_state_dict, master_weights + def unlink_shared_memory(self): if not ("async_save" in self.args.unified_checkpoint_config): return @@ -1038,7 +1088,8 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected def unified_optimizer_into_shards( args, model, - optimizer, + optim_state_dict, + master_weights, safe_serialization=False, ): """Get optimizer state dict and master weight state dict. @@ -1048,13 +1099,6 @@ def unified_optimizer_into_shards( safe_serialization (bool, optional): safe serialization using safetensors. Defaults to False. """ paddle.device.cuda.empty_cache() - optim_state_dict = nested_copy(optimizer.state_dict()) - master_weights = None - if "master_weights" in optim_state_dict.keys(): - master_weights = optim_state_dict["master_weights"] - optim_state_dict.pop("master_weights") - if "LR_Scheduler" in optim_state_dict.keys(): - optim_state_dict.pop("LR_Scheduler") # gather global master_weights status. 
global_master_weights = reduce_master_weights_status(master_weights is not None) @@ -1885,6 +1929,39 @@ def distributed_send_recv( return state_dict +def distributed_send_recv_splited_param( + state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False +): + global_rank = dist.get_rank() + for key in list(state_dict.keys()): + if state_dict[key].numel().item() == 1: # for example: beta1, beta2 + continue + + static_name = key if is_master_weights else generate_base_static_name(key)[0] + if static_name not in partial_tensor_list: + continue + + recv_rank = recv_table[static_name] + send_info = send_table[static_name] + + if global_rank == recv_rank: + tmp_tensor_list = [] + for send_rank, begin, end in send_info: + if send_rank == recv_rank: + tmp_tensor_list.append(state_dict[key]) + else: + tmp_tensor = paddle.empty(shape=[end - begin], dtype=state_dict[key].dtype) + dist.stream.recv(tmp_tensor, src=send_rank) + tmp_tensor_list.append(tmp_tensor) + state_dict[key] = paddle.concat(tmp_tensor_list, axis=0).reshape(param_shape_info[static_name][0]) + else: + for send_rank, _, _ in send_info: + if global_rank == send_rank: + dist.stream.send(state_dict[key], dst=recv_rank) + state_dict.pop(key) + return state_dict + + def get_sharded_file_name(args, file_name, is_optimizer=False): if not is_optimizer: shard_file = file_name.replace( diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 307adc7b60ec..06db04927d79 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -2654,7 +2654,6 @@ def _load_optimizer_and_scheduler(self, checkpoint): opt_state_dict = None else: opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer( - args=self.args, model=self.model, optimizer=self.optimizer, resume_from_checkpoint=checkpoint, From b6aa309d699b1be823b2290ab6cfa110815128b2 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Thu, 10 Oct 2024 16:48:51 +0800 Subject: [PATCH 05/16] fix save split_param --- .../trainer/plugins/unified_checkpoint.py | 32 +++++++++++++++---- 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py index 0295f98edcd6..4c170596f621 100644 --- a/paddlenlp/trainer/plugins/unified_checkpoint.py +++ b/paddlenlp/trainer/plugins/unified_checkpoint.py @@ -732,6 +732,8 @@ def gather_split_param_for_optimizer(self, optimizer): param_shape_info[key] = ( buffer._sharding_param_grad_view[key]._param.shape, buffer._sharding_param_grad_view[key]._param.numel().item(), + buffer._sharding_param_grad_view[key]._index, + buffer._sharding_param_grad_view[key]._padded_size, ) param_slice_info["global_rank"] = global_rank param_slice_info_list = [] @@ -752,7 +754,7 @@ def gather_split_param_for_optimizer(self, optimizer): if optim_state_dict[key].numel().item() == 1: # for example: beta1, beta2 continue begin, end = param_slice_info[static_name] - shape, numel = param_shape_info[static_name] + shape, numel, _, _ = param_shape_info[static_name] if end - begin == numel: # full tensor optim_state_dict[key] = optim_state_dict[key].reshape(shape) elif end <= begin: # empty tensor @@ -1945,26 +1947,44 @@ def distributed_send_recv_splited_param( continue static_name = key if is_master_weights else generate_base_static_name(key)[0] + shape, numel, index, padded_size = param_shape_info[static_name] + if static_name not in partial_tensor_list: + state_dict[key] = state_dict[key].reshape(shape) 
continue recv_rank = recv_table[static_name] send_info = send_table[static_name] + base_padding_start = index + numel + base_padding_end = index + padded_size + if global_rank == recv_rank: tmp_tensor_list = [] for send_rank, begin, end in send_info: + padding_start = max(begin, base_padding_start) + padding_end = min(end, base_padding_end) + if send_rank == recv_rank: - tmp_tensor_list.append(state_dict[key]) + tensor = ( + state_dict[key] if padding_start >= padding_end else state_dict[key][: padding_start - begin] + ) + tmp_tensor_list.append(tensor) else: - tmp_tensor = paddle.empty(shape=[end - begin], dtype=state_dict[key].dtype) + length = end - begin if padding_start >= padding_end else padding_start - begin + tmp_tensor = paddle.empty(shape=[length], dtype=state_dict[key].dtype) dist.stream.recv(tmp_tensor, src=send_rank) tmp_tensor_list.append(tmp_tensor) - state_dict[key] = paddle.concat(tmp_tensor_list, axis=0).reshape(param_shape_info[static_name][0]) + state_dict[key] = paddle.concat(tmp_tensor_list, axis=0).reshape(shape) else: - for send_rank, _, _ in send_info: + for send_rank, begin, end in send_info: + padding_start = max(begin, base_padding_start) + padding_end = min(end, base_padding_end) if global_rank == send_rank: - dist.stream.send(state_dict[key], dst=recv_rank) + tensor = ( + state_dict[key] if padding_start >= padding_end else state_dict[key][: padding_start - begin] + ) + dist.stream.send(tensor, dst=recv_rank) state_dict.pop(key) return state_dict From bf5d72b18408d07f220e91a636fb0c11858eacb9 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 11 Oct 2024 20:19:30 +0800 Subject: [PATCH 06/16] add load uc split_param --- .../trainer/plugins/unified_checkpoint.py | 163 ++++++++++++++++++ 1 file changed, 163 insertions(+) diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py index 4c170596f621..2e16f064fec9 100644 --- a/paddlenlp/trainer/plugins/unified_checkpoint.py +++ b/paddlenlp/trainer/plugins/unified_checkpoint.py @@ -983,7 +983,160 @@ def unified_checkpoint_into_shards( return state_dict, shard_file, sharded_index +def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint): + returned_optim_state_dict = nested_copy(optimizer.state_dict()) + + index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME + + resolved_archive_file, sharded_metadata = get_optimizer_shard_files( + optimizer_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, index_filename), + ) + has_master_weights = True if sharded_metadata["master_weights"] else False + + typename_set = set() + for key in sharded_metadata["weight_map"].keys(): + _, typename = key.split("/") + typename_set.add(typename) + + model_state_dict = get_expected_state_dict(model) + model_keys = list(model_state_dict.keys()) + static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()} # get optimizer param mappings + struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} + + expected_keys = [] + param_slice_info = {} + param_shape_info = {} + for buffer in optimizer._inner_opt._comm_buffer_list: + for key in buffer._sharding_param_grad_view.keys(): + begin = buffer._sharding_param_grad_view[key]._param_begin + end = buffer._sharding_param_grad_view[key]._param_end + if end > begin: + expected_keys.append(key) + shape = buffer._sharding_param_grad_view[key]._param.shape + numel = 
buffer._sharding_param_grad_view[key]._param.numel().item() + index = buffer._sharding_param_grad_view[key]._index + padded_size = buffer._sharding_param_grad_view[key]._padded_size + param_slice_info[key] = (begin, end) + param_shape_info[key] = (shape, numel, index, padded_size) + + expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys]) + expected_keys_optim = [] + for key in expected_keys: + for typename in typename_set: + expected_keys_optim.append(f"{key}/{typename}") + expected_keys_optim = set(expected_keys_optim) + + if len(resolved_archive_file) > 1: + resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards") + + if has_master_weights: + returned_optim_state_dict["master_weights"] = {} + resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files( + optimizer_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, index_filename_master_weights), + ) + if len(resolved_archive_file_mw) > 1: + resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards") + + def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False): + returned_state_dict = {} + + if model.config.tensor_parallel_degree > 1: + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + tp_actions = model._get_tensor_parallel_convert_actions(model_keys, is_split=True, ignore_error=True) + else: + tp_actions = model.get_tensor_parallel_convert_actions(model.config, model_keys, ignore_error=True) + if not is_master_weights: + tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys) + + for shard_file in resolved_archive_file: + if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]): + continue + + if model.config.tensor_parallel_degree > 1: + state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected") + else: + state_dict = load_state_dict(shard_file, None, expected_keys, device="expected") + + returned_state_dict.update(state_dict) + del state_dict + gc.collect() + + return returned_state_dict + + # get tp params + state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim) + if has_master_weights: + state_dict_master_weight = load_resolved_archive_file( + resolved_archive_file_mw, + sharded_metadata_mw, + expected_keys, + is_master_weights=True, + ) + + # need to split param for different sharding rank, maybe need to deal with oom issue. 
+ for key in list(state_dict_optim.keys()): + key_name = key.split("/") + static_name = struct2static_name_mappings.get(key_name[0], None) + + if state_dict_optim[key].numel().item() > 1: + begin, end = param_slice_info[static_name] + shape, numel, index, padded_size = param_shape_info[static_name] + state_dict_optim[key] = state_dict_optim[key].reshape([-1]) + state_dict_optim[key] = state_dict_optim[key][begin - index : end - index] + + padding_start = max(begin, index + numel) + padding_end = min(end, index + padded_size) + if padding_start < padding_end: + state_dict_optim[key] = paddle.concat( + ( + state_dict_optim[key], + paddle.zeros([padding_end - padding_start], dtype=state_dict_optim[key].dtype), + ) + ) + + if has_master_weights: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) + returned_optim_state_dict[key_name] = state_dict_optim.pop(key) + returned_optim_state_dict[key_name].name = key_name + + if has_master_weights: + for key in list(state_dict_master_weight.keys()): + static_name = struct2static_name_mappings.get(key, None) + if state_dict_master_weight[key].numel().item() > 1: + begin, end = param_slice_info[static_name] + shape, numel, index, padded_size = param_shape_info[static_name] + state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1]) + state_dict_master_weight[key] = state_dict_master_weight[key][begin - index : end - index] + + padding_start = max(begin, index + numel) + padding_end = min(end, index + padded_size) + if padding_start < padding_end: + state_dict_master_weight[key] = paddle.concat( + ( + state_dict_master_weight[key], + paddle.zeros([padding_end - padding_start], dtype=state_dict_master_weight[key].dtype), + ) + ) + returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key) + returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) + + return returned_optim_state_dict + + def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): + # Special process with split param. + if ( + args.sharding_parallel_degree > 1 + and ShardingOption.SHARD_OP in args.sharding + and "split_param" in args.sharding_parallel_config + ): + returned_optim_state_dict = load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint) + return returned_optim_state_dict + # init and get optimizer LR_Scheduler returned_optim_state_dict = nested_copy(optimizer.state_dict()) @@ -1314,6 +1467,16 @@ def check_unified_optimizer(args, model, optimizer, resume_from_checkpoint, safe sharding_group = hcg.get_sharding_parallel_group() sharding_rank = sharding_group.rank struct2static_name_mappings = {k: v.name for k, v in model.state_dict().items()} + + if ( + args.sharding_parallel_degree > 1 + and ShardingOption.SHARD_OP in args.sharding + and "split_param" in args.sharding_parallel_config + ): + # We do not check optimizer files completion for split_param, since it is very complicated. Directly support local resume. 
+ logger.warning("We only support local resume for split_param mode, do not support dynamically loading.") + return True + if sharding_group.nranks > 1: param2rank = optimizer._param2rank From 9fdaae279d6007faa8eeaa0796b89d91ed262b95 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Mon, 14 Oct 2024 11:40:10 +0800 Subject: [PATCH 07/16] update uc files --- .../trainer/plugins/unified_checkpoint.py | 568 +----------------- .../plugins/unified_checkpoint_utils.py | 567 +++++++++++++++++ paddlenlp/utils/nested.py | 10 + 3 files changed, 603 insertions(+), 542 deletions(-) create mode 100644 paddlenlp/trainer/plugins/unified_checkpoint_utils.py diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py index 2e16f064fec9..1cfb7b0053b4 100644 --- a/paddlenlp/trainer/plugins/unified_checkpoint.py +++ b/paddlenlp/trainer/plugins/unified_checkpoint.py @@ -29,7 +29,7 @@ from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM from paddlenlp.trainer.argparser import strtobool -from paddlenlp.trainer.trainer_utils import ExplicitEnum, ShardingOption +from paddlenlp.trainer.trainer_utils import ShardingOption from paddlenlp.trainer.utils.helper import distributed_file, distributed_isfile from paddlenlp.transformers.model_utils import ( PretrainedModel, @@ -46,15 +46,12 @@ get_checkpoint_shard_files, is_safetensors_available, ) -from paddlenlp.utils.distributed import distributed_allgather, distributed_gather from paddlenlp.utils.env import ( LORA_WEIGHTS_NAME, PADDLE_MASTER_WEIGHTS_INDEX_NAME, PADDLE_MASTER_WEIGHTS_NAME, PADDLE_OPTIMIZER_INDEX_NAME, PADDLE_OPTIMIZER_NAME, - PADDLE_PEFT_WEIGHTS_INDEX_NAME, - PADDLE_WEIGHTS_INDEX_NAME, PADDLE_WEIGHTS_NAME, PAST_KEY_VALUES_FILE_NAME, PREFIX_WEIGHTS_NAME, @@ -68,8 +65,7 @@ SAFE_WEIGHTS_NAME, ) from paddlenlp.utils.log import logger -from paddlenlp.utils.nested import nested_copy, nested_copy_place -from paddlenlp.utils.tools import get_env_device +from paddlenlp.utils.nested import flatten_list, nested_copy, nested_copy_place if is_safetensors_available(): from safetensors.numpy import save_file as safe_save_file @@ -86,39 +82,33 @@ _traverse_copy_to_shm, create_meta_dict, ) - -FP32_MASTER = "fp32_master_0" -optimizer_scalar_name = [ - "beta1_pow_acc_0", - "beta2_pow_acc_0", -] -optimizer_non_scaler_name = [ - "moment1_0", - "moment2_0", - "velocity_0", -] # to be added - +from .unified_checkpoint_utils import ( + FP32_MASTER, + UnifiedCheckpointOption, + filter_params, + gather_sharded_object, + generate_base_static_name, + get_expected_keys, + get_expected_state_dict, + get_optimizer_shard_files, + get_sharded_file_name, + get_sharded_index, + is_need_master_weight, + mapping_optimizer_tp_actions, + merge_tensor_parallel_for_optimizer, + merge_tensor_parallel_with_shard, + optimizer_non_scaler_name, + optimizer_scalar_name, + reduce_master_weights_status, + select_model_weight_index, + update_master_weight_status, +) DEST_PLACE = paddle.CPUPlace() if paddle.device.is_compiled_with_cuda(): DEST_PLACE = paddle.CUDAPinnedPlace() -class UnifiedCheckpointOption(ExplicitEnum): - """ - "- skip_save_model_weight: do not save model weights when the masters weight exist\n" - "- master_weight_compatible: 1. if the master weights exist, only load when needed\n" - " 2. 
if master weights does not exist, convert model weights to master weights when needed\n" - "- async_save: enable asynchronous saving checkpoints to disk\n" - "- enable_all_options: enable all optimization configurations\n" - """ - - SKIP_SAVE_MODEL_WEIGHT = "skip_save_model_weight" - MASTER_WEIGHT_COMPATIBLE = "master_weight_compatible" - ASYNC_SAVE = "async_save" - IGNORE_MERGE_OPTIMIZER = "ignore_merge_optimizer" - - class UnifiedCheckpointHandler: def __init__(self, args): self.args = args @@ -820,7 +810,7 @@ def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, sa """ Only dataset_rank == 0 can enter this function. """ - index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=True) + index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=True) resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( pretrained_model_name_or_path=resume_from_checkpoint, @@ -1366,7 +1356,7 @@ def unified_optimizer_into_shards( def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False): - index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=False) + index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False) index_filename = os.path.join(resume_from_checkpoint, index_filename) # Find index json file and distribute this file in global group. if distributed_isfile(index_filename): @@ -1570,26 +1560,6 @@ def save_prefix_past_key_value(model_to_save, save_directory): np.save(os.path.join(save_directory, PAST_KEY_VALUES_FILE_NAME), past_key_value) -def get_expected_state_dict(model_to_save): - if isinstance(model_to_save, PretrainedModel): - state_dict = model_to_save.state_dict() - if ( - hasattr(model_to_save.config, "tie_word_embeddings") - and model_to_save.config.tie_word_embeddings - and hasattr(model_to_save, "_tied_weights_keys") - and model_to_save._tied_weights_keys is not None - ): - for key in model_to_save._tied_weights_keys: - if key in state_dict: - state_dict.pop(key) - elif isinstance(model_to_save, LoRAModel): - state_dict = model_to_save.get_trainable_state_dict() - elif isinstance(model_to_save, PrefixModelForCausalLM): - state_dict = model_to_save.prefix_encoder.state_dict() - - return state_dict - - def create_dispatch_table(args, model, file_keyname_mappings, file_machine_mappings, resume_from_checkpoint): """Create dispatch table for dynamically loading state dict. 
@@ -1689,7 +1659,7 @@ def create_optimizer_dispatch_table( def load_unified_checkpoint_dynamically(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): - index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=False) + index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False) index_filename = os.path.join(resume_from_checkpoint, index_filename) with open(index_filename, "r") as f: @@ -2150,489 +2120,3 @@ def distributed_send_recv_splited_param( dist.stream.send(tensor, dst=recv_rank) state_dict.pop(key) return state_dict - - -def get_sharded_file_name(args, file_name, is_optimizer=False): - if not is_optimizer: - shard_file = file_name.replace( - ".pdparams", - f"-{args.logical_process_index + 1:05d}-of-{args.world_size//args.dataset_world_size:05d}.pdparams", - ) - shard_file = shard_file.replace( - ".safetensors", - f"-{args.logical_process_index + 1:05d}-of-{args.world_size//args.dataset_world_size:05d}.safetensors", - ) - else: - hcg = fleet.get_hybrid_communicate_group() - dp_group = hcg.get_data_parallel_group() - shard_file = file_name.replace( - ".pdparams", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.pdparams" - ) - shard_file = shard_file.replace( - ".safetensors", - f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.safetensors", - ) - shard_file = shard_file.replace( - ".pdopt", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.pdopt" - ) - return shard_file - - -def get_sharded_index( - index_file_list, - total_size_list, -): - # save index json file - local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0)) - if local_rank == 0: - sharded_index_json = {} - - sharded_index_json["metadata"] = {"total_size": sum(total_size_list)} - - weight_map = {} - for i, index_file in enumerate(index_file_list): - weight_map.update(index_file_list[i]) - - sharded_index_json["weight_map"] = weight_map - return sharded_index_json - - return None - - -def reduce_master_weights_status(has_master_weights=False): - data = paddle.to_tensor([has_master_weights], dtype="int32") - - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - pp_group = hcg.get_pipe_parallel_group() - sharding_group = hcg.get_sharding_parallel_group() - - if tp_group.nranks > 1: - dist.all_reduce(data, op=dist.ReduceOp.SUM, group=tp_group) - if pp_group.nranks > 1: - dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pp_group) - if sharding_group.nranks > 1: - dist.all_reduce(data, op=dist.ReduceOp.SUM, group=sharding_group) - - return data.item() > 0 - - -def gather_sharded_object(index_file, total_size, is_optimizer=False): - - index_file_list, total_size_list = [], [] - - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - pp_group = hcg.get_pipe_parallel_group() - - logger.info( - f"Unified checkpoint: generating sharded_index json files for {'optimizer or master weight' if is_optimizer else 'model weight'}." 
- ) - - if tp_group.nranks > 1: - dist.all_gather_object(index_file_list, index_file, tp_group) - dist.all_gather_object(total_size_list, total_size, tp_group) - if pp_group.nranks > 1: - pp_index_file_list = [] - pp_total_size_list = [] - dist.all_gather_object( - pp_index_file_list, index_file_list if len(index_file_list) > 0 else index_file, pp_group - ) - dist.all_gather_object( - pp_total_size_list, total_size_list if len(total_size_list) > 0 else total_size, pp_group - ) - index_file_list = pp_index_file_list - total_size_list = pp_total_size_list - - index_file_list = flatten_list(index_file_list) - total_size_list = flatten_list(total_size_list) - - # for pure sharding - if len(index_file_list) == 0 and len(total_size_list) == 0: - index_file_list = [index_file] - total_size_list = [total_size] - if is_optimizer: - sharding_group = hcg.get_sharding_parallel_group() - if sharding_group.nranks > 1: - sharding_index_file_list = [] - sharding_total_size_list = [] - dist.all_gather_object(sharding_index_file_list, index_file_list, sharding_group) - dist.all_gather_object(sharding_total_size_list, total_size_list, sharding_group) - index_file_list = flatten_list(sharding_index_file_list) - total_size_list = flatten_list(sharding_total_size_list) - - return index_file_list, total_size_list - - -def generate_base_static_name(vname): - # return base static name and specific type name, like [embedding_0.w_0, moment1_0] - if FP32_MASTER in vname: - vname = vname.split("_" + FP32_MASTER + "_") - return vname[0], vname[1] - else: - vname = vname.split(".") - a = vname[0] + "." + vname[1][:3] - b = vname[1][4:] - return a, b - - -def filter_params(model_to_save, state_dict, is_optimizer=False): - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - - tp_size = tp_group.nranks - tp_rank = tp_group.rank - - # for pure sharding or pure pp - if tp_size <= 1: - return [list(state_dict.keys())] - - filter_tensor_list = [[] for i in range(tp_size)] - - if tp_rank == 0: - tensor_bytes_dict = {} - model_state_dict = get_expected_state_dict(model_to_save) - for (k, v) in state_dict.items(): - model_v = model_state_dict[k.split("/")[0]] if is_optimizer else v - if hasattr(model_v, "is_distributed") and model_v.is_distributed: - tensor_bytes_dict[k] = v.numel().item() * tp_size * dtype_byte_size(v.dtype) - else: - tensor_bytes_dict[k] = v.numel().item() * dtype_byte_size(v.dtype) - - filter_tensor_list = [] - current_block = [] - current_block_size = 0 - total_size = 0 - - max_shard_size = (sum(tensor_bytes_dict.values()) + tp_size - 1) // tp_size - - for index, (key, weight_size) in enumerate(tensor_bytes_dict.items()): - # If this weight is going to tip up over the maximal size, we split. 
- # if current_block_size + weight_size > max_shard_size: - if total_size + weight_size > max_shard_size * (len(filter_tensor_list) + 1) or ( - len(tensor_bytes_dict) - index < (tp_size - len(filter_tensor_list)) - ): - # fix if the first param is large than max_shard_size - if len(current_block) > 0: - filter_tensor_list.append(current_block) - current_block = [] - current_block_size = 0 - - current_block.append(key) - current_block_size += weight_size - total_size += weight_size - - filter_tensor_list.append(current_block) - if len(filter_tensor_list) < tp_size: - filter_tensor_list.extend([[] for i in range(tp_size - len(filter_tensor_list))]) - - dist.broadcast_object_list( - filter_tensor_list, - src=hcg.get_model_parallel_group_src_rank(), - group=tp_group, - ) - - return filter_tensor_list - - -def merge_large_tensor_parallel(tensor, tp_group, tp_action, dst_rank, is_dst): - num_rows = tensor.shape[0] - num_splits = 4 - parts = np.array_split(np.arange(num_rows), num_splits) - splits = [len(part) for part in parts] - split_parts = np.insert(np.cumsum(splits), 0, 0) - split_tensors = [] - for i in range(num_splits): - if get_env_device() == "xpu": - ret = distributed_allgather(tensor[split_parts[i] : split_parts[i + 1], :], group=tp_group, offload=False) - else: - ret = distributed_gather( - tensor[split_parts[i] : split_parts[i + 1], :], dst=dst_rank, group=tp_group, offload=False - ) - # Copy to CPUPlace temporarily, may lower speed. - if ret is not None: - ret = [t.cpu() for t in ret] - split_tensors.append(ret) - concat_tensors = [] - if is_dst: - for i in range(tp_group.nranks): - tmp = [] - for j in range(num_splits): - tmp.append(split_tensors[j][i]) - concat_tensors.append(paddle.concat(tmp)) - tensor = tp_action(concat_tensors) - else: - tensor = None - return tensor - - -def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - tp_rank = tp_group.rank - - # filter actions for pipeline mode - if hcg.get_pipe_parallel_group().nranks > 1: - filter_keys = set([y for x in all_filter_keys for y in x]) - for key in list(tp_actions.keys()): - if key not in filter_keys: - tp_actions.pop(key) - - state_dict_to_save = {} - max_key_len = max([len(_) for _ in all_filter_keys]) - for i in range(max_key_len): - for j, filter_keys in enumerate(all_filter_keys): - is_dst = tp_rank == j - if i > len(filter_keys) - 1: - continue - key = filter_keys[i] - tensor = state_dict[key] - if key in tp_actions: - # Get tensor size - tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks - if tensor_bytes >= 5 * 1024 * 1024 * 1024: # temporarily set 5GB as threshold - tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[key], j, is_dst) - else: - if get_env_device() == "xpu": - ret = distributed_allgather(tensor, group=tp_group, offload=False) - else: - ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False) - action = tp_actions.pop(key) - tensor = action(ret) if is_dst else None - else: - if is_dst: - tensor = tensor._copy_to(DEST_PLACE, False) if tensor.place.is_cpu_place() else tensor - else: - tensor = None - - if is_dst: - state_dict_to_save[key] = tensor - - if len(tp_actions) > 0: - for x in tp_actions.keys(): - logger.warning(f"key <{x}> need to merge tensor parallel but we can't find in model state.") - - return state_dict_to_save - - -def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys): - # Core 
function for UC - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - tp_rank = tp_group.rank - - state_dict_to_save = {} - max_key_len = max([len(_) for _ in all_filter_keys]) - for i in range(max_key_len): - for j, filter_keys in enumerate(all_filter_keys): - is_dst = tp_rank == j - if i > len(filter_keys) - 1: - continue - # get base model key - model_key = filter_keys[i].split("/")[0] - tensor = state_dict[filter_keys[i]] - if model_key in tp_actions: - # for example: beta1, beta2 - if tensor.numel().item() == 1: - if is_dst: - tensor = tensor._copy_to(DEST_PLACE, False) if not tensor.place.is_cpu_place() else tensor - else: - tensor = None - else: - # Get tensor size - tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks - if tensor_bytes >= 5 * 1024 * 1024 * 1024: # temporarily set 5GB as threshold - tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[model_key], j, is_dst) - else: - if get_env_device() == "xpu": - ret = distributed_allgather(tensor, group=tp_group, offload=False) - else: - ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False) - action = tp_actions[model_key] - tensor = action(ret) if is_dst else None - else: - if is_dst: - tensor = tensor._copy_to(DEST_PLACE, False) if not tensor.place.is_cpu_place() else tensor - else: - tensor = None - - if is_dst: - state_dict_to_save[filter_keys[i]] = tensor - - return state_dict_to_save - - -def get_optimizer_shard_files(optimizer_path, index_filename): - """ - For a given model: - - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the - Hub - - returns the list of paths to all the shards, as well as some metadata. - For the description of each arg, see [`PretrainedModel.from_pretrained`]. `index_filename` is the full path to the - index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub). - """ - - import json - - if not os.path.isfile(index_filename): - raise ValueError(f"Can't find a optimizer index ({index_filename}) in {optimizer_path}.") - - with open(index_filename, "r") as f: - index = json.loads(f.read()) - - shard_filenames = sorted(set(index["weight_map"].values())) - sharded_metadata = index["metadata"] - sharded_metadata["all_optimizer_keys"] = list(index["weight_map"].keys()) - sharded_metadata["weight_map"] = index["weight_map"].copy() - sharded_metadata["master_weights"] = index.get("master_weights", False) - - file_map = {file: set() for file in shard_filenames} - for weight, file in index["weight_map"].items(): - file_map[file].add(weight) - - sharded_metadata["file_map"] = file_map - - # First, let's deal with local folder. - # TODO: if optimizer_path is a folder, we should check if the optimizer is already cached or not. 
- if os.path.isdir(optimizer_path): - shard_filenames = [os.path.join(optimizer_path, f) for f in shard_filenames] - return shard_filenames, sharded_metadata - - -def get_expected_keys(sharded_metadata, model, optimizer): - hcg = fleet.get_hybrid_communicate_group() - sharding_group = hcg.get_sharding_parallel_group() - sharding_rank = sharding_group.rank - in_sharding_parallel_model = sharding_group.nranks > 1 - if in_sharding_parallel_model: - params2rank = optimizer._param2rank - - struct2static_name_mappings = {k: v.name for k, v in get_expected_state_dict(model).items()} - - expected_keys = [] - for key in list(sharded_metadata["all_optimizer_keys"]): - key_name = key.split("/")[0] - static_name = struct2static_name_mappings.get(key_name, None) - - if in_sharding_parallel_model: - params_rank = params2rank.get(static_name, None) - if params_rank == sharding_rank: - expected_keys.append(key) - else: - if static_name is not None: - expected_keys.append(key) - expected_keys = set(expected_keys) - - loaded_keys = sharded_metadata["all_optimizer_keys"] - missing_keys = expected_keys - set(loaded_keys) - if len(missing_keys) > 0: - raise ValueError(f"optimizer missing weights keys: {missing_keys}") - - return expected_keys - - -def mapping_optimizer_tp_actions(tp_actions, optimizer_loaded_keys): - """# convert param.name to - param.key/moment1_0 - or param.key/beta1_XXX - or param.key/beta2_XXX - Args: - tp_actions (dict): dictionay of tensor parallel actions {key: action} - optimizer_loaded_keys (list or set): [param.key1/moment1_0, param.key2/beta1_XXX, param.key3/beta2_XXX] - Returns: - dict: new dictionay of tensor parallel actions {key: action} - """ - new_actions = {} - for key in optimizer_loaded_keys: - key_base, typename = key.split("/") - if typename in optimizer_non_scaler_name and key_base in tp_actions: - new_actions[key] = tp_actions[key_base] - return new_actions - - -def flatten_list(nested_list): - flattened_list = [] - for item in nested_list: - if isinstance(item, list): - flattened_list.extend(flatten_list(item)) - else: - flattened_list.append(item) - return flattened_list - - -def select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=True): - """ - try select model weight index from model weight or master weight index. 
- """ - - # find model weight index file - if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): - index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_PEFT_WEIGHTS_INDEX_NAME - else: - index_filename = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_WEIGHTS_INDEX_NAME - - index_filename_path = os.path.join(resume_from_checkpoint, index_filename) - identify_func = os.path.isfile if local else distributed_isfile - - if identify_func(index_filename_path): - return index_filename - else: - index_filename = PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME - index_filename_path = os.path.join(resume_from_checkpoint, index_filename) - - if identify_func(index_filename_path): - return index_filename - else: - raise ValueError("Can't find a valid unified model or master weight checkpoint to load.") - - -def update_master_weight_status(args, optimizer, has_master_weight, safe_serialization): - if is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16)): - if not has_master_weight: - if UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value in args.unified_checkpoint_config: - index_filename_master_weights = ( - PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME - ) - has_master_weight = True - logger.warning( - "The unified checkpoint does not contain master weight, " - "the model weight will be loaded as master weight." - ) - else: - raise ValueError( - "Can't find a valid unified master weight checkpoint," - f"add '{UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value}' into 'unified_checkpoint_config' to " - "load model checkpoint as master weight" - ) - else: - has_master_weight = True - index_filename_master_weights = ( - PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME - ) - if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config: - index_filename_master_weights = ( - PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME - ) - else: - has_master_weight = False - index_filename_master_weights = None - - return has_master_weight, index_filename_master_weights - - -def unwrap_optimizer(optimizer): - while hasattr(optimizer, "_inner_opt") or hasattr(optimizer, "_optim"): - if hasattr(optimizer, "_inner_opt"): - optimizer = optimizer._inner_opt - if hasattr(optimizer, "_optim"): - optimizer = optimizer._optim - - return optimizer - - -def is_need_master_weight(optimizer, is_fp16_or_bp16): - optimizer = unwrap_optimizer(optimizer) - if hasattr(optimizer, "_multi_precision"): - return optimizer._multi_precision and is_fp16_or_bp16 - else: - return False diff --git a/paddlenlp/trainer/plugins/unified_checkpoint_utils.py b/paddlenlp/trainer/plugins/unified_checkpoint_utils.py new file mode 100644 index 000000000000..67a876e5fbd3 --- /dev/null +++ b/paddlenlp/trainer/plugins/unified_checkpoint_utils.py @@ -0,0 +1,567 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +import paddle +import paddle.distributed as dist +from paddle.distributed import fleet + +from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM +from paddlenlp.trainer.trainer_utils import ExplicitEnum +from paddlenlp.trainer.utils.helper import distributed_isfile +from paddlenlp.transformers.model_utils import PretrainedModel +from paddlenlp.transformers.utils import dtype_byte_size +from paddlenlp.utils.distributed import distributed_allgather, distributed_gather +from paddlenlp.utils.env import ( + PADDLE_MASTER_WEIGHTS_INDEX_NAME, + PADDLE_PEFT_WEIGHTS_INDEX_NAME, + PADDLE_WEIGHTS_INDEX_NAME, + SAFE_MASTER_WEIGHTS_INDEX_NAME, + SAFE_PEFT_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_INDEX_NAME, +) +from paddlenlp.utils.log import logger +from paddlenlp.utils.nested import flatten_list +from paddlenlp.utils.tools import get_env_device + +FP32_MASTER = "fp32_master_0" +optimizer_scalar_name = [ + "beta1_pow_acc_0", + "beta2_pow_acc_0", +] +optimizer_non_scaler_name = [ + "moment1_0", + "moment2_0", + "velocity_0", +] # to be added + + +DEST_PLACE = paddle.CPUPlace() +if paddle.device.is_compiled_with_cuda(): + DEST_PLACE = paddle.CUDAPinnedPlace() + + +class UnifiedCheckpointOption(ExplicitEnum): + """ + "- skip_save_model_weight: do not save model weights when the masters weight exist\n" + "- master_weight_compatible: 1. if the master weights exist, only load when needed\n" + " 2. if master weights does not exist, convert model weights to master weights when needed\n" + "- async_save: enable asynchronous saving checkpoints to disk\n" + "- enable_all_options: enable all optimization configurations\n" + """ + + SKIP_SAVE_MODEL_WEIGHT = "skip_save_model_weight" + MASTER_WEIGHT_COMPATIBLE = "master_weight_compatible" + ASYNC_SAVE = "async_save" + IGNORE_MERGE_OPTIMIZER = "ignore_merge_optimizer" + + +"""master weights related functions""" + + +def unwrap_optimizer(optimizer): + while hasattr(optimizer, "_inner_opt") or hasattr(optimizer, "_optim"): + if hasattr(optimizer, "_inner_opt"): + optimizer = optimizer._inner_opt + if hasattr(optimizer, "_optim"): + optimizer = optimizer._optim + return optimizer + + +def is_need_master_weight(optimizer, is_fp16_or_bp16): + optimizer = unwrap_optimizer(optimizer) + if hasattr(optimizer, "_multi_precision"): + return optimizer._multi_precision and is_fp16_or_bp16 + else: + return False + + +def update_master_weight_status(args, optimizer, has_master_weight, safe_serialization): + if is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16)): + if not has_master_weight: + if UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value in args.unified_checkpoint_config: + index_filename_master_weights = ( + PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME + ) + has_master_weight = True + logger.warning( + "The unified checkpoint does not contain master weight, " + "the model weight will be loaded as master weight." 
+ ) + else: + raise ValueError( + "Can't find a valid unified master weight checkpoint," + f"add '{UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value}' into 'unified_checkpoint_config' to " + "load model checkpoint as master weight" + ) + else: + has_master_weight = True + index_filename_master_weights = ( + PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME + ) + if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config: + index_filename_master_weights = ( + PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME + ) + else: + has_master_weight = False + index_filename_master_weights = None + + return has_master_weight, index_filename_master_weights + + +def reduce_master_weights_status(has_master_weights=False): + data = paddle.to_tensor([has_master_weights], dtype="int32") + + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + pp_group = hcg.get_pipe_parallel_group() + sharding_group = hcg.get_sharding_parallel_group() + + if tp_group.nranks > 1: + dist.all_reduce(data, op=dist.ReduceOp.SUM, group=tp_group) + if pp_group.nranks > 1: + dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pp_group) + if sharding_group.nranks > 1: + dist.all_reduce(data, op=dist.ReduceOp.SUM, group=sharding_group) + + return data.item() > 0 + + +def select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=True): + """ + try select model weight index from model weight or master weight index. + """ + + # find model weight index file + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_PEFT_WEIGHTS_INDEX_NAME + else: + index_filename = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_WEIGHTS_INDEX_NAME + + index_filename_path = os.path.join(resume_from_checkpoint, index_filename) + identify_func = os.path.isfile if local else distributed_isfile + + if identify_func(index_filename_path): + return index_filename + else: + index_filename = PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME + index_filename_path = os.path.join(resume_from_checkpoint, index_filename) + + if identify_func(index_filename_path): + return index_filename + else: + raise ValueError("Can't find a valid unified model or master weight checkpoint to load.") + + +def mapping_optimizer_tp_actions(tp_actions, optimizer_loaded_keys): + """# convert param.name to + param.key/moment1_0 + or param.key/beta1_XXX + or param.key/beta2_XXX + Args: + tp_actions (dict): dictionay of tensor parallel actions {key: action} + optimizer_loaded_keys (list or set): [param.key1/moment1_0, param.key2/beta1_XXX, param.key3/beta2_XXX] + Returns: + dict: new dictionay of tensor parallel actions {key: action} + """ + new_actions = {} + for key in optimizer_loaded_keys: + key_base, typename = key.split("/") + if typename in optimizer_non_scaler_name and key_base in tp_actions: + new_actions[key] = tp_actions[key_base] + return new_actions + + +def get_expected_state_dict(model_to_save): + if isinstance(model_to_save, PretrainedModel): + state_dict = model_to_save.state_dict() + if ( + hasattr(model_to_save.config, "tie_word_embeddings") + and model_to_save.config.tie_word_embeddings + and hasattr(model_to_save, "_tied_weights_keys") + and model_to_save._tied_weights_keys is not None + ): + for key in model_to_save._tied_weights_keys: + if key in 
state_dict: + state_dict.pop(key) + elif isinstance(model_to_save, LoRAModel): + state_dict = model_to_save.get_trainable_state_dict() + elif isinstance(model_to_save, PrefixModelForCausalLM): + state_dict = model_to_save.prefix_encoder.state_dict() + + return state_dict + + +def get_expected_keys(sharded_metadata, model, optimizer): + hcg = fleet.get_hybrid_communicate_group() + sharding_group = hcg.get_sharding_parallel_group() + sharding_rank = sharding_group.rank + in_sharding_parallel_model = sharding_group.nranks > 1 + if in_sharding_parallel_model: + params2rank = optimizer._param2rank + + struct2static_name_mappings = {k: v.name for k, v in get_expected_state_dict(model).items()} + + expected_keys = [] + for key in list(sharded_metadata["all_optimizer_keys"]): + key_name = key.split("/")[0] + static_name = struct2static_name_mappings.get(key_name, None) + + if in_sharding_parallel_model: + params_rank = params2rank.get(static_name, None) + if params_rank == sharding_rank: + expected_keys.append(key) + else: + if static_name is not None: + expected_keys.append(key) + expected_keys = set(expected_keys) + + loaded_keys = sharded_metadata["all_optimizer_keys"] + missing_keys = expected_keys - set(loaded_keys) + if len(missing_keys) > 0: + raise ValueError(f"optimizer missing weights keys: {missing_keys}") + + return expected_keys + + +def get_optimizer_shard_files(optimizer_path, index_filename): + """ + For a given model: + - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the + Hub + - returns the list of paths to all the shards, as well as some metadata. + For the description of each arg, see [`PretrainedModel.from_pretrained`]. `index_filename` is the full path to the + index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub). + """ + + import json + + if not os.path.isfile(index_filename): + raise ValueError(f"Can't find a optimizer index ({index_filename}) in {optimizer_path}.") + + with open(index_filename, "r") as f: + index = json.loads(f.read()) + + shard_filenames = sorted(set(index["weight_map"].values())) + sharded_metadata = index["metadata"] + sharded_metadata["all_optimizer_keys"] = list(index["weight_map"].keys()) + sharded_metadata["weight_map"] = index["weight_map"].copy() + sharded_metadata["master_weights"] = index.get("master_weights", False) + + file_map = {file: set() for file in shard_filenames} + for weight, file in index["weight_map"].items(): + file_map[file].add(weight) + + sharded_metadata["file_map"] = file_map + + # First, let's deal with local folder. + # TODO: if optimizer_path is a folder, we should check if the optimizer is already cached or not. + if os.path.isdir(optimizer_path): + shard_filenames = [os.path.join(optimizer_path, f) for f in shard_filenames] + return shard_filenames, sharded_metadata + + +def generate_base_static_name(vname): + # return base static name and specific type name, like [embedding_0.w_0, moment1_0] + if FP32_MASTER in vname: + vname = vname.split("_" + FP32_MASTER + "_") + return vname[0], vname[1] + else: + vname = vname.split(".") + a = vname[0] + "." 
+ vname[1][:3] + b = vname[1][4:] + return a, b + + +def merge_large_tensor_parallel(tensor, tp_group, tp_action, dst_rank, is_dst): + num_rows = tensor.shape[0] + num_splits = 4 + parts = np.array_split(np.arange(num_rows), num_splits) + splits = [len(part) for part in parts] + split_parts = np.insert(np.cumsum(splits), 0, 0) + split_tensors = [] + for i in range(num_splits): + if get_env_device() == "xpu": + ret = distributed_allgather(tensor[split_parts[i] : split_parts[i + 1], :], group=tp_group, offload=False) + else: + ret = distributed_gather( + tensor[split_parts[i] : split_parts[i + 1], :], dst=dst_rank, group=tp_group, offload=False + ) + # Copy to CPUPlace temporarily, may lower speed. + if ret is not None: + ret = [t.cpu() for t in ret] + split_tensors.append(ret) + concat_tensors = [] + if is_dst: + for i in range(tp_group.nranks): + tmp = [] + for j in range(num_splits): + tmp.append(split_tensors[j][i]) + concat_tensors.append(paddle.concat(tmp)) + tensor = tp_action(concat_tensors) + else: + tensor = None + return tensor + + +def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + tp_rank = tp_group.rank + + # filter actions for pipeline mode + if hcg.get_pipe_parallel_group().nranks > 1: + filter_keys = set([y for x in all_filter_keys for y in x]) + for key in list(tp_actions.keys()): + if key not in filter_keys: + tp_actions.pop(key) + + state_dict_to_save = {} + max_key_len = max([len(_) for _ in all_filter_keys]) + for i in range(max_key_len): + for j, filter_keys in enumerate(all_filter_keys): + is_dst = tp_rank == j + if i > len(filter_keys) - 1: + continue + key = filter_keys[i] + tensor = state_dict[key] + if key in tp_actions: + # Get tensor size + tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks + if tensor_bytes >= 5 * 1024 * 1024 * 1024: # temporarily set 5GB as threshold + tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[key], j, is_dst) + else: + if get_env_device() == "xpu": + ret = distributed_allgather(tensor, group=tp_group, offload=False) + else: + ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False) + action = tp_actions.pop(key) + tensor = action(ret) if is_dst else None + else: + if is_dst: + tensor = tensor._copy_to(DEST_PLACE, False) if tensor.place.is_cpu_place() else tensor + else: + tensor = None + + if is_dst: + state_dict_to_save[key] = tensor + + if len(tp_actions) > 0: + for x in tp_actions.keys(): + logger.warning(f"key <{x}> need to merge tensor parallel but we can't find in model state.") + + return state_dict_to_save + + +def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys): + # Core function for UC + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + tp_rank = tp_group.rank + + state_dict_to_save = {} + max_key_len = max([len(_) for _ in all_filter_keys]) + for i in range(max_key_len): + for j, filter_keys in enumerate(all_filter_keys): + is_dst = tp_rank == j + if i > len(filter_keys) - 1: + continue + # get base model key + model_key = filter_keys[i].split("/")[0] + tensor = state_dict[filter_keys[i]] + if model_key in tp_actions: + # for example: beta1, beta2 + if tensor.numel().item() == 1: + if is_dst: + tensor = tensor._copy_to(DEST_PLACE, False) if not tensor.place.is_cpu_place() else tensor + else: + tensor = None + else: + # Get tensor size + tensor_bytes = tensor.numel().item() * 
dtype_byte_size(tensor.dtype) * tp_group.nranks + if tensor_bytes >= 5 * 1024 * 1024 * 1024: # temporarily set 5GB as threshold + tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[model_key], j, is_dst) + else: + if get_env_device() == "xpu": + ret = distributed_allgather(tensor, group=tp_group, offload=False) + else: + ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False) + action = tp_actions[model_key] + tensor = action(ret) if is_dst else None + else: + if is_dst: + tensor = tensor._copy_to(DEST_PLACE, False) if not tensor.place.is_cpu_place() else tensor + else: + tensor = None + + if is_dst: + state_dict_to_save[filter_keys[i]] = tensor + + return state_dict_to_save + + +def filter_params(model_to_save, state_dict, is_optimizer=False): + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + + tp_size = tp_group.nranks + tp_rank = tp_group.rank + + # for pure sharding or pure pp + if tp_size <= 1: + return [list(state_dict.keys())] + + filter_tensor_list = [[] for i in range(tp_size)] + + if tp_rank == 0: + tensor_bytes_dict = {} + model_state_dict = get_expected_state_dict(model_to_save) + for (k, v) in state_dict.items(): + model_v = model_state_dict[k.split("/")[0]] if is_optimizer else v + if hasattr(model_v, "is_distributed") and model_v.is_distributed: + tensor_bytes_dict[k] = v.numel().item() * tp_size * dtype_byte_size(v.dtype) + else: + tensor_bytes_dict[k] = v.numel().item() * dtype_byte_size(v.dtype) + + filter_tensor_list = [] + current_block = [] + current_block_size = 0 + total_size = 0 + + max_shard_size = (sum(tensor_bytes_dict.values()) + tp_size - 1) // tp_size + + for index, (key, weight_size) in enumerate(tensor_bytes_dict.items()): + # If this weight is going to tip up over the maximal size, we split. 
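+            # For example (hypothetical sizes, in GB): with tp_size = 4 and tensor sizes
+            # [6, 2, 2, 2, 2, 2], max_shard_size = ceil(16 / 4) = 4 and the greedy walk
+            # below groups the keys into buckets of sizes [6], [2], [2, 2], [2, 2],
+            # one bucket per tp rank; rank j later gathers and saves only bucket j.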
+ # if current_block_size + weight_size > max_shard_size: + if total_size + weight_size > max_shard_size * (len(filter_tensor_list) + 1) or ( + len(tensor_bytes_dict) - index < (tp_size - len(filter_tensor_list)) + ): + # fix if the first param is large than max_shard_size + if len(current_block) > 0: + filter_tensor_list.append(current_block) + current_block = [] + current_block_size = 0 + + current_block.append(key) + current_block_size += weight_size + total_size += weight_size + + filter_tensor_list.append(current_block) + if len(filter_tensor_list) < tp_size: + filter_tensor_list.extend([[] for i in range(tp_size - len(filter_tensor_list))]) + + dist.broadcast_object_list( + filter_tensor_list, + src=hcg.get_model_parallel_group_src_rank(), + group=tp_group, + ) + + return filter_tensor_list + + +def get_sharded_file_name(args, file_name, is_optimizer=False): + if not is_optimizer: + shard_file = file_name.replace( + ".pdparams", + f"-{args.logical_process_index + 1:05d}-of-{args.world_size//args.dataset_world_size:05d}.pdparams", + ) + shard_file = shard_file.replace( + ".safetensors", + f"-{args.logical_process_index + 1:05d}-of-{args.world_size//args.dataset_world_size:05d}.safetensors", + ) + else: + hcg = fleet.get_hybrid_communicate_group() + dp_group = hcg.get_data_parallel_group() + shard_file = file_name.replace( + ".pdparams", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.pdparams" + ) + shard_file = shard_file.replace( + ".safetensors", + f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.safetensors", + ) + shard_file = shard_file.replace( + ".pdopt", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.pdopt" + ) + return shard_file + + +def get_sharded_index( + index_file_list, + total_size_list, +): + # save index json file + local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0)) + if local_rank == 0: + sharded_index_json = {} + + sharded_index_json["metadata"] = {"total_size": sum(total_size_list)} + + weight_map = {} + for i, _ in enumerate(index_file_list): + weight_map.update(index_file_list[i]) + + sharded_index_json["weight_map"] = weight_map + return sharded_index_json + + return None + + +def gather_sharded_object(index_file, total_size, is_optimizer=False): + + index_file_list, total_size_list = [], [] + + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + pp_group = hcg.get_pipe_parallel_group() + + logger.info( + f"Unified checkpoint: generating sharded_index json files for {'optimizer or master weight' if is_optimizer else 'model weight'}." 
+ ) + + if tp_group.nranks > 1: + dist.all_gather_object(index_file_list, index_file, tp_group) + dist.all_gather_object(total_size_list, total_size, tp_group) + if pp_group.nranks > 1: + pp_index_file_list = [] + pp_total_size_list = [] + dist.all_gather_object( + pp_index_file_list, index_file_list if len(index_file_list) > 0 else index_file, pp_group + ) + dist.all_gather_object( + pp_total_size_list, total_size_list if len(total_size_list) > 0 else total_size, pp_group + ) + index_file_list = pp_index_file_list + total_size_list = pp_total_size_list + + index_file_list = flatten_list(index_file_list) + total_size_list = flatten_list(total_size_list) + + # for pure sharding + if len(index_file_list) == 0 and len(total_size_list) == 0: + index_file_list = [index_file] + total_size_list = [total_size] + if is_optimizer: + sharding_group = hcg.get_sharding_parallel_group() + if sharding_group.nranks > 1: + sharding_index_file_list = [] + sharding_total_size_list = [] + dist.all_gather_object(sharding_index_file_list, index_file_list, sharding_group) + dist.all_gather_object(sharding_total_size_list, total_size_list, sharding_group) + index_file_list = flatten_list(sharding_index_file_list) + total_size_list = flatten_list(sharding_total_size_list) + + return index_file_list, total_size_list diff --git a/paddlenlp/utils/nested.py b/paddlenlp/utils/nested.py index 4e800231843c..43f012aa3d0e 100644 --- a/paddlenlp/utils/nested.py +++ b/paddlenlp/utils/nested.py @@ -116,3 +116,13 @@ def nested_copy_place(inputs, place=None, blocking=False): if isinstance(inputs, paddle.Tensor): inputs = inputs if inputs.place == place else inputs._copy_to(place, blocking) return inputs + + +def flatten_list(nested_list): + flattened_list = [] + for item in nested_list: + if isinstance(item, list): + flattened_list.extend(flatten_list(item)) + else: + flattened_list.append(item) + return flattened_list From 19071ef1a6c4df88d4d3d03ea14e7efb0855736c Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Mon, 14 Oct 2024 15:26:50 +0800 Subject: [PATCH 08/16] update uc files --- .../trainer/plugins/unified_checkpoint.py | 968 +----------------- .../plugins/unified_checkpoint_dynamic.py | 493 +++++++++ .../plugins/unified_checkpoint_sharding_v2.py | 298 ++++++ .../plugins/unified_checkpoint_single_card.py | 242 +++++ .../plugins/unified_checkpoint_utils.py | 118 ++- 5 files changed, 1157 insertions(+), 962 deletions(-) create mode 100644 paddlenlp/trainer/plugins/unified_checkpoint_dynamic.py create mode 100644 paddlenlp/trainer/plugins/unified_checkpoint_sharding_v2.py create mode 100644 paddlenlp/trainer/plugins/unified_checkpoint_single_card.py diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py index d1f759e41df0..1cf632c11be6 100644 --- a/paddlenlp/trainer/plugins/unified_checkpoint.py +++ b/paddlenlp/trainer/plugins/unified_checkpoint.py @@ -21,7 +21,6 @@ import time from multiprocessing import shared_memory -import numpy as np import paddle import paddle.distributed as dist from paddle.distributed import fleet @@ -41,7 +40,6 @@ _add_variant, _load_state_dict_into_model, faster_set_state_dict, - get_parameter_dtype, load_state_dict, unwrap_model, ) @@ -58,7 +56,6 @@ PADDLE_OPTIMIZER_INDEX_NAME, PADDLE_OPTIMIZER_NAME, PADDLE_WEIGHTS_NAME, - PAST_KEY_VALUES_FILE_NAME, PREFIX_WEIGHTS_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME, SAFE_MASTER_WEIGHTS_NAME, @@ -70,16 +67,14 @@ SAFE_WEIGHTS_NAME, ) from paddlenlp.utils.log import logger -from 
paddlenlp.utils.nested import flatten_list, nested_copy, nested_copy_place +from paddlenlp.utils.nested import flatten_list, nested_copy if is_safetensors_available(): from safetensors.numpy import save_file as safe_save_file if sys.platform.startswith("win"): - from safetensors import safe_open from safetensors.numpy import load_file else: - from paddlenlp.utils.safetensors import fast_safe_open as safe_open from paddlenlp.utils.safetensors import fast_load_file as load_file from .shared_memory_utils import ( @@ -87,6 +82,20 @@ _traverse_copy_to_shm, create_meta_dict, ) +from .unified_checkpoint_dynamic import ( + load_unified_checkpoint_dynamically, + load_unified_optimizer_dynamically, +) +from .unified_checkpoint_sharding_v2 import ( + gather_splited_param_for_optimizer, + load_unified_optimizer_split_param, +) +from .unified_checkpoint_single_card import ( + load_single_card_checkpoint, + load_single_card_optimizer, + save_single_card_checkpoint, + save_single_card_optimizer, +) from .unified_checkpoint_utils import ( FP32_MASTER, UnifiedCheckpointOption, @@ -102,18 +111,14 @@ mapping_optimizer_tp_actions, merge_tensor_parallel_for_optimizer, merge_tensor_parallel_with_shard, - optimizer_non_scaler_name, - optimizer_scalar_name, reduce_master_weights_status, rename_shard_file, + save_config, + save_prefix_past_key_value, select_model_weight_index, update_master_weight_status, ) -DEST_PLACE = paddle.CPUPlace() -if paddle.device.is_compiled_with_cuda(): - DEST_PLACE = paddle.CUDAPinnedPlace() - class UnifiedCheckpointHandler: def __init__(self, args): @@ -293,7 +298,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir): # Under non distributed environment. if paddle.distributed.get_world_size() <= 1: - self.save_single_card_checkpoint(model_to_save, output_dir) + save_single_card_checkpoint(model_to_save, output_dir) return skip_save_model_weight = False @@ -373,16 +378,14 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str) None """ if paddle.distributed.get_world_size() <= 1: - load_single_card_checkpoint(self.args, model, resume_from_checkpoint) + load_single_card_checkpoint(model, resume_from_checkpoint) return local_resume = check_unified_checkpoint(self.args, model, resume_from_checkpoint, safe_serialization=True) if not local_resume: logger.info("Begin to dynamically load unified checkpoint!") - load_unified_checkpoint_dynamically( - self.args, model, optimizer, resume_from_checkpoint, safe_serialization=True - ) + load_unified_checkpoint_dynamically(self.args, model, resume_from_checkpoint, safe_serialization=True) return if self.args.dataset_rank == 0 or self.args.use_expert_parallel: @@ -505,7 +508,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir): """ if paddle.distributed.get_world_size() <= 1: - self.save_single_card_optimizer(model, optimizer, output_dir) + save_single_card_optimizer(model, optimizer, output_dir) return if ( @@ -513,7 +516,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir): and ShardingOption.SHARD_OP in self.args.sharding and "split_param" in self.args.sharding_parallel_config ): - optim_state_dict, master_weights = self.gather_split_param_for_optimizer(optimizer) + optim_state_dict, master_weights = gather_splited_param_for_optimizer(optimizer) else: optim_state_dict = nested_copy(optimizer.state_dict()) master_weights = None @@ -587,7 +590,7 @@ def load_unified_optimizer(self, model, optimizer, resume_from_checkpoint): """ if paddle.distributed.get_world_size() <= 1: - 
optim_state_dict = load_single_card_optimizer(self.args, model, optimizer, resume_from_checkpoint) + optim_state_dict = load_single_card_optimizer(model, optimizer, resume_from_checkpoint) return optim_state_dict has_merge_optimizer_safetensors = distributed_isfile( @@ -622,190 +625,6 @@ def load_unified_optimizer(self, model, optimizer, resume_from_checkpoint): return returned_optim_state_dict return None - def save_single_card_checkpoint(self, model_to_save, output_dir): - """Save checkpoint for non-distributed environment.""" - - state_dict = get_expected_state_dict(model_to_save) - if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): - weight_filename = "peft_model-00001-of-00001.safetensors" - index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME - else: - weight_filename = "model-00001-of-00001.safetensors" - index_filename = SAFE_WEIGHTS_INDEX_NAME - # get index json - index_weight_file = {} - total_size = 0 - for key, weight in state_dict.items(): - index_weight_file[key] = weight_filename - total_size += weight.numel().item() * dtype_byte_size(weight.dtype) - sharded_index_json = {} - sharded_index_json["metadata"] = {"total_size": total_size} - sharded_index_json["weight_map"] = index_weight_file - if isinstance(model_to_save, LoRAModel): - sharded_index_json["type"] = "lora" - elif isinstance(model_to_save, PrefixModelForCausalLM): - sharded_index_json["type"] = "ptuning" - - os.makedirs(output_dir, exist_ok=True) - path = os.path.join(output_dir, index_filename) - with open(path, "w") as f: - json.dump(sharded_index_json, f, indent=4) - - # save checkpoint - self._file_save_async_or_sync( - state_dict, path=os.path.join(output_dir, weight_filename), is_sync=True, state_dict_type="model_weight" - ) - - if isinstance(model_to_save, PrefixModelForCausalLM): - save_prefix_past_key_value(model_to_save, output_dir) - model_to_save.prefix_config.save_pretrained(output_dir) - if isinstance(model_to_save, LoRAModel): - model_to_save.lora_config.save_pretrained(output_dir) - - config_to_save = save_config(model_to_save) - config_to_save.architectures = [model_to_save.__class__.__name__] - config_to_save.save_pretrained(output_dir) - - def save_single_card_optimizer(self, model, optimizer, output_dir): - """ "Save optimizer for non-distributed environment.""" - # Split into optimizer params and master weights. 
- optim_state_dict = nested_copy(optimizer.state_dict()) - master_weights = None - if "master_weights" in optim_state_dict.keys(): - master_weights = optim_state_dict.pop("master_weights") - if "LR_Scheduler" in optim_state_dict.keys(): - optim_state_dict.pop("LR_Scheduler") - - static2struct_name_mappings = {} - state_dict = get_expected_state_dict(model) - fp32_weight = {} - for k, v in state_dict.items(): - static2struct_name_mappings[v.name] = k - if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32: - fp32_weight[k] = v - - # rename optimizer param - for key in list(optim_state_dict.keys()): - static_name, type_name = generate_base_static_name(key) - new_name = static2struct_name_mappings[static_name] + "/" + type_name - optim_state_dict[new_name] = optim_state_dict.pop(key) - if master_weights is not None: - for key in list(master_weights.keys()): - master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) - master_weights.update(fp32_weight) - - # save index json - index_optimizer_file, index_master_weight_file = {}, {} - total_optim_size, total_master_weight_size = 0, 0 - for key, weight in optim_state_dict.items(): - index_optimizer_file[key] = "optimizer-00001-of-00001.safetensors" - total_optim_size += weight.numel().item() * dtype_byte_size(weight.dtype) - if master_weights is not None: - for key, weight in master_weights.items(): - index_master_weight_file[key] = "master_weights-00001-of-00001.safetensors" - total_master_weight_size += weight.numel().item() * dtype_byte_size(weight.dtype) - path = os.path.join(output_dir, SAFE_OPTIMIZER_INDEX_NAME) - master_path = os.path.join(output_dir, SAFE_MASTER_WEIGHTS_INDEX_NAME) - with open(path, "w") as f: - has_master_weights = master_weights is not None - json.dump( - { - "metadata": {"total_size": total_optim_size}, - "weight_map": index_optimizer_file, - "master_weights": has_master_weights, - }, - f, - indent=4, - ) - if master_weights is not None: - with open(master_path, "w") as f: - json.dump( - {"metadata": {"total_size": total_master_weight_size}, "weight_map": index_master_weight_file}, - f, - indent=4, - ) - - # save optimizer state dict - self._file_save_async_or_sync( - optim_state_dict, - path=os.path.join(output_dir, "optimizer-00001-of-00001.safetensors"), - is_sync=True, - state_dict_type="optimizer_weight", - ) - if master_weights is not None: - self._file_save_async_or_sync( - master_weights, - path=os.path.join(output_dir, "master_weights-00001-of-00001.safetensors"), - is_sync=True, - state_dict_type="master_weight", - ) - - def gather_split_param_for_optimizer(self, optimizer): - hcg = fleet.get_hybrid_communicate_group() - sharding_group = hcg.get_sharding_parallel_group() - global_rank = dist.get_rank() - param_slice_info = {} - param_shape_info = {} - for buffer in optimizer._inner_opt._comm_buffer_list: - for key in buffer._sharding_param_grad_view.keys(): - param_slice_info[key] = ( - buffer._sharding_param_grad_view[key]._param_begin, - buffer._sharding_param_grad_view[key]._param_end, - ) - param_shape_info[key] = ( - buffer._sharding_param_grad_view[key]._param.shape, - buffer._sharding_param_grad_view[key]._param.numel().item(), - buffer._sharding_param_grad_view[key]._index, - buffer._sharding_param_grad_view[key]._padded_size, - ) - param_slice_info["global_rank"] = global_rank - param_slice_info_list = [] - dist.all_gather_object(param_slice_info_list, param_slice_info, group=sharding_group) - - optim_state_dict = nested_copy(optimizer.state_dict()) - master_weights 
= None - if "master_weights" in optim_state_dict.keys(): - master_weights = optim_state_dict.pop("master_weights") - if "LR_Scheduler" in optim_state_dict.keys(): - optim_state_dict.pop("LR_Scheduler") - - # deal with optimizer param - partial_tensor_list = [] - for key in list(optim_state_dict.keys()): - static_name, _ = generate_base_static_name(key) - if static_name in param_slice_info.keys(): - if optim_state_dict[key].numel().item() == 1: # for example: beta1, beta2 - continue - begin, end = param_slice_info[static_name] - shape, numel, _, _ = param_shape_info[static_name] - if end - begin == numel: # full tensor - optim_state_dict[key] = optim_state_dict[key].reshape(shape) - elif end <= begin: # empty tensor - continue - else: # partial tensor, end > begin but end - begin < numel - partial_tensor_list.append(static_name) - - send_table = {} - recv_table = {} - for key in partial_tensor_list: - sharding_ranklist = [] - for slice_info in param_slice_info_list: - begin, end = slice_info[key] - if end > begin: - sharding_ranklist.append((slice_info["global_rank"], begin, end)) - recv_table[key] = sharding_ranklist[0][0] # which sharding_rank to recv the splited tensor - send_table[key] = [(rank, begin, end) for rank, begin, end in sharding_ranklist] - - distributed_send_recv_splited_param( - optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False - ) - if master_weights is not None: - distributed_send_recv_splited_param( - master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True - ) - - return optim_state_dict, master_weights - def unlink_shared_memory(self): if not ("async_save" in self.args.unified_checkpoint_config): return @@ -936,18 +755,6 @@ def _remove_unused_keys( raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") -def save_config(model_to_save): - dtype = get_parameter_dtype(model_to_save) - model_to_save.config.dtype = str(dtype).split(".")[1] - config_to_save = copy.deepcopy(model_to_save.config) - - if config_to_save.tensor_parallel_degree > 1: - # do we need to change? 
- config_to_save.tensor_parallel_degree = 1 - - return config_to_save - - def unified_checkpoint_into_shards( args, model_to_save, @@ -1019,150 +826,6 @@ def unified_checkpoint_into_shards( return state_dict, shard_file, sharded_index -def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint): - returned_optim_state_dict = nested_copy(optimizer.state_dict()) - - index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME - - resolved_archive_file, sharded_metadata = get_optimizer_shard_files( - optimizer_path=resume_from_checkpoint, - index_filename=os.path.join(resume_from_checkpoint, index_filename), - ) - has_master_weights = True if sharded_metadata["master_weights"] else False - - typename_set = set() - for key in sharded_metadata["weight_map"].keys(): - _, typename = key.split("/") - typename_set.add(typename) - - model_state_dict = get_expected_state_dict(model) - model_keys = list(model_state_dict.keys()) - static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()} # get optimizer param mappings - struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} - - expected_keys = [] - param_slice_info = {} - param_shape_info = {} - for buffer in optimizer._inner_opt._comm_buffer_list: - for key in buffer._sharding_param_grad_view.keys(): - begin = buffer._sharding_param_grad_view[key]._param_begin - end = buffer._sharding_param_grad_view[key]._param_end - if end > begin: - expected_keys.append(key) - shape = buffer._sharding_param_grad_view[key]._param.shape - numel = buffer._sharding_param_grad_view[key]._param.numel().item() - index = buffer._sharding_param_grad_view[key]._index - padded_size = buffer._sharding_param_grad_view[key]._padded_size - param_slice_info[key] = (begin, end) - param_shape_info[key] = (shape, numel, index, padded_size) - - expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys]) - expected_keys_optim = [] - for key in expected_keys: - for typename in typename_set: - expected_keys_optim.append(f"{key}/{typename}") - expected_keys_optim = set(expected_keys_optim) - - if len(resolved_archive_file) > 1: - resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards") - - if has_master_weights: - returned_optim_state_dict["master_weights"] = {} - resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files( - optimizer_path=resume_from_checkpoint, - index_filename=os.path.join(resume_from_checkpoint, index_filename_master_weights), - ) - if len(resolved_archive_file_mw) > 1: - resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards") - - def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False): - returned_state_dict = {} - - if model.config.tensor_parallel_degree > 1: - if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): - tp_actions = model._get_tensor_parallel_convert_actions(model_keys, is_split=True, ignore_error=True) - else: - tp_actions = model.get_tensor_parallel_convert_actions(model.config, model_keys, ignore_error=True) - if not is_master_weights: - tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys) - - for shard_file in resolved_archive_file: - if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]): - continue - - if model.config.tensor_parallel_degree > 1: - state_dict = load_state_dict(shard_file, 
tp_actions, expected_keys, device="expected") - else: - state_dict = load_state_dict(shard_file, None, expected_keys, device="expected") - - returned_state_dict.update(state_dict) - del state_dict - gc.collect() - - return returned_state_dict - - # get tp params - state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim) - if has_master_weights: - state_dict_master_weight = load_resolved_archive_file( - resolved_archive_file_mw, - sharded_metadata_mw, - expected_keys, - is_master_weights=True, - ) - - # need to split param for different sharding rank, maybe need to deal with oom issue. - for key in list(state_dict_optim.keys()): - key_name = key.split("/") - static_name = struct2static_name_mappings.get(key_name[0], None) - - if state_dict_optim[key].numel().item() > 1: - begin, end = param_slice_info[static_name] - shape, numel, index, padded_size = param_shape_info[static_name] - state_dict_optim[key] = state_dict_optim[key].reshape([-1]) - state_dict_optim[key] = state_dict_optim[key][begin - index : end - index] - - padding_start = max(begin, index + numel) - padding_end = min(end, index + padded_size) - if padding_start < padding_end: - state_dict_optim[key] = paddle.concat( - ( - state_dict_optim[key], - paddle.zeros([padding_end - padding_start], dtype=state_dict_optim[key].dtype), - ) - ) - - if has_master_weights: - key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) - else: - key_name = "_".join([static_name, key_name[1]]) - returned_optim_state_dict[key_name] = state_dict_optim.pop(key) - returned_optim_state_dict[key_name].name = key_name - - if has_master_weights: - for key in list(state_dict_master_weight.keys()): - static_name = struct2static_name_mappings.get(key, None) - if state_dict_master_weight[key].numel().item() > 1: - begin, end = param_slice_info[static_name] - shape, numel, index, padded_size = param_shape_info[static_name] - state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1]) - state_dict_master_weight[key] = state_dict_master_weight[key][begin - index : end - index] - - padding_start = max(begin, index + numel) - padding_end = min(end, index + padded_size) - if padding_start < padding_end: - state_dict_master_weight[key] = paddle.concat( - ( - state_dict_master_weight[key], - paddle.zeros([padding_end - padding_start], dtype=state_dict_master_weight[key].dtype), - ) - ) - returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key) - returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) - - return returned_optim_state_dict - - def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): # Special process with split param. 
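     # With "split_param" in the sharding config, each sharding rank owns only a slice of
     # every parameter (the [_param_begin, _param_end) range tracked by the optimizer's
     # comm buffers), so this branch delegates to load_unified_optimizer_split_param to
     # reload per-rank slices instead of whole tensors.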
if ( @@ -1170,7 +833,7 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin and ShardingOption.SHARD_OP in args.sharding and "split_param" in args.sharding_parallel_config ): - returned_optim_state_dict = load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint) + returned_optim_state_dict = load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint) return returned_optim_state_dict # init and get optimizer LR_Scheduler @@ -1624,586 +1287,3 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, if has_master_weights: local_resume_rw = check_dynamic_load(args, index_mw["weight_map"], existed_files_mw, is_master_weights=True) return local_resume & local_resume_rw - - -def save_prefix_past_key_value(model_to_save, save_directory): - past_key_value = model_to_save.prefix_encoder(model_to_save.prefix_tokens.unsqueeze(0).expand([1, -1])) - past_key_value = past_key_value.reshape( - [ - model_to_save.prefix_config.num_prefix_tokens, - 2, - model_to_save.prefix_config.num_hidden_layers, - model_to_save.num_heads, - model_to_save.head_dim, - ] - ) - past_key_value = paddle.transpose(past_key_value, perm=[2, 1, 3, 0, 4]).cpu().numpy() - model_to_save.prefix_config.save_pretrained(save_directory) - np.save(os.path.join(save_directory, PAST_KEY_VALUES_FILE_NAME), past_key_value) - - -def create_dispatch_table(args, model, file_keyname_mappings, file_machine_mappings, resume_from_checkpoint): - """Create dispatch table for dynamically loading state dict. - - Args: - args - """ - - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - tp_rank = tp_group.rank - - # Create tensor receive table, contains {"key0": [global_rank, tp_rank], "key1": [global_rank, tp_rank]} - dispatch_list = [] - recv_table = {} - if args.dataset_rank == 0: - state_dict = get_expected_state_dict(model) - for (k, v) in state_dict.items(): - if hasattr(v, "is_distributed") and v.is_distributed: - recv_table[k] = [(dist.get_rank(), tp_rank)] - else: - recv_table[k] = [(dist.get_rank(), -1)] - - # Gather receive table in global group. - dist.all_gather_object(dispatch_list, recv_table) - recv_table = {} - for dl in dispatch_list: - for key, value in dl.items(): - if key not in recv_table: - recv_table[key] = value - else: - recv_table[key] += value - - # Create send table, to decide which worker to send the key. 
Contains {"key0:" global_rank, "key1": global_rank, ...} - send_table = create_send_table(file_keyname_mappings, file_machine_mappings) - - return send_table, recv_table - - -def create_optimizer_dispatch_table( - args, - model, - optimizer, - file_keyname_mappings, - file_machine_mappings, - resume_from_checkpoint, - struct2static_name_mappings, - is_master_weights=False, - typename_set=None, -): - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - sharding_group = hcg.get_sharding_parallel_group() - sharding_rank = sharding_group.rank - if sharding_group.nranks > 1: - param2rank = optimizer._param2rank - tp_rank = tp_group.rank - - # Create receive table, contains {"param_key0": [global_rank, tp_rank], "param_key1": [global_rank, tp_rank]} - dispatch_list = [] - recv_table = {} - if args.data_parallel_rank == 0: - state_dict = get_expected_state_dict(model) - for (k, v) in state_dict.items(): - if sharding_group.nranks > 1: - static_name = struct2static_name_mappings[k] - param_rank = param2rank.get(static_name, None) - if param_rank != sharding_rank: - continue - if is_master_weights: - if hasattr(v, "is_distributed") and v.is_distributed: - recv_table[k] = [(dist.get_rank(), tp_rank)] - else: - recv_table[k] = [(dist.get_rank(), -1)] - else: - for typename in typename_set: - type_key = k + "/" + typename - if typename in optimizer_non_scaler_name: - if hasattr(v, "is_distributed") and v.is_distributed: - recv_table[type_key] = [(dist.get_rank(), tp_rank)] - else: - recv_table[type_key] = [(dist.get_rank(), -1)] - else: - recv_table[type_key] = [(dist.get_rank(), -1)] - - dist.all_gather_object(dispatch_list, recv_table) - recv_table = {} - for dl in dispatch_list: - for k, v in dl.items(): - if k not in recv_table: - recv_table[k] = v - else: - recv_table[k] += v - - # Create send table, to decide which worker to send the key. Contains {"param_key0:" 0, "param_key1": 1, ...} - send_table = create_send_table(file_keyname_mappings, file_machine_mappings) - return send_table, recv_table - - -def load_unified_checkpoint_dynamically(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): - index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False) - index_filename = os.path.join(resume_from_checkpoint, index_filename) - - with open(index_filename, "r") as f: - index = json.loads(f.read()) - - # `file_keyname_mappings` indicates which keys each file contains. For example, {"model-00001-of-00002.safetensors": ["llama.embed_tokens.weight", "llama.layers.0.self_attn.q_proj.weight", ...]} - # `file_machine_mappings` indicates the machine where the files appear. For example, {"model-00001-of-00002.safetensors": [machine_0, machine_1], "model-00002-of-00002.safetensors": [machine_0]} - file_keyname_mappings, file_machine_mappings = get_file_mappings(index, resume_from_checkpoint) - - logger.debug("Creating dispatch table for unified checkpoint load ...") - # Get send_table and recv_table. The send table indicates which workers are responsible for sending tensors, and the recv table indicates which workers should receive the tensors. - send_table, recv_table = create_dispatch_table( - args, model, file_keyname_mappings, file_machine_mappings, resume_from_checkpoint - ) - - # Get all the keys that are splited by tensor parallelism. 
- all_tp_keys = set() - for k, v in recv_table.items(): - if v[0][1] != -1: - all_tp_keys.add(k) - - config_revise = copy.deepcopy(model.config) - config_revise.tensor_parallel_rank = None - if len(all_tp_keys) == 0: - tp_actions = {} - else: - # Get corresponding tensor parallel actions. - if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): - tp_actions = model._get_tensor_parallel_convert_actions( - set(all_tp_keys), is_split=True, ignore_error=True, config=config_revise - ) - else: - tp_actions = model.get_tensor_parallel_convert_actions(config_revise, all_tp_keys, ignore_error=True) - - logger.debug("Distributed send recv for state dict load ...") - # Distribute the checkpoint tensor dynamically, using the `send_table` and `recv_table` we create before. - state_dict = distributed_send_recv( - config_revise, - get_expected_state_dict(model), - tp_actions, - send_table, - recv_table, - resume_from_checkpoint, - file_keyname_mappings, - file_machine_mappings, - ) - dist.barrier() - logger.debug("Setting state dict into model ...") - error_msgs = _load_state_dict_into_model(model, state_dict, "") - if len(error_msgs) > 0: - error_msg = "\n\t".join(error_msgs) - raise RuntimeError(f"Error(s) in loading dynamic state_dict for {model.__class__.__name__}:\n\t{error_msg}") - - -def load_unified_optimizer_dynamically(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): - optim_state_dict = nested_copy(optimizer.state_dict()) - if "master_weights" in optim_state_dict.keys(): - optim_state_dict.pop("master_weights") - - if safe_serialization: - index_filename, index_filename_mw = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME - else: - index_filename, index_filename_mw = PADDLE_OPTIMIZER_INDEX_NAME, PADDLE_MASTER_WEIGHTS_INDEX_NAME - - with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f: - index = json.loads(f.read()) - - # `file_keyname_mappings` indicates which keys each file contains. For example, {"optimizer-00001-of-00002.safetensors": ["llama.embed_tokens.weight/moment1_0", "llama.layers.1.mlp.gate_proj.weight/moment1_0", ...]} - # `file_machine_mappings` indicates the machine where the files appear. For example, {"optimizer-00001-of-00002.safetensors": [machine_0, machine_1], "optimizer-00002-of-00002.safetensors": [machine_0]} - file_keyname_mappings, file_machine_mappings = get_file_mappings(index, resume_from_checkpoint) - - has_master_weights = index["master_weights"] - # update has_master_weights and index_filename_master_weights - # 1. if the master weights exists, only has_master_weights is set True and load master weights when needed - # 2. if master weights does not exist, convert model weights to master weights when needed - has_master_weights, index_filename_mw = update_master_weight_status( - args, optimizer, has_master_weights, safe_serialization - ) - - if has_master_weights: - with open(os.path.join(resume_from_checkpoint, index_filename_mw), "r") as f: - index_mw = json.loads(f.read()) - file_keyname_mappings_mw, file_machine_mappings_mw = get_file_mappings(index_mw, resume_from_checkpoint) - - # Get optimizer param type name, like moment1_0, moment2_0, beta1_pow_acc_0. 
- typename_set = set() - for key in index["weight_map"].keys(): - _, typename = key.split("/") - typename_set.add(typename) - - model_state_dict = get_expected_state_dict(model) - struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} - static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()} - # Get send_table and recv_table. The send table indicates which workers are responsible for sending tensors, and the recv table indicates which workers should receive the tensors. - send_table, recv_table = create_optimizer_dispatch_table( - args, - model, - optimizer, - file_keyname_mappings, - file_machine_mappings, - resume_from_checkpoint, - struct2static_name_mappings, - is_master_weights=False, - typename_set=typename_set, - ) - if has_master_weights: - send_table_mw, recv_table_mw = create_optimizer_dispatch_table( - args, - model, - optimizer, - file_keyname_mappings_mw, - file_machine_mappings_mw, - resume_from_checkpoint, - struct2static_name_mappings, - is_master_weights=True, - ) - - # Initialize optimizer state dict. - hcg = fleet.get_hybrid_communicate_group() - sharding_group = hcg.get_sharding_parallel_group() - if sharding_group.nranks > 1: - param2rank = optimizer._param2rank - optim_state_dict_mw = {} - - def check_optimizer_param(parameter): - if sharding_group.nranks > 1: - param_rank = param2rank.get(parameter.name, None) - if param_rank != sharding_group.rank: - return False - if parameter.stop_gradient: - return False - return True - - optimizer_keys_with_shape = [] - if isinstance(optimizer._parameter_list[0], dict): - for param_group in optimizer._parameter_list: - # If parameter groups are set, there must be `params` key. This is guaranteed by the optimizer's initialization code. - for parameter in param_group["params"]: - if check_optimizer_param(parameter): - optimizer_keys_with_shape.append((parameter.name, parameter.shape)) - else: - for parameter in optimizer._parameter_list: - if check_optimizer_param(parameter): - optimizer_keys_with_shape.append((parameter.name, parameter.shape)) - - # see how to change - for static_name, shape in optimizer_keys_with_shape: - k = static2struct_name_mappings[static_name] - for typename in typename_set: - new_k = k + "/" + typename - if typename in optimizer_scalar_name: - optim_state_dict[new_k] = paddle.empty([1], dtype="float32") - else: - optim_state_dict[new_k] = paddle.empty(shape, dtype="float32") - if has_master_weights: - optim_state_dict_mw[k] = paddle.empty(shape, dtype="float32") - - # Get all the keys that are splited by tensor parallelism. - all_tp_keys = set() - for k, v in recv_table.items(): - structure_name, typename = k.split("/") - if typename in optimizer_non_scaler_name: - if v[0][1] != -1: - all_tp_keys.add(structure_name) - - # Get corresponding tensor parallel actions. 
- config_revise = copy.deepcopy(model.config) - config_revise.tensor_parallel_rank = None - if len(all_tp_keys) == 0: - tp_actions = {} - else: - if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): - tp_actions = model._get_tensor_parallel_convert_actions( - set(all_tp_keys), is_split=True, ignore_error=True, config=config_revise - ) - else: - tp_actions = model.get_tensor_parallel_convert_actions(config_revise, all_tp_keys, ignore_error=True) - optimizer_keys = list(index["weight_map"].keys()) - optimizer_tp_actions = mapping_optimizer_tp_actions(tp_actions, optimizer_keys) - if has_master_weights: - optimizer_tp_actions.update(tp_actions) - - # Distribute the optimizer checkpoint dynamically, using the `send_table` and `recv_table` we create before. - optim_state_dict = distributed_send_recv( - config_revise, - optim_state_dict, - optimizer_tp_actions, - send_table, - recv_table, - resume_from_checkpoint, - file_keyname_mappings, - file_machine_mappings, - ) - dist.barrier() - if has_master_weights: - optim_state_dict_mw = distributed_send_recv( - config_revise, - optim_state_dict_mw, - optimizer_tp_actions, - send_table_mw, - recv_table_mw, - resume_from_checkpoint, - file_keyname_mappings_mw, - file_machine_mappings_mw, - ) - dist.barrier() - - # Rename optimizer state dict. - for key in list(optim_state_dict.keys()): - if key == "LR_Scheduler": - continue - key_name = key.split("/") - static_name = struct2static_name_mappings[key_name[0]] - if has_master_weights: - if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: - key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) - else: - key_name = "_".join([static_name, key_name[1]]) - else: - key_name = "_".join([static_name, key_name[1]]) - optim_state_dict[key_name] = optim_state_dict.pop(key) - optim_state_dict[key_name].name = key_name - - if has_master_weights: - optim_state_dict["master_weights"] = {} - for key in list(optim_state_dict_mw.keys()): - static_name = struct2static_name_mappings[key] - optim_state_dict["master_weights"][static_name] = optim_state_dict_mw.pop(key) - optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) - - if args.data_parallel_rank == 0: - return optim_state_dict - return None - - -def load_single_card_checkpoint(args, model, resume_from_checkpoint: str): - if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): - index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME - else: - index_filename = SAFE_WEIGHTS_INDEX_NAME - resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( - pretrained_model_name_or_path=resume_from_checkpoint, - index_filename=os.path.join(resume_from_checkpoint, index_filename), - ) - - loaded_keys = sharded_metadata["all_checkpoint_keys"] - model_state_dict = get_expected_state_dict(model) - expected_keys = set(list(model_state_dict.keys())) - missing_keys = expected_keys - set(loaded_keys) - - if len(missing_keys) > 0: - raise ValueError(f"Missing keys: {missing_keys}") - - state_dict = load_state_dict(resolved_archive_file[0], None, expected_keys) - error_msgs = _load_state_dict_into_model(model, state_dict, "") - del state_dict - gc.collect() - - if error_msgs: - raise RuntimeError(f"Error(s) in loading state dict for {model.__class__.__name__}:\n\t{error_msgs}") - - -def load_single_card_optimizer(args, model, optimizer, resume_from_checkpoint: str): - returned_optim_state_dict = nested_copy(optimizer.state_dict()) - - resolved_archive_file, sharded_metadata = 
get_optimizer_shard_files( - optimizer_path=resume_from_checkpoint, - index_filename=os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME), - ) - has_master_weights = True if sharded_metadata["master_weights"] else False - - model_state_dict = get_expected_state_dict(model) - struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} - expected_keys = sharded_metadata["all_optimizer_keys"] - - if has_master_weights: - returned_optim_state_dict["master_weights"] = {} - resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files( - optimizer_path=resume_from_checkpoint, - index_filename=os.path.join(resume_from_checkpoint, SAFE_MASTER_WEIGHTS_INDEX_NAME), - ) - expected_keys_mw = sharded_metadata_mw["all_optimizer_keys"] - - state_dict_optim = load_state_dict(resolved_archive_file[0], None, expected_keys) - if has_master_weights: - state_dict_optim_mw = load_state_dict(resolved_archive_file_mw[0], None, expected_keys_mw) - - for key in list(state_dict_optim.keys()): - key_name = key.split("/") - static_name = struct2static_name_mappings[key_name[0]] - if has_master_weights: - if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: - key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) - else: - key_name = "_".join([static_name, key_name[1]]) - returned_optim_state_dict[key_name] = state_dict_optim.pop(key) - returned_optim_state_dict[key_name].name = key_name - if has_master_weights: - for key in list(state_dict_optim_mw.keys()): - static_name = struct2static_name_mappings[key] - returned_optim_state_dict["master_weights"][static_name] = state_dict_optim_mw.pop(key) - returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) - - returned_optim_state_dict = nested_copy_place( - returned_optim_state_dict, - place=paddle.framework._current_expected_place(), - blocking=True, - ) - return returned_optim_state_dict - - -def get_file_mappings(index, resume_from_checkpoint): - file_keyname_mappings = {} - for k, v in index["weight_map"].items(): - if v not in file_keyname_mappings: - file_keyname_mappings[v] = [] - file_keyname_mappings[v].append(k) - for k in file_keyname_mappings.keys(): - file_keyname_mappings[k] = sorted(file_keyname_mappings[k]) - - local_device_count = int(os.getenv("PADDLE_LOCAL_SIZE")) - local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0)) - global_rank = dist.get_rank() - file_machine_mappings = {} - for filename in file_keyname_mappings.keys(): - if local_rank == 0 and os.path.exists(os.path.join(resume_from_checkpoint, filename)): - file_machine_mappings[filename] = [global_rank // local_device_count] - file_machine_list = [] - dist.all_gather_object(file_machine_list, file_machine_mappings) - file_machine_mappings = {} - for mappings in file_machine_list: - for k, v in mappings.items(): - if k not in file_machine_mappings: - file_machine_mappings[k] = v - else: - file_machine_mappings[k] += v - return file_keyname_mappings, file_machine_mappings - - -def create_send_table(file_keyname_mappings, file_machine_mappings): - send_table = {} - global_rank = dist.get_rank() - local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0)) - local_device_count = int(os.getenv("PADDLE_LOCAL_SIZE")) - for filename, keys in file_keyname_mappings.items(): - machine = file_machine_mappings[filename][0] - is_src = (global_rank // local_device_count) == machine - for i, key in enumerate(keys): - if is_src and local_rank == i % local_device_count: - send_table[key] = global_rank - 
dispatch_list = [] - dist.all_gather_object(dispatch_list, send_table) - send_table = {} - for dl in dispatch_list: - send_table.update(dl) - return send_table - - -def distributed_send_recv( - config, - state_dict, - tp_actions, - send_table, - recv_table, - resume_from_checkpoint, - file_keyname_mappings, - file_machine_mappings, -): - - local_device_count = int(os.getenv("PADDLE_LOCAL_SIZE")) - global_rank = dist.get_rank() - for filename in file_keyname_mappings.keys(): - machine = file_machine_mappings[filename][0] - is_src = global_rank // local_device_count == machine - if is_src: - f = safe_open(os.path.join(resume_from_checkpoint, filename), framework="np") - - for key in file_keyname_mappings[filename]: - recv_info = recv_table[key] - recv_ranklist = [a for (a, b) in recv_info] - if is_src and global_rank == send_table[key]: - py_safe_slice_ = f.get_slice(key) - # send - if key in tp_actions: - weight = tp_actions[key](py_safe_slice_) - # copy weight to GPU - for j in range(len(weight)): - with device_guard(): - weight[j] = paddle.Tensor(weight[j], zero_copy=True) - weight[j] = weight[j]._copy_to(paddle.framework._current_expected_place(), False) - - for recv_rank, split_index in recv_info: - if recv_rank == global_rank: - state_dict[key] = weight[split_index] - else: - dist.stream.send(weight[split_index], dst=recv_rank) - else: - # no need to tp split - weight = py_safe_slice_[:] - with device_guard(): - weight = paddle.Tensor(weight, zero_copy=True) - weight = weight._copy_to(paddle.framework._current_expected_place(), False) - for recv_rank, _ in recv_info: - if recv_rank == global_rank: - state_dict[key] = weight - else: - dist.stream.send(weight, dst=recv_rank) - - if global_rank != send_table[key] and global_rank in recv_ranklist: - dist.stream.recv(state_dict[key], src=send_table[key]) - - if is_src: - f.__exit__(None, None, None) - - return state_dict - - -def distributed_send_recv_splited_param( - state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False -): - global_rank = dist.get_rank() - for key in list(state_dict.keys()): - if state_dict[key].numel().item() == 1: # for example: beta1, beta2 - continue - - static_name = key if is_master_weights else generate_base_static_name(key)[0] - shape, numel, index, padded_size = param_shape_info[static_name] - if static_name not in partial_tensor_list: - state_dict[key] = state_dict[key].reshape(shape) - continue - - recv_rank = recv_table[static_name] - send_info = send_table[static_name] - - base_padding_start = index + numel - base_padding_end = index + padded_size - - if global_rank == recv_rank: - tmp_tensor_list = [] - for send_rank, begin, end in send_info: - padding_start = max(begin, base_padding_start) - padding_end = min(end, base_padding_end) - - if send_rank == recv_rank: - tensor = ( - state_dict[key] if padding_start >= padding_end else state_dict[key][: padding_start - begin] - ) - tmp_tensor_list.append(tensor) - else: - length = end - begin if padding_start >= padding_end else padding_start - begin - tmp_tensor = paddle.empty(shape=[length], dtype=state_dict[key].dtype) - dist.stream.recv(tmp_tensor, src=send_rank) - tmp_tensor_list.append(tmp_tensor) - state_dict[key] = paddle.concat(tmp_tensor_list, axis=0).reshape(shape) - else: - for send_rank, begin, end in send_info: - padding_start = max(begin, base_padding_start) - padding_end = min(end, base_padding_end) - if global_rank == send_rank: - tensor = ( - state_dict[key] if padding_start >= padding_end else 
state_dict[key][: padding_start - begin] - ) - dist.stream.send(tensor, dst=recv_rank) - state_dict.pop(key) - return state_dict diff --git a/paddlenlp/trainer/plugins/unified_checkpoint_dynamic.py b/paddlenlp/trainer/plugins/unified_checkpoint_dynamic.py new file mode 100644 index 000000000000..bd5a5873b359 --- /dev/null +++ b/paddlenlp/trainer/plugins/unified_checkpoint_dynamic.py @@ -0,0 +1,493 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unified Checkpoint Dynamic Loading Functions""" + +import copy +import json +import os +import sys + +import paddle +import paddle.distributed as dist +from paddle.distributed import fleet + +try: + from paddle.base import core +except: + core = None + +from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM +from paddlenlp.transformers.model_utils import _load_state_dict_into_model +from paddlenlp.transformers.utils import device_guard, is_safetensors_available +from paddlenlp.utils.env import ( + PADDLE_MASTER_WEIGHTS_INDEX_NAME, + PADDLE_OPTIMIZER_INDEX_NAME, + SAFE_MASTER_WEIGHTS_INDEX_NAME, + SAFE_OPTIMIZER_INDEX_NAME, +) +from paddlenlp.utils.log import logger +from paddlenlp.utils.nested import nested_copy + +if is_safetensors_available(): + if sys.platform.startswith("win"): + from safetensors import safe_open + else: + from paddlenlp.utils.safetensors import fast_safe_open as safe_open + +from .unified_checkpoint_utils import ( + FP32_MASTER, + get_expected_state_dict, + mapping_optimizer_tp_actions, + optimizer_non_scaler_name, + optimizer_scalar_name, + select_model_weight_index, + update_master_weight_status, +) + + +def create_send_table(file_keyname_mappings, file_machine_mappings): + send_table = {} + global_rank = dist.get_rank() + local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0)) + local_device_count = int(os.getenv("PADDLE_LOCAL_SIZE")) + for filename, keys in file_keyname_mappings.items(): + machine = file_machine_mappings[filename][0] + is_src = (global_rank // local_device_count) == machine + for i, key in enumerate(keys): + if is_src and local_rank == i % local_device_count: + send_table[key] = global_rank + dispatch_list = [] + dist.all_gather_object(dispatch_list, send_table) + send_table = {} + for dl in dispatch_list: + send_table.update(dl) + return send_table + + +def create_dispatch_table(args, model, file_keyname_mappings, file_machine_mappings): + """Create dispatch table for dynamically loading state dict. 
+ + Args: + args + """ + + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + tp_rank = tp_group.rank + + # Create tensor receive table, contains {"key0": [global_rank, tp_rank], "key1": [global_rank, tp_rank]} + dispatch_list = [] + recv_table = {} + if args.dataset_rank == 0: + state_dict = get_expected_state_dict(model) + for (k, v) in state_dict.items(): + if hasattr(v, "is_distributed") and v.is_distributed: + recv_table[k] = [(dist.get_rank(), tp_rank)] + else: + recv_table[k] = [(dist.get_rank(), -1)] + + # Gather receive table in global group. + dist.all_gather_object(dispatch_list, recv_table) + recv_table = {} + for dl in dispatch_list: + for key, value in dl.items(): + if key not in recv_table: + recv_table[key] = value + else: + recv_table[key] += value + + # Create send table, to decide which worker to send the key. Contains {"key0:" global_rank, "key1": global_rank, ...} + send_table = create_send_table(file_keyname_mappings, file_machine_mappings) + + return send_table, recv_table + + +def create_optimizer_dispatch_table( + args, + model, + optimizer, + file_keyname_mappings, + file_machine_mappings, + struct2static_name_mappings, + is_master_weights=False, + typename_set=None, +): + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + sharding_group = hcg.get_sharding_parallel_group() + sharding_rank = sharding_group.rank + if sharding_group.nranks > 1: + param2rank = optimizer._param2rank + tp_rank = tp_group.rank + + # Create receive table, contains {"param_key0": [global_rank, tp_rank], "param_key1": [global_rank, tp_rank]} + dispatch_list = [] + recv_table = {} + if args.data_parallel_rank == 0: + state_dict = get_expected_state_dict(model) + for (k, v) in state_dict.items(): + if sharding_group.nranks > 1: + static_name = struct2static_name_mappings[k] + param_rank = param2rank.get(static_name, None) + if param_rank != sharding_rank: + continue + if is_master_weights: + if hasattr(v, "is_distributed") and v.is_distributed: + recv_table[k] = [(dist.get_rank(), tp_rank)] + else: + recv_table[k] = [(dist.get_rank(), -1)] + else: + for typename in typename_set: + type_key = k + "/" + typename + if typename in optimizer_non_scaler_name: + if hasattr(v, "is_distributed") and v.is_distributed: + recv_table[type_key] = [(dist.get_rank(), tp_rank)] + else: + recv_table[type_key] = [(dist.get_rank(), -1)] + else: + recv_table[type_key] = [(dist.get_rank(), -1)] + + dist.all_gather_object(dispatch_list, recv_table) + recv_table = {} + for dl in dispatch_list: + for k, v in dl.items(): + if k not in recv_table: + recv_table[k] = v + else: + recv_table[k] += v + + # Create send table, to decide which worker to send the key. 
Contains {"param_key0:" 0, "param_key1": 1, ...} + send_table = create_send_table(file_keyname_mappings, file_machine_mappings) + return send_table, recv_table + + +def get_file_mappings(index, resume_from_checkpoint): + file_keyname_mappings = {} + for k, v in index["weight_map"].items(): + if v not in file_keyname_mappings: + file_keyname_mappings[v] = [] + file_keyname_mappings[v].append(k) + for k in file_keyname_mappings.keys(): + file_keyname_mappings[k] = sorted(file_keyname_mappings[k]) + + local_device_count = int(os.getenv("PADDLE_LOCAL_SIZE")) + local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0)) + global_rank = dist.get_rank() + file_machine_mappings = {} + for filename in file_keyname_mappings.keys(): + if local_rank == 0 and os.path.exists(os.path.join(resume_from_checkpoint, filename)): + file_machine_mappings[filename] = [global_rank // local_device_count] + file_machine_list = [] + dist.all_gather_object(file_machine_list, file_machine_mappings) + file_machine_mappings = {} + for mappings in file_machine_list: + for k, v in mappings.items(): + if k not in file_machine_mappings: + file_machine_mappings[k] = v + else: + file_machine_mappings[k] += v + return file_keyname_mappings, file_machine_mappings + + +def distributed_send_recv( + state_dict, + tp_actions, + send_table, + recv_table, + resume_from_checkpoint, + file_keyname_mappings, + file_machine_mappings, +): + + local_device_count = int(os.getenv("PADDLE_LOCAL_SIZE")) + global_rank = dist.get_rank() + for filename in file_keyname_mappings.keys(): + machine = file_machine_mappings[filename][0] + is_src = global_rank // local_device_count == machine + if is_src: + f = safe_open(os.path.join(resume_from_checkpoint, filename), framework="np") + + for key in file_keyname_mappings[filename]: + recv_info = recv_table[key] + recv_ranklist = [a for (a, _) in recv_info] + if is_src and global_rank == send_table[key]: + py_safe_slice_ = f.get_slice(key) + # send + if key in tp_actions: + weight = tp_actions[key](py_safe_slice_) + # copy weight to GPU + for j in range(len(weight)): + with device_guard(): + weight[j] = paddle.Tensor(weight[j], zero_copy=True) + weight[j] = weight[j]._copy_to(paddle.framework._current_expected_place(), False) + + for recv_rank, split_index in recv_info: + if recv_rank == global_rank: + state_dict[key] = weight[split_index] + else: + dist.stream.send(weight[split_index], dst=recv_rank) + else: + # no need to tp split + weight = py_safe_slice_[:] + with device_guard(): + weight = paddle.Tensor(weight, zero_copy=True) + weight = weight._copy_to(paddle.framework._current_expected_place(), False) + for recv_rank, _ in recv_info: + if recv_rank == global_rank: + state_dict[key] = weight + else: + dist.stream.send(weight, dst=recv_rank) + + if global_rank != send_table[key] and global_rank in recv_ranklist: + dist.stream.recv(state_dict[key], src=send_table[key]) + + if is_src: + f.__exit__(None, None, None) + + return state_dict + + +def load_unified_checkpoint_dynamically(args, model, resume_from_checkpoint, safe_serialization=False): + index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False) + index_filename = os.path.join(resume_from_checkpoint, index_filename) + + with open(index_filename, "r") as f: + index = json.loads(f.read()) + + # `file_keyname_mappings` indicates which keys each file contains. 
For example, {"model-00001-of-00002.safetensors": ["llama.embed_tokens.weight", "llama.layers.0.self_attn.q_proj.weight", ...]} + # `file_machine_mappings` indicates the machine where the files appear. For example, {"model-00001-of-00002.safetensors": [machine_0, machine_1], "model-00002-of-00002.safetensors": [machine_0]} + file_keyname_mappings, file_machine_mappings = get_file_mappings(index, resume_from_checkpoint) + + logger.debug("Creating dispatch table for unified checkpoint load ...") + # Get send_table and recv_table. The send table indicates which workers are responsible for sending tensors, and the recv table indicates which workers should receive the tensors. + send_table, recv_table = create_dispatch_table( + args, + model, + file_keyname_mappings, + file_machine_mappings, + ) + + # Get all the keys that are splited by tensor parallelism. + all_tp_keys = set() + for k, v in recv_table.items(): + if v[0][1] != -1: + all_tp_keys.add(k) + + config_revise = copy.deepcopy(model.config) + config_revise.tensor_parallel_rank = None + if len(all_tp_keys) == 0: + tp_actions = {} + else: + # Get corresponding tensor parallel actions. + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + tp_actions = model._get_tensor_parallel_convert_actions( + set(all_tp_keys), is_split=True, ignore_error=True, config=config_revise + ) + else: + tp_actions = model.get_tensor_parallel_convert_actions(config_revise, all_tp_keys, ignore_error=True) + + logger.debug("Distributed send recv for state dict load ...") + # Distribute the checkpoint tensor dynamically, using the `send_table` and `recv_table` we create before. + state_dict = distributed_send_recv( + get_expected_state_dict(model), + tp_actions, + send_table, + recv_table, + resume_from_checkpoint, + file_keyname_mappings, + file_machine_mappings, + ) + dist.barrier() + logger.debug("Setting state dict into model ...") + error_msgs = _load_state_dict_into_model(model, state_dict, "") + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + raise RuntimeError(f"Error(s) in loading dynamic state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + +def load_unified_optimizer_dynamically(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): + optim_state_dict = nested_copy(optimizer.state_dict()) + if "master_weights" in optim_state_dict.keys(): + optim_state_dict.pop("master_weights") + + if safe_serialization: + index_filename, index_filename_mw = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME + else: + index_filename, index_filename_mw = PADDLE_OPTIMIZER_INDEX_NAME, PADDLE_MASTER_WEIGHTS_INDEX_NAME + + with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f: + index = json.loads(f.read()) + + # `file_keyname_mappings` indicates which keys each file contains. For example, {"optimizer-00001-of-00002.safetensors": ["llama.embed_tokens.weight/moment1_0", "llama.layers.1.mlp.gate_proj.weight/moment1_0", ...]} + # `file_machine_mappings` indicates the machine where the files appear. For example, {"optimizer-00001-of-00002.safetensors": [machine_0, machine_1], "optimizer-00002-of-00002.safetensors": [machine_0]} + file_keyname_mappings, file_machine_mappings = get_file_mappings(index, resume_from_checkpoint) + + has_master_weights = index["master_weights"] + # update has_master_weights and index_filename_master_weights + # 1. if the master weights exists, only has_master_weights is set True and load master weights when needed + # 2. 
if master weights does not exist, convert model weights to master weights when needed + has_master_weights, index_filename_mw = update_master_weight_status( + args, optimizer, has_master_weights, safe_serialization + ) + + if has_master_weights: + with open(os.path.join(resume_from_checkpoint, index_filename_mw), "r") as f: + index_mw = json.loads(f.read()) + file_keyname_mappings_mw, file_machine_mappings_mw = get_file_mappings(index_mw, resume_from_checkpoint) + + # Get optimizer param type name, like moment1_0, moment2_0, beta1_pow_acc_0. + typename_set = set() + for key in index["weight_map"].keys(): + _, typename = key.split("/") + typename_set.add(typename) + + model_state_dict = get_expected_state_dict(model) + struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} + static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()} + # Get send_table and recv_table. The send table indicates which workers are responsible for sending tensors, and the recv table indicates which workers should receive the tensors. + send_table, recv_table = create_optimizer_dispatch_table( + args, + model, + optimizer, + file_keyname_mappings, + file_machine_mappings, + struct2static_name_mappings, + is_master_weights=False, + typename_set=typename_set, + ) + if has_master_weights: + send_table_mw, recv_table_mw = create_optimizer_dispatch_table( + args, + model, + optimizer, + file_keyname_mappings_mw, + file_machine_mappings_mw, + struct2static_name_mappings, + is_master_weights=True, + ) + + # Initialize optimizer state dict. + hcg = fleet.get_hybrid_communicate_group() + sharding_group = hcg.get_sharding_parallel_group() + if sharding_group.nranks > 1: + param2rank = optimizer._param2rank + optim_state_dict_mw = {} + + def check_optimizer_param(parameter): + if sharding_group.nranks > 1: + param_rank = param2rank.get(parameter.name, None) + if param_rank != sharding_group.rank: + return False + if parameter.stop_gradient: + return False + return True + + optimizer_keys_with_shape = [] + if isinstance(optimizer._parameter_list[0], dict): + for param_group in optimizer._parameter_list: + # If parameter groups are set, there must be `params` key. This is guaranteed by the optimizer's initialization code. + for parameter in param_group["params"]: + if check_optimizer_param(parameter): + optimizer_keys_with_shape.append((parameter.name, parameter.shape)) + else: + for parameter in optimizer._parameter_list: + if check_optimizer_param(parameter): + optimizer_keys_with_shape.append((parameter.name, parameter.shape)) + + # see how to change + for static_name, shape in optimizer_keys_with_shape: + k = static2struct_name_mappings[static_name] + for typename in typename_set: + new_k = k + "/" + typename + if typename in optimizer_scalar_name: + optim_state_dict[new_k] = paddle.empty([1], dtype="float32") + else: + optim_state_dict[new_k] = paddle.empty(shape, dtype="float32") + if has_master_weights: + optim_state_dict_mw[k] = paddle.empty(shape, dtype="float32") + + # Get all the keys that are splited by tensor parallelism. + all_tp_keys = set() + for k, v in recv_table.items(): + structure_name, typename = k.split("/") + if typename in optimizer_non_scaler_name: + if v[0][1] != -1: + all_tp_keys.add(structure_name) + + # Get corresponding tensor parallel actions. 
+ config_revise = copy.deepcopy(model.config) + config_revise.tensor_parallel_rank = None + if len(all_tp_keys) == 0: + tp_actions = {} + else: + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + tp_actions = model._get_tensor_parallel_convert_actions( + set(all_tp_keys), is_split=True, ignore_error=True, config=config_revise + ) + else: + tp_actions = model.get_tensor_parallel_convert_actions(config_revise, all_tp_keys, ignore_error=True) + optimizer_keys = list(index["weight_map"].keys()) + optimizer_tp_actions = mapping_optimizer_tp_actions(tp_actions, optimizer_keys) + if has_master_weights: + optimizer_tp_actions.update(tp_actions) + + # Distribute the optimizer checkpoint dynamically, using the `send_table` and `recv_table` we create before. + optim_state_dict = distributed_send_recv( + optim_state_dict, + optimizer_tp_actions, + send_table, + recv_table, + resume_from_checkpoint, + file_keyname_mappings, + file_machine_mappings, + ) + dist.barrier() + if has_master_weights: + optim_state_dict_mw = distributed_send_recv( + optim_state_dict_mw, + optimizer_tp_actions, + send_table_mw, + recv_table_mw, + resume_from_checkpoint, + file_keyname_mappings_mw, + file_machine_mappings_mw, + ) + dist.barrier() + + # Rename optimizer state dict. + for key in list(optim_state_dict.keys()): + if key == "LR_Scheduler": + continue + key_name = key.split("/") + static_name = struct2static_name_mappings[key_name[0]] + if has_master_weights: + if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) + optim_state_dict[key_name] = optim_state_dict.pop(key) + optim_state_dict[key_name].name = key_name + + if has_master_weights: + optim_state_dict["master_weights"] = {} + for key in list(optim_state_dict_mw.keys()): + static_name = struct2static_name_mappings[key] + optim_state_dict["master_weights"][static_name] = optim_state_dict_mw.pop(key) + optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) + + if args.data_parallel_rank == 0: + return optim_state_dict + return None diff --git a/paddlenlp/trainer/plugins/unified_checkpoint_sharding_v2.py b/paddlenlp/trainer/plugins/unified_checkpoint_sharding_v2.py new file mode 100644 index 000000000000..ac41fa287aef --- /dev/null +++ b/paddlenlp/trainer/plugins/unified_checkpoint_sharding_v2.py @@ -0,0 +1,298 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
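# ---------------------------------------------------------------------------
# Aside (illustrative sketch, not part of the patch): the key renaming that the
# dynamic optimizer loader above performs at the end. Unified keys look like
# "<structured name>/<state typename>" and are mapped back to the static
# optimizer state names Paddle expects. The structured/static names below are
# invented, and FP32_MASTER is assumed to resolve to "fp32_master_0".
def to_static_optimizer_name(struct_key, struct2static, has_master_weights, param_is_fp32):
    struct_name, typename = struct_key.split("/")
    static_name = struct2static[struct_name]
    if has_master_weights and not param_is_fp32:
        # non-fp32 params keep an extra fp32 master copy, reflected in the name
        return "_".join([static_name, "fp32_master_0", typename])
    return "_".join([static_name, typename])

# to_static_optimizer_name(
#     "llama.embed_tokens.weight/moment1_0",
#     {"llama.embed_tokens.weight": "embedding_0.w_0"},
#     has_master_weights=True, param_is_fp32=False,
# )  -> "embedding_0.w_0_fp32_master_0_moment1_0"
# ---------------------------------------------------------------------------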
+"""Support Sharding Stage1 V2(split param) for Unified Checkpoint""" + +import gc +import os + +import paddle +import paddle.distributed as dist +from paddle.distributed import fleet +from tqdm.auto import tqdm + +from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM +from paddlenlp.transformers.model_utils import load_state_dict +from paddlenlp.utils.env import ( + SAFE_MASTER_WEIGHTS_INDEX_NAME, + SAFE_OPTIMIZER_INDEX_NAME, +) +from paddlenlp.utils.nested import nested_copy + +from .unified_checkpoint_utils import ( + FP32_MASTER, + generate_base_static_name, + get_expected_state_dict, + get_optimizer_shard_files, + mapping_optimizer_tp_actions, +) + + +def distributed_send_recv_splited_param( + state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False +): + global_rank = dist.get_rank() + for key in list(state_dict.keys()): + if state_dict[key].numel().item() == 1: # for example: beta1, beta2 + continue + + static_name = key if is_master_weights else generate_base_static_name(key)[0] + shape, numel, index, padded_size = param_shape_info[static_name] + if static_name not in partial_tensor_list: + state_dict[key] = state_dict[key].reshape(shape) + continue + + recv_rank = recv_table[static_name] + send_info = send_table[static_name] + + base_padding_start = index + numel + base_padding_end = index + padded_size + + if global_rank == recv_rank: + tmp_tensor_list = [] + for send_rank, begin, end in send_info: + padding_start = max(begin, base_padding_start) + padding_end = min(end, base_padding_end) + + if send_rank == recv_rank: + tensor = ( + state_dict[key] if padding_start >= padding_end else state_dict[key][: padding_start - begin] + ) + tmp_tensor_list.append(tensor) + else: + length = end - begin if padding_start >= padding_end else padding_start - begin + tmp_tensor = paddle.empty(shape=[length], dtype=state_dict[key].dtype) + dist.stream.recv(tmp_tensor, src=send_rank) + tmp_tensor_list.append(tmp_tensor) + state_dict[key] = paddle.concat(tmp_tensor_list, axis=0).reshape(shape) + else: + for send_rank, begin, end in send_info: + padding_start = max(begin, base_padding_start) + padding_end = min(end, base_padding_end) + if global_rank == send_rank: + tensor = ( + state_dict[key] if padding_start >= padding_end else state_dict[key][: padding_start - begin] + ) + dist.stream.send(tensor, dst=recv_rank) + state_dict.pop(key) + return state_dict + + +def gather_splited_param_for_optimizer(optimizer): + hcg = fleet.get_hybrid_communicate_group() + sharding_group = hcg.get_sharding_parallel_group() + global_rank = dist.get_rank() + param_slice_info = {} + param_shape_info = {} + for buffer in optimizer._inner_opt._comm_buffer_list: + for key in buffer._sharding_param_grad_view.keys(): + param_slice_info[key] = ( + buffer._sharding_param_grad_view[key]._param_begin, + buffer._sharding_param_grad_view[key]._param_end, + ) + param_shape_info[key] = ( + buffer._sharding_param_grad_view[key]._param.shape, + buffer._sharding_param_grad_view[key]._param.numel().item(), + buffer._sharding_param_grad_view[key]._index, + buffer._sharding_param_grad_view[key]._padded_size, + ) + param_slice_info["global_rank"] = global_rank + param_slice_info_list = [] + dist.all_gather_object(param_slice_info_list, param_slice_info, group=sharding_group) + + optim_state_dict = nested_copy(optimizer.state_dict()) + master_weights = None + if "master_weights" in optim_state_dict.keys(): + master_weights = optim_state_dict.pop("master_weights") + if "LR_Scheduler" in 
optim_state_dict.keys(): + optim_state_dict.pop("LR_Scheduler") + + # deal with optimizer param + partial_tensor_list = [] + for key in list(optim_state_dict.keys()): + static_name, _ = generate_base_static_name(key) + if static_name in param_slice_info.keys(): + if optim_state_dict[key].numel().item() == 1: # for example: beta1, beta2 + continue + begin, end = param_slice_info[static_name] + shape, numel, _, _ = param_shape_info[static_name] + if end - begin == numel: # full tensor + optim_state_dict[key] = optim_state_dict[key].reshape(shape) + elif end <= begin: # empty tensor + continue + else: # partial tensor, end > begin but end - begin < numel + partial_tensor_list.append(static_name) + + send_table = {} + recv_table = {} + for key in partial_tensor_list: + sharding_ranklist = [] + for slice_info in param_slice_info_list: + begin, end = slice_info[key] + if end > begin: + sharding_ranklist.append((slice_info["global_rank"], begin, end)) + recv_table[key] = sharding_ranklist[0][0] # which sharding_rank to recv the splited tensor + send_table[key] = [(rank, begin, end) for rank, begin, end in sharding_ranklist] + + distributed_send_recv_splited_param( + optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False + ) + if master_weights is not None: + distributed_send_recv_splited_param( + master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True + ) + return optim_state_dict, master_weights + + +def load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint): + returned_optim_state_dict = nested_copy(optimizer.state_dict()) + + index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME + + resolved_archive_file, sharded_metadata = get_optimizer_shard_files( + optimizer_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, index_filename), + ) + has_master_weights = True if sharded_metadata["master_weights"] else False + + typename_set = set() + for key in sharded_metadata["weight_map"].keys(): + _, typename = key.split("/") + typename_set.add(typename) + + model_state_dict = get_expected_state_dict(model) + model_keys = list(model_state_dict.keys()) + static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()} # get optimizer param mappings + struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} + + expected_keys = [] + param_slice_info = {} + param_shape_info = {} + for buffer in optimizer._inner_opt._comm_buffer_list: + for key in buffer._sharding_param_grad_view.keys(): + begin = buffer._sharding_param_grad_view[key]._param_begin + end = buffer._sharding_param_grad_view[key]._param_end + if end > begin: + expected_keys.append(key) + shape = buffer._sharding_param_grad_view[key]._param.shape + numel = buffer._sharding_param_grad_view[key]._param.numel().item() + index = buffer._sharding_param_grad_view[key]._index + padded_size = buffer._sharding_param_grad_view[key]._padded_size + param_slice_info[key] = (begin, end) + param_shape_info[key] = (shape, numel, index, padded_size) + + expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys]) + expected_keys_optim = [] + for key in expected_keys: + for typename in typename_set: + expected_keys_optim.append(f"{key}/{typename}") + expected_keys_optim = set(expected_keys_optim) + + if len(resolved_archive_file) > 1: + resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards") + + if 
has_master_weights: + returned_optim_state_dict["master_weights"] = {} + resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files( + optimizer_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, index_filename_master_weights), + ) + if len(resolved_archive_file_mw) > 1: + resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards") + + def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False): + returned_state_dict = {} + + if model.config.tensor_parallel_degree > 1: + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + tp_actions = model._get_tensor_parallel_convert_actions(model_keys, is_split=True, ignore_error=True) + else: + tp_actions = model.get_tensor_parallel_convert_actions(model.config, model_keys, ignore_error=True) + if not is_master_weights: + tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys) + + for shard_file in resolved_archive_file: + if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]): + continue + + if model.config.tensor_parallel_degree > 1: + state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected") + else: + state_dict = load_state_dict(shard_file, None, expected_keys, device="expected") + + returned_state_dict.update(state_dict) + del state_dict + gc.collect() + + return returned_state_dict + + # get tp params + state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim) + if has_master_weights: + state_dict_master_weight = load_resolved_archive_file( + resolved_archive_file_mw, + sharded_metadata_mw, + expected_keys, + is_master_weights=True, + ) + + # need to split param for different sharding rank, maybe need to deal with oom issue. 
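# ---------------------------------------------------------------------------
# Aside (illustrative sketch, not part of the patch): the slice-and-pad
# arithmetic that the loop below applies to every loaded optimizer tensor. A
# parameter occupies `padded_size` slots starting at offset `index` in the
# fused comm buffer, of which only `numel` are real data; each sharding rank
# owns the flat range [begin, end). Numbers in the usage comment are made up.
import paddle

def slice_param_for_rank(full_flat, begin, end, index, numel, padded_size):
    local = full_flat[begin - index : end - index]   # this rank's slice of real data
    padding_start = max(begin, index + numel)        # where real data stops inside [begin, end)
    padding_end = min(end, index + padded_size)      # where the padded region stops
    if padding_start < padding_end:                  # this rank also owns part of the padding
        pad = paddle.zeros([padding_end - padding_start], dtype=full_flat.dtype)
        local = paddle.concat([local, pad])
    return local

# Example: numel=10 padded to 16, rank owns [8, 16) -> 2 real values + 6 zeros:
# slice_param_for_rank(paddle.arange(10, dtype="float32"), 8, 16, 0, 10, 16)
# ---------------------------------------------------------------------------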
+ for key in list(state_dict_optim.keys()): + key_name = key.split("/") + static_name = struct2static_name_mappings.get(key_name[0], None) + + if state_dict_optim[key].numel().item() > 1: + begin, end = param_slice_info[static_name] + shape, numel, index, padded_size = param_shape_info[static_name] + state_dict_optim[key] = state_dict_optim[key].reshape([-1]) + state_dict_optim[key] = state_dict_optim[key][begin - index : end - index] + + padding_start = max(begin, index + numel) + padding_end = min(end, index + padded_size) + if padding_start < padding_end: + state_dict_optim[key] = paddle.concat( + ( + state_dict_optim[key], + paddle.zeros([padding_end - padding_start], dtype=state_dict_optim[key].dtype), + ) + ) + + if has_master_weights: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) + returned_optim_state_dict[key_name] = state_dict_optim.pop(key) + returned_optim_state_dict[key_name].name = key_name + + if has_master_weights: + for key in list(state_dict_master_weight.keys()): + static_name = struct2static_name_mappings.get(key, None) + if state_dict_master_weight[key].numel().item() > 1: + begin, end = param_slice_info[static_name] + shape, numel, index, padded_size = param_shape_info[static_name] + state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1]) + state_dict_master_weight[key] = state_dict_master_weight[key][begin - index : end - index] + + padding_start = max(begin, index + numel) + padding_end = min(end, index + padded_size) + if padding_start < padding_end: + state_dict_master_weight[key] = paddle.concat( + ( + state_dict_master_weight[key], + paddle.zeros([padding_end - padding_start], dtype=state_dict_master_weight[key].dtype), + ) + ) + returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key) + returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) + + return returned_optim_state_dict diff --git a/paddlenlp/trainer/plugins/unified_checkpoint_single_card.py b/paddlenlp/trainer/plugins/unified_checkpoint_single_card.py new file mode 100644 index 000000000000..8931e33617e5 --- /dev/null +++ b/paddlenlp/trainer/plugins/unified_checkpoint_single_card.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Save and load single card checkpoint for Unified Checkpoint""" + +import gc +import json +import os + +import paddle + +try: + from paddle.base import core +except: + core = None + +from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM +from paddlenlp.transformers.model_utils import ( + _load_state_dict_into_model, + load_state_dict, +) +from paddlenlp.transformers.utils import ( + dtype_byte_size, + get_checkpoint_shard_files, + is_safetensors_available, +) +from paddlenlp.utils.env import ( + SAFE_MASTER_WEIGHTS_INDEX_NAME, + SAFE_OPTIMIZER_INDEX_NAME, + SAFE_PEFT_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_INDEX_NAME, +) +from paddlenlp.utils.log import logger +from paddlenlp.utils.nested import nested_copy + +if is_safetensors_available(): + from safetensors.numpy import save_file as safe_save_file + +from .unified_checkpoint_utils import ( + FP32_MASTER, + generate_base_static_name, + get_expected_state_dict, + get_optimizer_shard_files, + save_config, + save_prefix_past_key_value, +) + + +def save_file_sync(state_dict, path): + for k in list(state_dict.keys()): + if isinstance(state_dict[k], paddle.Tensor): + state_dict[k] = state_dict.pop(k).cpu().numpy() + safe_save_file(state_dict, path, metadata={"format": "np"}) + + +def save_single_card_checkpoint(model_to_save, output_dir): + """Save checkpoint for non-distributed environment.""" + + state_dict = get_expected_state_dict(model_to_save) + if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): + weight_filename = "peft_model-00001-of-00001.safetensors" + index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME + else: + weight_filename = "model-00001-of-00001.safetensors" + index_filename = SAFE_WEIGHTS_INDEX_NAME + # get index json + index_weight_file = {} + total_size = 0 + for key, weight in state_dict.items(): + index_weight_file[key] = weight_filename + total_size += weight.numel().item() * dtype_byte_size(weight.dtype) + sharded_index_json = {} + sharded_index_json["metadata"] = {"total_size": total_size} + sharded_index_json["weight_map"] = index_weight_file + if isinstance(model_to_save, LoRAModel): + sharded_index_json["type"] = "lora" + elif isinstance(model_to_save, PrefixModelForCausalLM): + sharded_index_json["type"] = "ptuning" + + os.makedirs(output_dir, exist_ok=True) + path = os.path.join(output_dir, index_filename) + with open(path, "w") as f: + json.dump(sharded_index_json, f, indent=4) + + # save checkpoint, do no support asynchronous save for single card currently. + logger.warning("Asynchronous saving is not supported for single card environment currently.") + save_file_sync(state_dict, path=os.path.join(output_dir, weight_filename)) + + if isinstance(model_to_save, PrefixModelForCausalLM): + save_prefix_past_key_value(model_to_save, output_dir) + model_to_save.prefix_config.save_pretrained(output_dir) + if isinstance(model_to_save, LoRAModel): + model_to_save.lora_config.save_pretrained(output_dir) + + config_to_save = save_config(model_to_save) + config_to_save.architectures = [model_to_save.__class__.__name__] + config_to_save.save_pretrained(output_dir) + + +def save_single_card_optimizer(model, optimizer, output_dir): + """ "Save optimizer for non-distributed environment.""" + # Split into optimizer params and master weights. 
+ optim_state_dict = nested_copy(optimizer.state_dict()) + master_weights = None + if "master_weights" in optim_state_dict.keys(): + master_weights = optim_state_dict.pop("master_weights") + if "LR_Scheduler" in optim_state_dict.keys(): + optim_state_dict.pop("LR_Scheduler") + + static2struct_name_mappings = {} + state_dict = get_expected_state_dict(model) + fp32_weight = {} + for k, v in state_dict.items(): + static2struct_name_mappings[v.name] = k + if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32: + fp32_weight[k] = v + + # rename optimizer param + for key in list(optim_state_dict.keys()): + static_name, type_name = generate_base_static_name(key) + new_name = static2struct_name_mappings[static_name] + "/" + type_name + optim_state_dict[new_name] = optim_state_dict.pop(key) + if master_weights is not None: + for key in list(master_weights.keys()): + master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) + master_weights.update(fp32_weight) + + # save index json + index_optimizer_file, index_master_weight_file = {}, {} + total_optim_size, total_master_weight_size = 0, 0 + for key, weight in optim_state_dict.items(): + index_optimizer_file[key] = "optimizer-00001-of-00001.safetensors" + total_optim_size += weight.numel().item() * dtype_byte_size(weight.dtype) + if master_weights is not None: + for key, weight in master_weights.items(): + index_master_weight_file[key] = "master_weights-00001-of-00001.safetensors" + total_master_weight_size += weight.numel().item() * dtype_byte_size(weight.dtype) + path = os.path.join(output_dir, SAFE_OPTIMIZER_INDEX_NAME) + master_path = os.path.join(output_dir, SAFE_MASTER_WEIGHTS_INDEX_NAME) + with open(path, "w") as f: + has_master_weights = master_weights is not None + json.dump( + { + "metadata": {"total_size": total_optim_size}, + "weight_map": index_optimizer_file, + "master_weights": has_master_weights, + }, + f, + indent=4, + ) + if master_weights is not None: + with open(master_path, "w") as f: + json.dump( + {"metadata": {"total_size": total_master_weight_size}, "weight_map": index_master_weight_file}, + f, + indent=4, + ) + + # save optimizer state dict + save_file_sync(optim_state_dict, path=os.path.join(output_dir, "optimizer-00001-of-00001.safetensors")) + if master_weights is not None: + save_file_sync(master_weights, path=os.path.join(output_dir, "master_weights-00001-of-00001.safetensors")) + + +def load_single_card_checkpoint(model, resume_from_checkpoint: str): + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME + else: + index_filename = SAFE_WEIGHTS_INDEX_NAME + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, index_filename), + ) + + loaded_keys = sharded_metadata["all_checkpoint_keys"] + model_state_dict = get_expected_state_dict(model) + expected_keys = set(list(model_state_dict.keys())) + missing_keys = expected_keys - set(loaded_keys) + + if len(missing_keys) > 0: + raise ValueError(f"Missing keys: {missing_keys}") + + state_dict = load_state_dict(resolved_archive_file[0], None, expected_keys) + error_msgs = _load_state_dict_into_model(model, state_dict, "") + del state_dict + gc.collect() + + if error_msgs: + raise RuntimeError(f"Error(s) in loading state dict for {model.__class__.__name__}:\n\t{error_msgs}") + + +def load_single_card_optimizer(model, optimizer, resume_from_checkpoint: 
str): + returned_optim_state_dict = nested_copy(optimizer.state_dict()) + + resolved_archive_file, sharded_metadata = get_optimizer_shard_files( + optimizer_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME), + ) + has_master_weights = True if sharded_metadata["master_weights"] else False + + model_state_dict = get_expected_state_dict(model) + struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} + expected_keys = sharded_metadata["all_optimizer_keys"] + + if has_master_weights: + returned_optim_state_dict["master_weights"] = {} + resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files( + optimizer_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, SAFE_MASTER_WEIGHTS_INDEX_NAME), + ) + expected_keys_mw = sharded_metadata_mw["all_optimizer_keys"] + + state_dict_optim = load_state_dict(resolved_archive_file[0], None, expected_keys) + if has_master_weights: + state_dict_optim_mw = load_state_dict(resolved_archive_file_mw[0], None, expected_keys_mw) + + for key in list(state_dict_optim.keys()): + key_name = key.split("/") + static_name = struct2static_name_mappings[key_name[0]] + if has_master_weights: + if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) + returned_optim_state_dict[key_name] = state_dict_optim.pop(key) + returned_optim_state_dict[key_name].name = key_name + if has_master_weights: + for key in list(state_dict_optim_mw.keys()): + static_name = struct2static_name_mappings[key] + returned_optim_state_dict["master_weights"][static_name] = state_dict_optim_mw.pop(key) + returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) + return returned_optim_state_dict diff --git a/paddlenlp/trainer/plugins/unified_checkpoint_utils.py b/paddlenlp/trainer/plugins/unified_checkpoint_utils.py index b36277a5872d..795ebdbdbdc6 100644 --- a/paddlenlp/trainer/plugins/unified_checkpoint_utils.py +++ b/paddlenlp/trainer/plugins/unified_checkpoint_utils.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+"""Unified Checkpoint Utility Functions""" +import copy import os import numpy as np @@ -19,16 +21,22 @@ import paddle.distributed as dist from paddle.distributed import fleet +try: + from paddle.base import core +except: + core = None + from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM from paddlenlp.trainer.trainer_utils import ExplicitEnum from paddlenlp.trainer.utils.helper import distributed_isfile -from paddlenlp.transformers.model_utils import PretrainedModel +from paddlenlp.transformers.model_utils import PretrainedModel, get_parameter_dtype from paddlenlp.transformers.utils import dtype_byte_size from paddlenlp.utils.distributed import distributed_allgather, distributed_gather from paddlenlp.utils.env import ( PADDLE_MASTER_WEIGHTS_INDEX_NAME, PADDLE_PEFT_WEIGHTS_INDEX_NAME, PADDLE_WEIGHTS_INDEX_NAME, + PAST_KEY_VALUES_FILE_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME, @@ -69,9 +77,6 @@ class UnifiedCheckpointOption(ExplicitEnum): IGNORE_MERGE_OPTIMIZER = "ignore_merge_optimizer" -"""master weights related functions""" - - def unwrap_optimizer(optimizer): while hasattr(optimizer, "_inner_opt") or hasattr(optimizer, "_optim"): if hasattr(optimizer, "_inner_opt"): @@ -206,7 +211,7 @@ def get_expected_state_dict(model_to_save): return state_dict -def get_expected_keys(sharded_metadata, model, optimizer): +def get_expected_keys(args, sharded_metadata, model, optimizer, is_master_weights=False): hcg = fleet.get_hybrid_communicate_group() sharding_group = hcg.get_sharding_parallel_group() sharding_rank = sharding_group.rank @@ -214,11 +219,23 @@ def get_expected_keys(sharded_metadata, model, optimizer): if in_sharding_parallel_model: params2rank = optimizer._param2rank + model_state_dict = get_expected_state_dict(model) struct2static_name_mappings = {k: v.name for k, v in get_expected_state_dict(model).items()} expected_keys = [] for key in list(sharded_metadata["all_optimizer_keys"]): key_name = key.split("/")[0] + if ( + is_master_weights + and key_name in model_state_dict + and model_state_dict[key_name].dtype == core.VarDesc.VarType.FP32 + ): + continue + + if args.use_expert_parallel and args.data_parallel_rank > 0: + if key_name in model_state_dict and not getattr(model_state_dict[key_name], "no_sync", False): + continue + static_name = struct2static_name_mappings.get(key_name, None) if in_sharding_parallel_model: @@ -281,10 +298,13 @@ def generate_base_static_name(vname): vname = vname.split("_" + FP32_MASTER + "_") return vname[0], vname[1] else: - vname = vname.split(".") - a = vname[0] + "." + vname[1][:3] - b = vname[1][4:] - return a, b + # Directly deal with type names, for example: moe_gate_1_moment1_0. 
+ type_names = optimizer_scalar_name + optimizer_non_scaler_name + for name in type_names: + if name in vname: + a = vname.split(name)[0][:-1] + b = name + return a, b def merge_large_tensor_parallel(tensor, tp_group, tp_action, dst_rank, is_dst): @@ -321,7 +341,9 @@ def merge_large_tensor_parallel(tensor, tp_group, tp_action, dst_rank, is_dst): def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() + dp_group = hcg.get_data_parallel_group() tp_rank = tp_group.rank + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 # filter actions for pipeline mode if hcg.get_pipe_parallel_group().nranks > 1: @@ -339,6 +361,9 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): continue key = filter_keys[i] tensor = state_dict[key] + # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0. + if dp_rank > 0 and not getattr(tensor, "no_sync", False): + continue if key in tp_actions: # Get tensor size tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks @@ -362,16 +387,24 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): if len(tp_actions) > 0: for x in tp_actions.keys(): - logger.warning(f"key <{x}> need to merge tensor parallel but we can't find in model state.") + logger.debug(f"key <{x}> need to merge tensor parallel but we can't find in model state.") return state_dict_to_save -def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys): +def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys, model_state_dict=None): # Core function for UC hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() + dp_group = hcg.get_data_parallel_group() tp_rank = tp_group.rank + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 + + no_sync_kname = [] + if model_state_dict is not None: + for k, v in model_state_dict.items(): + if getattr(v, "no_sync", False): + no_sync_kname.append(k) state_dict_to_save = {} max_key_len = max([len(_) for _ in all_filter_keys]) @@ -383,6 +416,9 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys) # get base model key model_key = filter_keys[i].split("/")[0] tensor = state_dict[filter_keys[i]] + # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0. 
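# ---------------------------------------------------------------------------
# Aside (illustrative sketch, not part of the patch): what the new
# generate_base_static_name() branch above is meant to do for parameters with
# no explicit fp32-master marker. The typenames here are assumed to be the
# usual Adam state suffixes; the parameter name is made up.
def split_static_optimizer_name(vname, type_names=("moment1_0", "moment2_0",
                                                   "beta1_pow_acc_0", "beta2_pow_acc_0")):
    for name in type_names:
        if name in vname:
            base = vname.split(name)[0][:-1]   # drop the trailing "_" before the suffix
            return base, name
    raise ValueError(f"unknown optimizer state name: {vname}")

# split_static_optimizer_name("moe_gate_1_moment1_0") -> ("moe_gate_1", "moment1_0")
# ---------------------------------------------------------------------------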
+ if dp_rank > 0 and model_key not in no_sync_kname: + continue if model_key in tp_actions: # for example: beta1, beta2 if tensor.numel().item() == 1: @@ -425,7 +461,7 @@ def filter_params(model_to_save, state_dict, is_optimizer=False): if tp_size <= 1: return [list(state_dict.keys())] - filter_tensor_list = [[] for i in range(tp_size)] + filter_tensor_list = [[] for _ in range(tp_size)] if tp_rank == 0: tensor_bytes_dict = {} @@ -475,26 +511,29 @@ def filter_params(model_to_save, state_dict, is_optimizer=False): def get_sharded_file_name(args, file_name, is_optimizer=False): if not is_optimizer: + sd_degree = args.sharding_parallel_degree if args.sharding_parallel_degree > 1 else 1 + size = sd_degree if args.use_expert_parallel else args.dataset_world_size shard_file = file_name.replace( ".pdparams", - f"-{args.logical_process_index + 1:05d}-of-{args.world_size//args.dataset_world_size:05d}.pdparams", + f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdparams", ) shard_file = shard_file.replace( ".safetensors", - f"-{args.logical_process_index + 1:05d}-of-{args.world_size//args.dataset_world_size:05d}.safetensors", + f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.safetensors", ) else: hcg = fleet.get_hybrid_communicate_group() dp_group = hcg.get_data_parallel_group() + size = dp_group.nranks if not args.use_expert_parallel else 1 shard_file = file_name.replace( - ".pdparams", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.pdparams" + ".pdparams", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdparams" ) shard_file = shard_file.replace( ".safetensors", - f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.safetensors", + f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.safetensors", ) shard_file = shard_file.replace( - ".pdopt", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.pdopt" + ".pdopt", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdopt" ) return shard_file @@ -520,7 +559,7 @@ def get_sharded_index( return None -def gather_sharded_object(index_file, total_size, is_optimizer=False): +def gather_sharded_object(index_file, total_size, is_optimizer=False, use_expert_parallel=False): index_file_list, total_size_list = [], [] @@ -554,6 +593,17 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False): if len(index_file_list) == 0 and len(total_size_list) == 0: index_file_list = [index_file] total_size_list = [total_size] + + if use_expert_parallel: + data_group = hcg.get_data_parallel_group() + if data_group.nranks > 1: + data_index_file_list = [] + data_total_size_list = [] + dist.all_gather_object(data_index_file_list, index_file_list, data_group) + dist.all_gather_object(data_total_size_list, total_size_list, data_group) + index_file_list = flatten_list(data_index_file_list) + total_size_list = flatten_list(data_total_size_list) + if is_optimizer: sharding_group = hcg.get_sharding_parallel_group() if sharding_group.nranks > 1: @@ -570,11 +620,14 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False): def rename_shard_file(args, shard_file, file_name): """rename shard file when using expert_parallel.""" assert args.use_expert_parallel, "only expert_parallel need to use this function" + shard_file_list = [] + hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() pp_group = hcg.get_pipe_parallel_group() 
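# ---------------------------------------------------------------------------
# Aside (illustrative, not part of the patch): the shard-numbering scheme used
# by get_sharded_file_name()/rename_shard_file() above. The divisor is the
# dataset world size, or the sharding degree when expert parallel is enabled;
# the sizes below are made-up examples.
def shard_file_name(file_name, shard_index, world_size, divisor):
    total_shards = world_size // divisor
    return file_name.replace(
        ".safetensors", f"-{shard_index + 1:05d}-of-{total_shards:05d}.safetensors"
    )

# shard_file_name("model.safetensors", 0, 16, 8) -> "model-00001-of-00002.safetensors"
# ---------------------------------------------------------------------------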
data_group = hcg.get_data_parallel_group() + if tp_group.nranks > 1: dist.all_gather_object(shard_file_list, shard_file, tp_group) if pp_group.nranks > 1: @@ -589,6 +642,7 @@ def rename_shard_file(args, shard_file, file_name): data_shard_file_list, shard_file_list if len(shard_file_list) > 0 else shard_file, data_group ) shard_file_list = flatten_list(data_shard_file_list) + new_index = shard_file_list.index(shard_file) sd_degree = args.sharding_parallel_degree if args.sharding_parallel_degree > 1 else 1 shard_file = file_name.replace( @@ -600,3 +654,31 @@ def rename_shard_file(args, shard_file, file_name): f"-{new_index + 1:05d}-of-{args.world_size//sd_degree:05d}.safetensors", ) return shard_file + + +def save_config(model_to_save): + dtype = get_parameter_dtype(model_to_save) + model_to_save.config.dtype = str(dtype).split(".")[1] + config_to_save = copy.deepcopy(model_to_save.config) + + if config_to_save.tensor_parallel_degree > 1: + # do we need to change? + config_to_save.tensor_parallel_degree = 1 + + return config_to_save + + +def save_prefix_past_key_value(model_to_save, save_directory): + past_key_value = model_to_save.prefix_encoder(model_to_save.prefix_tokens.unsqueeze(0).expand([1, -1])) + past_key_value = past_key_value.reshape( + [ + model_to_save.prefix_config.num_prefix_tokens, + 2, + model_to_save.prefix_config.num_hidden_layers, + model_to_save.num_heads, + model_to_save.head_dim, + ] + ) + past_key_value = paddle.transpose(past_key_value, perm=[2, 1, 3, 0, 4]).cpu().numpy() + model_to_save.prefix_config.save_pretrained(save_directory) + np.save(os.path.join(save_directory, PAST_KEY_VALUES_FILE_NAME), past_key_value) From cbbc074ab09da0bb9cc1de5b81e6c32eba572d4b Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Thu, 24 Oct 2024 16:50:59 +0800 Subject: [PATCH 09/16] update split_param loading --- .../plugins/unified_checkpoint_sharding_v2.py | 27 ++++++++++--------- paddlenlp/trainer/training_args.py | 6 +++++ 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/paddlenlp/trainer/plugins/unified_checkpoint_sharding_v2.py b/paddlenlp/trainer/plugins/unified_checkpoint_sharding_v2.py index ac41fa287aef..f8eddb8691f1 100644 --- a/paddlenlp/trainer/plugins/unified_checkpoint_sharding_v2.py +++ b/paddlenlp/trainer/plugins/unified_checkpoint_sharding_v2.py @@ -224,12 +224,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected for shard_file in resolved_archive_file: if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]): continue - if model.config.tensor_parallel_degree > 1: - state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected") + state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="cpu") else: - state_dict = load_state_dict(shard_file, None, expected_keys, device="expected") - + state_dict = load_state_dict(shard_file, None, expected_keys, device="cpu") returned_state_dict.update(state_dict) del state_dict gc.collect() @@ -238,13 +236,6 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected # get tp params state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim) - if has_master_weights: - state_dict_master_weight = load_resolved_archive_file( - resolved_archive_file_mw, - sharded_metadata_mw, - expected_keys, - is_master_weights=True, - ) # need to split param for different sharding rank, maybe need to deal with oom issue. 
for key in list(state_dict_optim.keys()): @@ -266,15 +257,24 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected paddle.zeros([padding_end - padding_start], dtype=state_dict_optim[key].dtype), ) ) - if has_master_weights: key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) else: key_name = "_".join([static_name, key_name[1]]) + + state_dict_optim[key] = state_dict_optim[key]._copy_to(paddle.framework._current_expected_place(), False) + returned_optim_state_dict[key_name] = state_dict_optim.pop(key) returned_optim_state_dict[key_name].name = key_name if has_master_weights: + state_dict_master_weight = load_resolved_archive_file( + resolved_archive_file_mw, + sharded_metadata_mw, + expected_keys, + is_master_weights=True, + ) + for key in list(state_dict_master_weight.keys()): static_name = struct2static_name_mappings.get(key, None) if state_dict_master_weight[key].numel().item() > 1: @@ -292,6 +292,9 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected paddle.zeros([padding_end - padding_start], dtype=state_dict_master_weight[key].dtype), ) ) + state_dict_master_weight[key] = state_dict_master_weight[key]._copy_to( + paddle.framework._current_expected_place(), False + ) returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key) returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 6aed76f17cea..b8d94df2e0a7 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1402,6 +1402,12 @@ def is_segment_parallel_supported(): f"but got logging_steps={self.logging_steps}." ) + if "split_param" in sharding_parallel_config: + assert self.sharding == [ShardingOption.SHARD_OP], "Only sharding stage1 support split_param." + assert ( + self.amp_master_grad + ), "If `split_param` in sharding_parallel_config, `amp_master_grad` must be True." 
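# ---------------------------------------------------------------------------
# Aside (illustrative, not part of the patch): a configuration that satisfies
# the split_param assertions above -- sharding stage1 only, with
# amp_master_grad enabled. Field names follow the usual TrainingArguments
# flags; the degree and the other values are hypothetical.
split_param_training_args = dict(
    sharding="stage1",
    sharding_parallel_degree=8,
    sharding_parallel_config="split_param",
    amp_master_grad=True,
    unified_checkpoint=True,
)
# ---------------------------------------------------------------------------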
+ fleet.init(is_collective=True, strategy=strategy) logger.info(strategy) From 7678fad7b780620173093d61c2bfed7056fdd97f Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 25 Oct 2024 14:33:20 +0800 Subject: [PATCH 10/16] mkdir unified_checkpoint directory --- paddlenlp/trainer/trainer.py | 5 +- .../trainer/unified_checkpoint/__init__.py | 15 + .../check_unified_checkpoint.py | 247 ++++++++ .../shared_memory_utils.py | 0 .../unified_checkpoint.py | 584 ++---------------- .../unified_checkpoint_dynamic.py | 0 .../unified_checkpoint_locally_load.py | 268 ++++++++ .../unified_checkpoint_sharding_v2.py | 11 +- .../unified_checkpoint_single_card.py | 17 +- .../unified_checkpoint_utils.py | 102 ++- 10 files changed, 686 insertions(+), 563 deletions(-) create mode 100644 paddlenlp/trainer/unified_checkpoint/__init__.py create mode 100644 paddlenlp/trainer/unified_checkpoint/check_unified_checkpoint.py rename paddlenlp/trainer/{plugins => unified_checkpoint}/shared_memory_utils.py (100%) rename paddlenlp/trainer/{plugins => unified_checkpoint}/unified_checkpoint.py (62%) rename paddlenlp/trainer/{plugins => unified_checkpoint}/unified_checkpoint_dynamic.py (100%) create mode 100644 paddlenlp/trainer/unified_checkpoint/unified_checkpoint_locally_load.py rename paddlenlp/trainer/{plugins => unified_checkpoint}/unified_checkpoint_sharding_v2.py (97%) rename paddlenlp/trainer/{plugins => unified_checkpoint}/unified_checkpoint_single_card.py (93%) rename paddlenlp/trainer/{plugins => unified_checkpoint}/unified_checkpoint_utils.py (90%) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index ac3c63b01047..7a5990098d0f 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -113,7 +113,6 @@ from .argparser import strtobool from .integrations import get_reporting_integration_callbacks from .plugins.timer import RuntimeTimer, get_timers, set_timers -from .plugins.unified_checkpoint import UnifiedCheckpointHandler from .trainer_callback import ( CallbackHandler, DefaultFlowCallback, @@ -144,6 +143,7 @@ speed_metrics, ) from .training_args import TrainingArguments +from .unified_checkpoint import UnifiedCheckpointHandler from .utils import reshard as reshard_util from .utils.async_save import AsyncSaver from .utils.helper import ( # nested_truncate, @@ -598,7 +598,6 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): if use_unified_checkpoint: self.unified_checkpoint_handler.load_unified_checkpoint( self.model, - self.optimizer, resume_from_checkpoint, ) logger.info(f"Loading model from {resume_from_checkpoint} using unified checkpoint.") @@ -1241,7 +1240,6 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): if self.args.unified_checkpoint: self.unified_checkpoint_handler.load_unified_checkpoint( self.model, - self.optimizer, self.state.best_model_checkpoint, ) if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1: @@ -1289,7 +1287,6 @@ def _load_best_model_from_peft_checkpoint(self): if self.args.unified_checkpoint: self.unified_checkpoint_handler.load_unified_checkpoint( self.model, - self.optimizer, self.state.best_model_checkpoint, ) if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1: diff --git a/paddlenlp/trainer/unified_checkpoint/__init__.py b/paddlenlp/trainer/unified_checkpoint/__init__.py new file mode 100644 index 000000000000..20a336cb3d8f --- /dev/null +++ b/paddlenlp/trainer/unified_checkpoint/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 
PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .unified_checkpoint import UnifiedCheckpointHandler diff --git a/paddlenlp/trainer/unified_checkpoint/check_unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/check_unified_checkpoint.py new file mode 100644 index 000000000000..76e8df9ce8d5 --- /dev/null +++ b/paddlenlp/trainer/unified_checkpoint/check_unified_checkpoint.py @@ -0,0 +1,247 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unfied checkpoint check functions.""" + +import json +import os + +import paddle +import paddle.distributed as dist +from paddle.distributed import fleet + +from paddlenlp.trainer.utils.helper import distributed_file, distributed_isfile +from paddlenlp.utils.env import ( + PADDLE_MASTER_WEIGHTS_INDEX_NAME, + PADDLE_OPTIMIZER_INDEX_NAME, + SAFE_MASTER_WEIGHTS_INDEX_NAME, + SAFE_OPTIMIZER_INDEX_NAME, +) +from paddlenlp.utils.log import logger +from paddlenlp.utils.nested import flatten_list + +try: + from paddle.base import core +except: + core = None + +from .unified_checkpoint_utils import ( + get_expected_state_dict, + is_sharding_split_param_mode, + select_model_weight_index, + update_master_weight_status, +) + + +def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False): + index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False) + index_filename = os.path.join(resume_from_checkpoint, index_filename) + # Find index json file and distribute this file in global group. + if distributed_isfile(index_filename): + distributed_file(index_filename) + else: + raise Exception( + f"Sorry, we can not find {index_filename}. This file should be appear at least on one machine." + ) + + with open(index_filename, "r") as f: + index = json.loads(f.read()) + all_weight_filenames = sorted(set(index["weight_map"].values())) + + # Get existed weight file list on current machine. + existed_filelist = [] + existed_files = [] + for filename in os.listdir(resume_from_checkpoint): + if filename in all_weight_filenames: + existed_files.append(filename) + + # Gather all the existed files in global group. 
+ dist.all_gather_object(existed_filelist, existed_files) + flatten_existed_filelist = flatten_list(existed_filelist) + diff_filelist = list(set(all_weight_filenames).difference(set(flatten_existed_filelist))) + if len(diff_filelist) != 0: + raise Exception(f"Sorry, the weight file list on the machines is not complete!, missing {diff_filelist}") + + # To decide whether to load the checkpoint locally, or need to dynamically send tensors across machines. + local_resume = True + if args.dataset_rank == 0 or args.use_expert_parallel: + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + pp_group = hcg.get_pipe_parallel_group() + dp_group = hcg.get_data_parallel_group() + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 + + need_files = set() + state_dict = get_expected_state_dict(model) + for key in state_dict.keys(): + filename = index["weight_map"][key] + # When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0. + if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False): + continue + need_files.add(filename) + diff_filelist = list(need_files.difference(set(existed_files))) + num_diff = paddle.to_tensor([len(diff_filelist)]) + if tp_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=tp_group) + if pp_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group) + if args.use_expert_parallel and dp_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group) + if num_diff.item() == 0: + local_resume = True + else: + local_resume = False + local_resume = paddle.to_tensor([local_resume]) + dist.all_reduce(local_resume, op=dist.ReduceOp.PROD) + local_resume = local_resume.item() + return local_resume + + +def check_unified_optimizer(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): + if not safe_serialization: + index_filename, index_filename_master_weights = PADDLE_OPTIMIZER_INDEX_NAME, PADDLE_MASTER_WEIGHTS_INDEX_NAME + else: + index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME + index_filename = os.path.join(resume_from_checkpoint, index_filename) + index_filename_master_weights = os.path.join(resume_from_checkpoint, index_filename_master_weights) + + # Find index json file and distribute the file in global group. + if distributed_isfile(index_filename): + distributed_file(index_filename) + else: + raise Exception( + f"Sorry, we can not find {index_filename}. This file should be appear at least on one machine." + ) + + with open(index_filename, "r") as f: + index = json.loads(f.read()) + all_optimizer_filenames = sorted(set(index["weight_map"].values())) + + has_master_weights = index["master_weights"] + # update has_master_weights and index_filename_master_weights + # 1. if the master weight exists, only has_master_weights is set True and loaded when needed + # 2. if master weight does not exist, convert model weight to master weight when needed + has_master_weights, index_filename_master_weights = update_master_weight_status( + args, optimizer, has_master_weights, safe_serialization + ) + if has_master_weights: + index_filename_master_weights = os.path.join(resume_from_checkpoint, index_filename_master_weights) + if distributed_isfile(index_filename_master_weights): + distributed_file(index_filename_master_weights) + else: + raise Exception( + f"Sorry, we can not find {index_filename_master_weights}. 
This file should be appear at least on one machine." + ) + with open(index_filename_master_weights, "r") as f: + index_mw = json.loads(f.read()) + all_mw_filenames = sorted(set(index_mw["weight_map"].values())) + + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + pp_group = hcg.get_pipe_parallel_group() + dp_group = hcg.get_data_parallel_group() + sharding_group = hcg.get_sharding_parallel_group() + sharding_rank = sharding_group.rank + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 + struct2static_name_mappings = {k: v.name for k, v in model.state_dict().items()} + + if is_sharding_split_param_mode(args): + # We do not check optimizer files completion for split_param, since it is very complicated. Directly support local resume. + logger.warning("We only support local resume for split_param mode, do not support dynamically loading.") + return True + + if sharding_group.nranks > 1: + param2rank = optimizer._param2rank + + def check_complete(all_filenames): + # Check whether the checkpoint files on machines are complete. If not complete, raise Exception. + existed_filelist = [] + existed_files = [] + for filename in os.listdir(resume_from_checkpoint): + if filename in all_filenames: + existed_files.append(filename) + + dist.all_gather_object(existed_filelist, existed_files) + flatten_existed_filelist = flatten_list(existed_filelist) + diff_filelist = list(set(all_filenames).difference(set(flatten_existed_filelist))) + if len(diff_filelist) != 0: + raise Exception( + f"Sorry, the optimizer file list on `data_parallel_rank==0` machines is not complete!, missing {diff_filelist}" + ) + return existed_files + + def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, typename_set=None): + # To decide whether to load the checkpoint locally, or need to dynamically distribute the checkpoint. + local_resume = True + if args.data_parallel_rank == 0 or args.use_expert_parallel: + need_files = set() + state_dict = get_expected_state_dict(model) + + for key in state_dict.keys(): + if sharding_group.nranks > 1: + static_name = struct2static_name_mappings.get(key, None) + param_rank = param2rank.get(static_name, None) + if param_rank != sharding_rank: + continue + + # When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0. 
+ if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False): + continue + + if is_master_weights and state_dict[key].dtype == core.VarDesc.VarType.FP32: + continue + + if not is_master_weights: + for type_name in typename_set: + type_key = key + "/" + type_name + filename = weight_map[type_key] + need_files.add(filename) + else: + filename = weight_map[key] + need_files.add(filename) + + diff_filelist = list(need_files.difference(set(existed_files))) + num_diff = paddle.to_tensor([len(diff_filelist)]) + if tp_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=tp_group) + if pp_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group) + if sharding_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=sharding_group) + if args.use_expert_parallel and dp_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group) + + if num_diff.item() == 0: + local_resume = True + else: + local_resume = False + local_resume = paddle.to_tensor([local_resume]) + dist.all_reduce(local_resume, op=dist.ReduceOp.PROD) + return local_resume.item() + + # check whether the optimizer checkpoint files are complete. + existed_files = check_complete(all_optimizer_filenames) + if has_master_weights: + existed_files_mw = check_complete(all_mw_filenames) + # get optimizer's param type name, like moment1_0. + typename_set = set() + for key in index["weight_map"].keys(): + _, typename = key.split("/") + typename_set.add(typename) + local_resume = check_dynamic_load( + args, index["weight_map"], existed_files, is_master_weights=False, typename_set=typename_set + ) + local_resume_rw = True + if has_master_weights: + local_resume_rw = check_dynamic_load(args, index_mw["weight_map"], existed_files_mw, is_master_weights=True) + return local_resume & local_resume_rw diff --git a/paddlenlp/trainer/plugins/shared_memory_utils.py b/paddlenlp/trainer/unified_checkpoint/shared_memory_utils.py similarity index 100% rename from paddlenlp/trainer/plugins/shared_memory_utils.py rename to paddlenlp/trainer/unified_checkpoint/shared_memory_utils.py diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py similarity index 62% rename from paddlenlp/trainer/plugins/unified_checkpoint.py rename to paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py index a724fb0968a2..463d99462a74 100644 --- a/paddlenlp/trainer/plugins/unified_checkpoint.py +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py @@ -13,7 +13,6 @@ # limitations under the License. 
import copy -import gc import json import multiprocessing import os @@ -24,7 +23,6 @@ import paddle import paddle.distributed as dist from paddle.distributed import fleet -from tqdm.auto import tqdm try: from paddle.base import core @@ -33,27 +31,20 @@ from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM from paddlenlp.trainer.argparser import strtobool -from paddlenlp.trainer.trainer_utils import ShardingOption -from paddlenlp.trainer.utils.helper import distributed_file, distributed_isfile +from paddlenlp.trainer.utils.helper import distributed_isfile from paddlenlp.transformers.model_utils import ( PretrainedModel, _add_variant, - _load_state_dict_into_model, - faster_set_state_dict, - load_state_dict, unwrap_model, ) from paddlenlp.transformers.utils import ( device_guard, dtype_byte_size, - get_checkpoint_shard_files, is_safetensors_available, ) from paddlenlp.utils.env import ( LORA_WEIGHTS_NAME, - PADDLE_MASTER_WEIGHTS_INDEX_NAME, PADDLE_MASTER_WEIGHTS_NAME, - PADDLE_OPTIMIZER_INDEX_NAME, PADDLE_OPTIMIZER_NAME, PADDLE_WEIGHTS_NAME, PREFIX_WEIGHTS_NAME, @@ -67,7 +58,7 @@ SAFE_WEIGHTS_NAME, ) from paddlenlp.utils.log import logger -from paddlenlp.utils.nested import flatten_list, nested_copy +from paddlenlp.utils.nested import nested_copy if is_safetensors_available(): from safetensors.numpy import save_file as safe_save_file @@ -77,6 +68,7 @@ else: from paddlenlp.utils.safetensors import fast_load_file as load_file +from .check_unified_checkpoint import check_unified_checkpoint, check_unified_optimizer from .shared_memory_utils import ( _read_state_dict_from_shm, _traverse_copy_to_shm, @@ -86,10 +78,11 @@ load_unified_checkpoint_dynamically, load_unified_optimizer_dynamically, ) -from .unified_checkpoint_sharding_v2 import ( - gather_splited_param_for_optimizer, - load_unified_optimizer_split_param, +from .unified_checkpoint_locally_load import ( + load_unified_checkpoint_locally, + load_unified_optimizer_locally, ) +from .unified_checkpoint_sharding_v2 import gather_splited_param_for_optimizer from .unified_checkpoint_single_card import ( load_single_card_checkpoint, load_single_card_optimizer, @@ -102,30 +95,25 @@ filter_params, gather_sharded_object, generate_base_static_name, - get_expected_keys, get_expected_state_dict, - get_optimizer_shard_files, get_sharded_file_name, get_sharded_index, is_need_master_weight, - mapping_optimizer_tp_actions, + is_sharding_split_param_mode, merge_tensor_parallel_for_optimizer, merge_tensor_parallel_with_shard, reduce_master_weights_status, rename_shard_file, - save_config, - save_prefix_past_key_value, - select_model_weight_index, - update_master_weight_status, + save_model_config, ) -class UnifiedCheckpointHandler: +class AsyncCheckpointHander: def __init__(self, args): + # Mainly for asynchronous saving. self.args = args self.global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 - # Mainly for asynchronous saving. 
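+        # Shared-memory buffers used by the asynchronous save path to hand model, master and optimizer weights over to the background save processes.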
self._shm_model_weight = None self._shm_master_weight = None self._shm_optimizer_weight = None @@ -294,6 +282,51 @@ def _reset_and_update(self, shared_array, new_value): encoded_value = new_value.encode("utf-8") shared_array[: len(encoded_value)] = encoded_value + def unlink_shared_memory(self): + if not ("async_save" in self.args.unified_checkpoint_config): + return + + if self._shared_save_model_flag is not None: + while self._shared_save_model_flag[0] > 0: # async process is saving + if not self._process_model_weight.is_alive(): + raise RuntimeError("The process that saves model_weight has been killed unexpectedly.") + time.sleep(0.5) + self._shared_save_model_flag[0] = -1 + if self._shared_save_master_weight_flag is not None: + while self._shared_save_master_weight_flag[0] > 0: + if not self._process_master_weight.is_alive(): + raise RuntimeError("The process that saves master_weight has been killed unexpectedly.") + time.sleep(0.5) + self._shared_save_master_weight_flag[0] = -1 + if self._shared_save_optimizer_flag is not None: + while self._shared_save_optimizer_flag[0] > 0: + if not self._process_optimizer_weight.is_alive(): + raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.") + time.sleep(0.5) + self._shared_save_optimizer_flag[0] = -1 + + if self._shm_model_weight is not None: + self._shm_model_weight.close() + self._shm_model_weight.unlink() + self._shm_model_weight = None + if self._shm_master_weight is not None: + self._shm_master_weight.close() + self._shm_master_weight.unlink() + self._shm_master_weight = None + if self._shm_optimizer_weight is not None: + self._shm_optimizer_weight.close() + self._shm_optimizer_weight.unlink() + self._shm_optimizer_weight = None + + if paddle.distributed.get_world_size() > 1: + dist.barrier() + + +class UnifiedCheckpointHandler: + def __init__(self, args): + self.args = args + self.async_handler = AsyncCheckpointHander(args) + def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None): """save unified checkpoint @@ -342,7 +375,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None) is_sync_save = True if "async_save" in self.args.unified_checkpoint_config: is_sync_save = False - self._file_save_async_or_sync( + self.async_handler._file_save_async_or_sync( state_dict, path=os.path.join(save_directory, shard_file), signal_path=signal_dir, @@ -361,25 +394,8 @@ def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None) json.dump(sharded_index, f, indent=4) if self.args.should_save: - # Save prefix model past_key_values - if isinstance(model_to_save, PrefixModelForCausalLM): - save_prefix_past_key_value(model_to_save, save_directory) - model_to_save.prefix_config.save_pretrained(save_directory) - if isinstance(model_to_save, LoRAModel): - model_to_save.lora_config.save_pretrained(save_directory) - - # save the config - config_to_save = save_config(model_to_save) - # Attach architecture to the config - if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): - config_to_save.architectures = [model_to_save.model.__class__.__name__] - else: - config_to_save.architectures = [model_to_save.__class__.__name__] - if self.args.should_save: - config_to_save.save_pretrained(save_directory) - # save generation config - if model_to_save.can_generate(): - model_to_save.generation_config.save_pretrained(save_directory) + save_model_config(model_to_save, save_directory) + paddle.device.cuda.empty_cache() if 
strtobool(os.getenv("FLAG_LLM_PDC", "False")) and self.args.should_save: @@ -391,7 +407,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None) } paddle.save(save_info, os.path.join(save_directory, ".saving_info")) - def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str): + def load_unified_checkpoint(self, model, resume_from_checkpoint: str): """Load potential model checkpoint Args: @@ -464,14 +480,14 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp is_sync_save = True if "async_save" in self.args.unified_checkpoint_config: is_sync_save = False - self._file_save_async_or_sync( + self.async_handler._file_save_async_or_sync( optim_state_dict, path=os.path.join(output_dir, optimizer_name), signal_path=signal_dir, is_sync=is_sync_save, state_dict_type="optimizer_weight", ) - self._file_save_async_or_sync( + self.async_handler._file_save_async_or_sync( master_weights, path=os.path.join(output_dir, master_weights_name), signal_path=signal_dir, @@ -539,11 +555,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir): save_single_card_optimizer(model, optimizer, output_dir) # no need to save signal return - if ( - self.args.sharding_parallel_degree > 1 - and ShardingOption.SHARD_OP in self.args.sharding - and "split_param" in self.args.sharding_parallel_config - ): + if is_sharding_split_param_mode(self.args): optim_state_dict, master_weights = gather_splited_param_for_optimizer(optimizer) else: optim_state_dict = nested_copy(optimizer.state_dict()) @@ -578,7 +590,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir): is_sync_save = True if "async_save" in self.args.unified_checkpoint_config: is_sync_save = False - self._file_save_async_or_sync( + self.async_handler._file_save_async_or_sync( optim_state_dict, path=os.path.join(save_directory, shard_optim_file), signal_path=signal_dir, @@ -586,7 +598,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir): state_dict_type="optimizer_weight", ) if master_weight_state_dict is not None: - self._file_save_async_or_sync( + self.async_handler._file_save_async_or_sync( master_weight_state_dict, path=os.path.join(save_directory, shard_master_weight_file), signal_path=signal_dir, @@ -658,140 +670,7 @@ def load_unified_optimizer(self, model, optimizer, resume_from_checkpoint): return None def unlink_shared_memory(self): - if not ("async_save" in self.args.unified_checkpoint_config): - return - - if self._shared_save_model_flag is not None: - while self._shared_save_model_flag[0] > 0: # async process is saving - if not self._process_model_weight.is_alive(): - raise RuntimeError("The process that saves model_weight has been killed unexpectedly.") - time.sleep(0.5) - self._shared_save_model_flag[0] = -1 - if self._shared_save_master_weight_flag is not None: - while self._shared_save_master_weight_flag[0] > 0: - if not self._process_master_weight.is_alive(): - raise RuntimeError("The process that saves master_weight has been killed unexpectedly.") - time.sleep(0.5) - self._shared_save_master_weight_flag[0] = -1 - if self._shared_save_optimizer_flag is not None: - while self._shared_save_optimizer_flag[0] > 0: - if not self._process_optimizer_weight.is_alive(): - raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.") - time.sleep(0.5) - self._shared_save_optimizer_flag[0] = -1 - - if self._shm_model_weight is not None: - self._shm_model_weight.close() - 
self._shm_model_weight.unlink() - self._shm_model_weight = None - if self._shm_master_weight is not None: - self._shm_master_weight.close() - self._shm_master_weight.unlink() - self._shm_master_weight = None - if self._shm_optimizer_weight is not None: - self._shm_optimizer_weight.close() - self._shm_optimizer_weight.unlink() - self._shm_optimizer_weight = None - - if paddle.distributed.get_world_size() > 1: - dist.barrier() - - -def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False): - """ - Only dataset_rank == 0 or using expert parallel can enter this function. - """ - index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=True) - - resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( - pretrained_model_name_or_path=resume_from_checkpoint, - index_filename=os.path.join(resume_from_checkpoint, index_filename), - ) - loaded_keys = sharded_metadata["all_checkpoint_keys"] - - model_state_dict = get_expected_state_dict(model) - # If using expert parallel, when dp_rank > 0, need to modify the expected_keys here. - if not args.use_expert_parallel or (args.use_expert_parallel and args.data_parallel_rank == 0): - expected_keys = set(list(model_state_dict.keys())) - else: - expected_keys = set() - for key in model_state_dict.keys(): - if getattr(model_state_dict[key], "no_sync", False): - expected_keys.add(key) - missing_keys = expected_keys - set(loaded_keys) - - use_fast_set = True - if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): - use_fast_set = False - - if len(missing_keys) > 0: - raise ValueError(f"missing_keys: {missing_keys}") - - def _remove_unused_keys( - state_dict, - model_state_dict, - ): - unused_keys = set(state_dict.keys()) - set(model_state_dict.keys()) - for unused_key in unused_keys: - del state_dict[unused_key] - return unused_keys - - # This should always be a list but, just to be sure. - if not isinstance(resolved_archive_file, list): - resolved_archive_file = [resolved_archive_file] - - error_msgs = [] - - if len(resolved_archive_file) > 1: - resolved_archive_file = tqdm(resolved_archive_file, desc="Loading checkpoint shards") - - for shard_file in resolved_archive_file: - # TODO: check if no expected_keys in shard_file, then don't load it - if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]): - continue - - pre_tensor_parallel_split = False - if shard_file.endswith(".safetensors") and model.config.tensor_parallel_degree > 1: - pre_tensor_parallel_split = True - assert loaded_keys is not None, "loaded_keys is not None." - if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): - tp_actions = model._get_tensor_parallel_convert_actions( - set(loaded_keys), is_split=True, ignore_error=True - ) - else: - tp_actions = model.get_tensor_parallel_convert_actions(model.config, loaded_keys, ignore_error=True) - # Here we use expected_keys to optimize weights loading for pipeline model. 
Only works for safetensors - state_dict = load_state_dict( - shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys, device="expected" - ) - - if not pre_tensor_parallel_split: - # Since we load all keys but we only need one of pipeline stages - _ = _remove_unused_keys(state_dict, model_state_dict) - - if model.config.tensor_parallel_degree > 1 and not pre_tensor_parallel_split: - logger.info("Converting state_dict to Tensor Parallel Format") - # ignore error for multi shard, since only parts of data - state_dict = model.convert_tensor_parallel( - None, model.config, state_dict=state_dict, ignore_error=len(resolved_archive_file) > 1 - ) - - if use_fast_set: - error_msgs += faster_set_state_dict(model, state_dict, strict_dtype=False) - else: - error_msgs += _load_state_dict_into_model(model, state_dict, "") - - # force memory release - del state_dict - # gc.collect() - - if len(error_msgs) > 0: - error_msg = "\n\t".join(error_msgs) - if " but the expected shape is" in error_msg: - error_msg += ( - "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." - ) - raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + return self.async_handler.unlink_shared_memory() def unified_checkpoint_into_shards( @@ -865,129 +744,6 @@ def unified_checkpoint_into_shards( return state_dict, shard_file, sharded_index -def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): - # Special process with split param. - if ( - args.sharding_parallel_degree > 1 - and ShardingOption.SHARD_OP in args.sharding - and "split_param" in args.sharding_parallel_config - ): - returned_optim_state_dict = load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint) - return returned_optim_state_dict - - # init and get optimizer LR_Scheduler - returned_optim_state_dict = nested_copy(optimizer.state_dict()) - - if not safe_serialization: - index_filename, index_filename_master_weights = ( - PADDLE_OPTIMIZER_INDEX_NAME, - PADDLE_MASTER_WEIGHTS_INDEX_NAME, - ) - else: - index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME - - resolved_archive_file, sharded_metadata = get_optimizer_shard_files( - optimizer_path=resume_from_checkpoint, - index_filename=os.path.join(resume_from_checkpoint, index_filename), - ) - has_master_weights = True if sharded_metadata["master_weights"] else False - - model_state_dict = get_expected_state_dict(model) - model_keys = list(model_state_dict.keys()) - struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} # get optimizer param mappings - - expected_keys = get_expected_keys(args, sharded_metadata, model, optimizer) - - # This should always be a list but, just to be sure. - if not isinstance(resolved_archive_file, list): - resolved_archive_file = [resolved_archive_file] - - if len(resolved_archive_file) > 1: - resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards") - - # update has_master_weights and index_filename_master_weights - # 1. if the master weight exists, only has_master_weights is set True and loaded when needed - # 2. 
if master weight does not exist, convert model weight to master weight when needed - has_master_weights, index_filename_master_weights = update_master_weight_status( - args, optimizer, has_master_weights, safe_serialization - ) - - if has_master_weights: - returned_optim_state_dict["master_weights"] = {} - - resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files( - optimizer_path=resume_from_checkpoint, - index_filename=os.path.join(resume_from_checkpoint, index_filename_master_weights), - ) - - expected_keys_mw = get_expected_keys(args, sharded_metadata_mw, model, optimizer, is_master_weights=True) - if not isinstance(resolved_archive_file_mw, list): - resolved_archive_file_mw = [resolved_archive_file_mw] - if len(resolved_archive_file_mw) > 1: - resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards") - - def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False): - returned_state_dict = {} - # load optimizer - for shard_file in resolved_archive_file: - # TODO: check if no expected_keys in shard_file, then don't load it - if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]): - continue - - if shard_file.endswith(".safetensors"): - # assert model_keys is not None, "model_keys is None." TODO: correct the assert - if model.config.tensor_parallel_degree > 1: - if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): - tp_actions = model._get_tensor_parallel_convert_actions( - model_keys, is_split=True, ignore_error=True - ) - else: - tp_actions = model.get_tensor_parallel_convert_actions( - model.config, model_keys, ignore_error=True - ) - if not is_master_weights: - tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys) - - # Here we use expected_keys to optimize weights loading for pipeline model. 
Only works for safetensors - state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected") - else: - # for pipeline model, we don't need to use tp_actions - state_dict = load_state_dict(shard_file, None, expected_keys, device="expected") - - returned_state_dict.update(state_dict) - # force memory release - del state_dict - gc.collect() - return returned_state_dict - - state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys) - if has_master_weights: - state_dict_master_weight = load_resolved_archive_file( - resolved_archive_file_mw, sharded_metadata_mw, expected_keys_mw, is_master_weights=True - ) - # rename optimizer param - for key in list(state_dict_optim.keys()): - key_name = key.split("/") - static_name = struct2static_name_mappings[key_name[0]] - if has_master_weights: - if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: - key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) - else: - key_name = "_".join([static_name, key_name[1]]) - else: - key_name = "_".join([static_name, key_name[1]]) - returned_optim_state_dict[key_name] = state_dict_optim.pop(key) - returned_optim_state_dict[key_name].name = key_name - - if has_master_weights: - for key in list(state_dict_master_weight.keys()): - static_name = struct2static_name_mappings[key] - returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key) - returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) - - return returned_optim_state_dict - - def unified_optimizer_into_shards( args, model, @@ -1118,211 +874,3 @@ def unified_optimizer_into_shards( (optim_state_dict, shard_optimizer_file, sharded_optim_index), (master_weights, shard_master_weight_file, sharded_master_weight_index), ] - - -def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False): - index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False) - index_filename = os.path.join(resume_from_checkpoint, index_filename) - # Find index json file and distribute this file in global group. - if distributed_isfile(index_filename): - distributed_file(index_filename) - else: - raise Exception( - f"Sorry, we can not find {index_filename}. This file should be appear at least on one machine." - ) - - with open(index_filename, "r") as f: - index = json.loads(f.read()) - all_weight_filenames = sorted(set(index["weight_map"].values())) - - # Get existed weight file list on current machine. - existed_filelist = [] - existed_files = [] - for filename in os.listdir(resume_from_checkpoint): - if filename in all_weight_filenames: - existed_files.append(filename) - - # Gather all the existed files in global group. - dist.all_gather_object(existed_filelist, existed_files) - flatten_existed_filelist = flatten_list(existed_filelist) - diff_filelist = list(set(all_weight_filenames).difference(set(flatten_existed_filelist))) - if len(diff_filelist) != 0: - raise Exception(f"Sorry, the weight file list on the machines is not complete!, missing {diff_filelist}") - - # To decide whether to load the checkpoint locally, or need to dynamically send tensors across machines. 
- local_resume = True - if args.dataset_rank == 0 or args.use_expert_parallel: - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - pp_group = hcg.get_pipe_parallel_group() - dp_group = hcg.get_data_parallel_group() - dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 - - need_files = set() - state_dict = get_expected_state_dict(model) - for key in state_dict.keys(): - filename = index["weight_map"][key] - # When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0. - if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False): - continue - need_files.add(filename) - diff_filelist = list(need_files.difference(set(existed_files))) - num_diff = paddle.to_tensor([len(diff_filelist)]) - if tp_group.nranks > 1: - dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=tp_group) - if pp_group.nranks > 1: - dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group) - if args.use_expert_parallel and dp_group.nranks > 1: - dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group) - if num_diff.item() == 0: - local_resume = True - else: - local_resume = False - local_resume = paddle.to_tensor([local_resume]) - dist.all_reduce(local_resume, op=dist.ReduceOp.PROD) - local_resume = local_resume.item() - return local_resume - - -def check_unified_optimizer(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): - if not safe_serialization: - index_filename, index_filename_master_weights = PADDLE_OPTIMIZER_INDEX_NAME, PADDLE_MASTER_WEIGHTS_INDEX_NAME - else: - index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME - index_filename = os.path.join(resume_from_checkpoint, index_filename) - index_filename_master_weights = os.path.join(resume_from_checkpoint, index_filename_master_weights) - - # Find index json file and distribute the file in global group. - if distributed_isfile(index_filename): - distributed_file(index_filename) - else: - raise Exception( - f"Sorry, we can not find {index_filename}. This file should be appear at least on one machine." - ) - - with open(index_filename, "r") as f: - index = json.loads(f.read()) - all_optimizer_filenames = sorted(set(index["weight_map"].values())) - - has_master_weights = index["master_weights"] - # update has_master_weights and index_filename_master_weights - # 1. if the master weight exists, only has_master_weights is set True and loaded when needed - # 2. if master weight does not exist, convert model weight to master weight when needed - has_master_weights, index_filename_master_weights = update_master_weight_status( - args, optimizer, has_master_weights, safe_serialization - ) - if has_master_weights: - index_filename_master_weights = os.path.join(resume_from_checkpoint, index_filename_master_weights) - if distributed_isfile(index_filename_master_weights): - distributed_file(index_filename_master_weights) - else: - raise Exception( - f"Sorry, we can not find {index_filename_master_weights}. This file should be appear at least on one machine." 
- ) - with open(index_filename_master_weights, "r") as f: - index_mw = json.loads(f.read()) - all_mw_filenames = sorted(set(index_mw["weight_map"].values())) - - hcg = fleet.get_hybrid_communicate_group() - tp_group = hcg.get_model_parallel_group() - pp_group = hcg.get_pipe_parallel_group() - dp_group = hcg.get_data_parallel_group() - sharding_group = hcg.get_sharding_parallel_group() - sharding_rank = sharding_group.rank - dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 - struct2static_name_mappings = {k: v.name for k, v in model.state_dict().items()} - - if ( - args.sharding_parallel_degree > 1 - and ShardingOption.SHARD_OP in args.sharding - and "split_param" in args.sharding_parallel_config - ): - # We do not check optimizer files completion for split_param, since it is very complicated. Directly support local resume. - logger.warning("We only support local resume for split_param mode, do not support dynamically loading.") - return True - - if sharding_group.nranks > 1: - param2rank = optimizer._param2rank - - def check_complete(all_filenames): - # Check whether the checkpoint files on machines are complete. If not complete, raise Exception. - existed_filelist = [] - existed_files = [] - for filename in os.listdir(resume_from_checkpoint): - if filename in all_filenames: - existed_files.append(filename) - - dist.all_gather_object(existed_filelist, existed_files) - flatten_existed_filelist = flatten_list(existed_filelist) - diff_filelist = list(set(all_filenames).difference(set(flatten_existed_filelist))) - if len(diff_filelist) != 0: - raise Exception( - f"Sorry, the optimizer file list on `data_parallel_rank==0` machines is not complete!, missing {diff_filelist}" - ) - return existed_files - - def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, typename_set=None): - # To decide whether to load the checkpoint locally, or need to dynamically distribute the checkpoint. - local_resume = True - if args.data_parallel_rank == 0 or args.use_expert_parallel: - need_files = set() - state_dict = get_expected_state_dict(model) - - for key in state_dict.keys(): - if sharding_group.nranks > 1: - static_name = struct2static_name_mappings.get(key, None) - param_rank = param2rank.get(static_name, None) - if param_rank != sharding_rank: - continue - - # When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0. 
- if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False): - continue - - if is_master_weights and state_dict[key].dtype == core.VarDesc.VarType.FP32: - continue - - if not is_master_weights: - for type_name in typename_set: - type_key = key + "/" + type_name - filename = weight_map[type_key] - need_files.add(filename) - else: - filename = weight_map[key] - need_files.add(filename) - - diff_filelist = list(need_files.difference(set(existed_files))) - num_diff = paddle.to_tensor([len(diff_filelist)]) - if tp_group.nranks > 1: - dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=tp_group) - if pp_group.nranks > 1: - dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group) - if sharding_group.nranks > 1: - dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=sharding_group) - if args.use_expert_parallel and dp_group.nranks > 1: - dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group) - - if num_diff.item() == 0: - local_resume = True - else: - local_resume = False - local_resume = paddle.to_tensor([local_resume]) - dist.all_reduce(local_resume, op=dist.ReduceOp.PROD) - return local_resume.item() - - # check whether the optimizer checkpoint files are complete. - existed_files = check_complete(all_optimizer_filenames) - if has_master_weights: - existed_files_mw = check_complete(all_mw_filenames) - # get optimizer's param type name, like moment1_0. - typename_set = set() - for key in index["weight_map"].keys(): - _, typename = key.split("/") - typename_set.add(typename) - local_resume = check_dynamic_load( - args, index["weight_map"], existed_files, is_master_weights=False, typename_set=typename_set - ) - local_resume_rw = True - if has_master_weights: - local_resume_rw = check_dynamic_load(args, index_mw["weight_map"], existed_files_mw, is_master_weights=True) - return local_resume & local_resume_rw diff --git a/paddlenlp/trainer/plugins/unified_checkpoint_dynamic.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_dynamic.py similarity index 100% rename from paddlenlp/trainer/plugins/unified_checkpoint_dynamic.py rename to paddlenlp/trainer/unified_checkpoint/unified_checkpoint_dynamic.py diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_locally_load.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_locally_load.py new file mode 100644 index 000000000000..794ec04f3832 --- /dev/null +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_locally_load.py @@ -0,0 +1,268 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Unfied checkpoint locally loading functions.""" + +import gc +import os + +from tqdm.auto import tqdm + +try: + from paddle.base import core +except: + core = None + +from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM +from paddlenlp.transformers.model_utils import ( + _load_state_dict_into_model, + faster_set_state_dict, + load_state_dict, +) +from paddlenlp.transformers.utils import get_checkpoint_shard_files +from paddlenlp.utils.env import ( + PADDLE_MASTER_WEIGHTS_INDEX_NAME, + PADDLE_OPTIMIZER_INDEX_NAME, + SAFE_MASTER_WEIGHTS_INDEX_NAME, + SAFE_OPTIMIZER_INDEX_NAME, +) +from paddlenlp.utils.log import logger +from paddlenlp.utils.nested import nested_copy + +from .unified_checkpoint_sharding_v2 import load_unified_optimizer_split_param +from .unified_checkpoint_utils import ( + FP32_MASTER, + get_expected_keys, + get_expected_state_dict, + get_optimizer_shard_files, + is_sharding_split_param_mode, + mapping_optimizer_tp_actions, + select_model_weight_index, + update_master_weight_status, +) + + +def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False): + """ + Only dataset_rank == 0 or using expert parallel can enter this function. + """ + index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=True) + + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, index_filename), + ) + loaded_keys = sharded_metadata["all_checkpoint_keys"] + + model_state_dict = get_expected_state_dict(model) + # If using expert parallel, when dp_rank > 0, need to modify the expected_keys here. + if not args.use_expert_parallel or (args.use_expert_parallel and args.data_parallel_rank == 0): + expected_keys = set(list(model_state_dict.keys())) + else: + expected_keys = set() + for key in model_state_dict.keys(): + if getattr(model_state_dict[key], "no_sync", False): + expected_keys.add(key) + missing_keys = expected_keys - set(loaded_keys) + + use_fast_set = True + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + use_fast_set = False + + if len(missing_keys) > 0: + raise ValueError(f"missing_keys: {missing_keys}") + + def _remove_unused_keys( + state_dict, + model_state_dict, + ): + unused_keys = set(state_dict.keys()) - set(model_state_dict.keys()) + for unused_key in unused_keys: + del state_dict[unused_key] + return unused_keys + + # This should always be a list but, just to be sure. + if not isinstance(resolved_archive_file, list): + resolved_archive_file = [resolved_archive_file] + + error_msgs = [] + + if len(resolved_archive_file) > 1: + resolved_archive_file = tqdm(resolved_archive_file, desc="Loading checkpoint shards") + + for shard_file in resolved_archive_file: + # TODO: check if no expected_keys in shard_file, then don't load it + if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]): + continue + + pre_tensor_parallel_split = False + if shard_file.endswith(".safetensors") and model.config.tensor_parallel_degree > 1: + pre_tensor_parallel_split = True + assert loaded_keys is not None, "loaded_keys is not None." 
+ if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + tp_actions = model._get_tensor_parallel_convert_actions( + set(loaded_keys), is_split=True, ignore_error=True + ) + else: + tp_actions = model.get_tensor_parallel_convert_actions(model.config, loaded_keys, ignore_error=True) + # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors + state_dict = load_state_dict( + shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys, device="expected" + ) + + if not pre_tensor_parallel_split: + # Since we load all keys but we only need one of pipeline stages + _ = _remove_unused_keys(state_dict, model_state_dict) + + if model.config.tensor_parallel_degree > 1 and not pre_tensor_parallel_split: + logger.info("Converting state_dict to Tensor Parallel Format") + # ignore error for multi shard, since only parts of data + state_dict = model.convert_tensor_parallel( + None, model.config, state_dict=state_dict, ignore_error=len(resolved_archive_file) > 1 + ) + + if use_fast_set: + error_msgs += faster_set_state_dict(model, state_dict, strict_dtype=False) + else: + error_msgs += _load_state_dict_into_model(model, state_dict, "") + + # force memory release + del state_dict + # gc.collect() + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if " but the expected shape is" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + +def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): + # Special process with split param. + if is_sharding_split_param_mode(args): + returned_optim_state_dict = load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint) + return returned_optim_state_dict + + # init and get optimizer LR_Scheduler + returned_optim_state_dict = nested_copy(optimizer.state_dict()) + + if not safe_serialization: + index_filename, index_filename_master_weights = ( + PADDLE_OPTIMIZER_INDEX_NAME, + PADDLE_MASTER_WEIGHTS_INDEX_NAME, + ) + else: + index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME + + resolved_archive_file, sharded_metadata = get_optimizer_shard_files( + optimizer_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, index_filename), + ) + has_master_weights = True if sharded_metadata["master_weights"] else False + + model_state_dict = get_expected_state_dict(model) + model_keys = list(model_state_dict.keys()) + struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} # get optimizer param mappings + + expected_keys = get_expected_keys(args, sharded_metadata, model, optimizer) + + # This should always be a list but, just to be sure. + if not isinstance(resolved_archive_file, list): + resolved_archive_file = [resolved_archive_file] + + if len(resolved_archive_file) > 1: + resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards") + + # update has_master_weights and index_filename_master_weights + # 1. if the master weight exists, only has_master_weights is set True and loaded when needed + # 2. 
if master weight does not exist, convert model weight to master weight when needed + has_master_weights, index_filename_master_weights = update_master_weight_status( + args, optimizer, has_master_weights, safe_serialization + ) + + if has_master_weights: + returned_optim_state_dict["master_weights"] = {} + + resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files( + optimizer_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, index_filename_master_weights), + ) + + expected_keys_mw = get_expected_keys(args, sharded_metadata_mw, model, optimizer, is_master_weights=True) + if not isinstance(resolved_archive_file_mw, list): + resolved_archive_file_mw = [resolved_archive_file_mw] + if len(resolved_archive_file_mw) > 1: + resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards") + + def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False): + returned_state_dict = {} + # load optimizer + for shard_file in resolved_archive_file: + # TODO: check if no expected_keys in shard_file, then don't load it + if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]): + continue + + if shard_file.endswith(".safetensors"): + # assert model_keys is not None, "model_keys is None." TODO: correct the assert + if model.config.tensor_parallel_degree > 1: + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + tp_actions = model._get_tensor_parallel_convert_actions( + model_keys, is_split=True, ignore_error=True + ) + else: + tp_actions = model.get_tensor_parallel_convert_actions( + model.config, model_keys, ignore_error=True + ) + if not is_master_weights: + tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys) + + # Here we use expected_keys to optimize weights loading for pipeline model. 
Only works for safetensors + state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected") + else: + # for pipeline model, we don't need to use tp_actions + state_dict = load_state_dict(shard_file, None, expected_keys, device="expected") + + returned_state_dict.update(state_dict) + # force memory release + del state_dict + gc.collect() + return returned_state_dict + + state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys) + if has_master_weights: + state_dict_master_weight = load_resolved_archive_file( + resolved_archive_file_mw, sharded_metadata_mw, expected_keys_mw, is_master_weights=True + ) + # rename optimizer param + for key in list(state_dict_optim.keys()): + key_name = key.split("/") + static_name = struct2static_name_mappings[key_name[0]] + if has_master_weights: + if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) + returned_optim_state_dict[key_name] = state_dict_optim.pop(key) + returned_optim_state_dict[key_name].name = key_name + + if has_master_weights: + for key in list(state_dict_master_weight.keys()): + static_name = struct2static_name_mappings[key] + returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key) + returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) + + return returned_optim_state_dict diff --git a/paddlenlp/trainer/plugins/unified_checkpoint_sharding_v2.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_sharding_v2.py similarity index 97% rename from paddlenlp/trainer/plugins/unified_checkpoint_sharding_v2.py rename to paddlenlp/trainer/unified_checkpoint/unified_checkpoint_sharding_v2.py index f8eddb8691f1..1ee9728b0788 100644 --- a/paddlenlp/trainer/plugins/unified_checkpoint_sharding_v2.py +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_sharding_v2.py @@ -38,9 +38,10 @@ ) -def distributed_send_recv_splited_param( +def merge_splited_param( state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False ): + """Merge the splited param in sharding group.""" global_rank = dist.get_rank() for key in list(state_dict.keys()): if state_dict[key].numel().item() == 1: # for example: beta1, beta2 @@ -144,13 +145,9 @@ def gather_splited_param_for_optimizer(optimizer): recv_table[key] = sharding_ranklist[0][0] # which sharding_rank to recv the splited tensor send_table[key] = [(rank, begin, end) for rank, begin, end in sharding_ranklist] - distributed_send_recv_splited_param( - optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False - ) + merge_splited_param(optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False) if master_weights is not None: - distributed_send_recv_splited_param( - master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True - ) + merge_splited_param(master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True) return optim_state_dict, master_weights diff --git a/paddlenlp/trainer/plugins/unified_checkpoint_single_card.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_single_card.py similarity index 93% rename from paddlenlp/trainer/plugins/unified_checkpoint_single_card.py rename to 
paddlenlp/trainer/unified_checkpoint/unified_checkpoint_single_card.py index 9d5d9164a7af..baabc3ff1e6e 100644 --- a/paddlenlp/trainer/plugins/unified_checkpoint_single_card.py +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_single_card.py @@ -51,8 +51,7 @@ generate_base_static_name, get_expected_state_dict, get_optimizer_shard_files, - save_config, - save_prefix_past_key_value, + save_model_config, ) @@ -96,19 +95,7 @@ def save_single_card_checkpoint(model_to_save, output_dir): logger.warning("Asynchronous saving is not supported for single card environment currently.") save_file_sync(state_dict, path=os.path.join(output_dir, weight_filename)) - if isinstance(model_to_save, PrefixModelForCausalLM): - save_prefix_past_key_value(model_to_save, output_dir) - model_to_save.prefix_config.save_pretrained(output_dir) - if isinstance(model_to_save, LoRAModel): - model_to_save.lora_config.save_pretrained(output_dir) - - config_to_save = save_config(model_to_save) - config_to_save.architectures = [model_to_save.__class__.__name__] - config_to_save.save_pretrained(output_dir) - - # save generation config - if model_to_save.can_generate(): - model_to_save.generation_config.save_pretrained(output_dir) + save_model_config(model_to_save, output_dir) def save_single_card_optimizer(model, optimizer, output_dir): diff --git a/paddlenlp/trainer/plugins/unified_checkpoint_utils.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_utils.py similarity index 90% rename from paddlenlp/trainer/plugins/unified_checkpoint_utils.py rename to paddlenlp/trainer/unified_checkpoint/unified_checkpoint_utils.py index 795ebdbdbdc6..bad8dabbafa2 100644 --- a/paddlenlp/trainer/plugins/unified_checkpoint_utils.py +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_utils.py @@ -27,7 +27,7 @@ core = None from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM -from paddlenlp.trainer.trainer_utils import ExplicitEnum +from paddlenlp.trainer.trainer_utils import ExplicitEnum, ShardingOption from paddlenlp.trainer.utils.helper import distributed_isfile from paddlenlp.transformers.model_utils import PretrainedModel, get_parameter_dtype from paddlenlp.transformers.utils import dtype_byte_size @@ -129,6 +129,9 @@ def update_master_weight_status(args, optimizer, has_master_weight, safe_seriali def reduce_master_weights_status(has_master_weights=False): + """ + Get master_weight status through tp, pp and sharding group. + """ data = paddle.to_tensor([has_master_weights], dtype="int32") hcg = fleet.get_hybrid_communicate_group() @@ -192,6 +195,9 @@ def mapping_optimizer_tp_actions(tp_actions, optimizer_loaded_keys): def get_expected_state_dict(model_to_save): + """ + Get trainable state_dict of model_to_save. + """ if isinstance(model_to_save, PretrainedModel): state_dict = model_to_save.state_dict() if ( @@ -293,7 +299,9 @@ def get_optimizer_shard_files(optimizer_path, index_filename): def generate_base_static_name(vname): - # return base static name and specific type name, like [embedding_0.w_0, moment1_0] + """ + Return base static name and specific type name, like [embedding_0.w_0, moment1_0] + """ if FP32_MASTER in vname: vname = vname.split("_" + FP32_MASTER + "_") return vname[0], vname[1] @@ -308,6 +316,9 @@ def generate_base_static_name(vname): def merge_large_tensor_parallel(tensor, tp_group, tp_action, dst_rank, is_dst): + """ + Move large tensor merge process to CPU, in order to avoid OOM.
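+    The tensor is merged in row blocks (np.array_split over the rows) so that only one slice is gathered and concatenated at a time.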
+ """ num_rows = tensor.shape[0] num_splits = 4 parts = np.array_split(np.arange(num_rows), num_splits) @@ -339,6 +350,9 @@ def merge_large_tensor_parallel(tensor, tp_group, tp_action, dst_rank, is_dst): def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): + """ + Merge tensor parallel according to tp_actions, used for model weight. + """ hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() dp_group = hcg.get_data_parallel_group() @@ -393,7 +407,9 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys, model_state_dict=None): - # Core function for UC + """ + Merge tensor parallel according to tp_actions, used for master_weight and optimizer weight. + """ hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() dp_group = hcg.get_data_parallel_group() @@ -451,6 +467,10 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys, def filter_params(model_to_save, state_dict, is_optimizer=False): + """ + Group according to the size of the tensor, aiming to make the weight size + stored on each device as equal as possible. + """ hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() @@ -510,6 +530,9 @@ def filter_params(model_to_save, state_dict, is_optimizer=False): def get_sharded_file_name(args, file_name, is_optimizer=False): + """ + Get safetensors file name for saving. + """ if not is_optimizer: sd_degree = args.sharding_parallel_degree if args.sharding_parallel_degree > 1 else 1 size = sd_degree if args.use_expert_parallel else args.dataset_world_size @@ -542,7 +565,9 @@ def get_sharded_index( index_file_list, total_size_list, ): - # save index json file + """ + Save safetensors index json file, including metadata and weight_map. + """ local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0)) if local_rank == 0: sharded_index_json = {} @@ -560,7 +585,9 @@ def get_sharded_index( def gather_sharded_object(index_file, total_size, is_optimizer=False, use_expert_parallel=False): - + """ + All gather sharded files list across different groups. + """ index_file_list, total_size_list = [], [] hcg = fleet.get_hybrid_communicate_group() @@ -618,7 +645,9 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False, use_expert def rename_shard_file(args, shard_file, file_name): - """rename shard file when using expert_parallel.""" + """ + Rename shard file when using expert_parallel. + """ assert args.use_expert_parallel, "only expert_parallel need to use this function" shard_file_list = [] @@ -656,19 +685,10 @@ def rename_shard_file(args, shard_file, file_name): return shard_file -def save_config(model_to_save): - dtype = get_parameter_dtype(model_to_save) - model_to_save.config.dtype = str(dtype).split(".")[1] - config_to_save = copy.deepcopy(model_to_save.config) - - if config_to_save.tensor_parallel_degree > 1: - # do we need to change? - config_to_save.tensor_parallel_degree = 1 - - return config_to_save - - def save_prefix_past_key_value(model_to_save, save_directory): + """ + Used only for PrefixModelForCausalLM. 
+ """ past_key_value = model_to_save.prefix_encoder(model_to_save.prefix_tokens.unsqueeze(0).expand([1, -1])) past_key_value = past_key_value.reshape( [ @@ -680,5 +700,49 @@ def save_prefix_past_key_value(model_to_save, save_directory): ] ) past_key_value = paddle.transpose(past_key_value, perm=[2, 1, 3, 0, 4]).cpu().numpy() - model_to_save.prefix_config.save_pretrained(save_directory) np.save(os.path.join(save_directory, PAST_KEY_VALUES_FILE_NAME), past_key_value) + + +def is_sharding_split_param_mode(args): + return ( + args.sharding_parallel_degree > 1 + and ShardingOption.SHARD_OP in args.sharding + and "split_param" in args.sharding_parallel_config + ) + + +def save_model_config(model_to_save, save_directory): + """ + Save model config. + """ + + def save_config(model_to_save): + dtype = get_parameter_dtype(model_to_save) + model_to_save.config.dtype = str(dtype).split(".")[1] + config_to_save = copy.deepcopy(model_to_save.config) + + if config_to_save.tensor_parallel_degree > 1: + # do we need to change? + config_to_save.tensor_parallel_degree = 1 + + return config_to_save + + # Save prefix model past_key_values + if isinstance(model_to_save, PrefixModelForCausalLM): + save_prefix_past_key_value(model_to_save, save_directory) + model_to_save.prefix_config.save_pretrained(save_directory) + if isinstance(model_to_save, LoRAModel): + model_to_save.lora_config.save_pretrained(save_directory) + + # save the config + config_to_save = save_config(model_to_save) + # Attach architecture to the config + if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): + config_to_save.architectures = [model_to_save.model.__class__.__name__] + else: + config_to_save.architectures = [model_to_save.__class__.__name__] + + config_to_save.save_pretrained(save_directory) + # save generation config + if model_to_save.can_generate(): + model_to_save.generation_config.save_pretrained(save_directory) From 238888d17e298adaff0c86f81f29ba6224444a51 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 25 Oct 2024 17:32:25 +0800 Subject: [PATCH 11/16] rename file --- .../{check_unified_checkpoint.py => check_uc.py} | 2 +- .../{unified_checkpoint_dynamic.py => uc_dynamic.py} | 4 ++-- ...checkpoint_locally_load.py => uc_locally_load.py} | 4 ++-- ...d_checkpoint_sharding_v2.py => uc_sharding_v2.py} | 2 +- ...d_checkpoint_single_card.py => uc_single_card.py} | 2 +- .../{unified_checkpoint_utils.py => uc_utils.py} | 0 .../trainer/unified_checkpoint/unified_checkpoint.py | 12 ++++++------ 7 files changed, 13 insertions(+), 13 deletions(-) rename paddlenlp/trainer/unified_checkpoint/{check_unified_checkpoint.py => check_uc.py} (99%) rename paddlenlp/trainer/unified_checkpoint/{unified_checkpoint_dynamic.py => uc_dynamic.py} (99%) rename paddlenlp/trainer/unified_checkpoint/{unified_checkpoint_locally_load.py => uc_locally_load.py} (99%) rename paddlenlp/trainer/unified_checkpoint/{unified_checkpoint_sharding_v2.py => uc_sharding_v2.py} (99%) rename paddlenlp/trainer/unified_checkpoint/{unified_checkpoint_single_card.py => uc_single_card.py} (99%) rename paddlenlp/trainer/unified_checkpoint/{unified_checkpoint_utils.py => uc_utils.py} (100%) diff --git a/paddlenlp/trainer/unified_checkpoint/check_unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/check_uc.py similarity index 99% rename from paddlenlp/trainer/unified_checkpoint/check_unified_checkpoint.py rename to paddlenlp/trainer/unified_checkpoint/check_uc.py index 76e8df9ce8d5..f03bc14e46e4 100644 --- 
a/paddlenlp/trainer/unified_checkpoint/check_unified_checkpoint.py +++ b/paddlenlp/trainer/unified_checkpoint/check_uc.py @@ -35,7 +35,7 @@ except: core = None -from .unified_checkpoint_utils import ( +from .uc_utils import ( get_expected_state_dict, is_sharding_split_param_mode, select_model_weight_index, diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_dynamic.py b/paddlenlp/trainer/unified_checkpoint/uc_dynamic.py similarity index 99% rename from paddlenlp/trainer/unified_checkpoint/unified_checkpoint_dynamic.py rename to paddlenlp/trainer/unified_checkpoint/uc_dynamic.py index bd5a5873b359..05090189cd47 100644 --- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_dynamic.py +++ b/paddlenlp/trainer/unified_checkpoint/uc_dynamic.py @@ -45,7 +45,7 @@ else: from paddlenlp.utils.safetensors import fast_safe_open as safe_open -from .unified_checkpoint_utils import ( +from .uc_utils import ( FP32_MASTER, get_expected_state_dict, mapping_optimizer_tp_actions, @@ -258,7 +258,7 @@ def distributed_send_recv( return state_dict -def load_unified_checkpoint_dynamically(args, model, resume_from_checkpoint, safe_serialization=False): +def load_uc_dynamically(args, model, resume_from_checkpoint, safe_serialization=False): index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False) index_filename = os.path.join(resume_from_checkpoint, index_filename) diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_locally_load.py b/paddlenlp/trainer/unified_checkpoint/uc_locally_load.py similarity index 99% rename from paddlenlp/trainer/unified_checkpoint/unified_checkpoint_locally_load.py rename to paddlenlp/trainer/unified_checkpoint/uc_locally_load.py index 794ec04f3832..abd4442dcb84 100644 --- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_locally_load.py +++ b/paddlenlp/trainer/unified_checkpoint/uc_locally_load.py @@ -39,8 +39,8 @@ from paddlenlp.utils.log import logger from paddlenlp.utils.nested import nested_copy -from .unified_checkpoint_sharding_v2 import load_unified_optimizer_split_param -from .unified_checkpoint_utils import ( +from .uc_sharding_v2 import load_unified_optimizer_split_param +from .uc_utils import ( FP32_MASTER, get_expected_keys, get_expected_state_dict, diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_sharding_v2.py b/paddlenlp/trainer/unified_checkpoint/uc_sharding_v2.py similarity index 99% rename from paddlenlp/trainer/unified_checkpoint/unified_checkpoint_sharding_v2.py rename to paddlenlp/trainer/unified_checkpoint/uc_sharding_v2.py index 1ee9728b0788..ab3f3a7f27b0 100644 --- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_sharding_v2.py +++ b/paddlenlp/trainer/unified_checkpoint/uc_sharding_v2.py @@ -29,7 +29,7 @@ ) from paddlenlp.utils.nested import nested_copy -from .unified_checkpoint_utils import ( +from .uc_utils import ( FP32_MASTER, generate_base_static_name, get_expected_state_dict, diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_single_card.py b/paddlenlp/trainer/unified_checkpoint/uc_single_card.py similarity index 99% rename from paddlenlp/trainer/unified_checkpoint/unified_checkpoint_single_card.py rename to paddlenlp/trainer/unified_checkpoint/uc_single_card.py index baabc3ff1e6e..84657c379419 100644 --- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_single_card.py +++ b/paddlenlp/trainer/unified_checkpoint/uc_single_card.py @@ -46,7 +46,7 @@ if is_safetensors_available(): from safetensors.numpy import 
save_file as safe_save_file -from .unified_checkpoint_utils import ( +from .uc_utils import ( FP32_MASTER, generate_base_static_name, get_expected_state_dict, diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint_utils.py b/paddlenlp/trainer/unified_checkpoint/uc_utils.py similarity index 100% rename from paddlenlp/trainer/unified_checkpoint/unified_checkpoint_utils.py rename to paddlenlp/trainer/unified_checkpoint/uc_utils.py diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py index 463d99462a74..7a229070201a 100644 --- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py @@ -68,28 +68,28 @@ else: from paddlenlp.utils.safetensors import fast_load_file as load_file -from .check_unified_checkpoint import check_unified_checkpoint, check_unified_optimizer +from .check_uc import check_unified_checkpoint, check_unified_optimizer from .shared_memory_utils import ( _read_state_dict_from_shm, _traverse_copy_to_shm, create_meta_dict, ) -from .unified_checkpoint_dynamic import ( +from .uc_dynamic import ( load_unified_checkpoint_dynamically, load_unified_optimizer_dynamically, ) -from .unified_checkpoint_locally_load import ( +from .uc_locally_load import ( load_unified_checkpoint_locally, load_unified_optimizer_locally, ) -from .unified_checkpoint_sharding_v2 import gather_splited_param_for_optimizer -from .unified_checkpoint_single_card import ( +from .uc_sharding_v2 import gather_splited_param_for_optimizer +from .uc_single_card import ( load_single_card_checkpoint, load_single_card_optimizer, save_single_card_checkpoint, save_single_card_optimizer, ) -from .unified_checkpoint_utils import ( +from .uc_utils import ( FP32_MASTER, UnifiedCheckpointOption, filter_params, From b219ba6fcf60cb491143df5bb94ad0119b28786b Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 25 Oct 2024 17:41:26 +0800 Subject: [PATCH 12/16] update async handler --- .../unified_checkpoint/async_uc_hander.py | 250 ++++++++++++++++++ .../trainer/unified_checkpoint/check_uc.py | 6 +- .../trainer/unified_checkpoint/uc_dynamic.py | 4 +- .../unified_checkpoint/uc_locally_load.py | 2 + .../unified_checkpoint/uc_sharding_v2.py | 2 + .../unified_checkpoint/uc_single_card.py | 7 + .../unified_checkpoint/unified_checkpoint.py | 226 +--------------- 7 files changed, 270 insertions(+), 227 deletions(-) create mode 100644 paddlenlp/trainer/unified_checkpoint/async_uc_hander.py diff --git a/paddlenlp/trainer/unified_checkpoint/async_uc_hander.py b/paddlenlp/trainer/unified_checkpoint/async_uc_hander.py new file mode 100644 index 000000000000..d57386e9591a --- /dev/null +++ b/paddlenlp/trainer/unified_checkpoint/async_uc_hander.py @@ -0,0 +1,250 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Asynchronous unified checkpoint handler.""" + +import multiprocessing +import os +import time +from multiprocessing import shared_memory + +import paddle +import paddle.distributed as dist + +from paddlenlp.transformers.utils import is_safetensors_available +from paddlenlp.utils.log import logger + +if is_safetensors_available(): + from safetensors.numpy import save_file as safe_save_file + +from .shared_memory_utils import ( + _read_state_dict_from_shm, + _traverse_copy_to_shm, + create_meta_dict, +) + +__all__ = ["AsyncCheckpointHander"] + + +class AsyncCheckpointHander: + def __init__(self, args): + # Mainly for asynchronous saving. + self.args = args + self.global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 + + self._shm_model_weight = None + self._shm_master_weight = None + self._shm_optimizer_weight = None + self._meta_dict_model = None + self._meta_dict_master_weight = None + self._meta_dict_optim = None + self._process_model_weight = None + self._process_master_weight = None + self._process_optimizer_weight = None + self._lock = None + self._shared_save_model_flag = None + self._shared_save_master_weight_flag = None + self._shared_save_optimizer_flag = None + + if "async_save" in self.args.unified_checkpoint_config: + self._lock = multiprocessing.Lock() + self._shared_save_model_path = multiprocessing.Array("c", 100000) + self._shared_save_model_signal_path = multiprocessing.Array("c", 100000) + self._shared_save_master_weight_path = multiprocessing.Array("c", 100000) + self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000) + self._shared_save_optimizer_path = multiprocessing.Array("c", 100000) + self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000) + self._shared_save_model_flag = multiprocessing.Array("i", 1) + self._shared_save_master_weight_flag = multiprocessing.Array("i", 1) + self._shared_save_optimizer_flag = multiprocessing.Array("i", 1) + + def _file_save_async_or_sync( + self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight" + ): + if is_sync: + for k in list(state_dict.keys()): + if isinstance(state_dict[k], paddle.Tensor): + state_dict[k] = state_dict.pop(k).cpu().numpy() + safe_save_file(state_dict, path, metadata={"format": "np"}) + else: + if state_dict_type == "model_weight": + if self._shm_model_weight is None: + self._meta_dict_model, buffer_size = create_meta_dict(state_dict) + self._shm_model_weight = shared_memory.SharedMemory(create=True, size=buffer_size) + shm_state_dict = self._shm_model_weight + meta_dict = self._meta_dict_model + shared_save_flag = self._shared_save_model_flag + shared_save_path = self._shared_save_model_path + shared_save_signal_path = self._shared_save_model_signal_path + if self._process_model_weight is None: + self._process_model_weight = multiprocessing.Process( + target=self._save_file_async_in_process, + args=( + meta_dict, + self._shm_model_weight.name, + self._shared_save_model_flag, + self._shared_save_model_path, + self._shared_save_model_signal_path, + self._lock, + state_dict_type, + self.global_rank, + ), + ) + self._process_model_weight.start() + process = self._process_model_weight + elif state_dict_type == "master_weight": + if self._shm_master_weight is None: + self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict) + self._shm_master_weight = shared_memory.SharedMemory(create=True, size=buffer_size) + shm_state_dict = self._shm_master_weight + meta_dict = 
self._meta_dict_master_weight + shared_save_flag = self._shared_save_master_weight_flag + shared_save_path = self._shared_save_master_weight_path + shared_save_signal_path = self._shared_save_master_weight_signal_path + if self._process_master_weight is None: + self._process_master_weight = multiprocessing.Process( + target=self._save_file_async_in_process, + args=( + meta_dict, + self._shm_master_weight.name, + self._shared_save_master_weight_flag, + self._shared_save_master_weight_path, + self._shared_save_master_weight_signal_path, + self._lock, + "model_weight" + if "skip_save_model_weight" in self.args.unified_checkpoint_config + else state_dict_type, + self.global_rank, + ), + ) + self._process_master_weight.start() + process = self._process_master_weight + elif state_dict_type == "optimizer_weight": + if self._shm_optimizer_weight is None: + self._meta_dict_optim, buffer_size = create_meta_dict(state_dict) + self._shm_optimizer_weight = shared_memory.SharedMemory(create=True, size=buffer_size) + shm_state_dict = self._shm_optimizer_weight + meta_dict = self._meta_dict_optim + shared_save_flag = self._shared_save_optimizer_flag + shared_save_path = self._shared_save_optimizer_path + shared_save_signal_path = self._shared_save_optimizer_signal_path + if self._process_optimizer_weight is None: + self._process_optimizer_weight = multiprocessing.Process( + target=self._save_file_async_in_process, + args=( + meta_dict, + self._shm_optimizer_weight.name, + self._shared_save_optimizer_flag, + self._shared_save_optimizer_path, + self._shared_save_optimizer_signal_path, + self._lock, + state_dict_type, + self.global_rank, + ), + ) + self._process_optimizer_weight.start() + process = self._process_optimizer_weight + + while True: # wait until no process is saving. + flag_value = shared_save_flag[0] + if flag_value == 0: + break + if not process.is_alive(): + raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.") + time.sleep(0.5) + logger.info(f"Wait for the previous save process to finish saving {state_dict_type}") + # only save model weight or save master weight, we enter this loop. + self._reset_and_update(shared_save_path, path) + self._reset_and_update(shared_save_signal_path, signal_path) + _traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf) + with self._lock: + shared_save_flag[0] = 1 + + def _save_file_async_in_process( + self, + meta_dict, + shm_name, + shared_save_flag, + shared_save_path, + shared_save_signal_path, + lock, + state_dict_type, + global_rank, + ): + shm = shared_memory.SharedMemory(name=shm_name) + while True: + flag_value = shared_save_flag[0] # if process uses `spawn`, cannot read this value. 
+ if flag_value == -1: # stop process + break + if flag_value == 0: # nothing to save + continue + if flag_value == 1: # need to save + path = shared_save_path[:].decode("utf-8").rstrip("\x00") + signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00") + logger.info(f"Start to async save {path}") + state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array + safe_save_file(state_dict, path, {"format": "np"}) + del state_dict + saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}") + paddle.save(global_rank, saved_signal_path) + with lock: + shared_save_flag[0] = 0 + time.sleep(0.5) + shm.close() + + def _reset_and_update(self, shared_array, new_value): + # clear array + for i in range(len(shared_array)): + shared_array[i] = b"\0" + # update array + encoded_value = new_value.encode("utf-8") + shared_array[: len(encoded_value)] = encoded_value + + def unlink_shared_memory(self): + if not ("async_save" in self.args.unified_checkpoint_config): + return + + if self._shared_save_model_flag is not None: + while self._shared_save_model_flag[0] > 0: # async process is saving + if not self._process_model_weight.is_alive(): + raise RuntimeError("The process that saves model_weight has been killed unexpectedly.") + time.sleep(0.5) + self._shared_save_model_flag[0] = -1 + if self._shared_save_master_weight_flag is not None: + while self._shared_save_master_weight_flag[0] > 0: + if not self._process_master_weight.is_alive(): + raise RuntimeError("The process that saves master_weight has been killed unexpectedly.") + time.sleep(0.5) + self._shared_save_master_weight_flag[0] = -1 + if self._shared_save_optimizer_flag is not None: + while self._shared_save_optimizer_flag[0] > 0: + if not self._process_optimizer_weight.is_alive(): + raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.") + time.sleep(0.5) + self._shared_save_optimizer_flag[0] = -1 + + if self._shm_model_weight is not None: + self._shm_model_weight.close() + self._shm_model_weight.unlink() + self._shm_model_weight = None + if self._shm_master_weight is not None: + self._shm_master_weight.close() + self._shm_master_weight.unlink() + self._shm_master_weight = None + if self._shm_optimizer_weight is not None: + self._shm_optimizer_weight.close() + self._shm_optimizer_weight.unlink() + self._shm_optimizer_weight = None + + if paddle.distributed.get_world_size() > 1: + dist.barrier() diff --git a/paddlenlp/trainer/unified_checkpoint/check_uc.py b/paddlenlp/trainer/unified_checkpoint/check_uc.py index f03bc14e46e4..287abe01f9c0 100644 --- a/paddlenlp/trainer/unified_checkpoint/check_uc.py +++ b/paddlenlp/trainer/unified_checkpoint/check_uc.py @@ -42,6 +42,8 @@ update_master_weight_status, ) +__all__ = ["check_unified_checkpoint", "check_unified_optimizer"] + def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False): index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False) @@ -102,7 +104,7 @@ def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serializa else: local_resume = False local_resume = paddle.to_tensor([local_resume]) - dist.all_reduce(local_resume, op=dist.ReduceOp.PROD) + dist.all_reduce(local_resume, op=dist.ReduceOp.MIN) local_resume = local_resume.item() return local_resume @@ -226,7 +228,7 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, else: local_resume = False local_resume = 
paddle.to_tensor([local_resume]) - dist.all_reduce(local_resume, op=dist.ReduceOp.PROD) + dist.all_reduce(local_resume, op=dist.ReduceOp.MIN) return local_resume.item() # check whether the optimizer checkpoint files are complete. diff --git a/paddlenlp/trainer/unified_checkpoint/uc_dynamic.py b/paddlenlp/trainer/unified_checkpoint/uc_dynamic.py index 05090189cd47..ee6cbb12dab6 100644 --- a/paddlenlp/trainer/unified_checkpoint/uc_dynamic.py +++ b/paddlenlp/trainer/unified_checkpoint/uc_dynamic.py @@ -55,6 +55,8 @@ update_master_weight_status, ) +__all__ = ["load_unified_checkpoint_dynamically", "load_unified_optimizer_dynamically"] + def create_send_table(file_keyname_mappings, file_machine_mappings): send_table = {} @@ -258,7 +260,7 @@ def distributed_send_recv( return state_dict -def load_uc_dynamically(args, model, resume_from_checkpoint, safe_serialization=False): +def load_unified_checkpoint_dynamically(args, model, resume_from_checkpoint, safe_serialization=False): index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False) index_filename = os.path.join(resume_from_checkpoint, index_filename) diff --git a/paddlenlp/trainer/unified_checkpoint/uc_locally_load.py b/paddlenlp/trainer/unified_checkpoint/uc_locally_load.py index abd4442dcb84..93ac3b1ae735 100644 --- a/paddlenlp/trainer/unified_checkpoint/uc_locally_load.py +++ b/paddlenlp/trainer/unified_checkpoint/uc_locally_load.py @@ -51,6 +51,8 @@ update_master_weight_status, ) +__all__ = ["load_unified_checkpoint_locally", "load_unified_optimizer_locally"] + def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False): """ diff --git a/paddlenlp/trainer/unified_checkpoint/uc_sharding_v2.py b/paddlenlp/trainer/unified_checkpoint/uc_sharding_v2.py index ab3f3a7f27b0..1c1602b46133 100644 --- a/paddlenlp/trainer/unified_checkpoint/uc_sharding_v2.py +++ b/paddlenlp/trainer/unified_checkpoint/uc_sharding_v2.py @@ -37,6 +37,8 @@ mapping_optimizer_tp_actions, ) +__all__ = ["gather_splited_param_for_optimizer", "load_unified_optimizer_split_param"] + def merge_splited_param( state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False diff --git a/paddlenlp/trainer/unified_checkpoint/uc_single_card.py b/paddlenlp/trainer/unified_checkpoint/uc_single_card.py index 84657c379419..cd16a3866a22 100644 --- a/paddlenlp/trainer/unified_checkpoint/uc_single_card.py +++ b/paddlenlp/trainer/unified_checkpoint/uc_single_card.py @@ -54,6 +54,13 @@ save_model_config, ) +__all__ = [ + "load_single_card_checkpoint", + "load_single_card_optimizer", + "save_single_card_checkpoint", + "save_single_card_optimizer", +] + def save_file_sync(state_dict, path): for k in list(state_dict.keys()): diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py index 7a229070201a..72b17c33c44c 100644 --- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py @@ -14,14 +14,10 @@ import copy import json -import multiprocessing import os import sys -import time -from multiprocessing import shared_memory import paddle -import paddle.distributed as dist from paddle.distributed import fleet try: @@ -61,19 +57,13 @@ from paddlenlp.utils.nested import nested_copy if is_safetensors_available(): - from safetensors.numpy import save_file as safe_save_file - if sys.platform.startswith("win"): from safetensors.numpy 
import load_file else: from paddlenlp.utils.safetensors import fast_load_file as load_file +from .async_uc_hander import AsyncCheckpointHander from .check_uc import check_unified_checkpoint, check_unified_optimizer -from .shared_memory_utils import ( - _read_state_dict_from_shm, - _traverse_copy_to_shm, - create_meta_dict, -) from .uc_dynamic import ( load_unified_checkpoint_dynamically, load_unified_optimizer_dynamically, @@ -107,219 +97,7 @@ save_model_config, ) - -class AsyncCheckpointHander: - def __init__(self, args): - # Mainly for asynchronous saving. - self.args = args - self.global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 - - self._shm_model_weight = None - self._shm_master_weight = None - self._shm_optimizer_weight = None - self._meta_dict_model = None - self._meta_dict_master_weight = None - self._meta_dict_optim = None - self._process_model_weight = None - self._process_master_weight = None - self._process_optimizer_weight = None - self._lock = None - self._shared_save_model_flag = None - self._shared_save_master_weight_flag = None - self._shared_save_optimizer_flag = None - - if "async_save" in self.args.unified_checkpoint_config: - self._lock = multiprocessing.Lock() - self._shared_save_model_path = multiprocessing.Array("c", 100000) - self._shared_save_model_signal_path = multiprocessing.Array("c", 100000) - self._shared_save_master_weight_path = multiprocessing.Array("c", 100000) - self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000) - self._shared_save_optimizer_path = multiprocessing.Array("c", 100000) - self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000) - self._shared_save_model_flag = multiprocessing.Array("i", 1) - self._shared_save_master_weight_flag = multiprocessing.Array("i", 1) - self._shared_save_optimizer_flag = multiprocessing.Array("i", 1) - - def _file_save_async_or_sync( - self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight" - ): - if is_sync: - for k in list(state_dict.keys()): - if isinstance(state_dict[k], paddle.Tensor): - state_dict[k] = state_dict.pop(k).cpu().numpy() - safe_save_file(state_dict, path, metadata={"format": "np"}) - else: - if state_dict_type == "model_weight": - if self._shm_model_weight is None: - self._meta_dict_model, buffer_size = create_meta_dict(state_dict) - self._shm_model_weight = shared_memory.SharedMemory(create=True, size=buffer_size) - shm_state_dict = self._shm_model_weight - meta_dict = self._meta_dict_model - shared_save_flag = self._shared_save_model_flag - shared_save_path = self._shared_save_model_path - shared_save_signal_path = self._shared_save_model_signal_path - if self._process_model_weight is None: - self._process_model_weight = multiprocessing.Process( - target=self._save_file_async_in_process, - args=( - meta_dict, - self._shm_model_weight.name, - self._shared_save_model_flag, - self._shared_save_model_path, - self._shared_save_model_signal_path, - self._lock, - state_dict_type, - self.global_rank, - ), - ) - self._process_model_weight.start() - process = self._process_model_weight - elif state_dict_type == "master_weight": - if self._shm_master_weight is None: - self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict) - self._shm_master_weight = shared_memory.SharedMemory(create=True, size=buffer_size) - shm_state_dict = self._shm_master_weight - meta_dict = self._meta_dict_master_weight - shared_save_flag = self._shared_save_master_weight_flag - 
shared_save_path = self._shared_save_master_weight_path - shared_save_signal_path = self._shared_save_master_weight_signal_path - if self._process_master_weight is None: - self._process_master_weight = multiprocessing.Process( - target=self._save_file_async_in_process, - args=( - meta_dict, - self._shm_master_weight.name, - self._shared_save_master_weight_flag, - self._shared_save_master_weight_path, - self._shared_save_master_weight_signal_path, - self._lock, - "model_weight" - if "skip_save_model_weight" in self.args.unified_checkpoint_config - else state_dict_type, - self.global_rank, - ), - ) - self._process_master_weight.start() - process = self._process_master_weight - elif state_dict_type == "optimizer_weight": - if self._shm_optimizer_weight is None: - self._meta_dict_optim, buffer_size = create_meta_dict(state_dict) - self._shm_optimizer_weight = shared_memory.SharedMemory(create=True, size=buffer_size) - shm_state_dict = self._shm_optimizer_weight - meta_dict = self._meta_dict_optim - shared_save_flag = self._shared_save_optimizer_flag - shared_save_path = self._shared_save_optimizer_path - shared_save_signal_path = self._shared_save_optimizer_signal_path - if self._process_optimizer_weight is None: - self._process_optimizer_weight = multiprocessing.Process( - target=self._save_file_async_in_process, - args=( - meta_dict, - self._shm_optimizer_weight.name, - self._shared_save_optimizer_flag, - self._shared_save_optimizer_path, - self._shared_save_optimizer_signal_path, - self._lock, - state_dict_type, - self.global_rank, - ), - ) - self._process_optimizer_weight.start() - process = self._process_optimizer_weight - - while True: # wait until no process is saving. - flag_value = shared_save_flag[0] - if flag_value == 0: - break - if not process.is_alive(): - raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.") - time.sleep(0.5) - logger.info(f"Wait for the previous save process to finish saving {state_dict_type}") - # only save model weight or save master weight, we enter this loop. - self._reset_and_update(shared_save_path, path) - self._reset_and_update(shared_save_signal_path, signal_path) - _traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf) - with self._lock: - shared_save_flag[0] = 1 - - def _save_file_async_in_process( - self, - meta_dict, - shm_name, - shared_save_flag, - shared_save_path, - shared_save_signal_path, - lock, - state_dict_type, - global_rank, - ): - shm = shared_memory.SharedMemory(name=shm_name) - while True: - flag_value = shared_save_flag[0] # if process uses `spawn`, cannot read this value. 
- if flag_value == -1: # stop process - break - if flag_value == 0: # nothing to save - continue - if flag_value == 1: # need to save - path = shared_save_path[:].decode("utf-8").rstrip("\x00") - signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00") - logger.info(f"Start to async save {path}") - state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array - safe_save_file(state_dict, path, {"format": "np"}) - del state_dict - saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}") - paddle.save(global_rank, saved_signal_path) - with lock: - shared_save_flag[0] = 0 - time.sleep(0.5) - shm.close() - - def _reset_and_update(self, shared_array, new_value): - # clear array - for i in range(len(shared_array)): - shared_array[i] = b"\0" - # update array - encoded_value = new_value.encode("utf-8") - shared_array[: len(encoded_value)] = encoded_value - - def unlink_shared_memory(self): - if not ("async_save" in self.args.unified_checkpoint_config): - return - - if self._shared_save_model_flag is not None: - while self._shared_save_model_flag[0] > 0: # async process is saving - if not self._process_model_weight.is_alive(): - raise RuntimeError("The process that saves model_weight has been killed unexpectedly.") - time.sleep(0.5) - self._shared_save_model_flag[0] = -1 - if self._shared_save_master_weight_flag is not None: - while self._shared_save_master_weight_flag[0] > 0: - if not self._process_master_weight.is_alive(): - raise RuntimeError("The process that saves master_weight has been killed unexpectedly.") - time.sleep(0.5) - self._shared_save_master_weight_flag[0] = -1 - if self._shared_save_optimizer_flag is not None: - while self._shared_save_optimizer_flag[0] > 0: - if not self._process_optimizer_weight.is_alive(): - raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.") - time.sleep(0.5) - self._shared_save_optimizer_flag[0] = -1 - - if self._shm_model_weight is not None: - self._shm_model_weight.close() - self._shm_model_weight.unlink() - self._shm_model_weight = None - if self._shm_master_weight is not None: - self._shm_master_weight.close() - self._shm_master_weight.unlink() - self._shm_master_weight = None - if self._shm_optimizer_weight is not None: - self._shm_optimizer_weight.close() - self._shm_optimizer_weight.unlink() - self._shm_optimizer_weight = None - - if paddle.distributed.get_world_size() > 1: - dist.barrier() +__all__ = ["UnifiedCheckpointHandler"] class UnifiedCheckpointHandler: From dbd13dfdbccf580e72e3a4ea91042fa8e16ef122 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 25 Oct 2024 22:13:18 +0800 Subject: [PATCH 13/16] update files --- .../{async_uc_hander.py => async_handler.py} | 4 ++-- .../{check_uc.py => check_completion.py} | 2 +- .../{uc_dynamic.py => load_dynamic.py} | 2 +- .../{uc_locally_load.py => load_local.py} | 4 ++-- ...ingle_card.py => load_save_single_card.py} | 2 +- ...ng_v2.py => sharding_split_param_utils.py} | 2 +- .../unified_checkpoint/unified_checkpoint.py | 19 ++++++++----------- .../{uc_utils.py => utils.py} | 0 tests/trainer/test_unified_checkpoint.py | 2 +- 9 files changed, 17 insertions(+), 20 deletions(-) rename paddlenlp/trainer/unified_checkpoint/{async_uc_hander.py => async_handler.py} (99%) rename paddlenlp/trainer/unified_checkpoint/{check_uc.py => check_completion.py} (99%) rename paddlenlp/trainer/unified_checkpoint/{uc_dynamic.py => load_dynamic.py} (99%) rename paddlenlp/trainer/unified_checkpoint/{uc_locally_load.py 
=> load_local.py} (99%) rename paddlenlp/trainer/unified_checkpoint/{uc_single_card.py => load_save_single_card.py} (99%) rename paddlenlp/trainer/unified_checkpoint/{uc_sharding_v2.py => sharding_split_param_utils.py} (99%) rename paddlenlp/trainer/unified_checkpoint/{uc_utils.py => utils.py} (100%) diff --git a/paddlenlp/trainer/unified_checkpoint/async_uc_hander.py b/paddlenlp/trainer/unified_checkpoint/async_handler.py similarity index 99% rename from paddlenlp/trainer/unified_checkpoint/async_uc_hander.py rename to paddlenlp/trainer/unified_checkpoint/async_handler.py index d57386e9591a..4206821b50e5 100644 --- a/paddlenlp/trainer/unified_checkpoint/async_uc_hander.py +++ b/paddlenlp/trainer/unified_checkpoint/async_handler.py @@ -33,10 +33,10 @@ create_meta_dict, ) -__all__ = ["AsyncCheckpointHander"] +__all__ = ["AsyncCheckpointHandler"] -class AsyncCheckpointHander: +class AsyncCheckpointHandler: def __init__(self, args): # Mainly for asynchronous saving. self.args = args diff --git a/paddlenlp/trainer/unified_checkpoint/check_uc.py b/paddlenlp/trainer/unified_checkpoint/check_completion.py similarity index 99% rename from paddlenlp/trainer/unified_checkpoint/check_uc.py rename to paddlenlp/trainer/unified_checkpoint/check_completion.py index 287abe01f9c0..cf337c468463 100644 --- a/paddlenlp/trainer/unified_checkpoint/check_uc.py +++ b/paddlenlp/trainer/unified_checkpoint/check_completion.py @@ -35,7 +35,7 @@ except: core = None -from .uc_utils import ( +from .utils import ( get_expected_state_dict, is_sharding_split_param_mode, select_model_weight_index, diff --git a/paddlenlp/trainer/unified_checkpoint/uc_dynamic.py b/paddlenlp/trainer/unified_checkpoint/load_dynamic.py similarity index 99% rename from paddlenlp/trainer/unified_checkpoint/uc_dynamic.py rename to paddlenlp/trainer/unified_checkpoint/load_dynamic.py index ee6cbb12dab6..064ecacc7c3c 100644 --- a/paddlenlp/trainer/unified_checkpoint/uc_dynamic.py +++ b/paddlenlp/trainer/unified_checkpoint/load_dynamic.py @@ -45,7 +45,7 @@ else: from paddlenlp.utils.safetensors import fast_safe_open as safe_open -from .uc_utils import ( +from .utils import ( FP32_MASTER, get_expected_state_dict, mapping_optimizer_tp_actions, diff --git a/paddlenlp/trainer/unified_checkpoint/uc_locally_load.py b/paddlenlp/trainer/unified_checkpoint/load_local.py similarity index 99% rename from paddlenlp/trainer/unified_checkpoint/uc_locally_load.py rename to paddlenlp/trainer/unified_checkpoint/load_local.py index 93ac3b1ae735..552289d8f383 100644 --- a/paddlenlp/trainer/unified_checkpoint/uc_locally_load.py +++ b/paddlenlp/trainer/unified_checkpoint/load_local.py @@ -39,8 +39,8 @@ from paddlenlp.utils.log import logger from paddlenlp.utils.nested import nested_copy -from .uc_sharding_v2 import load_unified_optimizer_split_param -from .uc_utils import ( +from .sharding_split_param_utils import load_unified_optimizer_split_param +from .utils import ( FP32_MASTER, get_expected_keys, get_expected_state_dict, diff --git a/paddlenlp/trainer/unified_checkpoint/uc_single_card.py b/paddlenlp/trainer/unified_checkpoint/load_save_single_card.py similarity index 99% rename from paddlenlp/trainer/unified_checkpoint/uc_single_card.py rename to paddlenlp/trainer/unified_checkpoint/load_save_single_card.py index cd16a3866a22..c8d514dda55f 100644 --- a/paddlenlp/trainer/unified_checkpoint/uc_single_card.py +++ b/paddlenlp/trainer/unified_checkpoint/load_save_single_card.py @@ -46,7 +46,7 @@ if is_safetensors_available(): from safetensors.numpy import save_file as 
safe_save_file -from .uc_utils import ( +from .utils import ( FP32_MASTER, generate_base_static_name, get_expected_state_dict, diff --git a/paddlenlp/trainer/unified_checkpoint/uc_sharding_v2.py b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py similarity index 99% rename from paddlenlp/trainer/unified_checkpoint/uc_sharding_v2.py rename to paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py index 1c1602b46133..f337b1a8186b 100644 --- a/paddlenlp/trainer/unified_checkpoint/uc_sharding_v2.py +++ b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py @@ -29,7 +29,7 @@ ) from paddlenlp.utils.nested import nested_copy -from .uc_utils import ( +from .utils import ( FP32_MASTER, generate_base_static_name, get_expected_state_dict, diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py index 72b17c33c44c..d6c2db82f126 100644 --- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py @@ -62,24 +62,21 @@ else: from paddlenlp.utils.safetensors import fast_load_file as load_file -from .async_uc_hander import AsyncCheckpointHander -from .check_uc import check_unified_checkpoint, check_unified_optimizer -from .uc_dynamic import ( +from .async_handler import AsyncCheckpointHandler +from .check_completion import check_unified_checkpoint, check_unified_optimizer +from .load_dynamic import ( load_unified_checkpoint_dynamically, load_unified_optimizer_dynamically, ) -from .uc_locally_load import ( - load_unified_checkpoint_locally, - load_unified_optimizer_locally, -) -from .uc_sharding_v2 import gather_splited_param_for_optimizer -from .uc_single_card import ( +from .load_local import load_unified_checkpoint_locally, load_unified_optimizer_locally +from .load_save_single_card import ( load_single_card_checkpoint, load_single_card_optimizer, save_single_card_checkpoint, save_single_card_optimizer, ) -from .uc_utils import ( +from .sharding_split_param_utils import gather_splited_param_for_optimizer +from .utils import ( FP32_MASTER, UnifiedCheckpointOption, filter_params, @@ -103,7 +100,7 @@ class UnifiedCheckpointHandler: def __init__(self, args): self.args = args - self.async_handler = AsyncCheckpointHander(args) + self.async_handler = AsyncCheckpointHandler(args) def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None): """save unified checkpoint diff --git a/paddlenlp/trainer/unified_checkpoint/uc_utils.py b/paddlenlp/trainer/unified_checkpoint/utils.py similarity index 100% rename from paddlenlp/trainer/unified_checkpoint/uc_utils.py rename to paddlenlp/trainer/unified_checkpoint/utils.py diff --git a/tests/trainer/test_unified_checkpoint.py b/tests/trainer/test_unified_checkpoint.py index 17fe0f14f9ea..8f5a1dfe7236 100644 --- a/tests/trainer/test_unified_checkpoint.py +++ b/tests/trainer/test_unified_checkpoint.py @@ -18,7 +18,7 @@ import numpy as np import pytest -from paddlenlp.trainer.plugins.unified_checkpoint import UnifiedCheckpointOption +from paddlenlp.trainer.unified_checkpoint.utils import UnifiedCheckpointOption from tests.parallel_launch import TestMultipleGpus from tests.testing_utils import ( require_paddle_at_least_2_gpu, From c758d9686abf589c564920653a996c670ebfa7f5 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Mon, 28 Oct 2024 14:51:32 +0800 Subject: [PATCH 14/16] update async_save_info.json file place --- paddlenlp/trainer/trainer.py | 12 
+++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index ba280a1832d5..f7b91720e1db 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -2308,7 +2308,7 @@ def save_model( if output_dir is None: output_dir = self.args.output_dir - if PREFIX_CHECKPOINT_DIR in output_dir: + if PREFIX_CHECKPOINT_DIR in output_dir and self.is_in_train: signal_dir = os.path.join(self.args.output_signal_dir, os.path.split(output_dir)[-1]) else: signal_dir = self.args.output_signal_dir @@ -2606,7 +2606,7 @@ def _save( # signal_dir is used for asynchronous saving situations. signal_dir = self.args.output_signal_dir if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: - if PREFIX_CHECKPOINT_DIR in output_dir: + if PREFIX_CHECKPOINT_DIR in output_dir and self.is_in_train: signal_dir = os.path.join(signal_dir, os.path.split(output_dir)[-1]) os.makedirs(signal_dir, exist_ok=True) logger.info(f"Saving model checkpoint finish signal to {signal_dir}") @@ -2626,9 +2626,11 @@ def _save( "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim, "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config, } - if os.path.exists(os.path.join(signal_dir, "async_save_info.json")): # afs cannot overwrite - os.remove(os.path.join(signal_dir, "async_save_info.json")) - with open(os.path.join(signal_dir, "async_save_info.json"), "w") as f: + if os.path.exists( + os.path.join(self.args.output_signal_dir, "async_save_info.json") + ): # afs cannot overwrite + os.remove(os.path.join(self.args.output_signal_dir, "async_save_info.json")) + with open(os.path.join(self.args.output_signal_dir, "async_save_info.json"), "w") as f: json.dump(save_info, f) if self.args.should_save: From 2dd22ca42286728158d8754cfbf323e58cbcc7a0 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Mon, 28 Oct 2024 15:42:18 +0800 Subject: [PATCH 15/16] update load non-merge --- .../unified_checkpoint/unified_checkpoint.py | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py index d6c2db82f126..0b9ec0097dfb 100644 --- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py @@ -31,13 +31,10 @@ from paddlenlp.transformers.model_utils import ( PretrainedModel, _add_variant, + load_state_dict, unwrap_model, ) -from paddlenlp.transformers.utils import ( - device_guard, - dtype_byte_size, - is_safetensors_available, -) +from paddlenlp.transformers.utils import dtype_byte_size, is_safetensors_available from paddlenlp.utils.env import ( LORA_WEIGHTS_NAME, PADDLE_MASTER_WEIGHTS_NAME, @@ -286,6 +283,10 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint): if has_master_weights: master_weights = load_file(master_weights_path) + optimizer_state_dict = load_state_dict(optimizer_path, None, None, device="expected") + if has_master_weights: + master_weights = load_state_dict(master_weights_path, None, None, device="expected") + # rename and move to paddle.Tensor for key in list(optimizer_state_dict.keys()): key_name = key.split("/") @@ -297,20 +298,14 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint): key_name = "_".join([static_name, key_name[1]]) else: key_name = "_".join([static_name, key_name[1]]) - with 
device_guard(): - weight = paddle.Tensor(optimizer_state_dict.pop(key), zero_copy=True) - weight = weight._copy_to(paddle.framework._current_expected_place(), False) - returned_optim_state_dict[key_name] = weight + returned_optim_state_dict[key_name] = optimizer_state_dict.pop(key) returned_optim_state_dict[key_name].name = key_name if has_master_weights: returned_optim_state_dict["master_weights"] = {} for key in list(master_weights.keys()): static_name = struct2static_name_mappings[key] - with device_guard(): - weight = paddle.Tensor(master_weights.pop(key), zero_copy=True) - weight = weight._copy_to(paddle.framework._current_expected_place(), False) - returned_optim_state_dict["master_weights"][static_name] = weight + returned_optim_state_dict["master_weights"][static_name] = master_weights.pop(key) returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) return returned_optim_state_dict From fd5dea054836d177118f6ccae0e3349dfb1f6024 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Mon, 28 Oct 2024 16:25:18 +0800 Subject: [PATCH 16/16] fix --- paddlenlp/trainer/trainer.py | 4 ++-- .../unified_checkpoint/unified_checkpoint.py | 13 +------------ 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index f7b91720e1db..6bacd8172e3e 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -2308,7 +2308,7 @@ def save_model( if output_dir is None: output_dir = self.args.output_dir - if PREFIX_CHECKPOINT_DIR in output_dir and self.is_in_train: + if PREFIX_CHECKPOINT_DIR in os.path.split(output_dir)[-1]: signal_dir = os.path.join(self.args.output_signal_dir, os.path.split(output_dir)[-1]) else: signal_dir = self.args.output_signal_dir @@ -2606,7 +2606,7 @@ def _save( # signal_dir is used for asynchronous saving situations. 
signal_dir = self.args.output_signal_dir if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: - if PREFIX_CHECKPOINT_DIR in output_dir and self.is_in_train: + if PREFIX_CHECKPOINT_DIR in os.path.split(output_dir)[-1]: signal_dir = os.path.join(signal_dir, os.path.split(output_dir)[-1]) os.makedirs(signal_dir, exist_ok=True) logger.info(f"Saving model checkpoint finish signal to {signal_dir}") diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py index 0b9ec0097dfb..5628874d5c30 100644 --- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py @@ -15,7 +15,6 @@ import copy import json import os -import sys import paddle from paddle.distributed import fleet @@ -34,7 +33,7 @@ load_state_dict, unwrap_model, ) -from paddlenlp.transformers.utils import dtype_byte_size, is_safetensors_available +from paddlenlp.transformers.utils import dtype_byte_size from paddlenlp.utils.env import ( LORA_WEIGHTS_NAME, PADDLE_MASTER_WEIGHTS_NAME, @@ -53,12 +52,6 @@ from paddlenlp.utils.log import logger from paddlenlp.utils.nested import nested_copy -if is_safetensors_available(): - if sys.platform.startswith("win"): - from safetensors.numpy import load_file - else: - from paddlenlp.utils.safetensors import fast_load_file as load_file - from .async_handler import AsyncCheckpointHandler from .check_completion import check_unified_checkpoint, check_unified_optimizer from .load_dynamic import ( @@ -279,10 +272,6 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint): model_state_dict = get_expected_state_dict(model) struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} # get optimizer param mappings - optimizer_state_dict = load_file(optimizer_path) - if has_master_weights: - master_weights = load_file(master_weights_path) - optimizer_state_dict = load_state_dict(optimizer_path, None, None, device="expected") if has_master_weights: master_weights = load_state_dict(master_weights_path, None, None, device="expected")
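
The patches above converge on a per-rank "done signal" convention for asynchronous saving: the background save process writes a small marker file named `.{state_dict_type}.done.{global_rank}` (with `state_dict_type` one of "model_weight", "master_weight", "optimizer_weight") into the signal directory once its shard has been flushed to disk. The sketch below only illustrates that convention; the helper names `emit_done_signal`, `all_ranks_done` and the parameter `saved_world_size` are illustrative assumptions, not PaddleNLP APIs.

    # Minimal sketch of the per-rank done-signal convention used by the async
    # checkpoint handler in these patches. Helper names are illustrative only.
    import os

    import paddle
    import paddle.distributed as dist


    def emit_done_signal(signal_dir, state_dict_type="optimizer_weight"):
        # Each rank drops a marker such as `.optimizer_weight.done.3`
        # (or `.optimizer_weight.done.-1` when running on a single card),
        # mirroring the paddle.save(global_rank, ...) pattern in the diffs.
        global_rank = dist.get_rank() if dist.get_world_size() > 1 else -1
        paddle.save(global_rank, os.path.join(signal_dir, f".{state_dict_type}.done.{global_rank}"))


    def all_ranks_done(signal_dir, saved_world_size, state_dict_type="optimizer_weight"):
        # A checkpoint is only safe to resume from once every rank of the run
        # that produced it has emitted its marker file.
        return all(
            os.path.exists(os.path.join(signal_dir, f".{state_dict_type}.done.{rank}"))
            for rank in range(saved_world_size)
        )

Consistent with this, the check-completion changes above replace `ReduceOp.PROD` with `ReduceOp.MIN` when reducing the boolean `local_resume` flag across ranks; over values in {0, 1} both reductions amount to a logical AND, so resumption is still only allowed when every rank reports a complete local checkpoint.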