Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Oct 17, 2024
1 parent 5244e61 commit 5b20bd3
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2369,12 +2369,21 @@ def _save_checkpoint(self, model, metrics=None):
output_dir = os.path.join(run_dir, checkpoint_folder)
signal_dir = os.path.join(run_signal_dir, checkpoint_folder)

if isinstance(self.model, LoRAModel) and (self.model.quantized or self.args.pipeline_parallel_degree > 1):
self.save_model(output_dir, False, signal_dir)
elif isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM):
self.save_model(output_dir, True, signal_dir)
signature = inspect.signature(self.save_model)
if "signal_dir" not in signature.parameters:
if isinstance(self.model, LoRAModel) and (self.model.quantized or self.args.pipeline_parallel_degree > 1):
self.save_model(output_dir)
elif isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM):
self.save_model(output_dir, True)

Check warning on line 2377 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2374-L2377

Added lines #L2374 - L2377 were not covered by tests
else:
self.save_model(output_dir)

Check warning on line 2379 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2379

Added line #L2379 was not covered by tests
else:
self.save_model(output_dir, False, signal_dir)
if isinstance(self.model, LoRAModel) and (self.model.quantized or self.args.pipeline_parallel_degree > 1):
self.save_model(output_dir, False, signal_dir)

Check warning on line 2382 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2382

Added line #L2382 was not covered by tests
elif isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM):
self.save_model(output_dir, True, signal_dir)

Check warning on line 2384 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2384

Added line #L2384 was not covered by tests
else:
self.save_model(output_dir, False, signal_dir)

# only save model state dict, ignore optimizer and scheduler
if not self.args.ignore_save_lr_and_optim:
Expand Down

0 comments on commit 5b20bd3

Please # to comment.