diff --git a/README.md b/README.md
index 45dce65cf0..c39ff22c97 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,7 @@ modeling and other text generation tasks.
### What's New:
+- September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
- August 2019: [WMT'19 models released](examples/wmt19/README.md)
- July 2019: fairseq relicensed under MIT license
- July 2019: [RoBERTa models and code released](examples/roberta/README.md)
@@ -32,6 +33,13 @@ Fairseq provides reference implementations of various sequence-to-sequence model
- [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
- [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
- [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
+- **Non-autoregressive Transformers**
+ - Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
+ - Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
+ - Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019)
+ - Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
+ - [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
+
**Additionally:**
- multi-GPU (distributed) training on one machine or across multiple machines
@@ -50,7 +58,7 @@ translation and language modeling datasets.
# Requirements and Installation
-* [PyTorch](http://pytorch.org/) version >= 1.1.0
+* [PyTorch](http://pytorch.org/) version >= 1.2.0
* Python version >= 3.5
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library with the `--cuda_ext` option
@@ -92,6 +100,7 @@ as well as example training and evaluation commands.
- [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
We also have more detailed READMEs to reproduce results from specific papers:
+- [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
- [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
- [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
- [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
diff --git a/examples/nonautoregressive_translation/README.md b/examples/nonautoregressive_translation/README.md
new file mode 100644
index 0000000000..5f030868cc
--- /dev/null
+++ b/examples/nonautoregressive_translation/README.md
@@ -0,0 +1,90 @@
+# Non-autoregressive Neural Machine Translation (NAT)
+
+This page mainly includes instructions for reproducing results from the paper
+* [Levenshtein Transformer (Gu et al., 2019)](https://arxiv.org/abs/1905.11006).
+
+We also provided our own implementations for several popular non-autoregressive-based models as reference:
+* [Non-Autoregressive Neural Machine Translation (Gu et al., 2017)](https://arxiv.org/abs/1711.02281)
+* [Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)](https://arxiv.org/abs/1802.06901)
+* [Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019)](https://arxiv.org/abs/1902.03249)
+* [Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)](https://arxiv.org/abs/1904.09324v2)
+
+## Dataset
+
+First, follow the [instructions to download and preprocess the WMT'14 En-De dataset](../translation#prepare-wmt14en2desh).
+Make sure to learn a joint vocabulary by passing the `--joined-dictionary` option to `fairseq-preprocess`.
+
+### Knowledge Distillation
+Following [Gu et al. 2019](https://arxiv.org/abs/1905.11006), [knowledge distillation](https://arxiv.org/abs/1606.07947) from an autoregressive model can effectively simplify the training data distribution, which is sometimes essential for NAT-based models to learn good translations.
+The easiest way of performing distillation is to follow the [instructions of training a standard transformer model](../translation) on the same data, and then decode the training set to produce a distillation dataset for NAT.
+
+### Download
+We also provided the preprocessed [original](http://dl.fbaipublicfiles.com/nat/original_dataset.zip) and [distillation](http://dl.fbaipublicfiles.com/nat/distill_dataset.zip) datasets. Please build the binarized dataset on your own.
+
+
+## Train a model
+
+Then we can train a nonautoregressive model using the `translation_lev` task and a new criterion `nat_loss`.
+Use the `--noise` flag to specify the input noise used on the target sentences.
+In default, we run the task for *Levenshtein Transformer*, with `--noise='random_delete'`. Full scripts to run other models can also be found [here](./scripts.md).
+
+The following command will train a *Levenshtein Transformer* on the binarized dataset.
+
+```bash
+fairseq-train \
+ data-bin/wmt14_en_de_distill \
+ --save-dir checkpoints \
+ --ddp-backend=no_c10d \
+ --task translation_lev \
+ --criterion nat_loss \
+ --arch levenshtein_transformer \
+ --noise random_delete \
+ --share-all-embeddings \
+ --optimizer adam --adam-betas '(0.9,0.98)' \
+ --lr 0.0005 --lr-scheduler inverse_sqrt \
+ --min-lr '1e-09' --warmup-updates 10000 \
+ --warmup-init-lr '1e-07' --label-smoothing 0.1 \
+ --dropout 0.3 --weight-decay 0.01 \
+ --decoder-learned-pos \
+ --encoder-learned-pos \
+ --apply-bert-init \
+ --log-format 'simple' --log-interval 100 \
+ --fixed-validation-seed 7 \
+ --max-tokens 8000 \
+ --save-interval-updates 10000 \
+ --max-update 300000
+```
+
+## Translate
+
+Once a model is trained, we can generate translations using an `iterative_refinement_generator` which will based on the model's initial output and iteratively read and greedily refine the translation until (1) the model predicts the same translations for two consecutive iterations; or (2) the generator reaches the maximum iterations (`--iter-decode-max-iter`). Use `--print-step` to check the actual # of iteration for each sentence.
+
+For *Levenshtein Transformer*, it sometimes helps to apply a `--iter-decode-eos-penalty` (typically, 0~3) to penalize the model finishing generation too early and generating too short translations.
+
+
+For example, to generate with `--iter-decode-max-iter=9`:
+```bash
+fairseq-generate \
+ data-bin/wmt14_en_de_distill \
+ --gen-subset test \
+ --task translation_lev \
+ --path checkpoints/checkpoint_best.pt \
+ --iter-decode-max-iter 9 \
+ --iter-decode-eos-penalty 0 \
+ --beam 1 --remove-bpe \
+ --print-step \
+ --batch-size 400
+```
+In the end of the generation, we can see the tokenized BLEU score for the translation.
+
+
+## Citation
+
+```bibtex
+@article{gu2019levenshtein,
+ title={Levenshtein Transformer},
+ author={Gu, Jiatao and Wang, Changhan and Zhao, Jake},
+ journal={arXiv preprint arXiv:1905.11006},
+ year={2019}
+}
+```
diff --git a/examples/nonautoregressive_translation/scripts.md b/examples/nonautoregressive_translation/scripts.md
new file mode 100644
index 0000000000..2fda7f6204
--- /dev/null
+++ b/examples/nonautoregressive_translation/scripts.md
@@ -0,0 +1,148 @@
+# Examples of Training scripts for Non-autoregressive Machine Translation models
+
+### Non-autoregressive Transformer (NAT, Gu et al., 2017)
+Note that we need to have an additional module to perform "length prediction" (`--length-loss-factor`) before generating the whole sequence.
+```bash
+fairseq-train \
+ data-bin/wmt14_en_de_distill \
+ --save-dir checkpoints \
+ --ddp-backend=no_c10d \
+ --task translation_lev \
+ --criterion nat_loss \
+ --arch nonautoregressive_transformer \
+ --noise full_mask \
+ --share-all-embeddings \
+ --optimizer adam --adam-betas '(0.9,0.98)' \
+ --lr 0.0005 --lr-scheduler inverse_sqrt \
+ --min-lr '1e-09' --warmup-updates 10000 \
+ --warmup-init-lr '1e-07' --label-smoothing 0.1 \
+ --dropout 0.3 --weight-decay 0.01 \
+ --decoder-learned-pos \
+ --encoder-learned-pos \
+ --pred-length-offset \
+ --length-loss-factor 0.1 \
+ --apply-bert-init \
+ --log-format 'simple' --log-interval 100 \
+ --fixed-validation-seed 7 \
+ --max-tokens 8000 \
+ --save-interval-updates 10000 \
+ --max-update 300000
+```
+
+### Non-autoregressive Transformer with Iterative Refinement (iNAT, Lee et al., 2018)
+Note that `--train-step` means how many iterations of refinement we used during training, and `--dae-ratio` controls the ratio of denoising auto-encoder training described in the original paper.
+```bash
+fairseq-train \
+ data-bin/wmt14_en_de_distill \
+ --save-dir checkpoints \
+ --ddp-backend=no_c10d \
+ --task translation_lev \
+ --criterion nat_loss \
+ --arch nonautoregressive_transformer \
+ --noise full_mask \
+ --share-all-embeddings \
+ --optimizer adam --adam-betas '(0.9,0.98)' \
+ --lr 0.0005 --lr-scheduler inverse_sqrt \
+ --min-lr '1e-09' --warmup-updates 10000 \
+ --warmup-init-lr '1e-07' --label-smoothing 0.1 \
+ --dropout 0.3 --weight-decay 0.01 \
+ --decoder-learned-pos \
+ --encoder-learned-pos \
+ --pred-length-offset \
+ --length-loss-factor 0.1 \
+ --train-step 4 \
+ --dae-ratio 0.5 \
+ --stochastic-approx \
+ --apply-bert-init \
+ --log-format 'simple' --log-interval 100 \
+ --fixed-validation-seed 7 \
+ --max-tokens 8000 \
+ --save-interval-updates 10000 \
+ --max-update 300000
+```
+
+### Insertion Transformer (InsT, Stern et al., 2019)
+Note that we need to specify the "slot-loss" (uniform or balanced tree) described in the original paper. Here we use `--label-tau` to control the temperature.
+
+```bash
+fairseq-train \
+ data-bin/wmt14_en_de_distill \
+ --save-dir checkpoints \
+ --ddp-backend=no_c10d \
+ --task translation_lev \
+ --criterion nat_loss \
+ --arch insertion_transformer \
+ --noise random_delete \
+ --share-all-embeddings \
+ --optimizer adam --adam-betas '(0.9,0.98)' \
+ --lr 0.0005 --lr-scheduler inverse_sqrt \
+ --min-lr '1e-09' --warmup-updates 10000 \
+ --warmup-init-lr '1e-07' --label-smoothing 0.1 \
+ --dropout 0.3 --weight-decay 0.01 \
+ --decoder-learned-pos \
+ --encoder-learned-pos \
+ --pred-length-offset \
+ --length-loss-factor 0.1 \
+ --apply-bert-init \
+ --log-format 'simple' --log-interval 100 \
+ --fixed-validation-seed 7 \
+ --max-tokens 8000 \
+ --save-interval-updates 10000 \
+ --max-update 300000
+```
+
+
+### Mask Predict (CMLM, Ghazvininejad et al., 2019)
+```bash
+fairseq-train \
+ data-bin/wmt14_en_de_distill \
+ --save-dir checkpoints \
+ --ddp-backend=no_c10d \
+ --task translation_lev \
+ --criterion nat_loss \
+ --arch cmlm_transformer \
+ --noise random_mask \
+ --share-all-embeddings \
+ --optimizer adam --adam-betas '(0.9,0.98)' \
+ --lr 0.0005 --lr-scheduler inverse_sqrt \
+ --min-lr '1e-09' --warmup-updates 10000 \
+ --warmup-init-lr '1e-07' --label-smoothing 0.1 \
+ --dropout 0.3 --weight-decay 0.01 \
+ --decoder-learned-pos \
+ --encoder-learned-pos \
+ --apply-bert-init \
+ --log-format 'simple' --log-interval 100 \
+ --fixed-validation-seed 7 \
+ --max-tokens 8000 \
+ --save-interval-updates 10000 \
+ --max-update 300000
+```
+
+
+
+
+### Levenshtein Transformer (LevT, Gu et al., 2019)
+```bash
+fairseq-train \
+ data-bin/wmt14_en_de_distill \
+ --save-dir checkpoints \
+ --ddp-backend=no_c10d \
+ --task translation_lev \
+ --criterion nat_loss \
+ --arch levenshtein_transformer \
+ --noise random_delete \
+ --share-all-embeddings \
+ --optimizer adam --adam-betas '(0.9,0.98)' \
+ --lr 0.0005 --lr-scheduler inverse_sqrt \
+ --min-lr '1e-09' --warmup-updates 10000 \
+ --warmup-init-lr '1e-07' --label-smoothing 0.1 \
+ --dropout 0.3 --weight-decay 0.01 \
+ --decoder-learned-pos \
+ --encoder-learned-pos \
+ --apply-bert-init \
+ --log-format 'simple' --log-interval 100 \
+ --fixed-validation-seed 7 \
+ --max-tokens 8000 \
+ --save-interval-updates 10000 \
+ --max-update 300000
+```
diff --git a/fairseq/clib/libnat/edit_dist.cpp b/fairseq/clib/libnat/edit_dist.cpp
new file mode 100644
index 0000000000..966e9083bf
--- /dev/null
+++ b/fairseq/clib/libnat/edit_dist.cpp
@@ -0,0 +1,222 @@
+/**
+ * Copyright 2017-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include // @manual=//caffe2:torch_extension
+#include
+
+using namespace ::std;
+
+vector> edit_distance2_with_dp(
+ vector& x,
+ vector& y) {
+ uint32_t lx = x.size();
+ uint32_t ly = y.size();
+ vector> d(lx + 1, vector(ly + 1));
+ for (uint32_t i = 0; i < lx + 1; i++) {
+ d[i][0] = i;
+ }
+ for (uint32_t j = 0; j < ly + 1; j++) {
+ d[0][j] = j;
+ }
+ for (uint32_t i = 1; i < lx + 1; i++) {
+ for (uint32_t j = 1; j < ly + 1; j++) {
+ d[i][j] =
+ min(min(d[i - 1][j], d[i][j - 1]) + 1,
+ d[i - 1][j - 1] + 2 * (x.at(i - 1) == y.at(j - 1) ? 0 : 1));
+ }
+ }
+ return d;
+}
+
+vector> edit_distance2_backtracking(
+ vector>& d,
+ vector& x,
+ vector& y,
+ uint32_t terminal_symbol) {
+ vector seq;
+ vector> edit_seqs(x.size() + 2, vector());
+ /*
+ edit_seqs:
+ 0~x.size() cell is the insertion sequences
+ last cell is the delete sequence
+ */
+
+ if (x.size() == 0) {
+ edit_seqs.at(0) = y;
+ return edit_seqs;
+ }
+
+ uint32_t i = d.size() - 1;
+ uint32_t j = d.at(0).size() - 1;
+
+ while ((i >= 0) && (j >= 0)) {
+ if ((i == 0) && (j == 0)) {
+ break;
+ }
+
+ if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
+ seq.push_back(1); // insert
+ seq.push_back(y.at(j - 1));
+ j--;
+ } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
+ seq.push_back(2); // delete
+ seq.push_back(x.at(i - 1));
+ i--;
+ } else {
+ seq.push_back(3); // keep
+ seq.push_back(x.at(i - 1));
+ i--;
+ j--;
+ }
+ }
+
+ uint32_t prev_op, op, s, word;
+ prev_op = 0, s = 0;
+ for (uint32_t k = 0; k < seq.size() / 2; k++) {
+ op = seq.at(seq.size() - 2 * k - 2);
+ word = seq.at(seq.size() - 2 * k - 1);
+ if (prev_op != 1) {
+ s++;
+ }
+ if (op == 1) // insert
+ {
+ edit_seqs.at(s - 1).push_back(word);
+ } else if (op == 2) // delete
+ {
+ edit_seqs.at(x.size() + 1).push_back(1);
+ } else {
+ edit_seqs.at(x.size() + 1).push_back(0);
+ }
+
+ prev_op = op;
+ }
+
+ for (uint32_t k = 0; k < edit_seqs.size(); k++) {
+ if (edit_seqs[k].size() == 0) {
+ edit_seqs[k].push_back(terminal_symbol);
+ }
+ }
+ return edit_seqs;
+}
+
+vector> edit_distance2_backtracking_with_delete(
+ vector>& d,
+ vector& x,
+ vector& y,
+ uint32_t terminal_symbol,
+ uint32_t deletion_symbol) {
+ vector seq;
+ vector> edit_seqs(x.size() + 1, vector());
+ /*
+ edit_seqs:
+ 0~x.size() cell is the insertion sequences
+ last cell is the delete sequence
+ */
+
+ if (x.size() == 0) {
+ edit_seqs.at(0) = y;
+ return edit_seqs;
+ }
+
+ uint32_t i = d.size() - 1;
+ uint32_t j = d.at(0).size() - 1;
+
+ while ((i >= 0) && (j >= 0)) {
+ if ((i == 0) && (j == 0)) {
+ break;
+ }
+
+ if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
+ seq.push_back(1); // insert
+ seq.push_back(y.at(j - 1));
+ j--;
+ } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
+ seq.push_back(2); // delete
+ seq.push_back(x.at(i - 1));
+ i--;
+ } else {
+ seq.push_back(3); // keep
+ seq.push_back(x.at(i - 1));
+ i--;
+ j--;
+ }
+ }
+
+ uint32_t prev_op, op, s, word;
+ prev_op = 0, s = 0;
+ for (uint32_t k = 0; k < seq.size() / 2; k++) {
+ op = seq.at(seq.size() - 2 * k - 2);
+ word = seq.at(seq.size() - 2 * k - 1);
+ if (prev_op != 1) {
+ s++;
+ }
+ if (op == 1) // insert
+ {
+ edit_seqs.at(s - 1).push_back(word);
+ } else if (op == 2) // delete
+ {
+ edit_seqs.at(s - 1).push_back(deletion_symbol);
+ }
+
+ prev_op = op;
+ }
+
+ for (uint32_t k = 0; k < edit_seqs.size(); k++) {
+ if (edit_seqs.at(k).size() == 0) {
+ edit_seqs.at(k).push_back(terminal_symbol);
+ }
+ }
+ return edit_seqs;
+}
+
+vector compute_ed2(
+ vector>& xs,
+ vector>& ys) {
+ vector distances(xs.size());
+ for (uint32_t i = 0; i < xs.size(); i++) {
+ vector> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
+ distances.at(i) = d.at(xs.at(i).size()).at(ys.at(i).size());
+ }
+ return distances;
+}
+
+vector>> suggested_ed2_path(
+ vector>& xs,
+ vector>& ys,
+ uint32_t terminal_symbol) {
+ vector>> seq(xs.size());
+ for (uint32_t i = 0; i < xs.size(); i++) {
+ vector> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
+ seq.at(i) =
+ edit_distance2_backtracking(d, xs.at(i), ys.at(i), terminal_symbol);
+ }
+ return seq;
+}
+
+vector>> suggested_ed2_path_with_delete(
+ vector>& xs,
+ vector>& ys,
+ uint32_t terminal_symbol,
+ uint32_t deletion_symbol) {
+ vector>> seq(xs.size());
+ for (uint32_t i = 0; i < xs.size(); i++) {
+ vector> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
+ seq.at(i) = edit_distance2_backtracking_with_delete(
+ d, xs.at(i), ys.at(i), terminal_symbol, deletion_symbol);
+ }
+ return seq;
+}
+
+PYBIND11_MODULE(libnat, m) {
+ m.def("compute_ed2", &compute_ed2, "compute_ed2");
+ m.def("suggested_ed2_path", &suggested_ed2_path, "suggested_ed2_path");
+ m.def(
+ "suggested_ed2_path_with_delete",
+ &suggested_ed2_path_with_delete,
+ "suggested_ed2_path_with_delete");
+}
diff --git a/fairseq/criterions/nat_loss.py b/fairseq/criterions/nat_loss.py
new file mode 100644
index 0000000000..ccb25298f4
--- /dev/null
+++ b/fairseq/criterions/nat_loss.py
@@ -0,0 +1,190 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch.nn.functional as F
+from fairseq import utils
+from torch import Tensor
+
+from . import FairseqCriterion, register_criterion
+
+
+@register_criterion("nat_loss")
+class LabelSmoothedDualImitationCriterion(FairseqCriterion):
+ @staticmethod
+ def add_args(parser):
+ """Add criterion-specific arguments to the parser."""
+ # fmt: off
+ parser.add_argument(
+ '--label-smoothing',
+ default=0.,
+ type=float,
+ metavar='D',
+ help='epsilon for label smoothing, 0 means no label smoothing')
+ # fmt: on
+
+ def _compute_loss(
+ self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0
+ ):
+ """
+ outputs: batch x len x d_model
+ targets: batch x len
+ masks: batch x len
+
+ policy_logprob: if there is some policy
+ depends on the likelihood score as rewards.
+ """
+
+ def mean_ds(x: Tensor, dim=None) -> Tensor:
+ return (
+ x.float().mean().type_as(x)
+ if dim is None
+ else x.float().mean(dim).type_as(x)
+ )
+
+ if masks is not None:
+ outputs, targets = outputs[masks], targets[masks]
+
+ logits = F.log_softmax(outputs, dim=-1)
+ if targets.dim() == 1:
+ losses = F.nll_loss(logits, targets, reduction="none")
+
+ else: # soft-labels
+ losses = F.kl_div(logits, targets, reduction="none")
+ losses = losses.float().sum(-1).type_as(losses)
+
+ nll_loss = mean_ds(losses)
+ if label_smoothing > 0:
+ loss = nll_loss * (1 - label_smoothing) - mean_ds(logits) * label_smoothing
+ else:
+ loss = nll_loss
+
+ loss = loss * factor
+ return {"name": name, "loss": loss, "nll_loss": nll_loss, "factor": factor}
+
+ def _custom_loss(self, loss, name="loss"):
+ return {"name": name, "loss": loss, "factor": 1}
+
+ def forward(self, model, sample, reduce=True):
+ """Compute the loss for the given sample.
+ Returns a tuple with three elements:
+ 1) the loss
+ 2) the sample size, which is used as the denominator for the gradient
+ 3) logging outputs to display while training
+ """
+ nsentences, ntokens = sample["nsentences"], sample["ntokens"]
+
+ # B x T
+ src_tokens, src_lengths = (
+ sample["net_input"]["src_tokens"],
+ sample["net_input"]["src_lengths"],
+ )
+ tgt_tokens, prev_output_tokens = sample["target"], sample["prev_target"]
+
+ outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens)
+ losses = []
+ if "mask_ins_out" in outputs:
+ mask_ins_losses = self._compute_loss(
+ outputs["mask_ins_out"],
+ outputs["mask_ins_tgt"],
+ outputs["mask_ins_mask"],
+ name="m_ins-loss",
+ factor=1 if "mask_ins_w" not in outputs else outputs["mask_ins_w"],
+ )
+ losses += [mask_ins_losses]
+
+ if "word_ins_out" in outputs:
+ word_ins_losses = self._compute_loss(
+ outputs["word_ins_out"],
+ outputs["word_ins_tgt"],
+ outputs["word_ins_mask"],
+ self.args.label_smoothing,
+ name="w_ins-loss",
+ factor=1 if "word_ins_w" not in outputs else outputs["word_ins_w"],
+ )
+
+ losses += [word_ins_losses]
+ nll_loss = word_ins_losses["nll_loss"]
+
+ if "word_del_out" in outputs:
+ word_del_losses = self._compute_loss(
+ outputs["word_del_out"],
+ outputs["word_del_tgt"],
+ outputs["word_del_mask"],
+ 0.01,
+ name="w_del-loss",
+ factor=1 if "word_del_w" not in outputs else outputs["word_del_w"],
+ )
+
+ losses += [word_del_losses]
+
+ if "length_out" in outputs:
+ length_losses = self._compute_loss(
+ outputs["length_out"],
+ outputs["length_tgt"],
+ name="len-loss",
+ factor=1 if "length_w" not in outputs else outputs["length_w"],
+ )
+
+ losses += [length_losses]
+
+ for w in outputs:
+ if "-loss" in w:
+ losses += [self._custom_loss(outputs[w], w)]
+
+ loss = sum(l["loss"] for l in losses)
+
+ # NOTE: as we are summing up per token mlm loss and per sentence nsp loss
+ # we don't need to use sample_size as denominator for the gradient
+ # here sample_size is just used for logging
+ sample_size = 1
+ logging_output = {
+ "loss": utils.item(loss.data) if reduce else loss.data,
+ "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
+ "ntokens": ntokens,
+ "nsentences": nsentences,
+ "sample_size": sample_size,
+ }
+
+ for l in losses:
+ logging_output[l["name"]] = (
+ utils.item(l["loss"].data / l["factor"])
+ if reduce
+ else l[["loss"]].data / l["factor"]
+ )
+
+ return loss, sample_size, logging_output
+
+ @staticmethod
+ def aggregate_logging_outputs(logging_outputs):
+ """Aggregate logging outputs from data parallel training."""
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+ loss = sum(log.get("loss", 0) for log in logging_outputs)
+ nll_loss = sum(log.get("nll_loss", 0) for log in logging_outputs)
+
+ results = {
+ "loss": loss / sample_size / math.log(2) if sample_size > 0 else 0.0,
+ "nll_loss": nll_loss / sample_size / math.log(2)
+ if sample_size > 0
+ else 0.0,
+ "ntokens": ntokens,
+ "nsentences": nsentences,
+ "sample_size": sample_size,
+ }
+
+ for key in logging_outputs[0]:
+ if key[-5:] == "-loss":
+ results[key[:-5]] = (
+ sum(log.get(key, 0) for log in logging_outputs)
+ / sample_size
+ / math.log(2)
+ if sample_size > 0
+ else 0.0
+ )
+
+ return results
diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py
index 417105e50b..5d135ba123 100644
--- a/fairseq/data/dictionary.py
+++ b/fairseq/data/dictionary.py
@@ -74,7 +74,10 @@ def token_string(i):
else:
return self[i]
- sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
+ if hasattr(self, 'bos_index'):
+ sent = ' '.join(token_string(i) for i in tensor if (i != self.eos()) and (i != self.bos()))
+ else:
+ sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
return data_utils.process_bpe_symbol(sent, bpe_symbol)
def unk_string(self, escape=False):
diff --git a/fairseq/iterative_refinement_generator.py b/fairseq/iterative_refinement_generator.py
new file mode 100644
index 0000000000..aee4884187
--- /dev/null
+++ b/fairseq/iterative_refinement_generator.py
@@ -0,0 +1,154 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from fairseq.models.model_utils import skip_tensors as _skip
+
+
+class IterativeRefinementGenerator(object):
+ def __init__(self,
+ tgt_dict,
+ eos_penalty=0.,
+ max_iter=10,
+ max_ratio=2,
+ decoding_format=None,
+ retain_dropout=False,
+ adaptive=True):
+ """
+ Generates translations based on iterative refinement.
+
+ Args:
+ tgt_dict: target dictionary
+ eos_penalty: if > 0.0, it penalized early-stopping in decoding
+ max_iter: maximum number of refinement iterations
+ max_ratio: generate sequences of maximum length ax, where x is the source length
+ decoding_format: decoding mode in {'unigram', 'ensemble', 'vote', 'dp', 'bs'}
+ retain_dropout: retaining dropout in the inference
+ adaptive: decoding with early stop
+ """
+ self.bos = tgt_dict.bos()
+ self.pad = tgt_dict.pad()
+ self.unk = tgt_dict.unk()
+ self.eos = tgt_dict.eos()
+ self.vocab_size = len(tgt_dict)
+ self.eos_penalty = eos_penalty
+ self.max_iter = max_iter
+ self.max_ratio = max_ratio
+ self.decoding_format = decoding_format
+ self.retain_dropout = retain_dropout
+ self.adaptive = adaptive
+
+ @torch.no_grad()
+ def generate(self, models, sample, prefix_tokens=None):
+
+ # TODO: model ensemble
+ assert len(models) == 1, 'only support single model'
+ model = models[0]
+ if not self.retain_dropout:
+ model.eval()
+
+ # TODO: better encoder inputs?
+ src_tokens = sample['net_input']['src_tokens']
+ src_lengths = sample['net_input']['src_lengths']
+ bsz, src_len = src_tokens.size()
+ sent_idxs = torch.arange(bsz, device=src_tokens.device)
+
+ # encoding
+ encoder_out = model.forward_encoder([src_tokens, src_lengths])
+
+ # initialize buffers (very model specific, with length prediction or not)
+ prev_decoder_out = model.initialize_output_tokens(
+ encoder_out, src_tokens)
+ prev_out_tokens = prev_decoder_out['output_tokens'].clone()
+
+ finalized = [[] for _ in range(bsz)]
+
+ def is_a_loop(x, y, s, a):
+ b, l_x, l_y = x.size(0), x.size(1), y.size(1)
+ if l_x > l_y:
+ y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1)
+ s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
+ if a is not None:
+ a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
+ elif l_x < l_y:
+ x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1)
+ return (x == y).all(1), y, s, a
+
+ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
+ cutoff = prev_out_token.ne(self.pad)
+ tokens = prev_out_token[cutoff]
+ scores = prev_out_score[cutoff]
+ if prev_out_attn is None:
+ hypo_attn, alignment = None, None
+ else:
+ hypo_attn = prev_out_attn[cutoff]
+ alignment = hypo_attn.max(dim=1)[1]
+ return {
+ 'steps': step,
+ 'tokens': tokens,
+ 'positional_scores': scores,
+ 'score': scores.mean(),
+ 'hypo_attn': hypo_attn,
+ 'alignment': alignment,
+ }
+
+ for step in range(self.max_iter + 1):
+
+ decoder_options = {
+ 'eos_penalty': self.eos_penalty,
+ 'max_ratio': self.max_ratio,
+ 'decoding_format': self.decoding_format
+ }
+ prev_decoder_out['step'] = step
+ prev_decoder_out['max_step'] = self.max_iter + 1
+
+ decoder_out = model.forward_decoder(
+ prev_decoder_out, encoder_out, **decoder_options
+ )
+
+ if self.adaptive:
+ # terminate if there is a loop
+ terminated, out_tokens, out_scores, out_attn = is_a_loop(
+ prev_out_tokens, decoder_out['output_tokens'],
+ decoder_out['output_scores'], decoder_out['attn'])
+ decoder_out['output_tokens'] = out_tokens
+ decoder_out['output_scores'] = out_scores
+ decoder_out['attn'] = out_attn
+
+ else:
+ terminated = decoder_out['output_tokens'].new_zeros(
+ decoder_out['output_tokens'].size(0)).bool()
+
+ if step == self.max_iter: # reach last iteration, terminate
+ terminated.fill_(1)
+
+ # collect finalized sentences
+ finalized_idxs = sent_idxs[terminated]
+ finalized_tokens = decoder_out['output_tokens'][terminated]
+ finalized_scores = decoder_out['output_scores'][terminated]
+ finalized_attn = None if decoder_out['attn'] is None else decoder_out['attn'][terminated]
+
+ for i in range(finalized_idxs.size(0)):
+ finalized[finalized_idxs[i]] = [
+ finalized_hypos(
+ step,
+ finalized_tokens[i],
+ finalized_scores[i],
+ None if finalized_attn is None else finalized_attn[i]
+ )
+ ]
+ # check if all terminated
+ if terminated.sum() == terminated.size(0):
+ break
+
+ # for next step
+ prev_decoder_out = _skip(decoder_out, ~terminated)
+ encoder_out = _skip(encoder_out, ~terminated)
+ sent_idxs = _skip(sent_idxs, ~terminated)
+
+ prev_out_tokens = prev_decoder_out['output_tokens'].clone()
+
+ return finalized
diff --git a/fairseq/models/cmlm_transformer.py b/fairseq/models/cmlm_transformer.py
new file mode 100644
index 0000000000..f76c93fd0f
--- /dev/null
+++ b/fairseq/models/cmlm_transformer.py
@@ -0,0 +1,136 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+This file implements:
+Ghazvininejad, Marjan, et al.
+"Constant-time machine translation with conditional masked language models."
+arXiv preprint arXiv:1904.09324 (2019).
+"""
+
+import torch
+from fairseq.models import register_model, register_model_architecture
+from fairseq.models.nonautoregressive_transformer import NATransformerModel
+
+
+def _skeptical_unmasking(output_scores, output_masks, p):
+ sorted_index = output_scores.sort(-1)[1]
+ boundary_len = (
+ (output_masks.sum(1, keepdim=True).type_as(output_scores) - 2) * p
+ ).long()
+ skeptical_mask = (
+ torch.arange(output_masks.size(1), device=output_masks.device)[None, :]
+ < boundary_len
+ )
+ return skeptical_mask.scatter(1, sorted_index, skeptical_mask)
+
+
+@register_model("cmlm_transformer")
+class CMLMNATransformerModel(NATransformerModel):
+ @staticmethod
+ def add_args(parser):
+ NATransformerModel.add_args(parser)
+
+ def forward(
+ self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
+ ):
+ assert not self.decoder.src_embedding_copy, "do not support embedding copy."
+
+ encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
+ length_out, length_tgt = self.decoder.forward_length_prediction(
+ encoder_out, tgt_tokens
+ )
+
+ word_ins_out, word_ins_tgt, _ = self.decoder(
+ prev_output_tokens, encoder_out=encoder_out, tgt_tokens=tgt_tokens
+ )
+ word_ins_mask = prev_output_tokens.eq(self.unk)
+ return {
+ "word_ins_out": word_ins_out,
+ "word_ins_tgt": word_ins_tgt,
+ "word_ins_mask": word_ins_mask,
+ "length_out": length_out,
+ "length_tgt": length_tgt,
+ "length_w": self.decoder.length_loss_factor,
+ }
+
+ def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
+
+ step = decoder_out["step"]
+ max_step = decoder_out["max_step"]
+
+ output_tokens = decoder_out["output_tokens"]
+ output_scores = decoder_out["output_scores"]
+
+ # execute the decoder
+ output_masks = output_tokens.eq(self.unk)
+ _scores, _tokens = self.decoder(
+ output_tokens, encoder_out=encoder_out, decoding_format=decoding_format
+ )
+ output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
+ output_scores.masked_scatter_(output_masks, _scores[output_masks])
+
+ # skeptical decoding (depend on the maximum decoding steps.)
+ if (step + 1) < max_step:
+ skeptical_mask = _skeptical_unmasking(
+ output_scores, output_tokens.ne(self.pad), 1 - (step + 1) / max_step
+ )
+
+ output_tokens.masked_fill_(skeptical_mask, self.unk)
+ output_scores.masked_fill_(skeptical_mask, 0.0)
+
+ return {"output_tokens": output_tokens, "output_scores": output_scores}
+
+
+@register_model_architecture("cmlm_transformer", "cmlm_transformer")
+def base_architecture(args):
+ args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
+ args.encoder_layers = getattr(args, "encoder_layers", 6)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
+ args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
+ args.decoder_ffn_embed_dim = getattr(
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
+ )
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.0)
+ args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+ args.activation_fn = getattr(args, "activation_fn", "relu")
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+ args.share_decoder_input_output_embed = getattr(
+ args, "share_decoder_input_output_embed", False
+ )
+ args.share_all_embeddings = getattr(args, "share_all_embeddings", True)
+ args.no_token_positional_embeddings = getattr(
+ args, "no_token_positional_embeddings", False
+ )
+ args.adaptive_input = getattr(args, "adaptive_input", False)
+ args.apply_bert_init = getattr(args, "apply_bert_init", False)
+
+ args.decoder_output_dim = getattr(
+ args, "decoder_output_dim", args.decoder_embed_dim
+ )
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
+
+ # --- special arguments ---
+ args.sg_length_pred = getattr(args, "sg_length_pred", False)
+ args.pred_length_offset = getattr(args, "pred_length_offset", False)
+ args.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
+ args.ngram_predictor = getattr(args, "ngram_predictor", 1)
+ args.src_embedding_copy = getattr(args, "src_embedding_copy", False)
+
+
+@register_model_architecture("cmlm_transformer", "cmlm_transformer_wmt_en_de")
+def iter_nat_wmt_en_de(args):
+ base_architecture(args)
diff --git a/fairseq/models/insertion_transformer.py b/fairseq/models/insertion_transformer.py
new file mode 100644
index 0000000000..5f5868a550
--- /dev/null
+++ b/fairseq/models/insertion_transformer.py
@@ -0,0 +1,259 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from fairseq import libnat
+from fairseq.models import register_model, register_model_architecture
+from fairseq.models.levenshtein_transformer import (
+ LevenshteinTransformerDecoder,
+ LevenshteinTransformerModel,
+)
+from fairseq.models.transformer import Linear, TransformerModel
+from fairseq.modules.transformer_sentence_encoder import init_bert_params
+
+
+class NegativeDistanceScore(object):
+ def __init__(self):
+
+ # pre-compute some values
+ self.scores = {}
+
+ self.scores[0.5] = self.compute_score_full(50, 0.5)
+ self.scores[1.0] = self.compute_score_full(50, 1.0)
+ self.scores[2.0] = self.compute_score_full(50, 2.0)
+
+ def __call__(self, i, L, tau):
+ if (tau is None) or (tau > 1000):
+ return 1 / L
+
+ if tau in self.scores:
+ if L < self.scores[tau].shape[0]:
+ return self.scores[tau][L - 1, i]
+ return self.compute_score(L, tau)[i]
+
+ def compute_score(self, L, tau):
+ s = np.array([-abs(L / 2 - i) / tau for i in range(L)])
+ s = np.exp(s - s.max())
+ return s / s.sum()
+
+ def compute_score_full(self, L, tau):
+ s = -abs(np.arange(0, L - 1)[:, None] / 2 - np.arange(L)[None, :]) / tau
+ s = np.tril(s, 0) + np.triu(s - float("inf"), 1)
+ s = np.exp(s - s.max(1, keepdims=True))
+ return s / s.sum(1, keepdims=True)
+
+
+neg_scorer = NegativeDistanceScore()
+
+
+def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx, vocab_size, tau=None):
+ B = in_tokens.size(0)
+ T = in_tokens.size(1)
+ V = vocab_size
+
+ with torch.cuda.device_of(in_tokens):
+ in_tokens_list = [
+ [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
+ ]
+ out_tokens_list = [
+ [t for t in s if t != padding_idx]
+ for i, s in enumerate(out_tokens.tolist())
+ ]
+
+ full_labels = libnat.suggested_ed2_path(
+ in_tokens_list, out_tokens_list, padding_idx
+ )
+ insert_labels = [a[:-1] for a in full_labels]
+
+ # numericalize1
+ insert_label_tensors = in_tokens.new_zeros(B * (T - 1) * V).float()
+ insert_index, insert_labels = zip(
+ *[
+ (w + (j + i * (T - 1)) * V, neg_scorer(k, len(label), tau))
+ for i, labels in enumerate(insert_labels)
+ for j, label in enumerate(labels[1:-1])
+ for k, w in enumerate(label)
+ ]
+ ) # HACK 1:-1
+ insert_index, insert_labels = [
+ torch.tensor(list(a), device=in_tokens.device)
+ for a in [insert_index, insert_labels]
+ ]
+ insert_label_tensors.scatter_(0, insert_index.long(), insert_labels)
+ insert_label_tensors = insert_label_tensors.view(B, T - 1, V)
+
+ return insert_label_tensors
+
+
+def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, padding_idx):
+
+ padding_masks = in_tokens[:, 1:].eq(padding_idx)
+ word_ins_scores.masked_fill_(padding_masks, 0.0)
+ word_ins_pred.masked_fill_(padding_masks, padding_idx)
+
+ in_coords = torch.arange(in_tokens.size(1), device=in_tokens.device)
+ in_coords = in_coords.unsqueeze(0).repeat(in_tokens.size(0), 1).type_as(in_scores)
+
+ # shift all padding predictions to infinite
+ out_coords = (in_coords[:, 1:] - 0.5).masked_fill(
+ word_ins_pred.eq(padding_idx), float("inf")
+ )
+ out_coords = torch.cat([in_coords, out_coords], 1).sort(-1)[1]
+ out_tokens = torch.cat([in_tokens, word_ins_pred], 1).gather(1, out_coords)
+ out_scores = torch.cat([in_scores, word_ins_scores], 1).gather(1, out_coords)
+ return out_tokens, out_scores
+
+
+@register_model("insertion_transformer")
+class InsertionTransformerModel(LevenshteinTransformerModel):
+ def __init__(self, encoder, decoder):
+ super().__init__(encoder, decoder)
+
+ @staticmethod
+ def add_args(parser):
+ TransformerModel.add_args(parser)
+ parser.add_argument(
+ "--apply-bert-init",
+ action="store_true",
+ help="use custom param initialization for BERT",
+ )
+ parser.add_argument("--label-tau", default=None, type=float)
+
+ @classmethod
+ def build_decoder(cls, args, tgt_dict, embed_tokens):
+ decoder = InsertionTransformerDecoder(args, tgt_dict, embed_tokens)
+ if getattr(args, "apply_bert_init", False):
+ decoder.apply(init_bert_params)
+ return decoder
+
+ def forward(
+ self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
+ ):
+
+ assert tgt_tokens is not None, "forward function only supports training."
+
+ # encoding
+ encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
+
+ # generate training labels for insertion
+ word_ins_out = self.decoder.forward_word_ins(
+ prev_output_tokens, encoder_out=encoder_out
+ )
+ word_ins_tgt = _get_ins_targets(
+ prev_output_tokens,
+ tgt_tokens,
+ self.pad,
+ self.unk,
+ len(self.tgt_dict),
+ tau=self.decoder.label_tau,
+ ).type_as(word_ins_out)
+ word_ins_masks = prev_output_tokens[:, 1:].ne(self.pad)
+
+ return {
+ "word_ins_out": word_ins_out,
+ "word_ins_tgt": word_ins_tgt,
+ "word_ins_mask": word_ins_masks,
+ }
+
+ def forward_decoder(
+ self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
+ ):
+
+ output_tokens = decoder_out["output_tokens"]
+ output_scores = decoder_out["output_scores"]
+ # TODO: decoding for InsertionTransformer
+ word_ins_out = self.decoder.forward_word_ins(
+ output_tokens, encoder_out=encoder_out
+ )
+ word_ins_score = F.log_softmax(word_ins_out, 2)
+ if eos_penalty > 0.0:
+ word_ins_score[:, :, self.pad] -= eos_penalty
+ word_ins_score, word_ins_pred = word_ins_score.max(-1)
+ output_tokens, output_scores = _apply_ins_words(
+ output_tokens, output_scores, word_ins_pred, word_ins_score, self.pad
+ )
+
+ # delete some unnecessary paddings
+ cut_off = output_tokens.ne(self.pad).sum(1).max()
+ output_tokens = output_tokens[:, :cut_off]
+ output_scores = output_scores[:, :cut_off]
+ return {"output_tokens": output_tokens, "output_scores": output_scores}
+
+
+class InsertionTransformerDecoder(LevenshteinTransformerDecoder):
+ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
+ # use the TransformerDecoder's __init__
+ super(LevenshteinTransformerDecoder, self).__init__(
+ args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
+ )
+
+ self.dictionary = dictionary
+ self.bos = dictionary.bos()
+ self.unk = dictionary.unk()
+ self.eos = dictionary.eos()
+ self.pool_out = Linear(self.output_embed_dim * 2, self.output_embed_dim)
+
+ self.label_tau = getattr(args, "label_tau", None)
+
+ def forward_word_ins(self, prev_output_tokens, encoder_out=None):
+ features, _ = self.extract_features(prev_output_tokens, encoder_out=encoder_out)
+ features = self.pool_out(
+ torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
+ )
+ return self.output_layer(features)
+
+ def forward_mask_ins(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def forward_word_del(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def forward_word_del_mask_ins(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+@register_model_architecture("insertion_transformer", "insertion_transformer")
+def base_architecture(args):
+ args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
+ args.encoder_layers = getattr(args, "encoder_layers", 6)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
+ args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
+ args.decoder_ffn_embed_dim = getattr(
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
+ )
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.0)
+ args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+ args.activation_fn = getattr(args, "activation_fn", "relu")
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+ args.share_decoder_input_output_embed = getattr(
+ args, "share_decoder_input_output_embed", False
+ )
+ args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
+ args.no_token_positional_embeddings = getattr(
+ args, "no_token_positional_embeddings", False
+ )
+ args.adaptive_input = getattr(args, "adaptive_input", False)
+ args.apply_bert_init = getattr(args, "apply_bert_init", False)
+
+ args.decoder_output_dim = getattr(
+ args, "decoder_output_dim", args.decoder_embed_dim
+ )
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
+
+ # special for insertion transformer
+ args.label_tau = getattr(args, "label_tau", None)
diff --git a/fairseq/models/iterative_nonautoregressive_transformer.py b/fairseq/models/iterative_nonautoregressive_transformer.py
new file mode 100644
index 0000000000..73585db354
--- /dev/null
+++ b/fairseq/models/iterative_nonautoregressive_transformer.py
@@ -0,0 +1,196 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from fairseq.models import register_model, register_model_architecture
+from fairseq.models.nonautoregressive_transformer import NATransformerModel
+
+
+def _sequential_poisoning(s, V, beta=0.33, bos=2, eos=3, pad=1):
+ # s: input batch
+ # V: vocabulary size
+ rand_words = torch.randint(low=4, high=V, size=s.size(), device=s.device)
+ choices = torch.rand(size=s.size(), device=s.device)
+ choices.masked_fill_((s == pad) | (s == bos) | (s == eos), 1)
+
+ replace = choices < beta / 3
+ repeat = (choices >= beta / 3) & (choices < beta * 2 / 3)
+ swap = (choices >= beta * 2 / 3) & (choices < beta)
+ safe = choices >= beta
+
+ for i in range(s.size(1) - 1):
+ rand_word = rand_words[:, i]
+ next_word = s[:, i + 1]
+ self_word = s[:, i]
+
+ replace_i = replace[:, i]
+ swap_i = swap[:, i] & (next_word != 3)
+ repeat_i = repeat[:, i] & (next_word != 3)
+ safe_i = safe[:, i] | ((next_word == 3) & (~replace_i))
+
+ s[:, i] = (
+ self_word * (safe_i | repeat_i).long()
+ + next_word * swap_i.long()
+ + rand_word * replace_i.long()
+ )
+ s[:, i + 1] = (
+ next_word * (safe_i | replace_i).long()
+ + self_word * (swap_i | repeat_i).long()
+ )
+ return s
+
+
+def gumbel_noise(input, TINY=1e-8):
+ return input.new_zeros(*input.size()).uniform_().add_(
+ TINY).log_().neg_().add_(TINY).log_().neg_()
+
+
+@register_model("iterative_nonautoregressive_transformer")
+class IterNATransformerModel(NATransformerModel):
+ @staticmethod
+ def add_args(parser):
+ NATransformerModel.add_args(parser)
+ parser.add_argument("--train-step", type=int,
+ help="number of refinement iterations during training")
+ parser.add_argument("--dae-ratio", type=float,
+ help="the probability of switching to the denoising auto-encoder loss")
+ parser.add_argument("--stochastic-approx", action="store_true",
+ help="sampling from the decoder as the inputs for next iteration")
+
+ @classmethod
+ def build_model(cls, args, task):
+ model = super().build_model(args, task)
+ model.train_step = getattr(args, "train_step", 4)
+ model.dae_ratio = getattr(args, "dae_ratio", 0.5)
+ model.stochastic_approx = getattr(args, "stochastic_approx", False)
+ return model
+
+ def forward(
+ self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
+ ):
+
+ B, T = prev_output_tokens.size()
+
+ # encoding
+ encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
+ length_out, length_tgt = self.decoder.forward_length_prediction(
+ encoder_out, tgt_tokens
+ )
+ word_ins_outs, word_ins_tgts, word_ins_masks = [], [], []
+ for t in range(self.train_step):
+ word_ins_out, word_ins_tgt, word_ins_mask = self.decoder(
+ prev_output_tokens,
+ encoder_out=encoder_out,
+ tgt_tokens=tgt_tokens,
+ step=t,
+ )
+
+ word_ins_outs.append(word_ins_out)
+ word_ins_tgts.append(word_ins_tgt)
+ word_ins_masks.append(word_ins_mask)
+
+ if t < (self.train_step - 1):
+ # prediction for next iteration
+ if self.stochastic_approx:
+ word_ins_prediction = (
+ word_ins_out + gumbel_noise(word_ins_out)
+ ).max(-1)[1]
+ else:
+ word_ins_prediction = word_ins_out.max(-1)[1]
+
+ prev_output_tokens = prev_output_tokens.masked_scatter(
+ word_ins_mask, word_ins_prediction[word_ins_mask]
+ )
+
+ if self.dae_ratio > 0:
+ # we do not perform denoising for the first iteration
+ corrputed = (
+ torch.rand(size=(B,), device=prev_output_tokens.device)
+ < self.dae_ratio
+ )
+ corrputed_tokens = _sequential_poisoning(
+ tgt_tokens[corrputed],
+ len(self.tgt_dict),
+ 0.33,
+ self.bos,
+ self.eos,
+ self.pad,
+ )
+ prev_output_tokens[corrputed] = corrputed_tokens
+
+ # concat everything
+ word_ins_out = torch.cat(word_ins_outs, 0)
+ word_ins_tgt = torch.cat(word_ins_tgts, 0)
+ word_ins_mask = torch.cat(word_ins_masks, 0)
+
+ return {
+ "word_ins_out": word_ins_out,
+ "word_ins_tgt": word_ins_tgt,
+ "word_ins_mask": word_ins_mask,
+ "length_out": length_out,
+ "length_tgt": length_tgt,
+ "length_w": self.decoder.length_loss_factor,
+ }
+
+
+@register_model_architecture(
+ "iterative_nonautoregressive_transformer", "iterative_nonautoregressive_transformer"
+)
+def base_architecture(args):
+ args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
+ args.encoder_layers = getattr(args, "encoder_layers", 6)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
+ args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
+ args.decoder_ffn_embed_dim = getattr(
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
+ )
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.0)
+ args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+ args.activation_fn = getattr(args, "activation_fn", "relu")
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+ args.share_decoder_input_output_embed = getattr(
+ args, "share_decoder_input_output_embed", False
+ )
+ args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
+ args.no_token_positional_embeddings = getattr(
+ args, "no_token_positional_embeddings", False
+ )
+ args.adaptive_input = getattr(args, "adaptive_input", False)
+ args.apply_bert_init = getattr(args, "apply_bert_init", False)
+
+ args.decoder_output_dim = getattr(
+ args, "decoder_output_dim", args.decoder_embed_dim
+ )
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
+
+ # --- special arguments ---
+ args.sg_length_pred = getattr(args, "sg_length_pred", False)
+ args.pred_length_offset = getattr(args, "pred_length_offset", False)
+ args.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
+ args.ngram_predictor = getattr(args, "ngram_predictor", 1)
+ args.src_embedding_copy = getattr(args, "src_embedding_copy", False)
+
+ args.train_step = getattr(args, "train_step", 4)
+ args.dae_ratio = getattr(args, "dae_ratio", 0.5)
+ args.stochastic_approx = getattr(args, "stochastic_approx", False)
+
+
+@register_model_architecture(
+ "iterative_nonautoregressive_transformer",
+ "iterative_nonautoregressive_transformer_wmt_en_de",
+)
+def iter_nat_wmt_en_de(args):
+ base_architecture(args)
diff --git a/fairseq/models/levenshtein_transformer.py b/fairseq/models/levenshtein_transformer.py
new file mode 100644
index 0000000000..876bf01a0f
--- /dev/null
+++ b/fairseq/models/levenshtein_transformer.py
@@ -0,0 +1,595 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+from fairseq import libnat
+from fairseq.models import register_model, register_model_architecture
+from fairseq.models.model_utils import fill_tensors as _fill, skip_tensors as _skip
+from fairseq.models.transformer import (
+ Embedding,
+ TransformerDecoder,
+ TransformerEncoder,
+ TransformerModel,
+)
+from fairseq.modules.transformer_sentence_encoder import init_bert_params
+
+
+def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
+ in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
+
+ with torch.cuda.device_of(in_tokens):
+ in_tokens_list = [
+ [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
+ ]
+ out_tokens_list = [
+ [t for t in s if t != padding_idx]
+ for i, s in enumerate(out_tokens.tolist())
+ ]
+
+ full_labels = libnat.suggested_ed2_path(
+ in_tokens_list, out_tokens_list, padding_idx
+ )
+ mask_inputs = [
+ [len(c) if c[0] != padding_idx else 0 for c in a[:-1]] for a in full_labels
+ ]
+
+ # generate labels
+ masked_tgt_masks = []
+ for mask_input in mask_inputs:
+ mask_label = []
+ for beam_size in mask_input[1:-1]: # HACK 1:-1
+ mask_label += [0] + [1 for _ in range(beam_size)]
+ masked_tgt_masks.append(
+ mask_label + [0 for _ in range(out_seq_len - len(mask_label))]
+ )
+ mask_ins_targets = [
+ mask_input[1:-1] + [0 for _ in range(in_seq_len - 1 - len(mask_input[1:-1]))]
+ for mask_input in mask_inputs
+ ]
+
+ # transform to tensor
+ masked_tgt_masks = torch.tensor(
+ masked_tgt_masks, device=out_tokens.device
+ ).bool()
+ mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
+ masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx)
+ return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets
+
+
+def _get_del_targets(in_tokens, out_tokens, padding_idx):
+ out_seq_len = out_tokens.size(1)
+
+ with torch.cuda.device_of(in_tokens):
+ in_tokens_list = [
+ [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
+ ]
+ out_tokens_list = [
+ [t for t in s if t != padding_idx]
+ for i, s in enumerate(out_tokens.tolist())
+ ]
+
+ full_labels = libnat.suggested_ed2_path(
+ in_tokens_list, out_tokens_list, padding_idx
+ )
+ word_del_targets = [b[-1] for b in full_labels]
+ word_del_targets = [
+ labels + [0 for _ in range(out_seq_len - len(labels))]
+ for labels in word_del_targets
+ ]
+
+ # transform to tensor
+ word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
+ return word_del_targets
+
+
+def _get_del_ins_targets(in_tokens, out_tokens, padding_idx):
+ in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
+
+ with torch.cuda.device_of(in_tokens):
+ in_tokens_list = [
+ [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
+ ]
+ out_tokens_list = [
+ [t for t in s if t != padding_idx]
+ for i, s in enumerate(out_tokens.tolist())
+ ]
+
+ full_labels = libnat.suggested_ed2_path(
+ in_tokens_list, out_tokens_list, padding_idx
+ )
+
+ word_del_targets = [b[-1] for b in full_labels]
+ word_del_targets = [
+ labels + [0 for _ in range(out_seq_len - len(labels))]
+ for labels in word_del_targets
+ ]
+
+ mask_inputs = [
+ [len(c) if c[0] != padding_idx else 0 for c in a[:-1]] for a in full_labels
+ ]
+ mask_ins_targets = [
+ mask_input[1:-1] + [0 for _ in range(in_seq_len - 1 - len(mask_input[1:-1]))]
+ for mask_input in mask_inputs
+ ]
+
+ # transform to tensor
+ mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
+ word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
+ return word_del_targets, mask_ins_targets
+
+
+def _apply_ins_masks(
+ in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
+):
+
+ in_masks = in_tokens.ne(padding_idx)
+ in_lengths = in_masks.sum(1)
+
+ # HACK: hacky way to shift all the paddings to eos first.
+ in_tokens.masked_fill_(~in_masks, eos_idx)
+ mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
+
+ out_lengths = in_lengths + mask_ins_pred.sum(1)
+ out_max_len = out_lengths.max()
+ out_masks = (
+ torch.arange(out_max_len, device=out_lengths.device)[None, :]
+ < out_lengths[:, None]
+ )
+
+ reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
+ out_tokens = (
+ in_tokens.new_zeros(in_tokens.size(0), out_max_len)
+ .fill_(padding_idx)
+ .masked_fill_(out_masks, unk_idx)
+ )
+ out_tokens[:, 0] = in_tokens[:, 0]
+ out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
+
+ out_scores = None
+ if in_scores is not None:
+ in_scores.masked_fill_(~in_masks, 0)
+ out_scores = in_scores.new_zeros(*out_tokens.size())
+ out_scores[:, 0] = in_scores[:, 0]
+ out_scores.scatter_(1, reordering, in_scores[:, 1:])
+
+ return out_tokens, out_scores
+
+
+def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx):
+ word_ins_masks = in_tokens.eq(unk_idx)
+ out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks])
+
+ if in_scores is not None:
+ out_scores = in_scores.masked_scatter(
+ word_ins_masks, word_ins_scores[word_ins_masks]
+ )
+ else:
+ out_scores = None
+
+ return out_tokens, out_scores
+
+
+def _apply_del_words(
+ in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
+):
+ # apply deletion to a tensor
+ in_masks = in_tokens.ne(padding_idx)
+ bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
+
+ max_len = in_tokens.size(1)
+ word_del_pred.masked_fill_(~in_masks, 1)
+ word_del_pred.masked_fill_(bos_eos_masks, 0)
+
+ reordering = (
+ torch.arange(max_len, device=in_tokens.device)[None, :]
+ .expand_as(in_tokens)
+ .contiguous()
+ .masked_fill_(word_del_pred, max_len)
+ .sort(1)[1]
+ )
+
+ out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)
+
+ out_scores = None
+ if in_scores is not None:
+ out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
+
+ out_attn = None
+ if in_attn is not None:
+ _mask = word_del_pred[:, :, None].expand_as(in_attn)
+ _reordering = reordering[:, :, None].expand_as(in_attn)
+ out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering)
+
+ return out_tokens, out_scores, out_attn
+
+
+@register_model("levenshtein_transformer")
+class LevenshteinTransformerModel(TransformerModel):
+ def __init__(self, encoder, decoder):
+ super().__init__(encoder, decoder)
+ self.tgt_dict = decoder.dictionary
+ self.bos = decoder.dictionary.bos()
+ self.eos = decoder.dictionary.eos()
+ self.pad = decoder.dictionary.pad()
+ self.unk = decoder.dictionary.unk()
+
+ @staticmethod
+ def add_args(parser):
+ TransformerModel.add_args(parser)
+ parser.add_argument(
+ "--apply-bert-init",
+ action="store_true",
+ help="use custom param initialization for BERT",
+ )
+ parser.add_argument(
+ "--early-exit",
+ default="6,6,6",
+ type=str,
+ help="number of decoder layers before mask_ins, word_ins and word_del heads",
+ )
+
+ @classmethod
+ def build_decoder(cls, args, tgt_dict, embed_tokens):
+ decoder = LevenshteinTransformerDecoder(args, tgt_dict, embed_tokens)
+ if getattr(args, "apply_bert_init", False):
+ decoder.apply(init_bert_params)
+ return decoder
+
+ @classmethod
+ def build_encoder(cls, args, src_dict, embed_tokens):
+ encoder = TransformerEncoder(args, src_dict, embed_tokens)
+ if getattr(args, "apply_bert_init", False):
+ encoder.apply(init_bert_params)
+ return encoder
+
+ def forward(
+ self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
+ ):
+
+ assert tgt_tokens is not None, "forward function only supports training."
+
+ # encoding
+ encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
+
+ # generate training labels for insertion
+ masked_tgt_masks, masked_tgt_tokens, mask_ins_targets = _get_ins_targets(
+ prev_output_tokens, tgt_tokens, self.pad, self.unk
+ )
+ mask_ins_targets = mask_ins_targets.clamp(min=0, max=255) # for safe prediction
+ mask_ins_masks = prev_output_tokens[:, 1:].ne(self.pad)
+
+ mask_ins_out, _ = self.decoder.forward_mask_ins(
+ prev_output_tokens, encoder_out=encoder_out
+ )
+ word_ins_out, _ = self.decoder.forward_word_ins(
+ masked_tgt_tokens, encoder_out=encoder_out
+ )
+
+ # make online prediction
+ word_predictions = F.log_softmax(word_ins_out, dim=-1).max(2)[1]
+ word_predictions.masked_scatter_(
+ ~masked_tgt_masks, tgt_tokens[~masked_tgt_masks]
+ )
+
+ # generate training labels for deletion
+ word_del_targets = _get_del_targets(word_predictions, tgt_tokens, self.pad)
+ word_del_out, _ = self.decoder.forward_word_del(
+ word_predictions, encoder_out)
+
+ return {
+ "mask_ins_out": mask_ins_out,
+ "mask_ins_tgt": mask_ins_targets,
+ "mask_ins_mask": mask_ins_masks,
+ "word_ins_out": word_ins_out,
+ "word_ins_tgt": tgt_tokens,
+ "word_ins_mask": masked_tgt_masks,
+ "word_del_out": word_del_out,
+ "word_del_tgt": word_del_targets,
+ "word_del_mask": word_predictions.ne(self.pad),
+ }
+
+ def forward_encoder(self, encoder_inputs):
+ return self.encoder(*encoder_inputs)
+
+ def forward_decoder(
+ self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
+ ):
+
+ output_tokens = decoder_out["output_tokens"]
+ output_scores = decoder_out["output_scores"]
+ attn = decoder_out["attn"]
+
+ if max_ratio is None:
+ max_lens = output_tokens.new(output_tokens.size(0)).fill_(255)
+ else:
+ max_lens = (
+ (~encoder_out["encoder_padding_mask"]).sum(1) * max_ratio
+ ).clamp(min=10)
+
+ # delete words
+ # do not delete tokens if it is
+ can_del_word = output_tokens.ne(self.pad).sum(1) > 2
+ if can_del_word.sum() != 0: # we cannot delete, skip
+ word_del_out, word_del_attn = self.decoder.forward_word_del(
+ _skip(output_tokens, can_del_word), _skip(encoder_out, can_del_word)
+ )
+ word_del_score = F.log_softmax(word_del_out, 2)
+ word_del_pred = word_del_score.max(-1)[1].bool()
+
+ _tokens, _scores, _attn = _apply_del_words(
+ output_tokens[can_del_word],
+ output_scores[can_del_word],
+ word_del_attn,
+ word_del_pred,
+ self.pad,
+ self.bos,
+ self.eos,
+ )
+ output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad)
+ output_scores = _fill(output_scores, can_del_word, _scores, 0)
+ attn = _fill(attn, can_del_word, _attn, 0.)
+
+ # insert placeholders
+ can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
+ if can_ins_mask.sum() != 0:
+ mask_ins_out, _ = self.decoder.forward_mask_ins(
+ _skip(output_tokens, can_ins_mask), _skip(encoder_out, can_ins_mask)
+ )
+ mask_ins_score = F.log_softmax(mask_ins_out, 2)
+ if eos_penalty > 0.0:
+ mask_ins_score[:, :, 0] -= eos_penalty
+ mask_ins_pred = mask_ins_score.max(-1)[1]
+ mask_ins_pred = torch.min(
+ mask_ins_pred, max_lens[:, None].expand_as(mask_ins_pred)
+ )
+
+ _tokens, _scores = _apply_ins_masks(
+ output_tokens[can_ins_mask],
+ output_scores[can_ins_mask],
+ mask_ins_pred,
+ self.pad,
+ self.unk,
+ self.eos,
+ )
+ output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
+ output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
+
+ # insert words
+ can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
+ if can_ins_word.sum() != 0:
+ word_ins_out, word_ins_attn = self.decoder.forward_word_ins(
+ _skip(output_tokens, can_ins_word), _skip(encoder_out, can_ins_word)
+ )
+ word_ins_score = F.log_softmax(word_ins_out, 2)
+ word_ins_pred = word_ins_score.max(-1)[1]
+
+ _tokens, _scores = _apply_ins_words(
+ output_tokens[can_ins_word],
+ output_scores[can_ins_word],
+ word_ins_pred,
+ word_ins_score,
+ self.unk,
+ )
+
+ output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
+ output_scores = _fill(output_scores, can_ins_word, _scores, 0)
+ attn = _fill(attn, can_ins_word, word_ins_attn, 0.)
+
+ # delete some unnecessary paddings
+ cut_off = output_tokens.ne(self.pad).sum(1).max()
+ output_tokens = output_tokens[:, :cut_off]
+ output_scores = output_scores[:, :cut_off]
+ attn = None if attn is None else attn[:, :cut_off, :]
+ return {
+ "output_tokens": output_tokens,
+ "output_scores": output_scores,
+ "attn": attn,
+ }
+
+ def initialize_output_tokens(self, encoder_out, src_tokens):
+ initial_output_tokens = src_tokens.new_zeros(src_tokens.size(0), 2)
+ initial_output_tokens[:, 0] = self.bos
+ initial_output_tokens[:, 1] = self.eos
+
+ initial_output_scores = initial_output_tokens.new_zeros(
+ *initial_output_tokens.size()
+ ).type_as(encoder_out["encoder_out"])
+ return {
+ "output_tokens": initial_output_tokens,
+ "output_scores": initial_output_scores,
+ "attn": None,
+ }
+
+
+class LevenshteinTransformerDecoder(TransformerDecoder):
+ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
+ super().__init__(
+ args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
+ )
+ self.dictionary = dictionary
+ self.bos = dictionary.bos()
+ self.unk = dictionary.unk()
+ self.eos = dictionary.eos()
+
+ self.embed_mask_ins = Embedding(256, self.output_embed_dim * 2, None)
+ self.embed_word_del = Embedding(2, self.output_embed_dim, None)
+ # del_word, ins_mask, ins_word
+ self.early_exit = [int(i) for i in args.early_exit.split(',')]
+ assert len(self.early_exit) == 3
+
+ def extract_features(
+ self, prev_output_tokens, encoder_out=None, early_exit=None, **unused
+ ):
+ """
+ Similar to *forward* but only return features.
+
+ Inputs:
+ prev_output_tokens: Tensor(B, T)
+ encoder_out: a dictionary of hidden states and masks
+
+ Returns:
+ tuple:
+ - the decoder's features of shape `(batch, tgt_len, embed_dim)`
+ - a dictionary with any model-specific outputs
+ the LevenshteinTransformer decoder has full-attention to all generated tokens
+ """
+ # embed positions
+ positions = (
+ self.embed_positions(prev_output_tokens)
+ if self.embed_positions is not None
+ else None
+ )
+
+ # embed tokens and positions
+ x = self.embed_scale * self.embed_tokens(prev_output_tokens)
+ if self.project_in_dim is not None:
+ x = self.project_in_dim(x)
+
+ if positions is not None:
+ x += positions
+ x = F.dropout(x, p=self.dropout, training=self.training)
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+ attn = None
+ inner_states = [x]
+
+ # decoder layers
+ decoder_padding_mask = prev_output_tokens.eq(self.padding_idx)
+ for i, layer in enumerate(self.layers):
+
+ # early exit from the decoder.
+ if (early_exit is not None) and (i >= early_exit):
+ break
+
+ x, attn = layer(
+ x,
+ encoder_out["encoder_out"] if encoder_out is not None else None,
+ encoder_out["encoder_padding_mask"]
+ if encoder_out is not None
+ else None,
+ self_attn_mask=None,
+ self_attn_padding_mask=decoder_padding_mask,
+ )
+ inner_states.append(x)
+
+ if self.layer_norm:
+ x = self.layer_norm(x)
+
+ # T x B x C -> B x T x C
+ x = x.transpose(0, 1)
+
+ if self.project_out_dim is not None:
+ x = self.project_out_dim(x)
+
+ return x, {"attn": attn, "inner_states": inner_states}
+
+ def forward_mask_ins(self, prev_output_tokens, encoder_out=None):
+ features, extra = self.extract_features(
+ prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[1]
+ )
+ features_cat = torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
+ return F.linear(features_cat, self.embed_mask_ins.weight), extra['attn']
+
+ def forward_word_ins(self, prev_output_tokens, encoder_out=None):
+ features, extra = self.extract_features(
+ prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[2]
+ )
+ return self.output_layer(features), extra['attn']
+
+ def forward_word_del(self, prev_output_tokens, encoder_out=None):
+ features, extra = self.extract_features(
+ prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[0]
+ )
+ return F.linear(features, self.embed_word_del.weight), extra['attn']
+
+ def forward_word_del_mask_ins(self, prev_output_tokens, encoder_out=None):
+ # merge the word-deletion and mask insertion into one operation,
+ assert self.early_exit[0] == self.early_exit[1], "must the same depth."
+ features, extra = self.extract_features(
+ prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[2]
+ )
+ features_cat = torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
+ f_word_del = F.linear(features, self.embed_word_del.weight)
+ f_mask_ins = F.linear(features_cat, self.embed_mask_ins.weight)
+ return f_word_del, f_mask_ins, extra['attn']
+
+
+@register_model_architecture("levenshtein_transformer", "levenshtein_transformer")
+def base_architecture(args):
+ args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
+ args.encoder_layers = getattr(args, "encoder_layers", 6)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
+ args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
+ args.decoder_ffn_embed_dim = getattr(
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
+ )
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.0)
+ args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+ args.activation_fn = getattr(args, "activation_fn", "relu")
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+ args.share_decoder_input_output_embed = getattr(
+ args, "share_decoder_input_output_embed", False
+ )
+ args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
+ args.no_token_positional_embeddings = getattr(
+ args, "no_token_positional_embeddings", False
+ )
+ args.adaptive_input = getattr(args, "adaptive_input", False)
+ args.apply_bert_init = getattr(args, "apply_bert_init", False)
+
+ args.decoder_output_dim = getattr(
+ args, "decoder_output_dim", args.decoder_embed_dim
+ )
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
+
+ args.early_exit = getattr(args, "early_exit", "(6, 6, 6)")
+
+
+@register_model_architecture(
+ "levenshtein_transformer", "levenshtein_transformer_wmt_en_de"
+)
+def levenshtein_transformer_wmt_en_de(args):
+ base_architecture(args)
+
+
+# similar parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
+@register_model_architecture(
+ "levenshtein_transformer", "levenshtein_transformer_vaswani_wmt_en_de_big"
+)
+def levenshtein_transformer_vaswani_wmt_en_de_big(args):
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
+ args.dropout = getattr(args, "dropout", 0.3)
+ base_architecture(args)
+
+
+# default parameters used in tensor2tensor implementation
+@register_model_architecture(
+ "levenshtein_transformer", "levenshtein_transformer_wmt_en_de_big"
+)
+def levenshtein_transformer_wmt_en_de_big_t2t(args):
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
+ args.activation_dropout = getattr(args, "activation_dropout", 0.1)
+ levenshtein_transformer_vaswani_wmt_en_de_big(args)
diff --git a/fairseq/models/model_utils.py b/fairseq/models/model_utils.py
new file mode 100644
index 0000000000..8217731c9e
--- /dev/null
+++ b/fairseq/models/model_utils.py
@@ -0,0 +1,62 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+def skip_tensors(x, mask):
+ """
+ Getting sliced (dim=0) tensor by mask. Supporting tensor and list/dict of tensors.
+ """
+ if isinstance(x, int):
+ return x
+
+ if x is None:
+ return None
+
+ if isinstance(x, torch.Tensor):
+ if x.size(0) == mask.size(0):
+ return x[mask]
+ elif x.size(1) == mask.size(0):
+ return x[:, mask]
+
+ if isinstance(x, list):
+ return [skip_tensors(x_i, mask) for x_i in x]
+
+ if isinstance(x, dict):
+ return {k: skip_tensors(v, mask) for k, v in x.items()}
+
+ raise NotImplementedError
+
+
+def fill_tensors(x, mask, y, padding_idx):
+ """
+ Filling tensor x with y at masked positions (dim=0).
+ """
+ if x is None:
+ return y
+ assert x.dim() == y.dim() and mask.size(0) == x.size(0)
+ assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))
+ n_selected = mask.sum()
+ assert n_selected == y.size(0)
+
+ if n_selected == x.size(0):
+ return y
+
+ if x.size(1) < y.size(1):
+ dims = [x.size(0), y.size(1) - x.size(1)]
+ if x.dim() == 3:
+ dims.append(x.size(2))
+ x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1)
+ x[mask] = y
+ elif x.size(1) > y.size(1):
+ x[mask] = padding_idx
+ if x.dim() == 2:
+ x[mask, :y.size(1)] = y
+ else:
+ x[mask, :y.size(1), :] = y
+ else:
+ x[mask] = y
+ return x
diff --git a/fairseq/models/nonautoregressive_transformer.py b/fairseq/models/nonautoregressive_transformer.py
new file mode 100644
index 0000000000..d45a5b443b
--- /dev/null
+++ b/fairseq/models/nonautoregressive_transformer.py
@@ -0,0 +1,640 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+from fairseq import utils
+from fairseq.models import register_model, register_model_architecture
+from fairseq.models.transformer import (
+ Embedding,
+ TransformerDecoder,
+ TransformerDecoderLayer,
+ TransformerEncoder,
+ TransformerModel,
+)
+from fairseq.modules import MultiheadAttention
+from fairseq.modules.transformer_sentence_encoder import init_bert_params
+
+
+def _mean_pooling(enc_feats, src_masks):
+ # enc_feats: T x B x C
+ # src_masks: B x T or None
+ if src_masks is None:
+ enc_feats = enc_feats.mean(0)
+ else:
+ src_masks = (~src_masks).transpose(0, 1).type_as(enc_feats)
+ enc_feats = (
+ (enc_feats / src_masks.sum(0)[None, :, None]) * src_masks[:, :, None]
+ ).sum(0)
+ return enc_feats
+
+
+def _argmax(x, dim):
+ return (x == x.max(dim, keepdim=True)[0]).type_as(x)
+
+
+def _dynamic_programming(tokens, scores):
+ N, B, T = tokens.size()
+ cum_scores = scores[:, :, 0].clone() # N x B
+ cum_choice = tokens.new_zeros(B, T)
+
+ # forward
+ for t in range(T - 1):
+ score, choice = cum_scores.max(0)
+ cum_choice[:, t] = choice
+ cum_scores[0] = score + scores[0, :, t + 1]
+ cum_scores[1:] = cum_scores[:-1] + scores[1:, :, t + 1]
+
+ # back-tracking
+ end_score, end_choice = cum_scores.max(0)
+ cum_choice[:, T - 1] = end_choice
+ for t in range(T - 2, -1, -1):
+ is_start = (cum_choice[:, t + 1] == 0).type_as(cum_choice)
+ cum_choice[:, t] = (cum_choice[:, t + 1] - 1) * ~is_start + cum_choice[
+ :, t
+ ] * is_start
+
+ # finalize the prediction
+ tokens = tokens.gather(0, cum_choice.unsqueeze(0)).squeeze(0)
+ scores = scores.gather(0, cum_choice.unsqueeze(0)).squeeze(0)
+ return scores, tokens
+
+
+def _beam_search(tokens, scores, W=None):
+ N, B, T = tokens.size()
+
+ if (W is None) or (W > N):
+ W = N
+
+
+def _uniform_assignment(src_lens, trg_lens):
+ max_trg_len = trg_lens.max()
+ steps = (src_lens.float() - 1) / (trg_lens.float() - 1) # step-size
+ # max_trg_len
+ index_t = torch.arange(max_trg_len, device=trg_lens.device).float()
+ index_t = steps[:, None] * index_t[None, :] # batch_size X max_trg_len
+ index_t = torch.round(index_t).long().detach()
+ return index_t
+
+
+@register_model("nonautoregressive_transformer")
+class NATransformerModel(TransformerModel):
+ def __init__(self, encoder, decoder):
+ super().__init__(encoder, decoder)
+ self.tgt_dict = decoder.dictionary
+ self.bos = decoder.dictionary.bos()
+ self.eos = decoder.dictionary.eos()
+ self.pad = decoder.dictionary.pad()
+ self.unk = decoder.dictionary.unk()
+
+ @staticmethod
+ def add_args(parser):
+ TransformerModel.add_args(parser)
+ parser.add_argument(
+ "--apply-bert-init",
+ action="store_true",
+ help="use custom param initialization for BERT",
+ )
+
+ # length prediction
+ parser.add_argument("--src-embedding-copy", action="store_true",
+ help="copy encoder word embeddings as the initial input of the decoder")
+ parser.add_argument("--pred-length-offset", action="store_true",
+ help="predicting the length difference between the target and source sentences")
+ parser.add_argument("--sg-length-pred", action="store_true",
+ help="stop the gradients back-propagated from the length predictor")
+ parser.add_argument("--length-loss-factor", type=float,
+ help="weights on the length prediction loss")
+
+ # n-gram predictor
+ parser.add_argument(
+ "--ngram-predictor",
+ nargs="?",
+ const=4,
+ default=1,
+ type=int,
+ help="adding an additional n-gram predictor.",
+ )
+
+ @classmethod
+ def build_decoder(cls, args, tgt_dict, embed_tokens):
+ decoder = NATransformerDecoder(args, tgt_dict, embed_tokens)
+ if getattr(args, "apply_bert_init", False):
+ decoder.apply(init_bert_params)
+ return decoder
+
+ @classmethod
+ def build_encoder(cls, args, src_dict, embed_tokens):
+ encoder = TransformerEncoder(args, src_dict, embed_tokens)
+ if getattr(args, "apply_bert_init", False):
+ encoder.apply(init_bert_params)
+ return encoder
+
+ def forward(
+ self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
+ ):
+ # encoding
+ encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
+ length_out, length_tgt = self.decoder.forward_length_prediction(
+ encoder_out, tgt_tokens
+ )
+
+ word_ins_out, word_ins_tgt, word_ins_mask = self.decoder(
+ prev_output_tokens, encoder_out=encoder_out, tgt_tokens=tgt_tokens
+ )
+
+ return {
+ "word_ins_out": word_ins_out,
+ "word_ins_tgt": word_ins_tgt,
+ "word_ins_mask": word_ins_mask,
+ "length_out": length_out,
+ "length_tgt": length_tgt,
+ "length_w": self.decoder.length_loss_factor,
+ }
+
+ def forward_encoder(self, encoder_inputs):
+ return self.encoder(*encoder_inputs)
+
+ def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
+ step = decoder_out["step"]
+ output_tokens = decoder_out["output_tokens"]
+ output_scores = decoder_out["output_scores"]
+
+ # execute the decoder
+ output_masks = output_tokens.ne(self.pad)
+ _scores, _tokens = self.decoder(
+ output_tokens,
+ encoder_out=encoder_out,
+ decoding_format=decoding_format,
+ step=step,
+ )
+ output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
+ output_scores.masked_scatter_(output_masks, _scores[output_masks])
+
+ return {"output_tokens": output_tokens, "output_scores": output_scores}
+
+ def initialize_output_tokens(self, encoder_out, src_tokens):
+ # length prediction
+ _, length_tgt = self.decoder.forward_length_prediction(encoder_out)
+ max_length = length_tgt.max()
+ idx_length = torch.arange(max_length, device=src_tokens.device)
+
+ initial_output_tokens = src_tokens.new_zeros(
+ src_tokens.size(0), max_length
+ ).fill_(self.pad)
+ initial_output_tokens.masked_fill_(
+ idx_length[None, :] < length_tgt[:, None], self.unk
+ )
+ initial_output_tokens[:, 0] = self.bos
+ initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)
+
+ initial_output_scores = initial_output_tokens.new_zeros(
+ *initial_output_tokens.size()
+ ).type_as(encoder_out["encoder_out"])
+
+ return {
+ "output_tokens": initial_output_tokens,
+ "output_scores": initial_output_scores,
+ }
+
+
+class NATransformerDecoder(TransformerDecoder):
+ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
+ super().__init__(
+ args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
+ )
+
+ self.dictionary = dictionary
+ self.bos = dictionary.bos()
+ self.unk = dictionary.unk()
+ self.eos = dictionary.eos()
+
+ self.encoder_embed_dim = args.encoder_embed_dim
+ self.sg_length_pred = getattr(args, "sg_length_pred", False)
+ self.pred_length_offset = getattr(args, "pred_length_offset", False)
+ self.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
+ self.src_embedding_copy = getattr(args, "src_embedding_copy", False)
+ self.embed_length = Embedding(256, self.encoder_embed_dim, None)
+
+ self.ngram_predictor = getattr(args, "ngram_predictor", 1)
+ self.ngram_layer = (
+ None if (self.ngram_predictor == 1) else NgramDecoderLayer(args, True)
+ )
+
+ def forward(
+ self,
+ prev_output_tokens,
+ encoder_out=None,
+ tgt_tokens=None,
+ decoding_format=None,
+ step=0,
+ **kwargs
+ ):
+
+ features, _ = self.extract_features(
+ prev_output_tokens,
+ encoder_out=encoder_out,
+ embedding_copy=(step == 0) & self.src_embedding_copy,
+ )
+
+ if tgt_tokens is not None:
+ if self.ngram_layer is None:
+ word_ins_mask = tgt_tokens.ne(self.padding_idx)
+ word_ins_tgt = tgt_tokens
+ else:
+ context_embeds, context_masks = self.forward_ngram_context(tgt_tokens)
+ features = self.ngram_layer(features, context_embeds=context_embeds)
+ word_ins_tgt = tgt_tokens[:, :, None].repeat(1, 1, self.ngram_predictor)
+ word_ins_mask = word_ins_tgt.ne(self.padding_idx) & context_masks
+
+ return self.output_layer(features), word_ins_tgt, word_ins_mask
+
+ else:
+ if self.ngram_layer is None:
+ return F.log_softmax(self.output_layer(features), -1).max(-1)
+ else:
+ # inner iterations
+ return self.forward_ngram_decoding(
+ features, prev_output_tokens.eq(self.padding_idx), decoding_format
+ )
+
+ def extract_features(
+ self,
+ prev_output_tokens,
+ encoder_out=None,
+ early_exit=None,
+ embedding_copy=False,
+ **unused
+ ):
+ """
+ Similar to *forward* but only return features.
+
+ Inputs:
+ prev_output_tokens: Tensor(B, T)
+ encoder_out: a dictionary of hidden states and masks
+
+ Returns:
+ tuple:
+ - the decoder's features of shape `(batch, tgt_len, embed_dim)`
+ - a dictionary with any model-specific outputs
+ the LevenshteinTransformer decoder has full-attention to all generated tokens
+ """
+ # embedding
+ if embedding_copy:
+ src_embd = encoder_out["encoder_embedding"]
+ src_mask = encoder_out["encoder_padding_mask"]
+ src_mask = (
+ ~src_mask
+ if src_mask is not None
+ else prev_output_tokens.new_ones(*src_embd.size()[:2]).bool()
+ )
+
+ x, decoder_padding_mask = self.forward_embedding(
+ prev_output_tokens,
+ self.forward_copying_source(
+ src_embd, src_mask, prev_output_tokens.ne(self.padding_idx)
+ ),
+ )
+
+ else:
+
+ x, decoder_padding_mask = self.forward_embedding(prev_output_tokens)
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+ attn = None
+ inner_states = [x]
+
+ # decoder layers
+ for i, layer in enumerate(self.layers):
+
+ # early exit from the decoder.
+ if (early_exit is not None) and (i >= early_exit):
+ break
+
+ x, attn = layer(
+ x,
+ encoder_out["encoder_out"] if encoder_out is not None else None,
+ encoder_out["encoder_padding_mask"]
+ if encoder_out is not None
+ else None,
+ self_attn_mask=None,
+ self_attn_padding_mask=decoder_padding_mask,
+ )
+ inner_states.append(x)
+
+ if self.layer_norm:
+ x = self.layer_norm(x)
+
+ # T x B x C -> B x T x C
+ x = x.transpose(0, 1)
+
+ if self.project_out_dim is not None:
+ x = self.project_out_dim(x)
+
+ return x, {"attn": attn, "inner_states": inner_states}
+
+ def forward_ngram_context(self, tgt_tokens):
+ tgt_embeds = self.forward_embedding(tgt_tokens)
+ n_contexts = self.ngram_predictor - 1
+
+ # shifting the embeddings
+ # context_embeds: N x B x T x C
+ # context_masks: B x T x N
+ context_embeds = tgt_embeds.new_zeros(n_contexts, *tgt_embeds.size())
+ context_masks = tgt_embeds.new_ones(
+ *tgt_embeds.size()[:2], self.ngram_predictor
+ ).bool()
+
+ for k in range(n_contexts):
+ context_embeds[k, :, k + 1:] = tgt_embeds[:, : -k - 1]
+ context_masks[:, : k + 1, k + 1] = 0
+
+ return context_embeds, context_masks
+
+ def forward_ngram_decoding(self, features, padding_mask=None, decoding_format=None):
+ context_embeds = None
+ scores, tokens = [], []
+ ensemble_score = None
+ ensemble_index = None
+
+ if decoding_format is None:
+ decoding_format = "ensemble"
+
+ for k in range(self.ngram_predictor):
+ ngram_out = self.ngram_layer(
+ features, context_embeds=context_embeds, incremental=True
+ )
+ ngram_scores = F.log_softmax(self.output_layer(ngram_out), -1)
+ max_score, max_token = ngram_scores.max(-1)
+
+ if decoding_format == "vote":
+ ngram_scores = _argmax(ngram_scores, -1)
+
+ if ensemble_score is None:
+ ensemble_score = ngram_scores
+ ensemble_index = ensemble_score.new_ones(*ensemble_score.size()[:2])
+ else:
+ ensemble_index[:, k:] = ensemble_index[:, k:] + 1
+ ensemble_score = ensemble_score + ngram_scores.masked_fill_(
+ (ensemble_index < k)
+ .unsqueeze(2)
+ .repeat(1, 1, ensemble_score.size(2)),
+ 0,
+ )
+ max_score[:, :k] = float("-inf")
+
+ if decoding_format == "unigram":
+ break
+
+ scores.append(max_score.masked_fill_(padding_mask, 0))
+ tokens.append(max_token.masked_fill_(padding_mask, self.padding_idx))
+
+ # context_embeds: N x B x T x C
+ if context_embeds is None:
+ context_embeds = self.forward_embedding(max_token).unsqueeze(0)
+
+ else:
+ context_embeds = torch.cat(
+ [self.forward_embedding(max_token).unsqueeze(0), context_embeds], 0
+ )
+
+ context_embeds[:, :, 1:] = context_embeds[:, :, :-1]
+
+ if decoding_format != "dp":
+ ensemble_score = ensemble_score / ensemble_index.unsqueeze(2)
+ return ensemble_score.max(-1)
+
+ else:
+ tokens = torch.cat([t.unsqueeze(0) for t in tokens], 0)
+ scores = torch.cat([s.unsqueeze(0) for s in scores], 0)
+ return _dynamic_programming(tokens, scores)
+
+ def forward_embedding(self, prev_output_tokens, states=None):
+ # embed positions
+ positions = (
+ self.embed_positions(prev_output_tokens)
+ if self.embed_positions is not None
+ else None
+ )
+
+ # embed tokens and positions
+ if states is None:
+ x = self.embed_scale * self.embed_tokens(prev_output_tokens)
+ if self.project_in_dim is not None:
+ x = self.project_in_dim(x)
+ else:
+ x = states
+
+ if positions is not None:
+ x += positions
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ decoder_padding_mask = prev_output_tokens.eq(self.padding_idx)
+ return x, decoder_padding_mask
+
+ def forward_copying_source(self, src_embeds, src_masks, tgt_masks):
+ length_sources = src_masks.sum(1)
+ length_targets = tgt_masks.sum(1)
+ mapped_inputs = _uniform_assignment(length_sources, length_targets).masked_fill(
+ ~tgt_masks, 0
+ )
+ copied_embedding = torch.gather(
+ src_embeds,
+ 1,
+ mapped_inputs.unsqueeze(-1).expand(
+ *mapped_inputs.size(), src_embeds.size(-1)
+ ),
+ )
+ return copied_embedding
+
+ def forward_length_prediction(self, encoder_out, tgt_tokens=None):
+ enc_feats = encoder_out["encoder_out"] # T x B x C
+ src_masks = encoder_out["encoder_padding_mask"] # B x T or None
+
+ if self.pred_length_offset:
+ if src_masks is None:
+ src_lengs = enc_feats.new_ones(enc_feats.size(1)).fill_(
+ enc_feats.size(0)
+ )
+ else:
+ src_lengs = (~src_masks).transpose(0, 1).type_as(enc_feats).sum(0)
+ src_lengs = src_lengs.long()
+
+ enc_feats = _mean_pooling(enc_feats, src_masks)
+ if self.sg_length_pred:
+ enc_feats = enc_feats.detach()
+
+ length_out = F.linear(enc_feats, self.embed_length.weight)
+
+ if tgt_tokens is not None:
+ # obtain the length target
+ tgt_lengs = tgt_tokens.ne(self.padding_idx).sum(1).long()
+ if self.pred_length_offset:
+ length_tgt = tgt_lengs - src_lengs + 128
+ else:
+ length_tgt = tgt_lengs
+ length_tgt = length_tgt.clamp(min=0, max=255)
+
+ else:
+ # predict the length target (greedy for now)
+ # TODO: implementing length-beam
+ pred_lengs = length_out.max(-1)[1]
+ if self.pred_length_offset:
+ length_tgt = pred_lengs - 128 + src_lengs
+ else:
+ length_tgt = pred_lengs
+
+ return length_out, length_tgt
+
+
+class NgramDecoderLayer(TransformerDecoderLayer):
+ """
+ N-gram Decoder Layer:
+
+ This module can be pluged in the last layer of any Non-autoregressive Model's
+ It provides an alternative way to capture local n-gram information by running the block multiple times.
+ """
+
+ def __init__(self, args, no_encoder_attn=False):
+ super(NgramDecoderLayer, self).__init__(args, no_encoder_attn=no_encoder_attn)
+ self.self_attn = MultiheadAttention(
+ embed_dim=self.embed_dim,
+ num_heads=1, # maybe n-gram does not need too many heads.
+ dropout=args.attention_dropout,
+ self_attention=False,
+ encoder_decoder_attention=True,
+ )
+
+ def forward(
+ self,
+ x,
+ encoder_out=None,
+ encoder_padding_mask=None,
+ context_embeds=None,
+ incremental=False,
+ ):
+ # x: T x B x C
+ # context_embeds: N x T x B x C
+ T, B, C = x.size()
+
+ residual = x
+ x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
+ x = x.contiguous().view(1, T * B, C).contiguous()
+
+ if context_embeds is not None:
+ N = context_embeds.size(0)
+ context_embeds = context_embeds.view(N, T * B, C).contiguous()
+
+ if not incremental:
+ assert context_embeds is not None, "we need context for training"
+ # attn_weights: (n_head x T x B) x 1 x N
+ # v: (n_head x T x B) x N x (dim / n_head)
+ # -- move the attention computation outside --
+ attn_weights, values = self.self_attn(
+ query=x, key=context_embeds, value=context_embeds, before_softmax=True
+ )
+
+ attn_weights = attn_weights.repeat(1, N, 1)
+ attn_masks = attn_weights.new_ones(N, N).triu_(1).bool()
+ attn_masks = attn_masks.unsqueeze(0).repeat(attn_weights.size(0), 1, 1)
+
+ attn_weights = attn_weights.masked_fill(attn_masks, float("-inf"))
+ attn_weights = utils.softmax(attn_weights, dim=-1).type_as(attn_weights)
+ attn_weights = F.dropout(
+ attn_weights, p=self.self_attn.dropout, training=self.training
+ )
+
+ # (n_head x T x B) x N x (dim / n_head)
+ attn = torch.bmm(attn_weights, values)
+ attn = attn.transpose(0, 1).contiguous()
+ attn = attn.view(N, T * B, C).contiguous()
+ attn = attn.transpose(1, 0).contiguous()
+ attn = attn.view(T, B, N, C)
+
+ residual = residual.unsqueeze(2)
+ x = self.self_attn.out_proj(attn)
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = torch.cat([residual, residual + x], 2)
+
+ else:
+ if context_embeds is None:
+ x = residual
+
+ else:
+ x, _ = self.self_attn(query=x, key=context_embeds, value=context_embeds)
+ x = x.view(T, B, C)
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = residual + x
+
+ x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
+
+ if self.encoder_attn is not None:
+ raise NotImplementedError
+
+ residual = x
+ x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
+ x = self.activation_fn(self.fc1(x))
+ x = F.dropout(x, p=self.activation_dropout, training=self.training)
+ x = self.fc2(x)
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = residual + x
+ x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
+ return x
+
+
+@register_model_architecture(
+ "nonautoregressive_transformer", "nonautoregressive_transformer"
+)
+def base_architecture(args):
+ args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
+ args.encoder_layers = getattr(args, "encoder_layers", 6)
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
+ args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
+ args.decoder_ffn_embed_dim = getattr(
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
+ )
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
+ args.attention_dropout = getattr(args, "attention_dropout", 0.0)
+ args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+ args.activation_fn = getattr(args, "activation_fn", "relu")
+ args.dropout = getattr(args, "dropout", 0.1)
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+ args.share_decoder_input_output_embed = getattr(
+ args, "share_decoder_input_output_embed", False
+ )
+ args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
+ args.no_token_positional_embeddings = getattr(
+ args, "no_token_positional_embeddings", False
+ )
+ args.adaptive_input = getattr(args, "adaptive_input", False)
+ args.apply_bert_init = getattr(args, "apply_bert_init", False)
+
+ args.decoder_output_dim = getattr(
+ args, "decoder_output_dim", args.decoder_embed_dim
+ )
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
+
+ # --- special arguments ---
+ args.sg_length_pred = getattr(args, "sg_length_pred", False)
+ args.pred_length_offset = getattr(args, "pred_length_offset", False)
+ args.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
+ args.src_embedding_copy = getattr(args, "src_embedding_copy", False)
+ args.ngram_predictor = getattr(args, "ngram_predictor", 1)
+
+
+@register_model_architecture(
+ "nonautoregressive_transformer", "nonautoregressive_transformer_wmt_en_de"
+)
+def nonautoregressive_transformer_wmt_en_de(args):
+ base_architecture(args)
diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py
index 7fedc77550..dd10ae5357 100644
--- a/fairseq/models/transformer.py
+++ b/fairseq/models/transformer.py
@@ -172,7 +172,7 @@ def build_embedding(dictionary, embed_dim, path=None):
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
- return TransformerModel(encoder, decoder)
+ return cls(encoder, decoder)
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
@@ -222,7 +222,15 @@ def __init__(self, args, dictionary, embed_tokens):
else:
self.layer_norm = None
- def forward(self, src_tokens, src_lengths):
+ def forward_embedding(self, src_tokens):
+ # embed tokens and positions
+ embed = self.embed_scale * self.embed_tokens(src_tokens)
+ if self.embed_positions is not None:
+ x = embed + self.embed_positions(src_tokens)
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ return x, embed
+
+ def forward(self, src_tokens, src_lengths, cls_input=None):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
@@ -237,11 +245,7 @@ def forward(self, src_tokens, src_lengths):
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
"""
- # embed tokens and positions
- x = self.embed_scale * self.embed_tokens(src_tokens)
- if self.embed_positions is not None:
- x += self.embed_positions(src_tokens)
- x = F.dropout(x, p=self.dropout, training=self.training)
+ x, encoder_embedding = self.forward_embedding(src_tokens)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
@@ -261,6 +265,7 @@ def forward(self, src_tokens, src_lengths):
return {
'encoder_out': x, # T x B x C
'encoder_padding_mask': encoder_padding_mask, # B x T
+ 'encoder_embedding': encoder_embedding, # B x T x C
}
def reorder_encoder_out(self, encoder_out, new_order):
@@ -332,7 +337,7 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
embed_dim = args.decoder_embed_dim
self.output_embed_dim = args.decoder_output_dim
- padding_idx = embed_tokens.padding_idx
+ self.padding_idx = embed_tokens.padding_idx
self.max_target_positions = args.max_target_positions
self.embed_tokens = embed_tokens
@@ -341,7 +346,7 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None
self.embed_positions = PositionalEmbedding(
- args.max_target_positions, embed_dim, padding_idx,
+ args.max_target_positions, embed_dim, self.padding_idx,
learned=args.decoder_learned_pos,
) if not args.no_token_positional_embeddings else None
diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py
index 4da628655e..8c28255dfb 100644
--- a/fairseq/modules/multihead_attention.py
+++ b/fairseq/modules/multihead_attention.py
@@ -91,7 +91,7 @@ def reset_parameters(self):
nn.init.xavier_normal_(self.bias_v)
def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
- need_weights=True, static_kv=False, attn_mask=None):
+ need_weights=True, static_kv=False, attn_mask=None, before_softmax=False):
"""Input shape: Time x Batch x Channel
Timesteps can be masked by supplying a T x T mask in the
@@ -239,6 +239,9 @@ def forward(self, query, key, value, key_padding_mask=None, incremental_state=No
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+ if before_softmax:
+ return attn_weights, v
+
attn_weights = utils.softmax(
attn_weights, dim=-1, onnx_trace=self.onnx_trace,
).type_as(attn_weights)
diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py
index 5da4909ca2..f4a80cceea 100644
--- a/fairseq/modules/transformer_layer.py
+++ b/fairseq/modules/transformer_layer.py
@@ -83,7 +83,7 @@ def forward(self, x, encoder_padding_mask, attn_mask=None):
residual = x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
if attn_mask is not None:
- attn_mask = attn_mask.masked_fill(attn_mask.byte(), -1e8)
+ attn_mask = attn_mask.masked_fill(attn_mask.bool(), -1e8)
# anything in original attn_mask = 1, becomes -1e8
# anything in original attn_mask = 0, becomes 0
# Note that we cannot use -inf here, because at some edge cases,
diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py
index 1699291253..9be7ab3080 100644
--- a/fairseq/modules/transformer_sentence_encoder.py
+++ b/fairseq/modules/transformer_sentence_encoder.py
@@ -36,7 +36,8 @@ def init_bert_params(module):
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
- module.weight.data[module.padding_idx].zero_()
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
if isinstance(module, MultiheadAttention):
module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
diff --git a/fairseq/options.py b/fairseq/options.py
index 54c7863908..bb1e27aeb7 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -280,6 +280,8 @@ def add_dataset_args(parser, train=False, gen=False):
' (train, valid, valid1, test, test1)')
group.add_argument('--validate-interval', type=int, default=1, metavar='N',
help='validate every N epochs')
+ group.add_argument('--fixed-validation-seed', default=None, type=int, metavar='N',
+ help='specified random seed for validation')
group.add_argument('--disable-validation', action='store_true',
help='disable validation')
group.add_argument('--max-tokens-valid', type=int, metavar='N',
@@ -493,6 +495,18 @@ def add_generation_args(parser):
help='strength of diversity penalty for Diverse Beam Search')
group.add_argument('--print-alignment', action='store_true',
help='if set, uses attention feedback to compute and print alignment to source tokens')
+ group.add_argument('--print-step', action='store_true')
+
+ # arguments for iterative refinement generator
+ group.add_argument('---iter-decode-eos-penalty', default=0.0, type=float, metavar='N',
+ help='if > 0.0, it penalized early-stopping in decoding.')
+ group.add_argument('--iter-decode-max-iter', default=10, type=int, metavar='N',
+ help='maximum iterations for iterative refinement.')
+ group.add_argument('--iter-decode-force-max-iter', action='store_true',
+ help='if set, run exact the maximum number of iterations without early stop')
+
+ # special decoding format for advanced decoding.
+ group.add_argument('--decoding-format', default=None, type=str, choices=['unigram', 'ensemble', 'vote', 'dp', 'bs'])
# fmt: on
return group
diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py
index d3f51cb35c..f3d60403ba 100644
--- a/fairseq/tasks/translation.py
+++ b/fairseq/tasks/translation.py
@@ -12,6 +12,7 @@
data_utils,
indexed_dataset,
LanguagePairDataset,
+ PrependTokenDataset,
)
from . import FairseqTask, register_task
@@ -22,7 +23,8 @@ def load_langpair_dataset(
src, src_dict,
tgt, tgt_dict,
combine, dataset_impl, upsample_primary,
- left_pad_source, left_pad_target, max_source_positions, max_target_positions,
+ left_pad_source, left_pad_target, max_source_positions,
+ max_target_positions, prepend_bos=False,
):
def split_exists(split, src, tgt, lang, data_path):
filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
@@ -67,6 +69,11 @@ def split_exists(split, src, tgt, lang, data_path):
src_dataset = ConcatDataset(src_datasets, sample_ratios)
tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
+ if prepend_bos:
+ assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
+ src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
+ tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
+
return LanguagePairDataset(
src_dataset, src_dataset.sizes, src_dict,
tgt_dataset, tgt_dataset.sizes, tgt_dict,
diff --git a/fairseq/tasks/translation_lev.py b/fairseq/tasks/translation_lev.py
new file mode 100644
index 0000000000..47d6a3ed4a
--- /dev/null
+++ b/fairseq/tasks/translation_lev.py
@@ -0,0 +1,149 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from fairseq.tasks import register_task
+from fairseq.tasks.translation import TranslationTask, load_langpair_dataset
+
+
+@register_task('translation_lev')
+class TranslationLevenshteinTask(TranslationTask):
+ """
+ Translation (Sequence Generation) task for Levenshtein Transformer
+ See `"Levenshtein Transformer" `_.
+ """
+
+ @staticmethod
+ def add_args(parser):
+ """Add task-specific arguments to the parser."""
+ # fmt: off
+ TranslationTask.add_args(parser)
+ parser.add_argument(
+ '--noise',
+ default='random_delete',
+ choices=['random_delete', 'random_mask', 'no_noise', 'full_mask'])
+
+ def load_dataset(self, split, epoch=0, combine=False, **kwargs):
+ """Load a given dataset split.
+
+ Args:
+ split (str): name of the split (e.g., train, valid, test)
+ """
+ paths = self.args.data.split(':')
+ assert len(paths) > 0
+ data_path = paths[epoch % len(paths)]
+
+ # infer langcode
+ src, tgt = self.args.source_lang, self.args.target_lang
+
+ self.datasets[split] = load_langpair_dataset(
+ data_path, split, src, self.src_dict, tgt, self.tgt_dict,
+ combine=combine, dataset_impl=self.args.dataset_impl,
+ upsample_primary=self.args.upsample_primary,
+ left_pad_source=self.args.left_pad_source,
+ left_pad_target=self.args.left_pad_target,
+ max_source_positions=self.args.max_source_positions,
+ max_target_positions=self.args.max_target_positions,
+ prepend_bos=True,
+ )
+
+ def inject_noise(self, target_tokens):
+ def _random_delete(target_tokens):
+ pad = self.tgt_dict.pad()
+ bos = self.tgt_dict.bos()
+ eos = self.tgt_dict.eos()
+
+ max_len = target_tokens.size(1)
+ target_mask = target_tokens.eq(pad)
+ target_score = target_tokens.clone().float().uniform_()
+ target_score.masked_fill_(
+ target_tokens.eq(bos) | target_tokens.eq(eos), 0.0)
+ target_score.masked_fill_(target_mask, 1)
+ target_score, target_rank = target_score.sort(1)
+ target_length = target_mask.size(1) - target_mask.float().sum(
+ 1, keepdim=True)
+
+ # do not delete and (we assign 0 score for them)
+ target_cutoff = 2 + ((target_length - 2) * target_score.new_zeros(
+ target_score.size(0), 1).uniform_()).long()
+ target_cutoff = target_score.sort(1)[1] >= target_cutoff
+
+ prev_target_tokens = target_tokens.gather(
+ 1, target_rank).masked_fill_(target_cutoff, pad).gather(
+ 1,
+ target_rank.masked_fill_(target_cutoff,
+ max_len).sort(1)[1])
+ prev_target_tokens = prev_target_tokens[:, :prev_target_tokens.
+ ne(pad).sum(1).max()]
+
+ return prev_target_tokens
+
+ def _random_mask(target_tokens):
+ pad = self.tgt_dict.pad()
+ bos = self.tgt_dict.bos()
+ eos = self.tgt_dict.eos()
+ unk = self.tgt_dict.unk()
+
+ target_mask = target_tokens.eq(bos) | target_tokens.eq(
+ eos) | target_tokens.eq(pad)
+ target_score = target_tokens.clone().float().uniform_()
+ target_score.masked_fill_(target_mask, 1.0)
+
+ prev_target_tokens = target_tokens.masked_fill(
+ target_score < target_score.new_zeros(target_score.size(0),
+ 1).uniform_(), unk)
+ return prev_target_tokens
+
+ def _full_mask(target_tokens):
+ pad = self.tgt_dict.pad()
+ bos = self.tgt_dict.bos()
+ eos = self.tgt_dict.eos()
+ unk = self.tgt_dict.unk()
+
+ target_mask = target_tokens.eq(bos) | target_tokens.eq(
+ eos) | target_tokens.eq(pad)
+ return target_tokens.masked_fill(~target_mask, unk)
+
+ if self.args.noise == 'random_delete':
+ return _random_delete(target_tokens)
+ elif self.args.noise == 'random_mask':
+ return _random_mask(target_tokens)
+ elif self.args.noise == 'full_mask':
+ return _full_mask(target_tokens)
+ elif self.args.noise == 'no_noise':
+ return target_tokens
+ else:
+ raise NotImplementedError
+
+ def build_generator(self, args):
+ from fairseq.iterative_refinement_generator import IterativeRefinementGenerator
+ return IterativeRefinementGenerator(
+ self.target_dictionary,
+ eos_penalty=getattr(args, 'iter_decode_eos_penalty', 0.0),
+ max_iter=getattr(args, 'iter_decode_max_iter', 10),
+ decoding_format=getattr(args, 'decoding_format', None),
+ adaptive=not getattr(args, 'iter_decode_force_max_iter', False))
+
+ def train_step(self,
+ sample,
+ model,
+ criterion,
+ optimizer,
+ ignore_grad=False):
+ model.train()
+ sample['prev_target'] = self.inject_noise(sample['target'])
+ loss, sample_size, logging_output = criterion(model, sample)
+ if ignore_grad:
+ loss *= 0
+ optimizer.backward(loss)
+ return loss, sample_size, logging_output
+
+ def valid_step(self, sample, model, criterion):
+ model.eval()
+ with torch.no_grad():
+ sample['prev_target'] = self.inject_noise(sample['target'])
+ loss, sample_size, logging_output = criterion(model, sample)
+ return loss, sample_size, logging_output
diff --git a/fairseq/utils.py b/fairseq/utils.py
index 1af2394434..80ecb6d083 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -359,3 +359,11 @@ def has_parameters(module):
return True
except StopIteration:
return False
+
+
+def set_torch_seed(seed):
+ # Set seed based on args.seed and the update number so that we get
+ # reproducible results when resuming from checkpoints
+ assert isinstance(seed, int)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
diff --git a/generate.py b/generate.py
index c23cc79868..6de1a69abd 100644
--- a/generate.py
+++ b/generate.py
@@ -159,6 +159,9 @@ def main(args):
' '.join(map(lambda x: str(utils.item(x)), alignment))
))
+ if args.print_step:
+ print('I-{}\t{}'.format(sample_id, hypo['steps']))
+
# Score only the top hypothesis
if has_target and j == 0:
if align_dict is not None or args.remove_bpe is not None:
diff --git a/setup.py b/setup.py
index 8f4604be11..33849f8105 100644
--- a/setup.py
+++ b/setup.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
from setuptools import setup, find_packages, Extension
+from torch.utils import cpp_extension
import sys
@@ -60,6 +61,12 @@ def include_dirs(self, dirs):
language='c++',
extra_compile_args=extra_compile_args,
),
+ cpp_extension.CppExtension(
+ 'fairseq.libnat',
+ sources=[
+ 'fairseq/clib/libnat/edit_dist.cpp',
+ ],
+ )
]
@@ -106,5 +113,6 @@ def include_dirs(self, dirs):
'fairseq-validate = fairseq_cli.validate:cli_main',
],
},
+ cmdclass={'build_ext': cpp_extension.BuildExtension},
zip_safe=False,
)
diff --git a/tests/test_binaries.py b/tests/test_binaries.py
index b517278273..8cede3c9fa 100644
--- a/tests/test_binaries.py
+++ b/tests/test_binaries.py
@@ -180,6 +180,52 @@ def test_dynamicconv(self):
])
generate_main(data_dir)
+ def test_levenshtein_transformer(self):
+ with contextlib.redirect_stdout(StringIO()):
+ with tempfile.TemporaryDirectory('test_levenshtein_transformer') as data_dir:
+ create_dummy_data(data_dir)
+ preprocess_translation_data(data_dir)
+ train_translation_model(data_dir, 'levenshtein_transformer', [
+ '--apply-bert-init', '--early-exit', '6,6,6',
+ '--criterion', 'nat_loss'
+ ], task='translation_lev')
+ generate_main(data_dir)
+
+ def test_nonautoregressive_transformer(self):
+ with contextlib.redirect_stdout(StringIO()):
+ with tempfile.TemporaryDirectory('test_nonautoregressive_transformer') as data_dir:
+ create_dummy_data(data_dir)
+ preprocess_translation_data(data_dir)
+ train_translation_model(data_dir, 'nonautoregressive_transformer', [
+ '--apply-bert-init', '--src-embedding-copy', '--criterion',
+ 'nat_loss', '--noise', 'full_mask', '--pred-length-offset',
+ '--length-loss-factor', '0.1'
+ ], task='translation_lev')
+ generate_main(data_dir)
+
+ def test_iterative_nonautoregressive_transformer(self):
+ with contextlib.redirect_stdout(StringIO()):
+ with tempfile.TemporaryDirectory('test_iterative_nonautoregressive_transformer') as data_dir:
+ create_dummy_data(data_dir)
+ preprocess_translation_data(data_dir)
+ train_translation_model(data_dir, 'iterative_nonautoregressive_transformer', [
+ '--apply-bert-init', '--src-embedding-copy', '--criterion',
+ 'nat_loss', '--noise', 'full_mask', '--stochastic-approx',
+ '--dae-ratio', '0.5', '--train-step', '3'
+ ], task='translation_lev')
+ generate_main(data_dir)
+
+ def test_insertion_transformer(self):
+ with contextlib.redirect_stdout(StringIO()):
+ with tempfile.TemporaryDirectory('test_insertion_transformer') as data_dir:
+ create_dummy_data(data_dir)
+ preprocess_translation_data(data_dir)
+ train_translation_model(data_dir, 'insertion_transformer', [
+ '--apply-bert-init', '--criterion', 'nat_loss', '--noise',
+ 'random_mask'
+ ], task='translation_lev')
+ generate_main(data_dir)
+
def test_mixture_of_experts(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_moe') as data_dir:
diff --git a/train.py b/train.py
index db04dc2190..3879375fe9 100644
--- a/train.py
+++ b/train.py
@@ -194,6 +194,11 @@ def get_training_stats(trainer):
def validate(args, trainer, task, epoch_itr, subsets):
"""Evaluate the model on the validation set(s) and return the losses."""
+
+ if args.fixed_validation_seed is not None:
+ # set fixed seed for every validation
+ utils.set_torch_seed(args.fixed_validation_seed)
+
valid_losses = []
for subset in subsets:
# Initialize data iterator