From 396b1888f283106d9c11a1090892530cff8b39cb Mon Sep 17 00:00:00 2001 From: Mddct Date: Fri, 29 Mar 2024 11:44:28 +0800 Subject: [PATCH] refactor cv log --- wenet/bin/train.py | 1 - wenet/utils/train_utils.py | 5 +++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/wenet/bin/train.py b/wenet/bin/train.py index 3d91e36265..4562d5f532 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -149,7 +149,6 @@ def main(): dist.barrier( ) # NOTE(xcsong): Ensure all ranks start CV at the same time. loss_dict = executor.cv(model, cv_data_loader, configs) - info_dict = { 'epoch': epoch, 'lrs': [group['lr'] for group in optimizer.param_groups], diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index d8cf8dd348..5d853ae7a4 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -627,6 +627,11 @@ def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None): # CV for name, value in loss_dict.items(): writer.add_scalar('cv/{}'.format(name), value, step + 1) + logging.info( + 'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format( + epoch, step, " ".join(["{:.4e}".format(lr) for lr in lrs]), + loss_dict["loss"], rank, loss_dict["acc"])) + if "step_" not in tag and (batch_idx + 1) % log_interval == 0: log_str = '{} | '.format(tag) if timer is not None: