From 2eafad30631b30b74ab58e733c4e066202e5ae99 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Thu, 17 Oct 2024 21:32:35 +0800 Subject: [PATCH] bug fix --- paddlenlp/trainer/trainer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index aa69b9414255..c7cfa72462b4 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -2311,7 +2311,7 @@ def save_model( self.model_wrapped.get_all_parameters(convert2cpu=True) if self.args.should_save_model_state: - self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel, signal_dir=signal_dir) + self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel) else: if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: os.makedirs(signal_dir, exist_ok=True) @@ -2592,15 +2592,16 @@ def _save( output_dir: Optional[str] = None, state_dict=None, merge_tensor_parallel=False, - signal_dir: Optional[str] = None, ): output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True) logger.info(f"Saving model checkpoint to {output_dir}") # signal_dir is used for asynchronous saving situations. + signal_dir = self.args.output_signal_dir if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: - signal_dir = signal_dir if signal_dir is not None else self.args.output_signal_dir + if PREFIX_CHECKPOINT_DIR in output_dir: + signal_dir = os.path.join(signal_dir, os.path.split(output_dir)[-1]) os.makedirs(signal_dir, exist_ok=True) logger.info(f"Saving model checkpoint finish signal to {signal_dir}")