
support pp accuracy calculation #9379

Merged · 11 commits · Nov 29, 2024
8 changes: 7 additions & 1 deletion paddlenlp/trainer/trainer.py
@@ -3190,7 +3190,13 @@
         # Metrics!
         if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
-            metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
+            # all_labels may be a tuple when prediction_step outputs a label_mask
+            if isinstance(all_labels, (list, tuple)):
+                # compute_metrics in train.py
+                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels[0]))
+            else:
+                # compute_metrics in modeling.py
+                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
         else:
             metrics = {}

Codecov / codecov/patch: added lines L3194, L3196, and L3199 in paddlenlp/trainer/trainer.py were not covered by tests.
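The branching above can be exercised in isolation with a minimal sketch. The `EvalPrediction` dataclass and `accuracy_metrics` function below are simplified stand-ins (not PaddleNLP's actual implementations), assumed here only to illustrate why `all_labels[0]` is taken when `prediction_step` returns a `(labels, label_mask)` tuple:

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class EvalPrediction:
    # simplified stand-in for paddlenlp.trainer.EvalPrediction
    predictions: np.ndarray
    label_ids: np.ndarray


def accuracy_metrics(eval_pred):
    # toy compute_metrics: fraction of argmax predictions matching the labels
    preds = eval_pred.predictions.argmax(axis=-1)
    return {"accuracy": float((preds == eval_pred.label_ids).mean())}


def compute(all_preds, all_labels):
    # mirrors the patched trainer logic: when prediction_step also emits a
    # label_mask, all_labels arrives as a tuple and only its first element
    # carries the label ids that compute_metrics expects
    if isinstance(all_labels, (list, tuple)):
        return accuracy_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels[0]))
    return accuracy_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))


preds = np.array([[0.1, 0.9], [0.8, 0.2]])  # argmax -> [1, 0]
labels = np.array([1, 1])
mask = np.array([1, 1])

print(compute(preds, labels))           # plain-labels path
print(compute(preds, (labels, mask)))   # tuple path: mask is ignored, labels used
```

Both calls yield the same metric, which is the point of the patch: the tuple-shaped output from the pipeline-parallel `prediction_step` no longer breaks `compute_metrics`.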
