Skip to content

Commit

Permalink
refactor cv log
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Mar 29, 2024
1 parent 10b7045 commit b39a8b1
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
5 changes: 2 additions & 3 deletions wenet/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def main():
configs['epoch'] = epoch

lrs = [group['lr'] for group in optimizer.param_groups]
logging.info('Epoch {} TRAIN info lr {} rank {}'.format(
epoch, lrs, rank))
logging.info('Epoch {} Step {} TRAIN info lr {} rank {}'.format(
epoch, executor.step, lrs, rank))

dist.barrier(
) # NOTE(xcsong): Ensure all ranks start Train at the same time.
Expand All @@ -149,7 +149,6 @@ def main():
dist.barrier(
) # NOTE(xcsong): Ensure all ranks start CV at the same time.
loss_dict = executor.cv(model, cv_data_loader, configs)

info_dict = {
'epoch': epoch,
'lrs': [group['lr'] for group in optimizer.param_groups],
Expand Down
13 changes: 10 additions & 3 deletions wenet/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,11 @@ def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None):
# CV
for name, value in loss_dict.items():
writer.add_scalar('cv/{}'.format(name), value, step + 1)
logging.info(
'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format(
epoch, step, " ".join(["{:.4e}".format(lr) for lr in lrs]),
loss_dict["loss"], rank, loss_dict["acc"]))

if "step_" not in tag and (batch_idx + 1) % log_interval == 0:
log_str = '{} | '.format(tag)
if timer is not None:
Expand All @@ -652,9 +657,11 @@ def log_per_epoch(writer, info_dict):
loss_dict = info_dict["loss_dict"]
lrs = info_dict['lrs']
rank = int(os.environ.get('RANK', 0))
logging.info('Epoch {} CV info lr {} cv_loss {} rank {} acc {}'.format(
epoch, " ".join(["{:.4e}".format(lr) for lr in lrs]),
loss_dict["loss"], rank, loss_dict["acc"]))
step = info_dict["step"]
logging.info(
'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format(
epoch, step, " ".join(["{:.4e}".format(lr) for lr in lrs]),
loss_dict["loss"], rank, loss_dict["acc"]))

if int(os.environ.get('RANK', 0)) == 0:
for i, lr in enumerate(info_dict["lrs"]):
Expand Down

0 comments on commit b39a8b1

Please # to comment.