Skip to content

Commit

Permalink
bug fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
DesmonDay committed Oct 17, 2024
1 parent 6ebe5b6 commit 2eafad3
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2311,7 +2311,7 @@ def save_model(
self.model_wrapped.get_all_parameters(convert2cpu=True)

if self.args.should_save_model_state:
self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel, signal_dir=signal_dir)
self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
else:
if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
os.makedirs(signal_dir, exist_ok=True)
Expand Down Expand Up @@ -2592,15 +2592,16 @@ def _save(
output_dir: Optional[str] = None,
state_dict=None,
merge_tensor_parallel=False,
signal_dir: Optional[str] = None,
):
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")

# signal_dir is used for asynchronous saving situations.
signal_dir = self.args.output_signal_dir
if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
signal_dir = signal_dir if signal_dir is not None else self.args.output_signal_dir
if PREFIX_CHECKPOINT_DIR in output_dir:
signal_dir = os.path.join(signal_dir, os.path.split(output_dir)[-1])

Check warning on line 2604 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2603-L2604

Added lines #L2603 - L2604 were not covered by tests
os.makedirs(signal_dir, exist_ok=True)
logger.info(f"Saving model checkpoint finish signal to {signal_dir}")

Expand Down

0 comments on commit 2eafad3

Please sign in to comment.