
Learning Rate Scheduling of Transformer


Introduction

The Transformer model has its own characteristic way of scheduling the learning rate. Usually the schedule has a warmup phase and a decay phase: during warmup the learning rate increases linearly, and afterwards it decays following an inverse-square-root or exponential curve. There are many configurations of this kind of scheduling, and many of them are included in tensor2tensor. In order to follow the latest progress, we try to keep our code compatible with t2t on this part. Some of this work is still in progress.

Overview

T2T currently has two kinds of scheduling configuration: a legacy configuration, which is the same as the method in Attention is All You Need, and a factored configuration, which combines different kinds of scheduling over different ranges of timesteps. Our code currently supports only the former; support for the latter is in progress.

Legacy Scheduling (Noam)

The overall formula for legacy scheduling, or Noam, is $$ \text{lrate} = \text{ret} * \text{opt_corr} * \text{init_lr} $$ where $\text{init_lr}$ is the initial learning rate, $\text{opt_corr}$ is a fixed value that depends on the specific optimizer, and $\text{ret}$ is a function of the number of update steps.

As we use Adam to train the Transformer, $\text{opt_corr}$ is set to 0.002 in T2T. Noam computes $\text{ret}$ as $$ \text{ret} = 5000.0 * d_{model}^{-0.5} * \min(t * \text{warmup_steps}^{-1.5}, t^{-0.5}) $$ where $d_{model}$ is the model dimension, $t$ is the current training step (starting from 1), and $\text{warmup_steps}$ is the given number of warmup steps.
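
For concreteness, the two formulas above can be combined into a small helper function. This is only a sketch of the scheduling rule described here, not the actual implementation in our code; the function and argument names are illustrative.

```python
def noam_lr(step, d_model=512, warmup_steps=4000, init_lr=0.1, opt_corr=0.002):
    """Noam learning rate at a given training step (step starts from 1)."""
    # ret = 5000 * d_model^{-0.5} * min(step * warmup_steps^{-1.5}, step^{-0.5})
    ret = 5000.0 * d_model ** -0.5 * min(step * warmup_steps ** -1.5, step ** -0.5)
    # lrate = ret * opt_corr * init_lr
    return ret * opt_corr * init_lr
```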

T2T provides two versioned settings, base_v1 and base_v2. In base_v1, $\text{warmup_steps}=4000$ and $\text{init_lr} = 0.1$; in base_v2, $\text{warmup_steps}=8000$ and $\text{init_lr} = 0.2$.
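
Using the sketch above, the peak learning rate (reached at $t = \text{warmup_steps}$) under the two settings comes out roughly as follows:

```python
# Peak learning rate is reached when step == warmup_steps.
print(noam_lr(4000, warmup_steps=4000, init_lr=0.1))  # base_v1: ~7.0e-4
print(noam_lr(8000, warmup_steps=8000, init_lr=0.2))  # base_v2: ~9.9e-4
```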

How to configure noam in our code

Taking base_v1 as an example, we can configure it in the YAML file as below:

optimizer_configs:
  optimizer: "adam"
  learning_rate: 0.1
  schedule_method: noam
  scheduler_configs:
    d_model: 512
    warmup_steps: 4000
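
As a rough illustration of how such a configuration might drive the optimizer, the sketch below wraps a PyTorch-style optimizer (anything exposing `param_groups`) and updates its learning rate every step. The class and method names are illustrative assumptions, not the actual API of this codebase.

```python
class NoamScheduler:
    """Illustrative wrapper that applies the Noam schedule to an optimizer."""

    def __init__(self, optimizer, d_model, warmup_steps, init_lr, opt_corr=0.002):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.init_lr = init_lr
        self.opt_corr = opt_corr
        self.step_num = 0

    def step(self):
        # Called once per parameter update: recompute the learning rate and
        # write it into every parameter group of the wrapped optimizer.
        self.step_num += 1
        lr = noam_lr(self.step_num, self.d_model, self.warmup_steps,
                     self.init_lr, self.opt_corr)
        for group in self.optimizer.param_groups:
            group["lr"] = lr

# Hypothetical usage with the base_v1 settings shown in the YAML above:
# scheduler = NoamScheduler(torch.optim.Adam(model.parameters()), 512, 4000, 0.1)
```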