[transformer] add multi warmup and learning rate for different modules #2449

Merged: 19 commits, Apr 11, 2024
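To make the intent concrete, here is a rough sketch of the kind of training config this PR enables, shown as the Python dict that the training code ends up reading (not a file from the PR). The list-valued `lr`, the `modules` list, and the `warmuplr` requirement follow the diffed code below; the module paths `encoder`/`decoder` and the `scheduler_conf` layout are illustrative assumptions.

```python
# Sketch of a config enabling per-module learning rates and warmups.
configs = {
    'optim': 'adam',
    'optim_conf': {
        # One lr per entry in `modules`, plus a final lr for every
        # parameter not covered by those modules.
        'lr': [5e-4, 1e-4, 1e-3],
        'modules': ['encoder', 'decoder'],  # dotted attribute paths into the model
    },
    # A list-valued lr currently requires the warmuplr scheduler.
    'scheduler': 'warmuplr',
    'scheduler_conf': {
        # One warmup length per optimizer param group (encoder, decoder, rest).
        'warmup_steps': [10000, 25000, 25000],
    },
}
```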
17 changes: 7 additions & 10 deletions wenet/bin/train.py
@@ -24,6 +24,7 @@
import torch.distributed as dist

from torch.distributed.elastic.multiprocessing.errors import record
from wenet.utils.common import lrs_to_str

from wenet.utils.executor import Executor
from wenet.utils.config import override_config
@@ -117,8 +118,7 @@ def main():

# Get executor
tag = configs["init_infos"].get("tag", "init")
executor = Executor(global_step=configs["init_infos"].get('step', -1) +
int("step_" in tag))
executor = Executor(global_step=configs["init_infos"].get('step', -1))

# Init scaler, used for pytorch amp mixed precision training
scaler = init_scaler(args)
@@ -134,9 +134,9 @@
for epoch in range(start_epoch, end_epoch):
configs['epoch'] = epoch

lr = optimizer.param_groups[0]['lr']
logging.info('Epoch {} TRAIN info lr {} rank {}'.format(
epoch, lr, rank))
lrs = [group['lr'] for group in optimizer.param_groups]
logging.info('Epoch {} Step {} TRAIN info lr {} rank {}'.format(
epoch, executor.step, lrs_to_str(lrs), rank))

dist.barrier(
) # NOTE(xcsong): Ensure all ranks start Train at the same time.
@@ -150,19 +150,16 @@
dist.barrier(
) # NOTE(xcsong): Ensure all ranks start CV at the same time.
loss_dict = executor.cv(model, cv_data_loader, configs)

lr = optimizer.param_groups[0]['lr']
logging.info('Epoch {} CV info lr {} cv_loss {} rank {} acc {}'.format(
epoch, lr, loss_dict["loss"], rank, loss_dict["acc"]))
info_dict = {
'epoch': epoch,
'lr': lr,
'lrs': [group['lr'] for group in optimizer.param_groups],
'step': executor.step,
'save_time': datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'),
'tag': "epoch_{}".format(epoch),
'loss_dict': loss_dict,
**configs
}
# epoch cv: tensorboard && log
log_per_epoch(writer, info_dict=info_dict)
save_model(model, info_dict=info_dict)

2 changes: 1 addition & 1 deletion wenet/utils/checkpoint.py
@@ -41,6 +41,7 @@ def load_checkpoint(model: torch.nn.Module, path: str) -> dict:


def save_state_dict_and_infos(state_dict, path: str, infos=None):
logging.info('Checkpoint: save to checkpoint %s' % path)
torch.save(state_dict, path)
info_path = re.sub('.pt$', '.yaml', path)
if infos is None:
@@ -56,7 +57,6 @@ def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
Args:
infos (dict or None): any info you want to save.
'''
logging.info('Checkpoint: save to checkpoint %s' % path)
if isinstance(model, torch.nn.DataParallel):
state_dict = model.module.state_dict()
elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
13 changes: 13 additions & 0 deletions wenet/utils/common.py
@@ -321,6 +321,19 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return mask


def get_nested_attribute(obj, attr_path):
if isinstance(obj, torch.nn.parallel.DistributedDataParallel):
obj = obj.module
attributes = attr_path.split('.')
for attr in attributes:
obj = getattr(obj, attr)
return obj


def lrs_to_str(lrs: List):
return " ".join(["{:.4e}".format(lr) for lr in lrs])


class StepTimer:
"""Utility class for measuring steps/second."""

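A short usage sketch for the two helpers added above (illustrative only, not part of the diff; the `Toy`/`Encoder` modules are made up for the example):

```python
import torch
from wenet.utils.common import get_nested_attribute, lrs_to_str

class Encoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Linear(4, 8)

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()

model = Toy()
# Dotted paths resolve one attribute at a time; a DDP-wrapped model is
# unwrapped to its .module first.
print(get_nested_attribute(model, 'encoder.embed'))  # Linear(in_features=4, out_features=8, bias=True)
print(lrs_to_str([5e-4, 1e-3]))                      # 5.0000e-04 1.0000e-03
```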
18 changes: 11 additions & 7 deletions wenet/utils/executor.py
@@ -31,7 +31,7 @@
class Executor:

def __init__(self, global_step: int = 0):
self.step = global_step
self.step = global_step + 1
self.train_step_timer = None
self.cv_step_timer = None

@@ -85,9 +85,12 @@ def train(self, model, optimizer, scheduler, train_data_loader,
info_dict = update_parameter_and_lr(model, optimizer,
scheduler, scaler,
info_dict)
# write training: tensorboard && log
log_per_step(writer, info_dict, timer=self.train_step_timer)
save_interval = info_dict.get('save_interval', sys.maxsize)
if self.step % save_interval == 0 and self.step != 0 \
and (batch_idx + 1) % info_dict["accum_grad"] == 0:
if (self.step +
1) % save_interval == 0 and self.step != 0 and (
batch_idx + 1) % info_dict["accum_grad"] == 0:
import torch.distributed as dist
# Ensure all ranks start CV at the same time in step mode
dist.barrier()
@@ -100,13 +103,14 @@
loss_dict,
"save_time":
datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'),
"lr":
optimizer.param_groups[0]['lr']
"lrs":
[group['lr'] for group in optimizer.param_groups]
})
save_model(model, info_dict)
# write final cv: tensorboard
log_per_step(writer, info_dict)
# Ensure all ranks start Train at the same time in step mode
dist.barrier()
log_per_step(writer, info_dict, timer=self.train_step_timer)
self.step += 1 if (batch_idx +
1) % info_dict["accum_grad"] == 0 else 0

@@ -143,7 +147,7 @@ def cv(self, model, cv_data_loader, configs):
loss_value = loss_value.item()
loss_dict[loss_name] = loss_dict.get(loss_name, 0) + \
loss_value * num_utts

# write cv: log
log_per_step(writer=None,
info_dict=info_dict,
timer=self.cv_step_timer)
28 changes: 17 additions & 11 deletions wenet/utils/scheduler.py
@@ -15,7 +15,7 @@
# Modified from ESPnet(https://github.com/espnet/espnet)
# NeMo(https://github.com/NVIDIA/NeMo)

from typing import Union
from typing import List, Union

import math
import warnings
@@ -43,11 +43,10 @@ class WarmupLR(_LRScheduler):
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: Union[int, float] = 25000,
warmup_steps: Union[int, float, List[Union[int, float]]] = 25000,
last_epoch: int = -1,
):
self.warmup_steps = warmup_steps

# __init__() must be invoked before setting field
# because step() is also invoked in __init__()
super().__init__(optimizer, last_epoch)
@@ -57,14 +56,21 @@ def __repr__(self):

def get_lr(self):
step_num = self.last_epoch + 1
if self.warmup_steps == 0:
return [lr * step_num**-0.5 for lr in self.base_lrs]
else:
return [
lr * self.warmup_steps**0.5 *
min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
for lr in self.base_lrs
]
warmup_steps = self.warmup_steps
if not isinstance(warmup_steps, List):
warmup_steps = [self.warmup_steps] * len(self.base_lrs)

def initlr_fn(lr):
return lr * step_num**-0.5

def warmuplr_fn(lr, warmup_step):
return lr * warmup_step**0.5 * min(step_num**-0.5,
step_num * warmup_step**-1.5)

return [
initlr_fn(lr) if warmup_steps[i] == 0 else warmuplr_fn(
lr, warmup_steps[i]) for (i, lr) in enumerate(self.base_lrs)
]

def set_step(self, step: int):
self.last_epoch = step
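A small standalone sketch (not part of the diff) of how the updated scheduler behaves: with a list of `warmup_steps`, each optimizer param group warms up on its own schedule. The two `Linear` layers simply stand in for module groups.

```python
import torch
from wenet.utils.scheduler import WarmupLR

enc = torch.nn.Linear(4, 4)    # stands in for e.g. the encoder parameters
rest = torch.nn.Linear(4, 4)   # stands in for the remaining parameters
optimizer = torch.optim.Adam([
    {'params': enc.parameters(), 'lr': 5e-4},
    {'params': rest.parameters(), 'lr': 1e-3},
])
# One warmup length per group; group i follows
#   base_lr_i * warmup_i**0.5 * min(step**-0.5, step * warmup_i**-1.5)
scheduler = WarmupLR(optimizer, warmup_steps=[10000, 25000])

for _ in range(3):
    optimizer.step()
    scheduler.step()
    print([group['lr'] for group in optimizer.param_groups])
```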
79 changes: 62 additions & 17 deletions wenet/utils/train_utils.py
@@ -15,7 +15,7 @@

from contextlib import nullcontext
import copy
from typing import Optional
from typing import List, Optional

import deepspeed
import json
@@ -41,10 +41,10 @@
convert_zero_checkpoint_to_fp32_state_dict)
from wenet.dataset.dataset import Dataset
from wenet.utils.checkpoint import save_checkpoint
from wenet.utils.common import StepTimer, get_nested_attribute, lrs_to_str
from wenet.utils.fsdp_utils import (check_gradient_checkpoint, fsdp_save_model,
apply_fsdp_checkpointing,
wenet_fsdp_wrap_policy)
from wenet.utils.common import StepTimer
from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing
from wenet.utils.ctc_utils import get_blank_id

@@ -439,10 +439,38 @@ def wrap_cuda_model(args, model, configs=None):


def init_optimizer_and_scheduler(args, configs, model):
groups = []
lr = configs['optim_conf'].get('lr')
if isinstance(lr, List):
assert configs['scheduler'] == 'warmuplr'
modules_m = configs['optim_conf']['modules']
assert isinstance(modules_m, List)
assert len(modules_m) + 1 == len(lr)
special_param_ids = set()
rest_params = []
for (i, m_str) in enumerate(modules_m):
sub_module = get_nested_attribute(model, m_str)
subs_params = []
for _, sub_params in sub_module.named_parameters():
subs_params.append(sub_params)
special_param_ids.add(id(sub_params))
groups.append({'params': subs_params, 'lr': lr[i]})
# other model's parameters
for _, param in model.named_parameters():
if id(param) not in special_param_ids:
rest_params.append(param)
groups.append({'params': rest_params, 'lr': lr[-1]})

params = groups if len(groups) > 0 else model.parameters()
optim_conf = copy.deepcopy(configs['optim_conf'])
if 'modules' in optim_conf:
del optim_conf['modules']
if isinstance(lr, List):
optim_conf['lr'] = lr[-1]
Comment on lines +468 to +469 (reviewer, Member):

Why is this step needed? Does the optimizer have to be given a single lr value rather than a list?

if configs['optim'] == 'adam':
optimizer = optim.Adam(model.parameters(), **configs['optim_conf'])
optimizer = optim.Adam(params, **optim_conf)
elif configs['optim'] == 'adamw':
optimizer = optim.AdamW(model.parameters(), **configs['optim_conf'])
optimizer = optim.AdamW(params, **optim_conf)
else:
raise ValueError("unknown optimizer: " + configs['optim'])

@@ -704,7 +732,7 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
scheduler.step()
grad_norm = grad_norm.item()

info_dict["lr"] = optimizer.param_groups[0]['lr']
info_dict["lrs"] = [group['lr'] for group in optimizer.param_groups]
info_dict["grad_norm"] = grad_norm

return info_dict
@@ -719,28 +747,36 @@ def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None):
train_engine = info_dict.get("train_engine", "torch_ddp")
accum_grad = info_dict.get('accum_grad', 1) if tag != "CV" else 1
log_interval = info_dict.get('log_interval', 10)
lr = info_dict.get("lr", 0.0)
lrs = info_dict.get("lrs", [0.0])
is_gradient_accumulation_boundary = info_dict.get(
"is_gradient_accumulation_boundary", False)

rank = int(os.environ.get('RANK', 0))

# TRAIN Tensorboard
if tag == "TRAIN" and rank == 0 and writer is not None:
if (train_engine == "deepspeed" and is_gradient_accumulation_boundary
) or (train_engine in ["torch_ddp", "torch_fsdp"] and
(batch_idx + 1) % accum_grad == 0):
writer.add_scalar('train/train_loss',
loss_dict['loss'] * accum_grad, step + 1)
writer.add_scalar('train/grad_norm', info_dict['grad_norm'],
step + 1)
loss_dict['loss'] * accum_grad, step)
writer.add_scalar('train/grad_norm', info_dict['grad_norm'], step)
for name, value in loss_dict.items():
if name != 'loss' and value is not None:
writer.add_scalar('train/{}'.format(name), value, step + 1)
writer.add_scalar('train/{}'.format(name), value, step)
# lr
for i, lr in enumerate(lrs):
writer.add_scalar('train/lr_{}'.format(i), lr, step)
# CV Tensorboard
elif "step_" in tag and rank == 0 and writer is not None:
writer.add_scalar('global_step/lr', lr, step + 1)
for name, value in loss_dict.items():
writer.add_scalar('global_step/{}'.format(name), value, step + 1)

writer.add_scalar('cv/{}'.format(name), value, step)
logging.info(
'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format(
epoch, step + 1, lrs_to_str(lrs), loss_dict["loss"], rank,
loss_dict["acc"]))
return

# TRAIN & CV, Shell log (stdout)
if (batch_idx + 1) % log_interval == 0:
log_str = '{} | '.format(tag)
if timer is not None:
@@ -757,16 +793,25 @@ def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None):
if name != 'loss' and value is not None:
log_str += '{} {:.6f} '.format(name, value)
if tag == "TRAIN":
log_str += 'lr {:.8f} grad_norm {:.6f} rank {}'.format(
lr, info_dict['grad_norm'], rank)
log_str += 'lr {} grad_norm {:.6f} rank {}'.format(
lrs_to_str(lrs), info_dict['grad_norm'], rank)
logging.debug(log_str)


def log_per_epoch(writer, info_dict):
epoch = info_dict["epoch"]
loss_dict = info_dict["loss_dict"]
lrs = info_dict['lrs']
rank = int(os.environ.get('RANK', 0))
step = info_dict["step"]
logging.info(
'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format(
epoch, step, lrs_to_str(lrs), loss_dict["loss"], rank,
loss_dict["acc"]))

if int(os.environ.get('RANK', 0)) == 0:
writer.add_scalar('epoch/lr', info_dict["lr"], epoch)
for i, lr in enumerate(info_dict["lrs"]):
writer.add_scalar('epoch/lr_{}'.format(i), lr, epoch)
for name, value in loss_dict.items():
writer.add_scalar('epoch/{}'.format(name), value, epoch)
