From f8b795f427a39c19a6b7245be240680617156948 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 27 Feb 2020 08:19:48 -0800 Subject: [PATCH] Move meters, metrics and progress_bar into fairseq.logging (#1046) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1046 Differential Revision: D20030412 Pulled By: myleott fbshipit-source-id: bd87391aa9cdb73306ee90a30eeb2bdeff3690f9 --- examples/speech_recognition/infer.py | 71 +++++----- fairseq/__init__.py | 8 ++ fairseq/logging/__init__.py | 0 fairseq/{ => logging}/meters.py | 0 fairseq/{ => logging}/metrics.py | 0 fairseq/{ => logging}/progress_bar.py | 99 ++++++++----- fairseq/options.py | 2 +- fairseq/trainer.py | 6 +- fairseq_cli/eval_lm.py | 158 +++++++++++---------- fairseq_cli/generate.py | 196 +++++++++++++------------- fairseq_cli/train.py | 33 +++-- fairseq_cli/validate.py | 13 +- 12 files changed, 328 insertions(+), 258 deletions(-) create mode 100644 fairseq/logging/__init__.py rename fairseq/{ => logging}/meters.py (100%) rename fairseq/{ => logging}/metrics.py (100%) rename fairseq/{ => logging}/progress_bar.py (79%) diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index bc6d51b0c7..4b9151c79a 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ -14,8 +14,8 @@ import sentencepiece as spm import torch -from fairseq import checkpoint_utils, options, progress_bar, utils, tasks -from fairseq.meters import StopwatchMeter, TimeMeter +from fairseq import checkpoint_utils, options, utils, tasks +from fairseq.logging import meters, progress_bar from fairseq.utils import import_user_module @@ -199,9 +199,15 @@ def main(args): # Load dataset (possibly sharded) itr = get_dataset_itr(args, task) + progress = progress_bar.progress_bar( + itr, + log_format=args.log_format, + log_interval=args.log_interval, + default_log_format=('tqdm' if not args.no_progress_bar else 'none'), + ) # Initialize generator - gen_timer = StopwatchMeter() + gen_timer = meters.StopwatchMeter() generator = task.build_generator(args) num_sentences = 0 @@ -213,36 +219,35 @@ def main(args): sp.Load(os.path.join(args.data, "spm.model")) res_files = prepare_result_files(args) - with progress_bar.build_progress_bar(args, itr) as t: - wps_meter = TimeMeter() - for sample in t: - sample = utils.move_to_cuda(sample) if use_cuda else sample - if "net_input" not in sample: - continue - - prefix_tokens = None - if args.prefix_size > 0: - prefix_tokens = sample["target"][:, : args.prefix_size] - - gen_timer.start() - hypos = task.inference_step(generator, models, sample, prefix_tokens) - num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) - gen_timer.stop(num_generated_tokens) - - for i, sample_id in enumerate(sample["id"].tolist()): - speaker = task.dataset(args.gen_subset).speakers[int(sample_id)] - id = task.dataset(args.gen_subset).ids[int(sample_id)] - target_tokens = ( - utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu() - ) - # Process top predictions - process_predictions( - args, hypos[i], sp, tgt_dict, target_tokens, res_files, speaker, id - ) - - wps_meter.update(num_generated_tokens) - t.log({"wps": round(wps_meter.avg)}) - num_sentences += sample["nsentences"] + wps_meter = meters.TimeMeter() + for sample in progress: + sample = utils.move_to_cuda(sample) if use_cuda else sample + if "net_input" not in sample: + continue + + prefix_tokens = None + if args.prefix_size > 0: + prefix_tokens = sample["target"][:, : args.prefix_size] + + 
gen_timer.start() + hypos = task.inference_step(generator, models, sample, prefix_tokens) + num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) + gen_timer.stop(num_generated_tokens) + + for i, sample_id in enumerate(sample["id"].tolist()): + speaker = task.dataset(args.gen_subset).speakers[int(sample_id)] + id = task.dataset(args.gen_subset).ids[int(sample_id)] + target_tokens = ( + utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu() + ) + # Process top predictions + process_predictions( + args, hypos[i], sp, tgt_dict, target_tokens, res_files, speaker, id + ) + + wps_meter.update(num_generated_tokens) + progress.log({"wps": round(wps_meter.avg)}) + num_sentences += sample["nsentences"] logger.info( "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}" diff --git a/fairseq/__init__.py b/fairseq/__init__.py index 3ac6cf8585..809c6ed2f7 100644 --- a/fairseq/__init__.py +++ b/fairseq/__init__.py @@ -6,6 +6,14 @@ __all__ = ['pdb'] __version__ = '0.9.0' +import sys + +# backwards compatibility to support `from fairseq.meters import AverageMeter` +from fairseq.logging import meters, metrics, progress_bar # noqa +sys.modules['fairseq.meters'] = meters +sys.modules['fairseq.metrics'] = metrics +sys.modules['fairseq.progress_bar'] = progress_bar + import fairseq.criterions # noqa import fairseq.models # noqa import fairseq.modules # noqa diff --git a/fairseq/logging/__init__.py b/fairseq/logging/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/fairseq/meters.py b/fairseq/logging/meters.py similarity index 100% rename from fairseq/meters.py rename to fairseq/logging/meters.py diff --git a/fairseq/metrics.py b/fairseq/logging/metrics.py similarity index 100% rename from fairseq/metrics.py rename to fairseq/logging/metrics.py diff --git a/fairseq/progress_bar.py b/fairseq/logging/progress_bar.py similarity index 79% rename from fairseq/progress_bar.py rename to fairseq/logging/progress_bar.py index f72136d475..5259659f48 100644 --- a/fairseq/progress_bar.py +++ b/fairseq/logging/progress_bar.py @@ -7,53 +7,86 @@ Wrapper around various loggers and progress bars (e.g., tqdm). 
""" -from collections import OrderedDict -from contextlib import contextmanager import json import logging -from numbers import Number import os import sys +from collections import OrderedDict +from contextlib import contextmanager +from numbers import Number +from typing import Optional import torch -from fairseq import distributed_utils -from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter +from .meters import AverageMeter, StopwatchMeter, TimeMeter logger = logging.getLogger(__name__) -def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', no_progress_bar='none'): - if args.log_format is None: - args.log_format = no_progress_bar if args.no_progress_bar else default - - if args.log_format == 'tqdm' and not sys.stderr.isatty(): - args.log_format = 'simple' - - if args.log_format == 'json': - bar = json_progress_bar(iterator, epoch, prefix, args.log_interval) - elif args.log_format == 'none': - bar = noop_progress_bar(iterator, epoch, prefix) - elif args.log_format == 'simple': - bar = simple_progress_bar(iterator, epoch, prefix, args.log_interval) - elif args.log_format == 'tqdm': - bar = tqdm_progress_bar(iterator, epoch, prefix) +def progress_bar( + iterator, + log_format: Optional[str] = None, + log_interval: int = 100, + epoch: Optional[int] = None, + prefix: Optional[str] = None, + tensorboard_logdir: Optional[str] = None, + default_log_format: str = 'tqdm', +): + if log_format is None: + log_format = default_log_format + if log_format == 'tqdm' and not sys.stderr.isatty(): + log_format = 'simple' + + if log_format == 'json': + bar = JsonProgressBar(iterator, epoch, prefix, log_interval) + elif log_format == 'none': + bar = NoopProgressBar(iterator, epoch, prefix) + elif log_format == 'simple': + bar = SimpleProgressBar(iterator, epoch, prefix, log_interval) + elif log_format == 'tqdm': + bar = TqdmProgressBar(iterator, epoch, prefix) else: - raise ValueError('Unknown log format: {}'.format(args.log_format)) + raise ValueError('Unknown log format: {}'.format(log_format)) - if args.tensorboard_logdir and distributed_utils.is_master(args): + if tensorboard_logdir: try: # [FB only] custom wrapper for TensorBoard import palaas # noqa - from fairseq.fb_tbmf_wrapper import fb_tbmf_wrapper - bar = fb_tbmf_wrapper(bar, args, args.log_interval) + from .fb_tbmf_wrapper import FbTbmfWrapper + bar = FbTbmfWrapper(bar, log_interval) except ImportError: - bar = tensorboard_log_wrapper(bar, args.tensorboard_logdir, args) + bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir) return bar +def build_progress_bar( + args, + iterator, + epoch: Optional[int] = None, + prefix: Optional[str] = None, + default: str = 'tqdm', + no_progress_bar: str = 'none', +): + """Legacy wrapper that takes an argparse.Namespace.""" + if getattr(args, 'no_progress_bar', False): + default = no_progress_bar + if getattr(args, 'distributed_rank', 0) == 0: + tensorboard_logdir = getattr(args, 'tensorboard_logdir', None) + else: + tensorboard_logdir = None + return progress_bar( + iterator, + log_format=args.log_format, + log_interval=args.log_interval, + epoch=epoch, + prefix=prefix, + tensorboard_logdir=tensorboard_logdir, + default_log_format=default, + ) + + def format_stat(stat): if isinstance(stat, Number): stat = '{:g}'.format(stat) @@ -68,7 +101,7 @@ def format_stat(stat): return stat -class progress_bar(object): +class BaseProgressBar(object): """Abstract class for progress bars.""" def __init__(self, iterable, epoch=None, prefix=None): self.iterable = iterable @@ -125,7 
+158,7 @@ def rename_logger(logger, new_name): logger.name = old_name -class json_progress_bar(progress_bar): +class JsonProgressBar(BaseProgressBar): """Log output in JSON format.""" def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000): @@ -179,7 +212,7 @@ def _format_stats(self, stats, epoch=None, update=None): return postfix -class noop_progress_bar(progress_bar): +class NoopProgressBar(BaseProgressBar): """No logging.""" def __init__(self, iterable, epoch=None, prefix=None): @@ -198,7 +231,7 @@ def print(self, stats, tag=None, step=None): pass -class simple_progress_bar(progress_bar): +class SimpleProgressBar(BaseProgressBar): """A minimal logger for non-TTY environments.""" def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000): @@ -233,7 +266,7 @@ def print(self, stats, tag=None, step=None): logger.info('{} | {}'.format(self.prefix, postfix)) -class tqdm_progress_bar(progress_bar): +class TqdmProgressBar(BaseProgressBar): """Log to tqdm.""" def __init__(self, iterable, epoch=None, prefix=None): @@ -261,13 +294,12 @@ def print(self, stats, tag=None, step=None): SummaryWriter = None -class tensorboard_log_wrapper(progress_bar): +class TensorboardProgressBarWrapper(BaseProgressBar): """Log to tensorboard.""" - def __init__(self, wrapped_bar, tensorboard_logdir, args): + def __init__(self, wrapped_bar, tensorboard_logdir): self.wrapped_bar = wrapped_bar self.tensorboard_logdir = tensorboard_logdir - self.args = args if SummaryWriter is None: logger.warning( @@ -281,7 +313,6 @@ def _writer(self, key): _writers = _tensorboard_writers if key not in _writers: _writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key)) - _writers[key].add_text('args', str(vars(self.args))) _writers[key].add_text('sys.argv', " ".join(sys.argv)) return _writers[key] diff --git a/fairseq/options.py b/fairseq/options.py index 07034cec70..2bc179a1fa 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -198,7 +198,7 @@ def get_parser(desc, default_task="translation"): parser = argparse.ArgumentParser(allow_abbrev=False) # fmt: off parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar') - parser.add_argument('--log-interval', type=int, default=1000, metavar='N', + parser.add_argument('--log-interval', type=int, default=100, metavar='N', help='log progress every N batches (when progress bar is disabled)') parser.add_argument('--log-format', default=None, help='log format to use', choices=['json', 'none', 'simple', 'tqdm']) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 821379e483..a69b2b2619 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -17,9 +17,9 @@ import torch -from fairseq import checkpoint_utils, distributed_utils, metrics, models, optim, utils +from fairseq import checkpoint_utils, distributed_utils, models, optim, utils from fairseq.file_io import PathManager -from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter +from fairseq.logging import meters, metrics from fairseq.optim import lr_scheduler @@ -226,7 +226,7 @@ def load_checkpoint( # reset TimeMeters, since their start times don't make sense anymore for meter in metrics.get_meters("default"): - if isinstance(meter, TimeMeter): + if isinstance(meter, meters.TimeMeter): meter.reset() else: logger.info("no existing checkpoint found {}".format(filename)) diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index 5a7fe11854..6acfe8899f 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -14,9 +14,10 @@ 
import torch -from fairseq import checkpoint_utils, options, progress_bar, tasks, utils +from fairseq import checkpoint_utils, options, tasks, utils from fairseq.data import LMContextWindowDataset -from fairseq.meters import StopwatchMeter, TimeMeter +from fairseq.logging import progress_bar +from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.sequence_scorer import SequenceScorer @@ -120,6 +121,12 @@ def main(parsed_args): shard_id=args.shard_id, num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) + progress = progress_bar.progress_bar( + itr, + log_format=args.log_format, + log_interval=args.log_interval, + default_log_format=('tqdm' if not args.no_progress_bar else 'none'), + ) gen_timer = StopwatchMeter() scorer = SequenceScorer(task.target_dictionary, args.softmax_batch) @@ -144,82 +151,81 @@ def main(parsed_args): word_stats = dict() - with progress_bar.build_progress_bar(args, itr) as t: - wps_meter = TimeMeter() - - for sample in t: - if 'net_input' not in sample: - continue - - sample = utils.move_to_cuda(sample) if use_cuda else sample - - gen_timer.start() - hypos = scorer.generate(models, sample) - gen_timer.stop(sample['ntokens']) - - for i, hypos_i in enumerate(hypos): - hypo = hypos_i[0] - sample_id = sample['id'][i] - - tokens = hypo['tokens'] - tgt_len = tokens.numel() - pos_scores = hypo['positional_scores'].float() - - if args.add_bos_token: - assert hypo['tokens'][0].item() == task.target_dictionary.bos() - tokens = tokens[1:] - pos_scores = pos_scores[1:] - - skipped_toks = 0 - if bpe_toks is not None: - for i in range(tgt_len - 1): - if tokens[i].item() in bpe_toks: - skipped_toks += 1 - pos_scores[i + 1] += pos_scores[i] - pos_scores[i] = 0 - - inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf')) - if inf_scores.any(): + wps_meter = TimeMeter() + + for sample in progress: + if 'net_input' not in sample: + continue + + sample = utils.move_to_cuda(sample) if use_cuda else sample + + gen_timer.start() + hypos = scorer.generate(models, sample) + gen_timer.stop(sample['ntokens']) + + for i, hypos_i in enumerate(hypos): + hypo = hypos_i[0] + sample_id = sample['id'][i] + + tokens = hypo['tokens'] + tgt_len = tokens.numel() + pos_scores = hypo['positional_scores'].float() + + if args.add_bos_token: + assert hypo['tokens'][0].item() == task.target_dictionary.bos() + tokens = tokens[1:] + pos_scores = pos_scores[1:] + + skipped_toks = 0 + if bpe_toks is not None: + for i in range(tgt_len - 1): + if tokens[i].item() in bpe_toks: + skipped_toks += 1 + pos_scores[i + 1] += pos_scores[i] + pos_scores[i] = 0 + + inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf')) + if inf_scores.any(): + logger.info( + 'skipping tokens with inf scores:', + task.target_dictionary.string(tokens[inf_scores.nonzero()]) + ) + pos_scores = pos_scores[(~inf_scores).nonzero()] + score_sum += pos_scores.sum().cpu() + count += pos_scores.numel() - skipped_toks + + if args.output_word_probs or args.output_word_stats: + w = '' + word_prob = [] + is_bpe = False + for i in range(len(tokens)): + w_ind = tokens[i].item() + w += task.source_dictionary[w_ind] + if bpe_toks is not None and w_ind in bpe_toks: + w = w[:-bpe_len] + is_bpe = True + else: + word_prob.append((w, pos_scores[i].item())) + + next_prob = None + ind = i + 1 + while ind < len(tokens): + if pos_scores[ind].item() != 0: + next_prob = pos_scores[ind] + break + ind += 1 + + word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item(), next_prob) + is_bpe = False + w = 
'' + if args.output_word_probs: logger.info( - 'skipping tokens with inf scores:', - task.target_dictionary.string(tokens[inf_scores.nonzero()]) + str(int(sample_id)) + " " + + ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob)) ) - pos_scores = pos_scores[(~inf_scores).nonzero()] - score_sum += pos_scores.sum().cpu() - count += pos_scores.numel() - skipped_toks - - if args.output_word_probs or args.output_word_stats: - w = '' - word_prob = [] - is_bpe = False - for i in range(len(tokens)): - w_ind = tokens[i].item() - w += task.source_dictionary[w_ind] - if bpe_toks is not None and w_ind in bpe_toks: - w = w[:-bpe_len] - is_bpe = True - else: - word_prob.append((w, pos_scores[i].item())) - - next_prob = None - ind = i + 1 - while ind < len(tokens): - if pos_scores[ind].item() != 0: - next_prob = pos_scores[ind] - break - ind += 1 - - word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item(), next_prob) - is_bpe = False - w = '' - if args.output_word_probs: - logger.info( - str(int(sample_id)) + " " - + ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob)) - ) - - wps_meter.update(sample['ntokens']) - t.log({'wps': round(wps_meter.avg)}) + + wps_meter.update(sample['ntokens']) + progress.log({'wps': round(wps_meter.avg)}) avg_nll_loss = -score_sum / count / math.log(2) # convert to base 2 logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format( diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 53c2a736ca..61c41e35c6 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -14,8 +14,9 @@ import torch -from fairseq import bleu, checkpoint_utils, options, progress_bar, tasks, utils -from fairseq.meters import StopwatchMeter, TimeMeter +from fairseq import bleu, checkpoint_utils, options, tasks, utils +from fairseq.logging import progress_bar +from fairseq.logging.meters import StopwatchMeter, TimeMeter def main(args): @@ -100,6 +101,12 @@ def _main(args, output_file): shard_id=args.shard_id, num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) + progress = progress_bar.progress_bar( + itr, + log_format=args.log_format, + log_interval=args.log_interval, + default_log_format=('tqdm' if not args.no_progress_bar else 'none'), + ) # Initialize generator gen_timer = StopwatchMeter() @@ -112,106 +119,105 @@ def _main(args, output_file): scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) num_sentences = 0 has_target = True - with progress_bar.build_progress_bar(args, itr) as t: - wps_meter = TimeMeter() - for sample in t: - sample = utils.move_to_cuda(sample) if use_cuda else sample - if 'net_input' not in sample: - continue - - prefix_tokens = None - if args.prefix_size > 0: - prefix_tokens = sample['target'][:, :args.prefix_size] - - gen_timer.start() - hypos = task.inference_step(generator, models, sample, prefix_tokens) - num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos) - gen_timer.stop(num_generated_tokens) - - for i, sample_id in enumerate(sample['id'].tolist()): - has_target = sample['target'] is not None - - # Remove padding - src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad()) - target_tokens = None + wps_meter = TimeMeter() + for sample in progress: + sample = utils.move_to_cuda(sample) if use_cuda else sample + if 'net_input' not in sample: + continue + + prefix_tokens = None + if args.prefix_size > 0: + prefix_tokens = sample['target'][:, :args.prefix_size] + + gen_timer.start() + hypos = task.inference_step(generator, models, sample, 
prefix_tokens) + num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos) + gen_timer.stop(num_generated_tokens) + + for i, sample_id in enumerate(sample['id'].tolist()): + has_target = sample['target'] is not None + + # Remove padding + src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad()) + target_tokens = None + if has_target: + target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu() + + # Either retrieve the original sentences or regenerate them from tokens. + if align_dict is not None: + src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id) + target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id) + else: + if src_dict is not None: + src_str = src_dict.string(src_tokens, args.remove_bpe) + else: + src_str = "" if has_target: - target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu() + target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True) - # Either retrieve the original sentences or regenerate them from tokens. - if align_dict is not None: - src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id) - target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id) - else: - if src_dict is not None: - src_str = src_dict.string(src_tokens, args.remove_bpe) - else: - src_str = "" - if has_target: - target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True) + if not args.quiet: + if src_dict is not None: + print('S-{}\t{}'.format(sample_id, src_str), file=output_file) + if has_target: + print('T-{}\t{}'.format(sample_id, target_str), file=output_file) + + # Process top predictions + for j, hypo in enumerate(hypos[i][:args.nbest]): + hypo_tokens, hypo_str, alignment = utils.post_process_prediction( + hypo_tokens=hypo['tokens'].int().cpu(), + src_str=src_str, + alignment=hypo['alignment'], + align_dict=align_dict, + tgt_dict=tgt_dict, + remove_bpe=args.remove_bpe, + ) if not args.quiet: - if src_dict is not None: - print('S-{}\t{}'.format(sample_id, src_str), file=output_file) - if has_target: - print('T-{}\t{}'.format(sample_id, target_str), file=output_file) - - # Process top predictions - for j, hypo in enumerate(hypos[i][:args.nbest]): - hypo_tokens, hypo_str, alignment = utils.post_process_prediction( - hypo_tokens=hypo['tokens'].int().cpu(), - src_str=src_str, - alignment=hypo['alignment'], - align_dict=align_dict, - tgt_dict=tgt_dict, - remove_bpe=args.remove_bpe, - ) - - if not args.quiet: - score = hypo['score'] / math.log(2) # convert to base 2 - print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str), file=output_file) - print('P-{}\t{}'.format( + score = hypo['score'] / math.log(2) # convert to base 2 + print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str), file=output_file) + print('P-{}\t{}'.format( + sample_id, + ' '.join(map( + lambda x: '{:.4f}'.format(x), + # convert from base e to base 2 + hypo['positional_scores'].div_(math.log(2)).tolist(), + )) + ), file=output_file) + + if args.print_alignment: + print('A-{}\t{}'.format( sample_id, - ' '.join(map( - lambda x: '{:.4f}'.format(x), - # convert from base e to base 2 - hypo['positional_scores'].div_(math.log(2)).tolist(), - )) + ' '.join(['{}-{}'.format(src_idx, tgt_idx) for src_idx, tgt_idx in alignment]) ), file=output_file) - if args.print_alignment: - print('A-{}\t{}'.format( - sample_id, - ' '.join(['{}-{}'.format(src_idx, tgt_idx) for src_idx, tgt_idx in alignment]) - ), file=output_file) - - if args.print_step: - 
print('I-{}\t{}'.format(sample_id, hypo['steps']), file=output_file) - - if getattr(args, 'retain_iter_history', False): - for step, h in enumerate(hypo['history']): - _, h_str, _ = utils.post_process_prediction( - hypo_tokens=h['tokens'].int().cpu(), - src_str=src_str, - alignment=None, - align_dict=None, - tgt_dict=tgt_dict, - remove_bpe=None, - ) - print('E-{}_{}\t{}'.format(sample_id, step, h_str), file=output_file) - - # Score only the top hypothesis - if has_target and j == 0: - if align_dict is not None or args.remove_bpe is not None: - # Convert back to tokens for evaluation with unk replacement and/or without BPE - target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True) - if hasattr(scorer, 'add_string'): - scorer.add_string(target_str, hypo_str) - else: - scorer.add(target_tokens, hypo_tokens) - - wps_meter.update(num_generated_tokens) - t.log({'wps': round(wps_meter.avg)}) - num_sentences += sample['nsentences'] + if args.print_step: + print('I-{}\t{}'.format(sample_id, hypo['steps']), file=output_file) + + if getattr(args, 'retain_iter_history', False): + for step, h in enumerate(hypo['history']): + _, h_str, _ = utils.post_process_prediction( + hypo_tokens=h['tokens'].int().cpu(), + src_str=src_str, + alignment=None, + align_dict=None, + tgt_dict=tgt_dict, + remove_bpe=None, + ) + print('E-{}_{}\t{}'.format(sample_id, step, h_str), file=output_file) + + # Score only the top hypothesis + if has_target and j == 0: + if align_dict is not None or args.remove_bpe is not None: + # Convert back to tokens for evaluation with unk replacement and/or without BPE + target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True) + if hasattr(scorer, 'add_string'): + scorer.add_string(target_str, hypo_str) + else: + scorer.add(target_tokens, hypo_tokens) + + wps_meter.update(num_generated_tokens) + progress.log({'wps': round(wps_meter.avg)}) + num_sentences += sample['nsentences'] logger.info('NOTE: hypothesis and token scores are output in base 2') logger.info('Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format( diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index cb70443893..61da10c4c2 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -16,12 +16,10 @@ import numpy as np import torch -from fairseq import ( - checkpoint_utils, distributed_utils, metrics, options, progress_bar, tasks, utils -) +from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import iterators +from fairseq.logging import meters, metrics, progress_bar from fairseq.trainer import Trainer -from fairseq.meters import StopwatchMeter logging.basicConfig( @@ -86,7 +84,7 @@ def main(args, init_distributed=False): max_epoch = args.max_epoch or math.inf max_update = args.max_update or math.inf lr = trainer.get_lr() - train_meter = StopwatchMeter() + train_meter = meters.StopwatchMeter() train_meter.start() valid_subsets = args.valid_subset.split(',') while ( @@ -158,8 +156,15 @@ def train(args, trainer, task, epoch_itr): else args.update_freq[-1] ) itr = iterators.GroupedIterator(itr, update_freq) - progress = progress_bar.build_progress_bar( - args, itr, epoch_itr.epoch, no_progress_bar='simple', + progress = progress_bar.progress_bar( + itr, + log_format=args.log_format, + log_interval=args.log_interval, + epoch=epoch_itr.epoch, + tensorboard_logdir=( + args.tensorboard_logdir if distributed_utils.is_master(args) else None + ), + default_log_format=('tqdm' if not args.no_progress_bar else 
'simple'), ) # task specific setup per epoch @@ -229,10 +234,16 @@ def validate(args, trainer, task, epoch_itr, subsets): shard_id=args.distributed_rank, num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) - progress = progress_bar.build_progress_bar( - args, itr, epoch_itr.epoch, - prefix='valid on \'{}\' subset'.format(subset), - no_progress_bar='simple' + progress = progress_bar.progress_bar( + itr, + log_format=args.log_format, + log_interval=args.log_interval, + epoch=epoch_itr.epoch, + prefix=f"valid on '{subset}' subset", + tensorboard_logdir=( + args.tensorboard_logdir if distributed_utils.is_master(args) else None + ), + default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), ) # create a new root metrics aggregator so validation metrics diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index 94fca20fdf..5306551e27 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -10,7 +10,8 @@ import torch -from fairseq import checkpoint_utils, metrics, options, progress_bar, utils +from fairseq import checkpoint_utils, options, utils +from fairseq.logging import metrics, progress_bar logging.basicConfig( format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', @@ -80,10 +81,12 @@ def main(args, override_args=None): seed=args.seed, num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) - progress = progress_bar.build_progress_bar( - args, itr, - prefix='valid on \'{}\' subset'.format(subset), - no_progress_bar='simple' + progress = progress_bar.progress_bar( + itr, + log_format=args.log_format, + log_interval=args.log_interval, + prefix=f"valid on '{subset}' subset", + default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), ) log_outputs = []
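
Usage note (illustrative, not part of the diff): a minimal sketch of the standalone progress_bar.progress_bar() entry point and the backwards-compatible module aliases registered in fairseq/__init__.py by this patch. The range(10) iterator, the epoch value, and the literal option values below are hypothetical placeholders, not part of the change.

    # Minimal sketch of the new fairseq.logging API (placeholder values, not from the patch).
    from fairseq.logging import meters, progress_bar

    itr = range(10)  # hypothetical stand-in for an epoch iterator
    progress = progress_bar.progress_bar(
        itr,
        log_format='simple',    # one of: 'json', 'none', 'simple', 'tqdm'
        log_interval=100,       # matches the new default in fairseq/options.py
        epoch=1,
        default_log_format='tqdm',
    )

    wps_meter = meters.TimeMeter()
    for sample in progress:
        wps_meter.update(1)
        progress.log({'wps': round(wps_meter.avg)})

    # Old import paths keep working via the sys.modules aliases added above:
    from fairseq.meters import AverageMeter  # resolves to fairseq.logging.meters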