Curriculum learning support (deepspeedai#17)
* initial commit

* script fix
conglongli authored Oct 28, 2021
1 parent 0d88755 commit db97cd2
Showing 10 changed files with 341 additions and 38 deletions.
32 changes: 32 additions & 0 deletions examples/curriculum_learning/README.md
@@ -0,0 +1,32 @@
This is a short tutorial on how to use and tune the curriculum learning (CL) integration. Currently it is only integrated for GPT pre-training. For technical details, please refer to our [paper](https://arxiv.org/abs/2108.06084).

# Disable batch size warmup (--rampup-batch-size)
In section 5.4 of our [paper](https://arxiv.org/abs/2108.06084) we demonstrate that seqlen-based curriculum learning provides much better training stability than the batch size warmup technique. So when using CL, you need to remove the `--rampup-batch-size` flag from your training script. We do not recommend combining CL and batch size warmup, because both of them reduce the number of tokens in a batch. A related change you may want to make is increasing your micro batch size, since without batch size warmup your global batch size is now fixed from the start.
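
As a back-of-the-envelope check (a sketch with hypothetical values, not part of the integration itself), you can verify that your new micro batch size still divides the fixed global batch cleanly:

```python
# Hypothetical values: 512 global batch on 128 data-parallel ranks.
global_batch_size = 512   # --global-batch-size
dp_size = 128             # data-parallel world size
micro_batch_size = 4      # --micro-batch-size, raised now that there is no warmup
assert global_batch_size % (micro_batch_size * dp_size) == 0
grad_acc_steps = global_batch_size // (micro_batch_size * dp_size)
print(grad_acc_steps)     # 1 gradient accumulation step
```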

# Token-based training termination

Because CL changes the length of each sequence/sample during training, it is very hard (if not impossible) to terminate training at exactly the desired number of tokens using a number of steps or samples. Thus we add a `--train-tokens` flag for accurate token-based termination. We recommend increasing your original `--train-samples` or `--train-iters` to a large enough number (e.g., 2x what you used for the baseline), and setting `--train-tokens` to the exact desired number of training tokens (e.g., 300B for GPT-3 style training).
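
The sketch below (our illustration, not the actual Megatron-LM training loop) shows why token-based termination is needed: under seqlen-based CL the number of tokens per step varies, so only a running token counter can stop training at the right point.

```python
# Illustrative only; constants roughly mirror ds_config_cl.json in this dir.
train_tokens = 10_000_000_000    # --train-tokens
global_batch_size = 512
consumed_train_tokens, step = 0, 0
while consumed_train_tokens < train_tokens:
    # Seqlen grows during the curriculum, so tokens/step is not constant.
    # "472" approximates a 60000-step fixed_linear schedule from 8 to 1024.
    curriculum_seqlen = min(8 + 8 * (step // 472), 1024)
    consumed_train_tokens += global_batch_size * curriculum_seqlen
    step += 1
print(step)  # the step at which at least 10B tokens have been consumed
```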

# Token-based LR decay

Again, because CL changes the number of tokens per batch, in Appendix A.2 of our [paper](https://arxiv.org/abs/2108.06084) we show that it is also necessary to make the LR decay token-based (to avoid decaying the LR too fast). Thus we add a `--lr-decay-tokens` flag, which is the number of tokens over which to decay the LR. If you were previously using `--lr-decay-samples`, you can compute `--lr-decay-tokens` simply by multiplying the former by the full seqlen (e.g., 2K for GPT-3). Then replace `--lr-decay-samples` with `--lr-decay-tokens` in your script.
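
As a worked example, here is the conversion using the values from `pretrain_gpt_cl.sh` in this directory (where the full seqlen is 1024 rather than GPT-3's 2048):

```python
lr_decay_samples = 126_953_125   # previous --lr-decay-samples
seqlen = 1024                    # --seq-length, the full (final) seqlen
lr_decay_tokens = lr_decay_samples * seqlen
print(lr_decay_tokens)           # 130000000000, i.e. the 130B in the script
```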

# LR warmup adjustment

We do not make LR warmup token-based, because doing so under CL would slow down the LR warmup, which is both unnecessary and harmful. However, you may need to adjust your `--lr-warmup-samples` or `--lr-warmup-iters` from the non-CL case for various reasons (e.g., if you used `--rampup-batch-size` in the non-CL case, the number of samples per batch at the beginning will now differ, since CL does not use it). Assuming you want to use `X` tokens to warm up the LR (for OpenAI GPT-3 this was 375M tokens), for the CL case you should set `--lr-warmup-samples` to `X` divided by the `min_difficulty` below, or set `--lr-warmup-iters` to `X` divided by `min_difficulty * --global-batch-size`. This is a rough estimate based on the fact that CL starts from seqlen `min_difficulty`, and the seqlen will not increase much during LR warmup.
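
For concreteness, here is the arithmetic with `X` = 375M warmup tokens (OpenAI GPT-3's choice); the `min_difficulty` of 64 and global batch size of 512 are illustrative values, not fixed recommendations:

```python
warmup_tokens = 375_000_000   # X, the desired number of LR warmup tokens
min_difficulty = 64           # starting seqlen of the curriculum
global_batch_size = 512
lr_warmup_samples = warmup_tokens // min_difficulty
lr_warmup_iters = warmup_tokens // (min_difficulty * global_batch_size)
print(lr_warmup_samples)      # 5859375 samples
print(lr_warmup_iters)        # 11444 iterations
```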

# Token-based tensorboard

Because of the above changes, we also add token-based tensorboard scalars, along with scalars that plot the seqlen at each step.

# Curriculum learning hyperparameter tuning strategy

The curriculum learning hyperparameters are all located in the DeepSpeed config JSON file (see the example `ds_config_cl.json` in this directory). There are a few config entries that you may need to adjust to your circumstances, two of which require some tuning. Appendix A.1 of our [paper](https://arxiv.org/abs/2108.06084) describes the tuning strategy in more detail.

1. `max_difficulty` should be set to the full seqlen (i.e., your `--seq-length`). There is no need to tune this.

2. `min_difficulty` is the seqlen that CL starts from. In general, a smaller `min_difficulty` provides a larger stability/convergence speed benefit. However, we observe that for larger models or different training data, starting from a very small seqlen can lead to significant validation PPL fluctuation (or even divergence) at the very beginning. We recommend starting with a `min_difficulty` of 64 and increasing it if you observe problems at the very beginning. Note that to enable Tensor Core acceleration, you should always use a multiple of 8.

3. `total_curriculum_step` is the total number of steps used by CL. In general, a larger `total_curriculum_step` provides a larger stability/convergence speed benefit. However, we observe that too large a `total_curriculum_step` can lead to overfitting and significant validation PPL fluctuation (or even divergence) during the first few multiples of the LR warmup steps. Our paper describes a detailed tuning strategy based on binary search; if you want to reduce the tuning effort, we recommend directly setting `total_curriculum_step` to half of the baseline's total number of steps. This may not provide the largest convergence speed benefit, but it should provide enough training stability gain.

4. `difficulty_step` is the change in seqlen per CL step. A smaller value is preferable since it gives a smoother curriculum and better stability, but like `min_difficulty` it needs to be a multiple of 8 for Tensor Core acceleration, so 8 is a good default. A sketch of how these four entries interact is shown after this list.
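
Below is a rough sketch of how these four entries interact under the `fixed_linear` schedule. This is our approximation for building intuition; refer to the DeepSpeed curriculum learning implementation for the authoritative rounding behavior.

```python
def fixed_linear_seqlen(step, min_difficulty=64, max_difficulty=1024,
                        total_curriculum_step=60_000, difficulty_step=8):
    """Approximate seqlen used at a given step under a fixed_linear schedule."""
    progress = min(step / total_curriculum_step, 1.0)
    seqlen = min_difficulty + progress * (max_difficulty - min_difficulty)
    # Keep the seqlen a multiple of difficulty_step (Tensor Core friendly).
    return max(int(seqlen // difficulty_step) * difficulty_step, min_difficulty)

print(fixed_linear_seqlen(0))       # 64
print(fixed_linear_seqlen(30_000))  # 544 (halfway through the curriculum)
print(fixed_linear_seqlen(90_000))  # 1024 (clamped at max_difficulty)
```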
37 changes: 37 additions & 0 deletions examples/curriculum_learning/ds_config_cl.json
@@ -0,0 +1,37 @@
{
  "train_batch_size": 512,
  "gradient_accumulation_steps": 1,
  "steps_per_print": 1,
  "zero_optimization": {
    "stage": 0
  },
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.00015,
      "max_grad_norm": 1.0,
      "betas": [0.9, 0.95]
    }
  },
  "gradient_clipping": 1.0,
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "wall_clock_breakdown": false,
  "zero_allow_untested_optimizer": false,
  "curriculum_learning": {
    "enabled": true,
    "curriculum_type": "seqlen",
    "min_difficulty": 8,
    "max_difficulty": 1024,
    "schedule_type": "fixed_linear",
    "schedule_config": {
      "total_curriculum_step": 60000,
      "difficulty_step": 8
    }
  }
}
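
A quick sanity check on the config above (approximate, following the schedule sketch in the README): the seqlen grows from 8 to 1024 in increments of 8, so the difficulty increases roughly every 472 steps over the 60000-step curriculum.

```python
min_difficulty, max_difficulty = 8, 1024
total_curriculum_step, difficulty_step = 60_000, 8
num_increments = (max_difficulty - min_difficulty) // difficulty_step
print(num_increments)                          # 127
print(total_curriculum_step / num_increments)  # ~472 steps per seqlen increment
```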
105 changes: 105 additions & 0 deletions examples/curriculum_learning/pretrain_gpt_cl.sh
@@ -0,0 +1,105 @@
#!/bin/bash

# This is a dummy training script that shows how to use curriculum
# learning; some parameters are not those used for actual GPT pretraining.

TARGET_GLOBAL_BATCH_SIZE=512
TRAIN_SAMPLES=146_484_375
LR=1.0e-4
MIN_LR=1.0e-5
LR_DECAY_SAMPLES=126_953_125
LR_WARMUP_SAMPLES=183_105
SEQLEN=1024

############################################################
# New configs for curriculum learning, see README.md
TRAIN_TOKENS=10_000_000_000
# LR_DECAY_TOKENS=LR_DECAY_SAMPLES*SEQLEN
LR_DECAY_TOKENS=130000000000
############################################################

LOG_INTERVAL=100
EVAL_ITERS=10
EVAL_INTERVAL=100
SAVE_INTERVAL=1000

VOCAB_PATH=/data/Megatron-LM/data/gpt2-vocab.json
MERGE_PATH=/data/Megatron-LM/data/gpt2-merges.txt
DATA_PATH=/data/Megatron-LM/data/indexed_datasets/megatron

MICRO_BATCH_SIZE=1
MP_SIZE=1
PP_SIZE=1

NUM_GPUS=128
echo ${NUM_GPUS}
if [[ $PP_SIZE -gt 0 ]]; then
    DP_SIZE=$(( ${NUM_GPUS} / (${PP_SIZE} * ${MP_SIZE}) ))
else
    DP_SIZE=$(( ${NUM_GPUS} / ${MP_SIZE} ))
fi
GRAD_ACC_STEPS=$(( ${TARGET_GLOBAL_BATCH_SIZE} / (${MICRO_BATCH_SIZE} * ${DP_SIZE}) ))

NAME="gpt-117M-pp${PP_SIZE}-mp${MP_SIZE}-bsz${TARGET_GLOBAL_BATCH_SIZE}-mbsz${MICRO_BATCH_SIZE}-cl"
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
host="${HOSTNAME}"
TENSORBOARD_DIR="tensorboard/${NAME}_${host}_${current_time}"
mkdir -p ${TENSORBOARD_DIR}
CHECKPOINT_PATH="checkpoints/${NAME}"

megatron_options=" \
        --data-path ${DATA_PATH} \
        --vocab-file ${VOCAB_PATH} \
        --merge-file ${MERGE_PATH} \
        --data-impl mmap \
        --override-lr-scheduler \
        --adam-beta1 0.9 \
        --adam-beta2 0.95 \
        --tensor-model-parallel-size ${MP_SIZE} \
        --init-method-std 0.014 \
        --lr-decay-tokens ${LR_DECAY_TOKENS} \
        --lr-warmup-samples ${LR_WARMUP_SAMPLES} \
        --micro-batch-size ${MICRO_BATCH_SIZE} \
        --global-batch-size ${TARGET_GLOBAL_BATCH_SIZE} \
        --num-layers 12 \
        --hidden-size 768 \
        --num-attention-heads 16 \
        --seq-length ${SEQLEN} \
        --max-position-embeddings ${SEQLEN} \
        --train-samples ${TRAIN_SAMPLES} \
        --train-tokens ${TRAIN_TOKENS} \
        --lr ${LR} \
        --min-lr ${MIN_LR} \
        --lr-decay-style cosine \
        --split 98,2,0 \
        --log-interval ${LOG_INTERVAL} \
        --eval-interval ${EVAL_INTERVAL} \
        --eval-iters ${EVAL_ITERS} \
        --save-interval ${SAVE_INTERVAL} \
        --weight-decay 0.1 \
        --clip-grad 1.0 \
        --hysteresis 2 \
        --num-workers 0 \
        --checkpoint-activations \
        --fp16 \
        --load ${CHECKPOINT_PATH} \
        --save ${CHECKPOINT_PATH} \
        --tensorboard-queue-size 1 \
        --log-timers-to-tensorboard \
        --log-batch-size-to-tensorboard \
        --log-validation-ppl-to-tensorboard \
        --no-masked-softmax-fusion \
        --tensorboard-dir ${TENSORBOARD_DIR}"

config_json="ds_config_cl.json"

deepspeed_options=" \
        --deepspeed \
        --deepspeed_config ${config_json} \
        --pipeline-model-parallel-size ${PP_SIZE} \
        --partition-activations"

run_cmd="deepspeed ../../pretrain_gpt.py ${megatron_options} ${deepspeed_options} &>> ${NAME}.log"
echo ${run_cmd}
eval ${run_cmd}
set +x
9 changes: 9 additions & 0 deletions megatron/arguments.py
@@ -166,6 +166,7 @@ def parse_args(extra_args_provider=None, defaults={},
     # Consumed tokens.
     args.consumed_train_samples = 0
     args.consumed_valid_samples = 0
+    args.consumed_train_tokens = 0
 
     # Iteration-based training.
     if args.train_iters:
@@ -239,6 +240,8 @@ def parse_args(extra_args_provider=None, defaults={},
             'for distribute-checkpointed-activations to work you '\
             'need to enable checkpoint-activations'
 
+    args.curriculum_learning = False
+
     _print_args(args)
     return args
 
@@ -410,6 +413,9 @@ def _add_training_args(parser):
                        help='Total number of samples to train over all '
                        'training runs. Note that either train-iters or '
                        'train-samples should be provided.')
+    group.add_argument('--train-tokens', type=int, default=None,
+                       help='Total number of tokens to train over all '
+                       'training runs.')
     group.add_argument('--log-interval', type=int, default=100,
                        help='Report loss and timing interval.')
     group.add_argument('--exit-interval', type=int, default=None,
@@ -475,6 +481,9 @@ def _add_learning_rate_args(parser):
     group.add_argument('--lr-decay-samples', type=int, default=None,
                        help='number of samples to decay learning rate over,'
                        ' If None defaults to `--train-samples`')
+    group.add_argument('--lr-decay-tokens', type=int, default=None,
+                       help='number of tokens to decay learning rate over,'
+                       ' If not None will override iter/sample-based decay')
     group.add_argument('--lr-warmup-fraction', type=float, default=None,
                        help='fraction of lr-warmup-(iters/samples) to use '
                        'for warmup (as a float)')
3 changes: 3 additions & 0 deletions megatron/checkpointing.py
@@ -125,6 +125,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         state_dict['args'] = args
         state_dict['checkpoint_version'] = 3.0
         state_dict['iteration'] = iteration
+        state_dict['tokens'] = args.consumed_train_tokens
 
         # DeepSpeed saves the model/optimizer/scheduler
         if not args.deepspeed:
@@ -332,6 +333,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
     else:
         try:
             iteration = state_dict['iteration']
+            if 'tokens' in state_dict:
+                args.consumed_train_tokens = state_dict['tokens']
         except KeyError:
             try:  # Backward compatible with older checkpoints
                 iteration = state_dict['total_iters']
52 changes: 40 additions & 12 deletions megatron/learning_rates.py
@@ -17,7 +17,7 @@
 
 import math
 
-from megatron import print_rank_0
+from megatron import print_rank_0, get_args
 
 class AnnealingLR(object):
     """Anneals the learning rate."""
@@ -26,7 +26,7 @@ def __init__(self, optimizer, max_lr, min_lr,
                  warmup_steps, decay_steps, decay_style,
                  use_checkpoint_lr_scheduler=True,
                  override_lr_scheduler=False):
-
+        args = get_args()
         # Class values.
         self.optimizer = optimizer
 
@@ -41,6 +41,10 @@ def __init__(self, optimizer, max_lr, min_lr,
         assert self.decay_steps > 0
         assert self.warmup_steps < self.decay_steps
 
+        self.decay_tokens = args.lr_decay_tokens
+        self.num_tokens = 0
+        self.warmup_tokens = 0
+
         self.decay_style = decay_style
 
         self.override_lr_scheduler = override_lr_scheduler
@@ -61,21 +65,35 @@ def get_lr(self):
 
         # Use linear warmup for the initial part.
         if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
+            if self.num_steps == self.warmup_steps and \
+                    self.decay_tokens is not None:
+                self.warmup_tokens = self.num_tokens
             return self.max_lr * float(self.num_steps) / \
                 float(self.warmup_steps)
 
         # If the learning rate is constant, just return the initial value.
         if self.decay_style == 'constant':
             return self.max_lr
 
-        # For any steps larger than `self.decay_steps`, use `self.min_lr`.
-        if self.num_steps > self.decay_steps:
-            return self.min_lr
-
-        # If we are done with the warmup period, use the decay style.
-        num_steps_ = self.num_steps - self.warmup_steps
-        decay_steps_ = self.decay_steps - self.warmup_steps
-        decay_ratio = float(num_steps_) / float(decay_steps_)
+        if self.decay_tokens is None:
+            # step-based decay
+
+            # For any steps larger than `self.decay_steps`, use `self.min_lr`.
+            if self.num_steps > self.decay_steps:
+                return self.min_lr
+
+            # If we are done with the warmup period, use the decay style.
+            num_steps_ = self.num_steps - self.warmup_steps
+            decay_steps_ = self.decay_steps - self.warmup_steps
+            decay_ratio = float(num_steps_) / float(decay_steps_)
+        else:
+            # token-based decay
+
+            if self.num_tokens > self.decay_tokens:
+                return self.min_lr
+            num_tokens_ = self.num_tokens - self.warmup_tokens
+            decay_tokens_ = self.decay_tokens - self.warmup_tokens
+            decay_ratio = float(num_tokens_) / float(decay_tokens_)
         assert decay_ratio >= 0.0
         assert decay_ratio <= 1.0
         delta_lr = self.max_lr - self.min_lr
@@ -91,8 +109,12 @@ def get_lr(self):
         return self.min_lr + coeff * delta_lr
 
 
-    def step(self, increment):
+    def step(self, increment, token_num=None):
         """Set lr for all parameters groups."""
+        if token_num is None:
+            args = get_args()
+            token_num = args.consumed_train_tokens
+        self.num_tokens = token_num
         self.num_steps += increment
         new_lr = self.get_lr()
         for group in self.optimizer.param_groups:
@@ -104,6 +126,8 @@ def state_dict(self):
             'max_lr': self.max_lr,
             'warmup_steps': self.warmup_steps,
             'num_steps': self.num_steps,
+            'warmup_tokens': self.warmup_tokens,
+            'num_tokens': self.num_tokens,
             'decay_style': self.decay_style,
             'decay_steps': self.decay_steps,
             'min_lr': self.min_lr
@@ -161,4 +185,8 @@ def load_state_dict(self, sd):
             num_steps = sd['num_iters']
         else:
             num_steps = sd['num_steps']
-        self.step(increment=num_steps)
+        if 'warmup_tokens' in sd:
+            self.warmup_tokens = sd['warmup_tokens']
+        if 'num_tokens' in sd:
+            self.num_tokens = sd['num_tokens']
+        self.step(num_steps, self.num_tokens)
14 changes: 13 additions & 1 deletion megatron/model/gpt_model.py
@@ -98,7 +98,19 @@ def set_input_tensor(self, input_tensor):
 
     def forward(self, input_ids, position_ids, attention_mask, labels=None,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
+                forward_method_parallel_output=None, curriculum_seqlen=None):
+        if curriculum_seqlen is not None:
+            args = get_args()
+            args.curriculum_seqlen = curriculum_seqlen
+            if curriculum_seqlen < input_ids.size()[1]:
+                # seqlen-based curriculum learning
+                # input_ids, position_ids, labels have size [batch size, seqlen]
+                input_ids = input_ids[:, :curriculum_seqlen].contiguous()
+                position_ids = position_ids[:, :curriculum_seqlen].contiguous()
+                labels = labels[:, :curriculum_seqlen].contiguous()
+
+                # attention_mask has size [1, 1, seqlen, seqlen]
+                attention_mask = attention_mask[:, :, :curriculum_seqlen, :curriculum_seqlen].contiguous()
 
         lm_output = self.language_model(
             input_ids,
11 changes: 10 additions & 1 deletion megatron/model/utils.py
@@ -40,7 +40,16 @@ def init_(tensor):
 
 
 def attention_mask_func(attention_scores, attention_mask):
-    attention_scores.masked_fill_(attention_mask, -10000.0)
+    args = get_args()
+    if args.curriculum_learning:
+        attention_mask_ = attention_mask
+        actual_seqlen = attention_scores.size()[2]
+        if actual_seqlen != attention_mask_.size()[2]:
+            # attention_mask has size [1, 1, seqlen, seqlen]
+            attention_mask_ = attention_mask_[:, :, :actual_seqlen, :actual_seqlen].contiguous()
+        attention_scores.masked_fill_(attention_mask_, -10000.0)
+    else:
+        attention_scores.masked_fill_(attention_mask, -10000.0)
     return attention_scores