Commit 6ebe5b6: bug fix
DesmonDay committed Oct 17, 2024
1 parent: 5244e61

Showing 2 changed files with 6 additions and 6 deletions.
paddlenlp/trainer/auto_trainer.py
1 change: 0 additions & 1 deletion

@@ -692,7 +692,6 @@ def _save(
         self,
         output_dir: Optional[str] = None,
         state_dict=None,
         merge_tensor_parallel=False,
-        signal_dir: Optional[str] = None,
     ):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         os.makedirs(output_dir, exist_ok=True)
paddlenlp/trainer/trainer.py
11 changes: 6 additions & 5 deletions

@@ -2292,7 +2292,6 @@ def save_model(
         self,
         output_dir: Optional[str] = None,
         merge_tensor_parallel: Optional[bool] = False,
-        signal_dir: Optional[str] = None,
     ):
         """
         Will save the model, so you can reload it using `from_pretrained()`.
@@ -2303,7 +2302,9 @@ def save_model(
         if output_dir is None:
             output_dir = self.args.output_dir

-        if signal_dir is None:
+        if PREFIX_CHECKPOINT_DIR in output_dir:
+            signal_dir = os.path.join(self.args.output_signal_dir, os.path.split(output_dir)[-1])
+        else:
             signal_dir = self.args.output_signal_dir

         if ShardingOption.FULL_SHARD in self.args.sharding:
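
Net effect of the hunk above: `save_model` no longer receives `signal_dir` from its caller and instead derives it from `output_dir`. A minimal sketch of that derivation, extracted into a standalone function for illustration (the helper name `derive_signal_dir` and the value of `PREFIX_CHECKPOINT_DIR` are assumptions; in PaddleNLP this logic runs inline against `self.args.output_signal_dir`):

```python
import os

# Assumed value: PaddleNLP's PREFIX_CHECKPOINT_DIR names the checkpoint
# folder prefix; "checkpoint" is used here only for illustration.
PREFIX_CHECKPOINT_DIR = "checkpoint"

def derive_signal_dir(output_dir: str, output_signal_dir: str) -> str:
    """Sketch of the derivation save_model now performs internally."""
    if PREFIX_CHECKPOINT_DIR in output_dir:
        # Saving a periodic checkpoint: mirror the checkpoint folder name
        # under the signal directory, e.g. signals/checkpoint-500.
        return os.path.join(output_signal_dir, os.path.split(output_dir)[-1])
    # Saving the final model: signals go to the signal-dir root.
    return output_signal_dir

print(derive_signal_dir("output/checkpoint-500", "signals"))  # signals/checkpoint-500
print(derive_signal_dir("output", "signals"))                 # signals
```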
@@ -2370,11 +2371,11 @@ def _save_checkpoint(self, model, metrics=None):
         signal_dir = os.path.join(run_signal_dir, checkpoint_folder)

         if isinstance(self.model, LoRAModel) and (self.model.quantized or self.args.pipeline_parallel_degree > 1):
-            self.save_model(output_dir, False, signal_dir)
+            self.save_model(output_dir)
         elif isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM):
-            self.save_model(output_dir, True, signal_dir)
+            self.save_model(output_dir, True)
         else:
-            self.save_model(output_dir, False, signal_dir)
+            self.save_model(output_dir)

         # only save model state dict, ignore optimizer and scheduler
         if not self.args.ignore_save_lr_and_optim:
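
Caller side, the three `save_model` call sites in `_save_checkpoint` drop the trailing `signal_dir` argument. Nothing is lost because `output_dir` at checkpoint time ends in the checkpoint folder name, so the derivation inside `save_model` reconstructs the same path the callers used to pass, assuming `run_signal_dir` matches `self.args.output_signal_dir`. A quick equivalence check with hypothetical folder names:

```python
import os

# Hypothetical stand-ins for _save_checkpoint's locals.
run_signal_dir = "output_signal"
checkpoint_folder = "checkpoint-500"  # PREFIX_CHECKPOINT_DIR + "-" + step
output_dir = os.path.join("output", checkpoint_folder)

# Old behavior: callers passed this explicitly to save_model.
old_signal_dir = os.path.join(run_signal_dir, checkpoint_folder)

# New behavior: save_model derives it from output_dir's last component.
new_signal_dir = os.path.join(run_signal_dir, os.path.split(output_dir)[-1])

assert old_signal_dir == new_signal_dir  # the fix preserves the signal path
```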
