diff --git a/README.md b/README.md
index 1c94920049..fd183b5aa6 100644
--- a/README.md
+++ b/README.md
@@ -89,7 +89,12 @@ and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more example
 * [PyTorch](http://pytorch.org/) version >= 1.4.0
 * Python version >= 3.6
 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
-* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library with the `--cuda_ext` and `--deprecated_fused_adam` options
+* **For faster training**, install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
+```bash
+git clone https://github.com/NVIDIA/apex
+cd apex
+pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--deprecated_fused_adam" --global-option="--xentropy" --global-option="--fast_multihead_attn" ./
+```
 
 To install fairseq:
 ```bash
diff --git a/fairseq/criterions/masked_lm.py b/fairseq/criterions/masked_lm.py
index 8f98c5a7e7..52cf25b6f1 100644
--- a/fairseq/criterions/masked_lm.py
+++ b/fairseq/criterions/masked_lm.py
@@ -8,7 +8,7 @@
 import torch
 import torch.nn.functional as F
 
-from fairseq import metrics, utils
+from fairseq import metrics, modules, utils
 from fairseq.criterions import FairseqCriterion, register_criterion
 
 
@@ -47,12 +47,8 @@ def forward(self, model, sample, reduce=True):
         targets = model.get_targets(sample, [logits])
         targets = targets[masked_tokens]
 
-        loss = F.nll_loss(
-            F.log_softmax(
-                logits.view(-1, logits.size(-1)),
-                dim=-1,
-                dtype=torch.float32,
-            ),
+        loss = modules.cross_entropy(
+            logits.view(-1, logits.size(-1)),
             targets.view(-1),
             reduction='sum',
             ignore_index=self.padding_idx,
diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py
index bb3a138e9a..422e123058 100644
--- a/fairseq/modules/__init__.py
+++ b/fairseq/modules/__init__.py
@@ -8,6 +8,7 @@
 from .beamable_mm import BeamableMM
 from .character_token_embedder import CharacterTokenEmbedder
 from .conv_tbc import ConvTBC
+from .cross_entropy import cross_entropy
 from .downsampled_multihead_attention import DownsampledMultiHeadAttention
 from .dynamic_convolution import DynamicConv, DynamicConv1dTBC
 from .dynamic_crf_layer import DynamicCRF
@@ -36,6 +37,7 @@
     'BeamableMM',
     'CharacterTokenEmbedder',
     'ConvTBC',
+    'cross_entropy',
     'DownsampledMultiHeadAttention',
     'DynamicConv1dTBC',
     'DynamicConv',
diff --git a/fairseq/modules/cross_entropy.py b/fairseq/modules/cross_entropy.py
new file mode 100644
index 0000000000..83a7df96fe
--- /dev/null
+++ b/fairseq/modules/cross_entropy.py
@@ -0,0 +1,55 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import torch
+import torch.nn.functional as F
+
+
+logger = logging.getLogger(__name__)
+
+
+def _cross_entropy_pytorch(logits, target, ignore_index=-100, reduction='mean'):
+    # Reference path: log-softmax in fp32 for numerical stability, then NLL.
+    # This mirrors the code that masked_lm.py used before this change.
+    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+    return F.nll_loss(
+        lprobs, target, ignore_index=ignore_index, reduction=reduction,
+    )
+
+
+try:
+    from apex.contrib import xentropy
+
+    logger.info('using fused cross entropy')
+
+    def cross_entropy(logits, target, ignore_index=-100, reduction='mean'):
+        if logits.device == torch.device('cpu'):
+            # the fused kernel is CUDA-only; fall back on CPU inputs
+            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
+        else:
+            half_to_float = (logits.dtype == torch.half)
+            # args: logits, labels, smoothing (0.0 = disabled), padding_idx, half_to_float
+            losses = xentropy.SoftmaxCrossEntropyLoss.apply(
+                logits, target, 0.0, ignore_index, half_to_float,
+            )
+            if reduction == 'sum':
+                return losses.sum()
+            elif reduction == 'mean':
+                if ignore_index >= 0:
+                    # average over non-ignored targets only
+                    return losses.sum() / target.ne(ignore_index).sum()
+                else:
+                    return losses.mean()
+            elif reduction == 'none':
+                return losses
+            else:
+                raise NotImplementedError
+
+except ImportError:
+
+    def cross_entropy(logits, target, ignore_index=-100, reduction='mean'):
+        return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
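The fused kernel is selected at import time by the `try`/`except` above, so nothing breaks when apex is absent; fairseq silently falls back to the pure-PyTorch path. A quick way to confirm which path is active after installing apex as in the README hunk (this check is a sketch, not part of the diff):

```python
import logging

# configure logging before importing fairseq so the module-level
# logger.info('using fused cross entropy') message is visible
logging.basicConfig(level=logging.INFO)

from fairseq.modules import cross_entropy  # noqa: F401

try:
    from apex.contrib import xentropy  # noqa: F401
    print('apex xentropy extension found: fused kernel used on GPU')
except ImportError:
    print('apex not available: using the PyTorch fallback')
```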
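For context, the new helper keeps the call pattern of the `F.log_softmax` + `F.nll_loss` pair it replaces in `masked_lm.py`. A minimal usage sketch; the shapes, vocabulary size, and `PAD = 1` below are illustrative assumptions, not values from the diff:

```python
import torch
from fairseq import modules

PAD = 1  # assumed padding index for this example

# illustrative batch: 8 masked positions over a 1000-token vocabulary
logits = torch.randn(8, 1000)
targets = torch.randint(0, 1000, (8,))

# same pattern as the updated MaskedLmLoss.forward: summed loss with
# padding targets excluded from the objective
loss = modules.cross_entropy(
    logits.view(-1, logits.size(-1)),
    targets.view(-1),
    reduction='sum',
    ignore_index=PAD,
)
```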
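Because the two code paths are chosen implicitly, a parity check against `F.cross_entropy` is a reasonable smoke test for an apex build. This sketch assumes a CUDA device with the `--xentropy` extension compiled in; on CPU both sides trivially reduce to the same PyTorch implementation:

```python
import torch
import torch.nn.functional as F
from fairseq.modules import cross_entropy

device = 'cuda' if torch.cuda.is_available() else 'cpu'
logits = torch.randn(32, 1000, device=device)
targets = torch.randint(0, 1000, (32,), device=device)
targets[:4] = 1  # a few padded positions to exercise ignore_index

fused = cross_entropy(logits, targets, ignore_index=1, reduction='sum')
reference = F.cross_entropy(logits, targets, ignore_index=1, reduction='sum')
assert torch.allclose(fused, reference, rtol=1e-4, atol=1e-4)
```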