Curriculum learning support (deepspeedai#17)

* initial commit
* script fix

1 parent 0d88755 · commit db97cd2
Showing 10 changed files with 341 additions and 38 deletions.
This is a short tutorial on how to use and tune the curriculum learning (CL) integration. Currently it is only integrated for GPT pre-training. For technical details please refer to our [paper](https://arxiv.org/abs/2108.06084).

# Disable batch size warmup (--rampup-batch-size)

In section 5.4 of our [paper](https://arxiv.org/abs/2108.06084) we demonstrate that (seqlen-based) curriculum learning provides much better training stability than the batch size warmup technique. So when using CL you need to remove the `--rampup-batch-size` config from your training script. Using both CL and batch size warmup is not recommended, because both of them reduce the number of tokens in a batch. Another related change you may want to make is to increase your micro batch size, since without batch size warmup your global batch size is now fixed from the very first step.
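As a rough illustration, the relevant flags in a Megatron-style launch script could look like the sketch below; the variable names and values are hypothetical, not taken from the example script later in this directory:

```bash
# Hedged sketch: with CL, the --rampup-batch-size line is simply removed, and
# the micro batch size can be raised because the global batch size is fixed
# from step 0. Values are illustrative only.
MICRO_BATCH_SIZE=4          # was 1 with batch size warmup; raise if memory allows
GLOBAL_BATCH_SIZE=512

batch_options=" \
    --micro-batch-size ${MICRO_BATCH_SIZE} \
    --global-batch-size ${GLOBAL_BATCH_SIZE}"
# (no --rampup-batch-size here)
echo ${batch_options}
```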
# Token-based training termination

Because CL changes the length of each sequence/sample during training, it is very hard (if not impossible) to use a number of steps/samples to terminate training at exactly the desired number of tokens. Thus we add a `--train-tokens` config as an alternative, accurate token-based termination. We recommend increasing your original `--train-samples` or `--train-iters` to a large enough number (e.g., 2X of what you used for the baseline), and setting `--train-tokens` to the exact desired number of training tokens (e.g., 300B for GPT-3-like training).
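A minimal sketch using the GPT-3-like numbers mentioned above (300B target tokens, full seqlen 2048); the exact values are illustrative only:

```bash
# Keep --train-tokens as the real stopping criterion and pad the sample budget
# so it never triggers first.
SEQLEN=2048
TRAIN_TOKENS=300000000000                             # exact token-based termination target
BASELINE_TRAIN_SAMPLES=$(( TRAIN_TOKENS / SEQLEN ))   # what a non-CL run would need
TRAIN_SAMPLES=$(( 2 * BASELINE_TRAIN_SAMPLES ))       # ~2X padding, per the recommendation above
echo "--train-samples ${TRAIN_SAMPLES} --train-tokens ${TRAIN_TOKENS}"
```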
# Token-based LR decay

Again, because CL changes the number of tokens per batch, in Appendix A.2 of our [paper](https://arxiv.org/abs/2108.06084) we show that it is also necessary to make the LR decay token-based (to avoid decaying the LR too fast). Thus we add a `--lr-decay-tokens` config, which is the number of LR decay tokens. If you were previously using `--lr-decay-samples`, you can calculate your `--lr-decay-tokens` simply by multiplying the former by the full seqlen (e.g., 2K for GPT-3). Then replace `--lr-decay-samples` with `--lr-decay-tokens` in your script.
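For example, with the numbers used in the example training script later in this directory (LR_DECAY_SAMPLES=126,953,125 and a full seqlen of 1024):

```bash
# Conversion from sample-based to token-based LR decay, as described above.
SEQLEN=1024
LR_DECAY_SAMPLES=126953125
LR_DECAY_TOKENS=$(( LR_DECAY_SAMPLES * SEQLEN ))   # 130000000000
echo "--lr-decay-tokens ${LR_DECAY_TOKENS}"        # replaces --lr-decay-samples
```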
# LR warmup adjustment

We do not make the LR warmup token-based, because doing so for CL would slow down the LR warmup, which is both unnecessary and harmful. However, you may need to adjust your `--lr-warmup-samples` or `--lr-warmup-iters` from the non-CL case for various reasons (e.g., if you used `--rampup-batch-size` in the non-CL case, CL does not use it, so the number of samples per batch will differ at the beginning). Assuming you want to use `X` tokens to warm up the LR (for OpenAI GPT-3 this was 375M tokens), then for the CL case you should set `--lr-warmup-samples` to `X` divided by the `min_difficulty` described below, or set `--lr-warmup-iters` to `X` divided by `min_difficulty * --global-batch-size`. This is a rough estimate based on the fact that CL starts from seqlen `min_difficulty` and the seqlen does not increase much during LR warmup.
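A worked sketch, assuming `X` = 375M warmup tokens (the OpenAI GPT-3 value mentioned above), `min_difficulty` = 64, and a global batch size of 512 (the batch size is an assumption for illustration):

```bash
# Rough LR warmup conversion for the CL case, per the estimate above.
WARMUP_TOKENS=375000000
MIN_DIFFICULTY=64
GLOBAL_BATCH_SIZE=512
LR_WARMUP_SAMPLES=$(( WARMUP_TOKENS / MIN_DIFFICULTY ))                       # 5859375
LR_WARMUP_ITERS=$(( WARMUP_TOKENS / (MIN_DIFFICULTY * GLOBAL_BATCH_SIZE) ))   # ~11444
echo "--lr-warmup-samples ${LR_WARMUP_SAMPLES} (or --lr-warmup-iters ${LR_WARMUP_ITERS})"
```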
# Token-based tensorboard

Because of the above changes, we also add token-based tensorboard scalars, as well as scalars that plot the seqlen at each step.
# Curriculum learning hyperparameters tuning strategy

The curriculum learning hyperparameters are all located in the DeepSpeed config JSON file (see the example `ds_config_cl.json` in this dir). There are a few config entries that you may need to adjust to your circumstances, two of which require some tuning. Appendix A.1 of our [paper](https://arxiv.org/abs/2108.06084) has a more detailed description of the tuning strategy.

1. `max_difficulty` should be set to the full seqlen (i.e., your `--seq-length`). No need to tune this.

2. `min_difficulty` is the beginning seqlen used by CL. In general, a smaller `min_difficulty` provides better stability/convergence speed benefits. However, we observe that for a larger model or for different training data, starting from a very small seqlen can lead to significant validation PPL fluctuation (or even divergence) at the very beginning. We recommend starting with `min_difficulty` at 64 and increasing it if you observe problems at the very beginning. Note that to enable Tensor Core acceleration you should always use a multiple of 8.

3. `total_curriculum_step` is the total number of steps used by CL. In general, a larger `total_curriculum_step` provides better stability/convergence speed benefits. However, we observe that a too-large `total_curriculum_step` can lead to overfitting and significant validation PPL fluctuation (or even divergence) during the first few multiples of the LR warmup steps. In our paper we have a detailed tuning strategy based on binary search. However, if you want to reduce the tuning effort, we recommend directly setting `total_curriculum_step` to half of the baseline's total number of steps. This may not provide the highest convergence speed benefit, but should provide enough training stability gains.

4. `difficulty_step` is the change in seqlen per CL step. A smaller value is preferable since it gives a smoother curriculum and better stability. Like `min_difficulty`, it needs to be a multiple of 8 for Tensor Core acceleration, so 8 is a good default. A short sketch that applies these defaults follows the list.
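Here is a minimal sketch applying the low-effort defaults from the list above; the baseline step count is a hypothetical placeholder:

```bash
# Low-effort CL hyperparameter defaults, following rules 1-4 above.
SEQLEN=1024                   # your --seq-length
BASELINE_TOTAL_STEPS=120000   # hypothetical: total steps of your non-CL baseline
MAX_DIFFICULTY=${SEQLEN}                               # rule 1: full seqlen, no tuning
MIN_DIFFICULTY=64                                      # rule 2: start at 64, keep a multiple of 8
TOTAL_CURRICULUM_STEP=$(( BASELINE_TOTAL_STEPS / 2 ))  # rule 3: half of the baseline's steps
DIFFICULTY_STEP=8                                      # rule 4: multiple of 8 for Tensor Cores
echo "max=${MAX_DIFFICULTY} min=${MIN_DIFFICULTY} steps=${TOTAL_CURRICULUM_STEP} step_size=${DIFFICULTY_STEP}"
```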
The example DeepSpeed config `ds_config_cl.json`:
{
  "train_batch_size": 512,
  "gradient_accumulation_steps": 1,
  "steps_per_print": 1,
  "zero_optimization": {
    "stage": 0
  },
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.00015,
      "max_grad_norm": 1.0,
      "betas": [0.9, 0.95]
    }
  },
  "gradient_clipping": 1.0,
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "wall_clock_breakdown": false,
  "zero_allow_untested_optimizer": false,
  "curriculum_learning": {
    "enabled": true,
    "curriculum_type": "seqlen",
    "min_difficulty": 8,
    "max_difficulty": 1024,
    "schedule_type": "fixed_linear",
    "schedule_config": {
      "total_curriculum_step": 60000,
      "difficulty_step": 8
    }
  }
}
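For intuition, with this config the `fixed_linear` schedule grows the seqlen from `min_difficulty` to `max_difficulty` over `total_curriculum_step` steps in increments of `difficulty_step`. The sketch below is my reading of that interpolation, not the exact DeepSpeed implementation:

```bash
# Approximate seqlen at a given step under fixed_linear (assumption: linear
# interpolation rounded down to a multiple of difficulty_step).
STEP=30000
MIN_DIFFICULTY=8
MAX_DIFFICULTY=1024
TOTAL_CURRICULUM_STEP=60000
DIFFICULTY_STEP=8
SEQLEN=$(( MIN_DIFFICULTY + (MAX_DIFFICULTY - MIN_DIFFICULTY) * STEP / TOTAL_CURRICULUM_STEP ))
SEQLEN=$(( SEQLEN / DIFFICULTY_STEP * DIFFICULTY_STEP ))
echo ${SEQLEN}   # ~512 halfway through the curriculum under these assumptions
```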
An example training script using this config:
#!/bin/bash

# This is a dummy train script to show how to use curriculum
# learning; some parameters are not for actual GPT pretraining.

TARGET_GLOBAL_BATCH_SIZE=512
TRAIN_SAMPLES=146_484_375
LR=1.0e-4
MIN_LR=1.0e-5
LR_DECAY_SAMPLES=126_953_125
LR_WARMUP_SAMPLES=183_105
SEQLEN=1024

############################################################
# New configs for curriculum learning, see README.md
TRAIN_TOKENS=10_000_000_000
# LR_DECAY_TOKENS=LR_DECAY_SAMPLES*SEQLEN
LR_DECAY_TOKENS=130000000000
############################################################

LOG_INTERVAL=100
EVAL_ITERS=10
EVAL_INTERVAL=100
SAVE_INTERVAL=1000

VOCAB_PATH=/data/Megatron-LM/data/gpt2-vocab.json
MERGE_PATH=/data/Megatron-LM/data/gpt2-merges.txt
DATA_PATH=/data/Megatron-LM/data/indexed_datasets/megatron

MICRO_BATCH_SIZE=1
MP_SIZE=1
PP_SIZE=1

NUM_GPUS=128
echo ${NUM_GPUS}
if [[ $PP_SIZE -gt 0 ]]; then
    DP_SIZE=$(( ${NUM_GPUS} / (${PP_SIZE} * ${MP_SIZE}) ))
else
    DP_SIZE=$(( ${NUM_GPUS} / ${MP_SIZE} ))
fi
GRAD_ACC_STEPS=$(( ${TARGET_GLOBAL_BATCH_SIZE} / (${MICRO_BATCH_SIZE} * ${DP_SIZE}) ))

NAME="gpt-117M-pp${PP_SIZE}-mp${MP_SIZE}-bsz${TARGET_GLOBAL_BATCH_SIZE}-mbsz${MICRO_BATCH_SIZE}-cl"
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
host="${HOSTNAME}"
TENSORBOARD_DIR="tensorboard/${NAME}_${host}_${current_time}"
mkdir -p ${TENSORBOARD_DIR}
CHECKPOINT_PATH="checkpoints/${NAME}"

megatron_options=" \
    --data-path ${DATA_PATH} \
    --vocab-file ${VOCAB_PATH} \
    --merge-file ${MERGE_PATH} \
    --data-impl mmap \
    --override-lr-scheduler \
    --adam-beta1 0.9 \
    --adam-beta2 0.95 \
    --tensor-model-parallel-size ${MP_SIZE} \
    --init-method-std 0.014 \
    --lr-decay-tokens ${LR_DECAY_TOKENS} \
    --lr-warmup-samples ${LR_WARMUP_SAMPLES} \
    --micro-batch-size ${MICRO_BATCH_SIZE} \
    --global-batch-size ${TARGET_GLOBAL_BATCH_SIZE} \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 16 \
    --seq-length ${SEQLEN} \
    --max-position-embeddings ${SEQLEN} \
    --train-samples ${TRAIN_SAMPLES} \
    --train-tokens ${TRAIN_TOKENS} \
    --lr ${LR} \
    --min-lr ${MIN_LR} \
    --lr-decay-style cosine \
    --split 98,2,0 \
    --log-interval ${LOG_INTERVAL} \
    --eval-interval ${EVAL_INTERVAL} \
    --eval-iters ${EVAL_ITERS} \
    --save-interval ${SAVE_INTERVAL} \
    --weight-decay 0.1 \
    --clip-grad 1.0 \
    --hysteresis 2 \
    --num-workers 0 \
    --checkpoint-activations \
    --fp16 \
    --load ${CHECKPOINT_PATH} \
    --save ${CHECKPOINT_PATH} \
    --tensorboard-queue-size 1 \
    --log-timers-to-tensorboard \
    --log-batch-size-to-tensorboard \
    --log-validation-ppl-to-tensorboard \
    --no-masked-softmax-fusion \
    --tensorboard-dir ${TENSORBOARD_DIR}"

config_json="ds_config_cl.json"

deepspeed_options=" \
    --deepspeed \
    --deepspeed_config ${config_json} \
    --pipeline-model-parallel-size ${PP_SIZE} \
    --partition-activations"

run_cmd="deepspeed ../../pretrain_gpt.py ${megatron_options} ${deepspeed_options} &>> ${NAME}.log"
echo ${run_cmd}
eval ${run_cmd}
set +x