build_generator api changes for the scripted SequenceGenerator (#697)
Summary:
Pull Request resolved: pytorch/translate#697

Pull Request resolved: #1922

Pull Request resolved: fairinternal/fairseq-py#1117

We are planning to deprecate the original SequenceGenerator and use the ScriptSequenceGenerator in fairseq. Because the constructor of the scripted SequenceGenerator has changed, I have updated the `build_generator` interface in fairseq, pyspeech, and pytorch translate.

Reviewed By: myleott

Differential Revision: D20683836

fbshipit-source-id: d01d891ebd067fe44291d3a0a784935edaf66acd
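
For downstream code the change is mechanical: load the model ensemble first, then hand it to build_generator. A minimal sketch of the new call pattern (the data directory and checkpoint path below are hypothetical):

from fairseq import checkpoint_utils, options, tasks

# Parse generation args for a hypothetical data directory.
parser = options.get_generation_parser()
args = options.parse_args_and_arch(parser, input_args=["data-bin/example"])

# Set up the task and load the model ensemble before building the generator.
task = tasks.setup_task(args)
models, _model_args = checkpoint_utils.load_model_ensemble(
    ["checkpoints/model.pt"], task=task
)

# Old API: generator = task.build_generator(args)
# New API: the models are passed in so the generator can be built around them.
generator = task.build_generator(models, args)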
liuchen9494 authored and facebook-github-bot committed Apr 3, 2020
1 parent f20dc23 commit cd2555a
Showing 10 changed files with 14 additions and 12 deletions.
2 changes: 1 addition & 1 deletion examples/speech_recognition/infer.py
@@ -208,7 +208,7 @@ def main(args):
 
     # Initialize generator
     gen_timer = meters.StopwatchMeter()
-    generator = task.build_generator(args)
+    generator = task.build_generator(models, args)
 
     num_sentences = 0
 
4 changes: 2 additions & 2 deletions examples/speech_recognition/tasks/speech_recognition.py
@@ -108,7 +108,7 @@ def load_dataset(self, split, combine=False, **kwargs):
         data_json_path = os.path.join(self.args.data, "{}.json".format(split))
         self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict)
 
-    def build_generator(self, args):
+    def build_generator(self, models, args):
         w2l_decoder = getattr(args, "w2l_decoder", None)
         if w2l_decoder == "viterbi":
             from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
@@ -119,7 +119,7 @@ def build_generator(self, args):
 
             return W2lKenLMDecoder(args, self.target_dictionary)
         else:
-            return super().build_generator(args)
+            return super().build_generator(models, args)
 
     @property
     def target_dictionary(self):
2 changes: 1 addition & 1 deletion fairseq/hub_utils.py
@@ -157,7 +157,7 @@ def generate(
         gen_args.beam = beam
         for k, v in kwargs.items():
             setattr(gen_args, k, v)
-        generator = self.task.build_generator(gen_args)
+        generator = self.task.build_generator(self.models, gen_args)
 
         results = []
         for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
2 changes: 1 addition & 1 deletion fairseq/models/bart/hub_interface.py
@@ -115,7 +115,7 @@ def generate(self, tokens: List[torch.LongTensor], beam: int = 5, verbose: bool
         gen_args.beam = beam
         for k, v in kwargs.items():
             setattr(gen_args, k, v)
-        generator = self.task.build_generator(gen_args)
+        generator = self.task.build_generator([self.model], gen_args)
         translations = self.task.inference_step(
             generator,
             [self.model],
2 changes: 1 addition & 1 deletion fairseq/tasks/fairseq_task.py
@@ -225,7 +225,7 @@ def build_criterion(self, args):
 
         return criterions.build_criterion(args, self)
 
-    def build_generator(self, args):
+    def build_generator(self, models, args):
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer
 
5 changes: 3 additions & 2 deletions fairseq/tasks/translation.py
@@ -261,6 +261,7 @@ def build_dataset_for_inference(self, src_tokens, src_lengths):
         return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary)
 
     def build_model(self, args):
+        model = super().build_model(args)
         if getattr(args, 'eval_bleu', False):
             assert getattr(args, 'eval_bleu_detok', None) is not None, (
                 '--eval-bleu-detok is required if using --eval-bleu; '
@@ -274,8 +275,8 @@ def build_model(self, args):
             ))
 
             gen_args = json.loads(getattr(args, 'eval_bleu_args', '{}') or '{}')
-            self.sequence_generator = self.build_generator(Namespace(**gen_args))
-        return super().build_model(args)
+            self.sequence_generator = self.build_generator([model], Namespace(**gen_args))
+        return model
 
     def valid_step(self, sample, model, criterion):
         loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
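
A note on the translation.py change above: since the generator is now built from the model, build_model has to construct the model before creating the eval-BLEU sequence generator, and then return that same instance. A minimal sketch of the resulting pattern (condensed from the diff above, which passes Namespace(**gen_args) rather than args):

def build_model(self, args):
    model = super().build_model(args)  # construct the model first
    if getattr(args, 'eval_bleu', False):
        # the generator receives the ensemble it will decode with
        self.sequence_generator = self.build_generator([model], args)
    return model  # return the instance the generator holds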
2 changes: 1 addition & 1 deletion fairseq/tasks/translation_from_pretrained_bart.py
@@ -79,7 +79,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
             append_source_id=True
         )
 
-    def build_generator(self, args):
+    def build_generator(self, models, args):
         if getattr(args, 'score_reference', False):
             from fairseq.sequence_scorer import SequenceScorer
             return SequenceScorer(
3 changes: 2 additions & 1 deletion fairseq/tasks/translation_lev.py
@@ -126,7 +126,8 @@ def _full_mask(target_tokens):
         else:
             raise NotImplementedError
 
-    def build_generator(self, args):
+    def build_generator(self, models, args):
+        # add models input to match the API for SequenceGenerator
         from fairseq.iterative_refinement_generator import IterativeRefinementGenerator
         return IterativeRefinementGenerator(
             self.target_dictionary,
2 changes: 1 addition & 1 deletion fairseq_cli/generate.py
@@ -111,7 +111,7 @@ def _main(args, output_file):
 
     # Initialize generator
     gen_timer = StopwatchMeter()
-    generator = task.build_generator(args)
+    generator = task.build_generator(models, args)
 
     # Handle tokenization and BPE
     tokenizer = encoders.build_tokenizer(args)
2 changes: 1 addition & 1 deletion fairseq_cli/interactive.py
@@ -112,7 +112,7 @@ def main(args):
             model.cuda()
 
     # Initialize generator
-    generator = task.build_generator(args)
+    generator = task.build_generator(models, args)
 
     # Handle tokenization and BPE
     tokenizer = encoders.build_tokenizer(args)
