diff --git a/examples/language_model/README.md b/examples/language_model/README.md
index 3992e2ca1b..66c5cb8e90 100644
--- a/examples/language_model/README.md
+++ b/examples/language_model/README.md
@@ -26,6 +26,10 @@ torch.hub.list('pytorch/fairseq')  # [..., 'transformer_lm.wmt19.en', ...]
 
 # Load an English LM trained on WMT'19 News Crawl data
 en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en', tokenizer='moses', bpe='fastbpe')
+en_lm.eval()  # disable dropout
+
+# Move model to GPU
+en_lm.cuda()
 
 # Sample from the language model
 en_lm.sample('Barack Obama', beam=1, sampling=True, sampling_topk=10, temperature=0.8)
diff --git a/examples/translation/README.md b/examples/translation/README.md
index db1844df55..055a508a28 100644
--- a/examples/translation/README.md
+++ b/examples/translation/README.md
@@ -34,13 +34,21 @@ torch.hub.list('pytorch/fairseq')  # [..., 'transformer.wmt16.en-de', ... ]
 
 # Load a transformer trained on WMT'16 En-De
 en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt16.en-de', tokenizer='moses', bpe='subword_nmt')
+en2de.eval()  # disable dropout
 
 # The underlying model is available under the *models* attribute
 assert isinstance(en2de.models[0], fairseq.models.transformer.TransformerModel)
 
+# Move model to GPU for faster translation
+en2de.cuda()
+
 # Translate a sentence
 en2de.translate('Hello world!')
 # 'Hallo Welt!'
+
+# Batched translation
+en2de.translate(['Hello world!', 'The cat sat on the mat.'])
+# ['Hallo Welt!', 'Die Katze saß auf der Matte.']
 ```
 
 Loading custom models:
diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py
index 7fa5a4fffb..0bce075306 100644
--- a/fairseq/hub_utils.py
+++ b/fairseq/hub_utils.py
@@ -7,6 +7,7 @@
 import argparse
 import copy
 import os
+from typing import List, Dict, Iterator, Tuple, Any
 
 import torch
 from torch import nn
@@ -106,6 +107,10 @@ def __init__(self, args, task, models):
         self.tokenizer = encoders.build_tokenizer(args)
         self.bpe = encoders.build_bpe(args)
 
+        self.max_positions = utils.resolve_max_positions(
+            self.task.max_positions(), *[model.max_positions() for model in models]
+        )
+
         # this is useful for determining the device
         self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
 
@@ -113,21 +118,35 @@ def __init__(self, args, task, models):
     def device(self):
         return self._float_tensor.device
 
-    def translate(self, sentence: str, beam: int = 5, verbose: bool = False, **kwargs) -> str:
-        return self.sample(sentence, beam, verbose, **kwargs)
+    def translate(self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs) -> List[str]:
+        return self.sample(sentences, beam, verbose, **kwargs)
 
-    def sample(self, sentence: str, beam: int = 1, verbose: bool = False, **kwargs) -> str:
-        input = self.encode(sentence)
-        hypo = self.generate(input, beam, verbose, **kwargs)[0]['tokens']
-        return self.decode(hypo)
+    def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> List[str]:
+        if isinstance(sentences, str):
+            return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
+        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
+        batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
+        return [self.decode(hypos[0]['tokens']) for hypos in batched_hypos]
 
-    def score(self, sentence: str, **kwargs):
+    def score(self, sentences: List[str], **kwargs):
+        if isinstance(sentences, str):
+            return self.score([sentences], **kwargs)[0]
         # NOTE: this doesn't support translation tasks currently
-        input = self.encode(sentence)
-        return self.generate(input, score_reference=True, **kwargs)[0]
-
-    def generate(self, tokens: torch.LongTensor, beam: int = 5, verbose: bool = False, **kwargs) -> torch.LongTensor:
-        sample = self._build_sample(tokens)
+        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
+        return [hypos[0] for hypos in self.generate(tokenized_sentences, score_reference=True, **kwargs)]
+
+    def generate(
+        self,
+        tokenized_sentences: List[torch.LongTensor],
+        beam: int = 5,
+        verbose: bool = False,
+        skip_invalid_size_inputs=False,
+        **kwargs
+    ) -> List[List[Dict[str, torch.Tensor]]]:
+        if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
+            return self.generate(
+                tokenized_sentences.unsqueeze(0), beam=beam, verbose=verbose, **kwargs
+            )[0]
 
         # build generator using current args as well as any kwargs
         gen_args = copy.copy(self.args)
@@ -136,30 +155,35 @@ def generate(self, tokens: torch.LongTensor, beam: int = 5, verbose: bool = Fals
             setattr(gen_args, k, v)
         generator = self.task.build_generator(gen_args)
 
-        translations = self.task.inference_step(generator, self.models, sample)
-
-        if verbose:
-            src_str_with_unk = self.string(tokens)
-            print('S\t{}'.format(src_str_with_unk))
+        results = []
+        for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
+            batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
+            translations = self.task.inference_step(generator, self.models, batch)
+            for id, hypos in zip(batch["id"].tolist(), translations):
+                results.append((id, hypos))
 
-        def getarg(name, default):
-            return getattr(gen_args, name, getattr(self.args, name, default))
+        # sort output to match input order
+        outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]
 
-        # Process top predictions
-        hypos = translations[0]
         if verbose:
-            for hypo in hypos:
-                hypo_str = self.decode(hypo['tokens'])
-                print('H\t{}\t{}'.format(hypo['score'], hypo_str))
-                print('P\t{}'.format(
-                    ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
-                ))
-                if hypo['alignment'] is not None and getarg('print_alignment', False):
-                    print('A\t{}'.format(
-                        ' '.join(map(lambda x: str(utils.item(x)), hypo['alignment'].int().cpu()))
-                    ))
-        return hypos
+            def getarg(name, default):
+                return getattr(gen_args, name, getattr(self.args, name, default))
+
+            for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
+                src_str_with_unk = self.string(source_tokens)
+                print('S\t{}'.format(src_str_with_unk))
+                for hypo in target_hypotheses:
+                    hypo_str = self.decode(hypo['tokens'])
+                    print('H\t{}\t{}'.format(hypo['score'], hypo_str))
+                    print('P\t{}'.format(
+                        ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
+                    ))
+                    if hypo['alignment'] is not None and getarg('print_alignment', False):
+                        print('A\t{}'.format(
+                            ' '.join(map(lambda x: str(utils.item(x)), hypo['alignment'].int().cpu()))
+                        ))
+        return outputs
 
     def encode(self, sentence: str) -> torch.LongTensor:
         sentence = self.tokenize(sentence)
@@ -197,15 +221,18 @@ def binarize(self, sentence: str) -> torch.LongTensor:
     def string(self, tokens: torch.LongTensor) -> str:
         return self.tgt_dict.string(tokens)
 
-    def _build_sample(self, src_tokens: torch.LongTensor):
-        assert torch.is_tensor(src_tokens)
-        dataset = self.task.build_dataset_for_inference([src_tokens], [src_tokens.numel()])
-        sample = dataset.collater([dataset[0]])
-        sample = utils.apply_to_sample(
-            lambda tensor: tensor.to(self.device),
-            sample
-        )
-        return sample
+    def _build_batches(
+        self, tokens: List[List[int]], skip_invalid_size_inputs: bool
+    ) -> Iterator[Dict[str, Any]]:
+        lengths = torch.LongTensor([t.numel() for t in tokens])
+        batch_iterator = self.task.get_batch_iterator(
+            dataset=self.task.build_dataset_for_inference(tokens, lengths),
+            max_tokens=self.args.max_tokens,
+            max_sentences=self.args.max_sentences,
+            max_positions=self.max_positions,
+            ignore_invalid_inputs=skip_invalid_size_inputs,
+        ).next_epoch_itr(shuffle=False)
+        return batch_iterator
 
 
 class BPEHubInterface(object):
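For reviewers, here is a minimal sketch of how the pieces in this patch compose end to end. It is illustrative only, not part of the diff: it assumes the `transformer.wmt16.en-de` hub model from the README hunk above downloads successfully, and that the hub model's default `max_tokens`/`max_sentences` settings admit a batch of this size.

```python
import torch

# Load the En-De transformer shown in the translation README (downloads on first use).
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt16.en-de',
                       tokenizer='moses', bpe='subword_nmt')
en2de.eval()  # disable dropout
en2de.cuda()  # optional; generate() moves each batch to the model's device

# translate()/sample() now accept a single string or a list of strings.
# The isinstance(sentences, str) check keeps the old API intact:
# a bare string in, a bare string out; a list in, a list out.
assert isinstance(en2de.translate('Hello world!'), str)
assert isinstance(en2de.translate(['Hello world!', 'Good morning.']), list)

# Lower-level path: encode manually and call generate() directly.
# Each element of the result is the full beam of hypothesis dicts for one
# input sentence, restored to input order even if batching reordered them.
tokenized = [en2de.encode(s) for s in ['Hello world!', 'Good morning.']]
batched_hypos = en2de.generate(tokenized, beam=5, skip_invalid_size_inputs=True)
best = [en2de.decode(hypos[0]['tokens']) for hypos in batched_hypos]
```

Note the backward-compatibility choice: rather than adding a separate batched entry point, `sample()` and `score()` wrap a `str` argument in a singleton list and unwrap the first result, and `generate()` does the same for a 1-D tensor, so existing single-input callers are unaffected.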