diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py
index 4b9151c79a..ffa0f1e753 100644
--- a/examples/speech_recognition/infer.py
+++ b/examples/speech_recognition/infer.py
@@ -208,7 +208,7 @@ def main(args):
 
     # Initialize generator
     gen_timer = meters.StopwatchMeter()
-    generator = task.build_generator(args)
+    generator = task.build_generator(models, args)
 
     num_sentences = 0
 
diff --git a/examples/speech_recognition/tasks/speech_recognition.py b/examples/speech_recognition/tasks/speech_recognition.py
index bd671e46dd..b555cfeefa 100644
--- a/examples/speech_recognition/tasks/speech_recognition.py
+++ b/examples/speech_recognition/tasks/speech_recognition.py
@@ -108,7 +108,7 @@ def load_dataset(self, split, combine=False, **kwargs):
         data_json_path = os.path.join(self.args.data, "{}.json".format(split))
         self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict)
 
-    def build_generator(self, args):
+    def build_generator(self, models, args):
         w2l_decoder = getattr(args, "w2l_decoder", None)
         if w2l_decoder == "viterbi":
             from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
@@ -119,7 +119,7 @@ def build_generator(self, args):
 
             return W2lKenLMDecoder(args, self.target_dictionary)
         else:
-            return super().build_generator(args)
+            return super().build_generator(models, args)
 
     @property
     def target_dictionary(self):
diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py
index ee990e6cd6..92a4c9092f 100644
--- a/fairseq/hub_utils.py
+++ b/fairseq/hub_utils.py
@@ -157,7 +157,7 @@ def generate(
             gen_args.beam = beam
         for k, v in kwargs.items():
             setattr(gen_args, k, v)
-        generator = self.task.build_generator(gen_args)
+        generator = self.task.build_generator(self.models, gen_args)
 
         results = []
         for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py
index 20d691bef7..f87291bfbd 100644
--- a/fairseq/models/bart/hub_interface.py
+++ b/fairseq/models/bart/hub_interface.py
@@ -115,7 +115,7 @@ def generate(self, tokens: List[torch.LongTensor], beam: int = 5, verbose: bool
         gen_args.beam = beam
         for k, v in kwargs.items():
             setattr(gen_args, k, v)
-        generator = self.task.build_generator(gen_args)
+        generator = self.task.build_generator([self.model], gen_args)
         translations = self.task.inference_step(
             generator,
             [self.model],
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index 61466657b5..438b3a10b3 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -225,7 +225,7 @@ def build_criterion(self, args):
 
         return criterions.build_criterion(args, self)
 
-    def build_generator(self, args):
+    def build_generator(self, models, args):
         if getattr(args, "score_reference", False):
             from fairseq.sequence_scorer import SequenceScorer
 
diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py
index ce81da96fd..e55b372683 100644
--- a/fairseq/tasks/translation.py
+++ b/fairseq/tasks/translation.py
@@ -261,6 +261,7 @@ def build_dataset_for_inference(self, src_tokens, src_lengths):
         return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary)
 
     def build_model(self, args):
+        model = super().build_model(args)
         if getattr(args, 'eval_bleu', False):
             assert getattr(args, 'eval_bleu_detok', None) is not None, (
                 '--eval-bleu-detok is required if using --eval-bleu; '
@@ -274,8 +275,8 @@ def build_model(self, args):
             ))
 
             gen_args = json.loads(getattr(args, 'eval_bleu_args', '{}') or '{}')
-            self.sequence_generator = self.build_generator(Namespace(**gen_args))
-        return super().build_model(args)
+            self.sequence_generator = self.build_generator([model], Namespace(**gen_args))
+        return model
 
     def valid_step(self, sample, model, criterion):
         loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
diff --git a/fairseq/tasks/translation_from_pretrained_bart.py b/fairseq/tasks/translation_from_pretrained_bart.py
index 4791394ebd..dfa8605563 100644
--- a/fairseq/tasks/translation_from_pretrained_bart.py
+++ b/fairseq/tasks/translation_from_pretrained_bart.py
@@ -79,7 +79,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
             append_source_id=True
         )
 
-    def build_generator(self, args):
+    def build_generator(self, models, args):
         if getattr(args, 'score_reference', False):
             from fairseq.sequence_scorer import SequenceScorer
             return SequenceScorer(
diff --git a/fairseq/tasks/translation_lev.py b/fairseq/tasks/translation_lev.py
index 093c340cab..d07c271569 100644
--- a/fairseq/tasks/translation_lev.py
+++ b/fairseq/tasks/translation_lev.py
@@ -126,7 +126,8 @@ def _full_mask(target_tokens):
         else:
             raise NotImplementedError
 
-    def build_generator(self, args):
+    def build_generator(self, models, args):
+        # add models input to match the API for SequenceGenerator
         from fairseq.iterative_refinement_generator import IterativeRefinementGenerator
         return IterativeRefinementGenerator(
             self.target_dictionary,
diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py
index 26cec20117..c439d0973f 100644
--- a/fairseq_cli/generate.py
+++ b/fairseq_cli/generate.py
@@ -111,7 +111,7 @@ def _main(args, output_file):
 
     # Initialize generator
     gen_timer = StopwatchMeter()
-    generator = task.build_generator(args)
+    generator = task.build_generator(models, args)
 
     # Handle tokenization and BPE
     tokenizer = encoders.build_tokenizer(args)
diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py
index 12efdf8fbb..cfcd2c535b 100644
--- a/fairseq_cli/interactive.py
+++ b/fairseq_cli/interactive.py
@@ -112,7 +112,7 @@ def main(args):
         model.cuda()
 
     # Initialize generator
-    generator = task.build_generator(args)
+    generator = task.build_generator(models, args)
 
     # Handle tokenization and BPE
     tokenizer = encoders.build_tokenizer(args)
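Taken together, this patch threads the loaded models through FairseqTask.build_generator, so a task can take the models into account when constructing a generator. A minimal sketch of the new calling convention follows; the checkpoint path is a placeholder, and it assumes checkpoint_utils.load_model_ensemble_and_task returns the (models, args, task) triple as in fairseq at the time of this change.

    from fairseq import checkpoint_utils

    # Load the model ensemble together with the task it was trained on;
    # `models` is a list of models even for a single checkpoint.
    models, args, task = checkpoint_utils.load_model_ensemble_and_task(
        ["/path/to/checkpoint.pt"]  # placeholder path
    )

    # build_generator now takes the model ensemble as its first argument
    generator = task.build_generator(models, args)

Any downstream task that overrides build_generator needs the same signature change, forwarding models to super().build_generator(models, args) as the speech recognition and BART tasks above do; hub interfaces wrapping a single model pass it as a one-element list ([self.model]).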