Skip to content

Commit

Permalink
refactor cv log
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Mar 29, 2024
1 parent 10b7045 commit 396b188
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
1 change: 0 additions & 1 deletion wenet/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def main():
dist.barrier(
) # NOTE(xcsong): Ensure all ranks start CV at the same time.
loss_dict = executor.cv(model, cv_data_loader, configs)

info_dict = {
'epoch': epoch,
'lrs': [group['lr'] for group in optimizer.param_groups],
Expand Down
5 changes: 5 additions & 0 deletions wenet/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,11 @@ def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None):
# CV
for name, value in loss_dict.items():
writer.add_scalar('cv/{}'.format(name), value, step + 1)
logging.info(
'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format(
epoch, step, " ".join(["{:.4e}".format(lr) for lr in lrs]),
loss_dict["loss"], rank, loss_dict["acc"]))

if "step_" not in tag and (batch_idx + 1) % log_interval == 0:
log_str = '{} | '.format(tag)
if timer is not None:
Expand Down

0 comments on commit 396b188

Please # to comment.