
support pp accuracy calculation #9379

Merged: 11 commits into PaddlePaddle:develop on Nov 29, 2024
Conversation

@wtmlon (Collaborator) commented Nov 6, 2024

PR types

PR changes

Description

@wtmlon wtmlon requested a review from DesmonDay November 6, 2024 11:23

paddle-bot bot commented Nov 6, 2024

Thanks for your contribution!

@CLAassistant commented Nov 6, 2024

CLA assistant check
All committers have signed the CLA.

if pp_group.nranks > 1:
    logit_shape = [[]]
    if "pp_logits" in infohub:
        logits = paddle.concat(infohub["pp_logits"], axis=0)
Contributor

Why is a concat used here? I don't quite understand.


codecov bot commented Nov 12, 2024

Codecov Report

Attention: Patch coverage is 0% with 2 lines in your changes missing coverage. Please review.

Project coverage is 53.10%. Comparing base (4b02477) to head (c0645e7).
Report is 13 commits behind head on develop.

Files with missing lines Patch % Lines
paddlenlp/trainer/trainer.py 0.00% 2 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9379      +/-   ##
===========================================
+ Coverage    52.93%   53.10%   +0.17%     
===========================================
  Files          688      694       +6     
  Lines       109379   110966    +1587     
===========================================
+ Hits         57899    58930    +1031     
- Misses       51480    52036     +556     


# evaluation doesn't support drop_last,
# so set `accumulate_steps` to the actual
# eval batch size.
model_config_backup = model.accumulate_steps
Collaborator

Isn't this naming a bit off? This is clearly not a model config.


logits = None
if "pp_logits" in infohub:
    logits = paddle.concat(infohub["pp_logits"], axis=0)
    logits = logits._copy_to(paddle.framework._current_expected_place(), False)
Collaborator

Is the copy here because pp_logits is kept in CPU memory or CUDA pinned memory?

Collaborator Author

Yes. If the logits were not kept in CPU or pinned memory here, the concat would add peak GPU memory of roughly twice the logits size and cause an OOM.
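
A minimal sketch of the staging pattern described above, assuming each micro-batch's logits are copied to pinned host memory before being appended to infohub["pp_logits"]; the accumulate_pp_logits/gather_pp_logits helper names are illustrative, not the PR's actual functions:

# Illustrative only: stage per-micro-batch logits in pinned host memory so the
# concat does not double the logits footprint on the GPU.
import paddle

infohub = {}

def accumulate_pp_logits(logits):
    # Copy each micro-batch's logits to CUDA pinned memory before storing them;
    # device memory then only holds one micro-batch of logits at a time.
    infohub.setdefault("pp_logits", []).append(
        logits._copy_to(paddle.CUDAPinnedPlace(), False)
    )

def gather_pp_logits():
    # Concatenate on host memory, then copy the single result tensor back to the
    # current device, keeping peak GPU usage near one full batch of logits.
    logits = paddle.concat(infohub["pp_logits"], axis=0)
    return logits._copy_to(paddle.framework._current_expected_place(), False)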

@ZHUI ZHUI self-requested a review November 13, 2024 07:20
@@ -3312,6 +3347,8 @@ def prediction_step(
        if self.args.pipeline_parallel_degree > 1:
            # hack for pipeline mode
            inputs = self._prepare_inputs(inputs)
            if self.args.metric_for_best_model == "accuracy":
Contributor

I'd suggest not putting this in the trainer; putting it in SFTTrainer makes more sense.

# evaluation doesn't support drop_last,
# so set `accumulate_steps` to the actual
# eval batch size.
model_config_backup = model.accumulate_steps

else:
    input_ids = inputs

model.accumulate_steps = input_ids.shape[0]
Contributor

Alternatively, just set model.micro_batch_size directly to 1.
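
A minimal sketch of the backup/restore pattern the diff implies, assuming a pipeline-parallel model object that exposes accumulate_steps; run_eval_forward is a hypothetical placeholder for the actual eval step:

# Illustrative only: temporarily match accumulate_steps to the received eval
# batch (which is not padded or dropped to a fixed size), then restore it.
accumulate_steps_backup = model.accumulate_steps
try:
    model.accumulate_steps = input_ids.shape[0]  # micro-steps = actual eval batch size
    loss, logits = run_eval_forward(model, input_ids)  # hypothetical eval call
finally:
    # Restore the training-time value so later train/eval steps are unaffected.
    model.accumulate_steps = accumulate_steps_backup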

@@ -81,6 +81,7 @@
    "fp16_opt_level": "O2",
    "max_grad_norm": 1.0,
    "dataloader_num_workers": 0,
    "metric_for_best_model": "accuracy",
Contributor

Follow up later by adapting this for the open-source models as well.

lugimzzz previously approved these changes Nov 27, 2024

@lugimzzz (Contributor) left a comment

LGTM

@lugimzzz (Contributor) left a comment

lgtm

@wawltor (Collaborator) left a comment

LGTM

@wawltor merged commit 741785a into PaddlePaddle:develop on Nov 29, 2024
9 of 12 checks passed
wtmlon added a commit to wtmlon/PaddleNLP that referenced this pull request Nov 29, 2024
* support pp accuracy calculation

* add pp accuracy ci

* add comment

* update

* mv logits accumulation to cpu

* refactor code

* code refactor

* remove ci, not support yet

* update