Make FusedLayerNorm wrapper pickleable
Summary: See https://fb.workplace.com/groups/1405155842844877/permalink/3408783919148716/

Differential Revision: D20133239

fbshipit-source-id: 78ee12b7a573af7cfe9acb9fd2557cbbb78b0592
Myle Ott authored and facebook-github-bot committed Feb 27, 2020
1 parent fdfdbec commit cef5653
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions fairseq/modules/layer_norm.py
@@ -6,18 +6,21 @@
 import torch
 
 
-def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
-    if not export and torch.cuda.is_available():
-        try:
-            from apex.normalization import FusedLayerNorm
+try:
+    from apex.normalization import FusedLayerNorm as _FusedLayerNorm
+    has_fused_layernorm = True
 
-            class _FusedLayerNorm(FusedLayerNorm):
+    class FusedLayerNorm(_FusedLayerNorm):
 
-                @torch.jit.unused
-                def forward(self, x):
-                    return super().forward(x)
+        @torch.jit.unused
+        def forward(self, x):
+            return super().forward(x)
 
-            return _FusedLayerNorm(normalized_shape, eps, elementwise_affine)
-        except ImportError:
-            pass
+except ImportError:
+    has_fused_layernorm = False
+
+
+def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
+    if not export and torch.cuda.is_available() and has_fused_layernorm:
+        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
     return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
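
Context for the change (not part of the commit page): the old code defined the FusedLayerNorm subclass inside the LayerNorm factory function, and the standard pickle module serializes an instance's class by reference (module plus qualified name), so instances of a function-local class cannot be pickled. Moving the subclass to module scope makes the wrapper picklable. The sketch below uses hypothetical stand-in classes to illustrate the difference; it is not code from the commit.

# Minimal sketch, assuming only the standard library.
import pickle


class ModuleLevelWrapper:
    # Stands in for the new module-level FusedLayerNorm subclass.
    def __init__(self, normalized_shape):
        self.normalized_shape = normalized_shape


def make_local_wrapper(normalized_shape):
    # Stands in for the old pattern: a subclass created inside LayerNorm().
    class LocalWrapper:
        def __init__(self, normalized_shape):
            self.normalized_shape = normalized_shape
    return LocalWrapper(normalized_shape)


pickle.dumps(ModuleLevelWrapper(512))  # succeeds: the class is importable by name

try:
    pickle.dumps(make_local_wrapper(512))
except (AttributeError, pickle.PicklingError) as exc:
    # CPython reports something like:
    # "Can't pickle local object 'make_local_wrapper.<locals>.LocalWrapper'"
    print("local class is not picklable:", exc)

With this change, a module returned by fairseq's LayerNorm factory should pickle normally even when the apex fused path is taken, since the wrapper class is now resolvable at module scope in fairseq/modules/layer_norm.py.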
