
[Auto Parallel] Add zero h1 pipeline scheduling for paddle #62865

Merged
merged 37 commits into PaddlePaddle:develop on Apr 18, 2024

Conversation

AndSonder
Contributor

@AndSonder AndSonder commented Mar 19, 2024

PR Category

Auto Parallel

PR Types

Others

Description

Add Zero-H1 (ZB-H1) pipeline parallel scheduling support for Paddle.

The actual scheduling result for Llama2 on 4 GPUs is shown below:

[image: ZB-H1 pipeline scheduling timeline on 4 GPUs]
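
For context, a minimal, hypothetical NumPy sketch of the core idea behind ZB-H1 (this is not the pass implementation): the backward of each matmul is split into an input-gradient job ("backward_b"), which is on the critical path to the previous stage, and a weight-gradient job ("backward_w"), which is only needed before the optimizer step and can therefore be deferred to fill pipeline bubbles.

import numpy as np

# Minimal sketch of the ZB-H1 splitting idea, not the PR's implementation.
# For Y = X @ W, the backward is split into two independent jobs:
#   backward_b: dX = dY @ W^T -> needed immediately by the previous stage
#   backward_w: dW = X^T @ dY -> only needed before the optimizer step,
#                                so it can be delayed to fill bubbles
def backward_b(dY, W):
    return dY @ W.T

def backward_w(X, dY):
    return X.T @ dY

X = np.random.randn(4, 8)
W = np.random.randn(8, 16)
dY = np.random.randn(4, 16)
dX = backward_b(dY, W)   # sent upstream right away
dW = backward_w(X, dY)   # scheduled later by the ZB-H1 schedule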

Test results on the PaddleNLP Llama2 model are as follows (pp4, batch size 1, hidden_layer=4):

Accuracy

The accuracy matches 1F1B; occasionally there are small discrepancies beyond the third decimal place, which is consistent with the description in the paper.

Loss comparison after 10000 steps on Llama2:

  • ZBH1: 2.6
  • 1F1B: 2.6

The loss curves for the first 10000 steps are shown below:

[image: loss curves of ZBH1 vs. 1F1B over the first 10000 steps]

Speed test

Test machine: 4x RTX 3090

Scheduling mode    interval_runtime    interval_samples_per_second    interval_steps_per_second
1F1B               3.17                5.1                            0.3
ZBH1               2.75                5.8                            0.4

GPU memory usage

Scheduling mode    GPU    max_memory_allocated       max_memory_reserved
1F1B               0      12605.69 MB                13405.76 MB
1F1B               1      8809.68 MB                 9611.76 MB
1F1B               2      7013.66 MB                 7785.76 MB
1F1B               3      7806.72 MB                 8561.76 MB
ZBH1               0      12921.69 MB (↑ 316)        13831.76 MB (↑ 426)
ZBH1               1      9639.7 MB (↑ 830)          10463.76 MB (↑ 852)
ZBH1               2      8357.72 MB (↑ 1344)        9149.76 MB (↑ 1364)
ZBH1               3      10597.38 MB (↑ 1790)       11219.76 MB (↑ 1658)

  • 1F1B total max_memory_allocated: 36035.75 MB
  • ZBH1 total max_memory_allocated: 41516.49 MB
  • 1F1B total max_memory_reserved: 35064.04 MB
  • ZBH1 total max_memory_reserved: 44650.04 MB

The test script is as follows:

set -x
unset CUDA_VISIBLE_DEVICES

task_name="llama_auto_static_dp2sharding2mp2pp2_vpp2"
# rm -rf output/$task_name/  # ckpt is saved in 'output/''
rm -rf "output/$task_name""_log"

# export PARALLEL_CROSS_ENTROPY=true
export FLAGS_call_stack_level=4
export PYTHONPATH=../../../:$PYTHONPATH
export GLOG_v=0

python -u -m paddle.distributed.launch \
    --gpus "0,1,2,3" \
    --log_dir "output/$task_name""_log" \
    run_pretrain_auto_static.py \
    --model_type "llama" \
    --model_name_or_path "facebook/llama-7b" \
    --tokenizer_name_or_path "facebook/llama-7b" \
    --input_dir "../data" \
    --output_dir "output/$task_name" \
    --split 949,50,1 \
    --max_seq_length 2048 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --use_flash_attention 0 \
    --use_fused_rms_norm 0 \
    --fp16 0 \
    --fp16_opt_level "O2"  \
    --scale_loss 1024 \
    --pipeline_parallel_degree  4 \
    --tensor_parallel_degree 1 \
    --pipeline_schedule_mode "ZBH1" \
    --learning_rate 0.0001 \
    --min_learning_rate 0.00001 \
    --max_steps 20 \
    --save_steps 5000 \
    --weight_decay 0.01 \
    --warmup_ratio 0.01 \
    --max_grad_norm 1.0 \
    --logging_steps 1 \
    --dataloader_num_workers 1 \
    --eval_steps 1000 \
    --report_to "visualdl" \
    --disable_tqdm true \
    --continue_training 0 \
    --recompute 0 \
    --recompute_granularity full \
    --do_train \
    --do_eval \
    --device "gpu" \
    --data_impl "mmap" \
    --enable_auto_parallel 1 \
    --sharding_parallel_degree 1 \
    --sharding "stage1" \

Related issue:


paddle-bot bot commented Mar 19, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Mar 19, 2024
@AndSonder AndSonder marked this pull request as ready for review March 21, 2024 13:23
@AndSonder AndSonder changed the title [AutoParallel] Add zero h1 for paddle [AutoParallel] Add zero h1 pipeline scheduling for paddle Mar 26, 2024
@AndSonder AndSonder changed the title [AutoParallel] Add zero h1 pipeline scheduling for paddle [Auto Parallel] Add zero h1 pipeline scheduling for paddle Apr 2, 2024

paddle-ci-bot bot commented Apr 9, 2024

Sorry to inform you that 4911d03's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

name = name.split("@")[0]
if not block._find_var_recursive(name):
    return "backward_b"
var = block._find_var_recursive(name)
Contributor

@heavyrain-lzy heavyrain-lzy Apr 15, 2024


Handle operators that have no output, such as send_v2 and c_sync_calc_stream.
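
For context, a hedged sketch of what handling such ops could look like when classifying ops into chunks (the helper name _op_chunk_type and the fallback chunk type are hypothetical, not the PR's exact code):

def _op_chunk_type(op, block):
    # Hypothetical sketch: guard against ops that produce no output vars
    # (e.g. send_v2, c_sync_calc_stream) before looking up an output var.
    out_names = op.output_arg_names
    if len(out_names) == 0:
        return "backward_b"
    name = out_names[0].split("@")[0]
    var = block._find_var_recursive(name)
    if var is None:
        return "backward_b"
    # ... further classification based on `var` is omitted in this sketch
    return "forward"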

Contributor Author


Done. When you have time, could you please test again to check whether it still hangs ~

def _partial_programs(self, program):
    dist_context = self.get_attr("dist_context")
    self._split_matmul_grad_ops_to_matmul(program, dist_context)
    types, sub_program_list = _program_for_zero_bubble(program)
Contributor


Following the _partial_programs of the 1F1B/FThenB passes, add the enable_send_recv_overlap setting here as well, for example the 1F1B _partial_programs:

def _partial_programs(self, program):
        # NOTE: The flag "enable_send_recv_overlap" may increase the reserved memory of GPUs.
        enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap")
        types = [FORWARD, BACKWARD, OPT]
        sub_program_list = _program_for_fthenb_and_1f1b(
            program, enable_send_recv_overlap
        )
        return types, sub_program_list

Contributor Author


Following the _partial_programs of the 1F1B/FThenB passes, add the enable_send_recv_overlap setting here as well, for example the 1F1B _partial_programs:

def _partial_programs(self, program):
        # NOTE: The flag "enable_send_recv_overlap" may increase the reserved memory of GPUs.
        enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap")
        types = [FORWARD, BACKWARD, OPT]
        sub_program_list = _program_for_fthenb_and_1f1b(
            program, enable_send_recv_overlap
        )
        return types, sub_program_list

How about adapting this in a separate PR? The same switch for VPP was also added in a follow-up PR ~

Contributor


Following the _partial_programs of the 1F1B/FThenB passes, add the enable_send_recv_overlap setting here as well, for example the 1F1B _partial_programs:

def _partial_programs(self, program):
        # NOTE: The flag "enable_send_recv_overlap" may increase the reserved memory of GPUs.
        enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap")
        types = [FORWARD, BACKWARD, OPT]
        sub_program_list = _program_for_fthenb_and_1f1b(
            program, enable_send_recv_overlap
        )
        return types, sub_program_list

How about adapting this in a separate PR? The same switch for VPP was also added in a follow-up PR ~

For VPP it was most likely just forgotten at first, which is why it was added separately later. Here it can be added together in this PR.

Contributor Author


Following the _partial_programs of the 1F1B/FThenB passes, add the enable_send_recv_overlap setting here as well, for example the 1F1B _partial_programs:

def _partial_programs(self, program):
        # NOTE: The flag "enable_send_recv_overlap" may increase the reserved memory of GPUs.
        enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap")
        types = [FORWARD, BACKWARD, OPT]
        sub_program_list = _program_for_fthenb_and_1f1b(
            program, enable_send_recv_overlap
        )
        return types, sub_program_list

How about adapting this in a separate PR? The same switch for VPP was also added in a follow-up PR ~

For VPP it was most likely just forgotten at first, which is why it was added separately later. Here it can be added together in this PR.

OK ~
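
For reference, a hedged sketch of how the zero-bubble pass's _partial_programs might look once the flag is wired through; passing the extra argument to _program_for_zero_bubble is an assumption, mirroring _program_for_fthenb_and_1f1b above:

def _partial_programs(self, program):
    # NOTE: The flag "enable_send_recv_overlap" may increase the reserved memory of GPUs.
    enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap")
    dist_context = self.get_attr("dist_context")
    self._split_matmul_grad_ops_to_matmul(program, dist_context)
    # Assumes _program_for_zero_bubble accepts the overlap flag, like
    # _program_for_fthenb_and_1f1b does.
    types, sub_program_list = _program_for_zero_bubble(
        program, enable_send_recv_overlap
    )
    return types, sub_program_list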

Contributor

@heavyrain-lzy heavyrain-lzy left a comment


LGTM

@heavyrain-lzy heavyrain-lzy merged commit adf8689 into PaddlePaddle:develop Apr 18, 2024
29 checks passed