fix auto ci return bug when run in v100 #9228

Merged: 1 commit, Oct 8, 2024
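The fix itself is a small control-flow change: the early return taken on non-A100 machines is replaced by wrapping the whole case body in an `if [ $IS_A100 -ne 0 ]` block, so the function always falls through to its closing "run end" echo; almost all of the 100 added and 101 deleted lines are the same body re-indented inside that block. Below is a minimal sketch of the before/after shapes. The stub names (demo_case_old, demo_case_new, run_training_steps) are illustrative, and the idea that the CI harness keys off the trailing marker is an assumption, not something stated in the diff.

#!/usr/bin/env bash
# Hypothetical stand-ins; the real script launches paddle.distributed.launch here.
IS_A100=${IS_A100:-0}
run_training_steps() { echo "  (training steps would run here)"; }

# Before: early return on V100 (IS_A100=0), so the trailing "run end" marker is never printed.
demo_case_old() {
    echo "=========== $FUNCNAME run begin ==========="
    if [ "$IS_A100" -eq 0 ]; then
        return
    fi
    run_training_steps
    echo "=========== $FUNCNAME run end ==========="
}

# After (the pattern this PR applies): the body is gated, but the function
# always reaches the end marker.
demo_case_new() {
    echo "=========== $FUNCNAME run begin ==========="
    if [ "$IS_A100" -ne 0 ]; then
        run_training_steps
    fi
    echo "=========== $FUNCNAME run end ==========="
}

demo_case_old   # on V100 this prints only the "run begin" line
demo_case_new   # prints both markers regardless of GPU type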
201 changes: 100 additions & 101 deletions in scripts/distribute/ci_case_auto.sh
@@ -1359,109 +1359,108 @@ function llama_align_dygraph_dy2st_pir_auto_grad_merge_bs2_fp32_DP1-MP1-PP1() {
function llama_align_dy2st_fthenb_and_vpp_auto_bs2_fp32_DP1-MP1-PP4() {
echo "=========== $FUNCNAME run begin ==========="
# Only A100 support this case.
- if [ $IS_A100 -eq 0 ]; then
-     return
- fi
+ if [ $IS_A100 -ne 0 ]; then
export FLAGS_call_stack_level=3
export NVIDIA_TF32_OVERRIDE=0
export FLAGS_max_inplace_grad_add=3

task_name="llama_align_dy2st_fthenb_and_vpp_auto_bs2_fp32_DP1_MP1_PP4"
case_out_dir="output/$task_name"
case_log_dir="output/$task_name""_log"
loss1=0
loss2=0
use_pir=1

max_step=10
to_static=1

for pp_mode in "1F1B" "VPP"; do
export FLAGS_enable_pir_api=${use_pir}
export FLAGS_enable_pir_in_executor=${use_pir}
rm -rf $case_out_dir
rm -rf $case_log_dir
rm -rf ${log_path}/$FUNCNAME
if [ "$pp_mode" == "FThenB" ]; then
vpp_degree=1
else
vpp_degree=2
fi

python -u -m paddle.distributed.launch \
--gpus "0,1,2,3" \
--log_dir $case_log_dir \
run_pretrain_auto.py \
--model_type "llama" \
--model_name_or_path "facebook/llama-7b" \
--tokenizer_name_or_path "facebook/llama-7b" \
--input_dir "./data" \
--output_dir $case_out_dir \
--split 949,50,1 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--warmup_steps 30 \
--max_grad_norm 0.0 \
--learning_rate 3e-05 \
--min_learning_rate 3e-06 \
--max_steps $max_step \
--logging_steps 1 \
--eval_steps 1000 \
--save_steps 50000 \
--continue_training 0 \
--do_train true \
--do_eval false \
--do_predict false \
--disable_tqdm true \
--skip_profile_timer true \
--save_total_limit 2 \
--device gpu \
--disable_tqdm true \
--dataloader_num_workers 1 \
--distributed_dataloader 0 \
--enable_auto_parallel 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 4 \
--per_device_eval_batch_size 2 \
--recompute false \
--recompute_use_reentrant true \
--recompute_granularity full \
--fp16 0 \
--fp16_opt_level "O2" \
--fuse_attention_ffn true \
--fuse_attention_qkv true \
--fuse_sequence_parallel_allreduce false \
--use_flash_attention 0 \
--use_fused_rope false \
--use_fused_rms_norm 0 \
--max_seq_length 2048 \
--hidden_size 1024 \
--sep_parallel_degree 1 \
--sequence_parallel false \
--pipeline_parallel_degree 4 \
--sharding_parallel_degree 1 \
--tensor_parallel_degree 1 \
--sharding "" \
--to_static ${to_static} \
--num_hidden_layers 8 \
--data_parallel_config "gradient_sync_after_accumulate" \
--pipeline_schedule_mode $pp_mode \
--virtual_pp_degree $vpp_degree \
>>${log_path}/$FUNCNAME 2>&1

loss=$(grep "global_step: 10," "$case_log_dir/workerlog.0" | grep -oP '(?<=loss: )\d+(\.\d+)?' | awk -F ',' '{print $1}')
if [ "$pp_mode" == "FThenB" ]; then
loss1=loss
else
loss2=loss
fi
echo "result: $pp_mode loss=$loss"
done
ips=-1
mem=-1
ips_base=-1
mem_base=-1
check_result $FUNCNAME ${loss1} ${loss2} ${ips_base} ${ips} ${mem_base} ${mem}
+ fi
echo "=========== $FUNCNAME run end ==========="
}
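After each run in the loop above, the case pulls the step-10 loss out of workerlog.0 with a grep/awk pipeline and hands the two collected values to check_result. A quick way to sanity-check that pipeline in isolation, against a fabricated sample line; only the pipeline itself is taken from the script, and the exact log format is an assumption:

# Fabricated sample log line (format assumed); the pipeline is copied from the case above.
# Requires GNU grep for the -P (PCRE) lookbehind.
sample='global_step: 10, loss: 9.23456, lr: 3e-05'
echo "$sample" \
    | grep "global_step: 10," \
    | grep -oP '(?<=loss: )\d+(\.\d+)?' \
    | awk -F ',' '{print $1}'
# expected output: 9.23456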
