diff --git a/examples/speech_recognition/tasks/speech_recognition.py b/examples/speech_recognition/tasks/speech_recognition.py
index dde0b12577..1181c9aef5 100644
--- a/examples/speech_recognition/tasks/speech_recognition.py
+++ b/examples/speech_recognition/tasks/speech_recognition.py
@@ -113,7 +113,7 @@ def load_dataset(self, split, combine=False, **kwargs):
         data_json_path = os.path.join(self.args.data, "{}.json".format(split))
         self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict)
 
-    def build_generator(self, models, args):
+    def build_generator(self, models, args, **unused):
         w2l_decoder = getattr(args, "w2l_decoder", None)
         if w2l_decoder == "viterbi":
             from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
diff --git a/fairseq/options.py b/fairseq/options.py
index e1df860fbe..31ed28a80e 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -380,6 +380,11 @@ def add_generation_args(parser):
                        help='if set, uses attention feedback to compute and print alignment to source tokens')
     group.add_argument('--print-step', action='store_true')
 
+    group.add_argument('--lm-path', default=None, type=str, metavar='PATH',
+                       help='path to lm checkpoint for lm fusion')
+    group.add_argument('--lm-weight', default=0.0, type=float, metavar='N',
+                       help='weight for lm probs for lm fusion')
+
     # arguments for iterative refinement generator
     group.add_argument('--iter-decode-eos-penalty', default=0.0, type=float, metavar='N',
                        help='if > 0.0, it penalized early-stopping in decoding.')
diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py
index 965594cd6e..ff45c7dfb7 100644
--- a/fairseq/sequence_generator.py
+++ b/fairseq/sequence_generator.py
@@ -33,6 +33,8 @@ def __init__(
         search_strategy=None,
         eos=None,
         symbols_to_strip_from_output=None,
+        lm_model=None,
+        lm_weight=1.0
     ):
         """Generates translations of a given source sentence.
 
@@ -94,6 +96,11 @@ def __init__(
 
         self.model.eval()
 
+        self.lm_model = lm_model
+        self.lm_weight = lm_weight
+        if self.lm_model is not None:
+            self.lm_model.eval()
+
     def cuda(self):
         self.model.cuda()
         return self
@@ -292,6 +299,15 @@ def _generate(
                 incremental_states,
                 self.temperature,
             )
+
+            if self.lm_model is not None:
+                lm_out = self.lm_model(tokens[:, : step + 1])
+                probs = self.lm_model.get_normalized_probs(
+                    lm_out, log_probs=True, sample=None
+                )
+                probs = probs[:, -1, :] * self.lm_weight
+                lprobs += probs
+
             lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
 
             lprobs[:, self.pad] = -math.inf  # never select pad
@@ -820,9 +836,11 @@ def forward_decoder(
                     avg_attn = attn
                 else:
                     avg_attn.add_(attn)
+
         avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
             self.models_size
         )
+
         if avg_attn is not None:
             avg_attn.div_(self.models_size)
         return avg_probs, avg_attn
diff --git a/fairseq/tasks/translation_from_pretrained_bart.py b/fairseq/tasks/translation_from_pretrained_bart.py
index b3c9f8e440..4d574ffc82 100644
--- a/fairseq/tasks/translation_from_pretrained_bart.py
+++ b/fairseq/tasks/translation_from_pretrained_bart.py
@@ -84,7 +84,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
             append_source_id=True
         )
 
-    def build_generator(self, models, args):
+    def build_generator(self, models, args, **unused):
         if getattr(args, 'score_reference', False):
             from fairseq.sequence_scorer import SequenceScorer
             return SequenceScorer(
diff --git a/fairseq/tasks/translation_lev.py b/fairseq/tasks/translation_lev.py
index be362a1881..18ac0ca385 100644
--- a/fairseq/tasks/translation_lev.py
+++ b/fairseq/tasks/translation_lev.py
@@ -128,7 +128,7 @@ def _full_mask(target_tokens):
         else:
             raise NotImplementedError
 
-    def build_generator(self, models, args):
+    def build_generator(self, models, args, **unused):
         # add models input to match the API for SequenceGenerator
         from fairseq.iterative_refinement_generator import IterativeRefinementGenerator
         return IterativeRefinementGenerator(
diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py
index 786f699432..0cf09feaee 100644
--- a/fairseq_cli/generate.py
+++ b/fairseq_cli/generate.py
@@ -7,6 +7,8 @@
 Translate pre-processed data with a trained model.
 """
 
+import ast
+from itertools import chain
 import logging
 import math
 import os
@@ -78,17 +80,39 @@ def _main(args, output_file):
         src_dict = None
     tgt_dict = task.target_dictionary
 
+    overrides = ast.literal_eval(args.model_overrides)
+
     # Load ensemble
     logger.info('loading model(s) from {}'.format(args.path))
     models, _model_args = checkpoint_utils.load_model_ensemble(
         utils.split_paths(args.path),
-        arg_overrides=eval(args.model_overrides),
+        arg_overrides=overrides,
         task=task,
         suffix=getattr(args, "checkpoint_suffix", ""),
     )
 
+    if args.lm_path is not None:
+        overrides['data'] = args.data
+
+        try:
+            lms, _ = checkpoint_utils.load_model_ensemble(
+                [args.lm_path],
+                arg_overrides=overrides,
+                task=None,
+            )
+        except:
+            logger.warning(f"Failed to load language model! Please make sure that the language model dict is the same "
+                           f"as target dict and is located in the data dir ({args.data})")
+            raise
+
+        assert len(lms) == 1
+    else:
+        lms = [None]
+
     # Optimize ensemble for generation
-    for model in models:
+    for model in chain(models, lms):
+        if model is None:
+            continue
         model.prepare_for_inference_(args)
         if args.fp16:
             model.half()
@@ -124,7 +148,12 @@ def _main(args, output_file):
 
     # Initialize generator
     gen_timer = StopwatchMeter()
-    generator = task.build_generator(models, args)
+
+    extra_gen_cls_kwargs = {
+        'lm_model': lms[0],
+        'lm_weight': args.lm_weight
+    }
+    generator = task.build_generator(models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs)
 
     # Handle tokenization and BPE
     tokenizer = encoders.build_tokenizer(args)
@@ -269,9 +298,11 @@ def decode_fn(x):
     if has_target:
         if args.bpe and not args.sacrebleu:
             if args.remove_bpe:
-                logger.warning("BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization")
+                logger.warning(
+                    "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization")
             else:
-                logger.warning("If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization")
+                logger.warning(
+                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization")
             # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
             print(
                 'Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()),
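
Note on the sequence_generator.py change: it implements LM shallow fusion. At each decoding step the language model's log-probabilities for the next token are scaled by --lm-weight and added to the translation model's log-probabilities before beam pruning. Below is a minimal, framework-independent sketch of that combination step; the function and tensor names are illustrative and are not fairseq's API:

    import torch

    def fuse_lm_scores(tm_lprobs: torch.Tensor, lm_lprobs: torch.Tensor, lm_weight: float = 0.1) -> torch.Tensor:
        """Shallow fusion for one decoding step (illustrative sketch).

        tm_lprobs: (batch * beam, vocab) log-probs from the translation model
        lm_lprobs: (batch * beam, vocab) log-probs from the language model
        """
        # Add the weighted LM scores to the translation model scores.
        fused = tm_lprobs + lm_weight * lm_lprobs
        # Replace NaNs with -inf so they can never win the beam search,
        # mirroring the lprobs[lprobs != lprobs] guard in the patch.
        fused[fused != fused] = float("-inf")
        return fused

With this patch, decoding picks up the language model through the new flags: --lm-path points at a fairseq LM checkpoint whose dictionary matches the target dictionary and is located in the data directory, and --lm-weight sets the fusion coefficient (the default of 0.0 leaves the translation model's scores unchanged).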