diff --git a/examples/speech_recognition/criterions/ASG_loss.py b/examples/speech_recognition/criterions/ASG_loss.py
index 29c8a3d78e..c0b7cf693d 100644
--- a/examples/speech_recognition/criterions/ASG_loss.py
+++ b/examples/speech_recognition/criterions/ASG_loss.py
@@ -13,8 +13,6 @@
 from fairseq.criterions import FairseqCriterion, register_criterion
 from examples.speech_recognition.data.replabels import pack_replabels
 
-from wav2letter.criterion import ASGLoss, CriterionScaleMode
-
 
 @register_criterion("asg_loss")
 class ASGCriterion(FairseqCriterion):
@@ -43,6 +41,8 @@ def add_args(parser):
         )
 
     def __init__(self, args, task):
+        from wav2letter.criterion import ASGLoss, CriterionScaleMode
+
         super().__init__(args, task)
         self.tgt_dict = task.target_dictionary
         self.eos = self.tgt_dict.eos()
diff --git a/examples/speech_recognition/datasets/asr_prep_json.py b/examples/speech_recognition/datasets/asr_prep_json.py
index e4b5d8f52f..2bab825b89 100644
--- a/examples/speech_recognition/datasets/asr_prep_json.py
+++ b/examples/speech_recognition/datasets/asr_prep_json.py
@@ -14,7 +14,6 @@
 import json
 import sentencepiece as spm
 import multiprocessing
-import torchaudio
 
 from fairseq.data import Dictionary
 
@@ -22,6 +21,7 @@
 
 
 def process_sample(aud_path, lable, utt_id, sp, tgt_dict):
+    import torchaudio
     input = {}
     output = {}
    si, ei = torchaudio.info(aud_path)
diff --git a/examples/speech_recognition/w2l_decoder.py b/examples/speech_recognition/w2l_decoder.py
index 141d41d6ca..c56448ce3d 100644
--- a/examples/speech_recognition/w2l_decoder.py
+++ b/examples/speech_recognition/w2l_decoder.py
@@ -13,16 +13,22 @@
 import torch
 from fairseq import utils
 from examples.speech_recognition.data.replabels import unpack_replabels
-from wav2letter.common import create_word_dict, load_words
-from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
-from wav2letter.decoder import (
-    CriterionType,
-    DecoderOptions,
-    KenLM,
-    SmearingMode,
-    Trie,
-    WordLMDecoder,
-)
+
+try:
+    from wav2letter.common import create_word_dict, load_words
+    from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
+    from wav2letter.decoder import (
+        CriterionType,
+        DecoderOptions,
+        KenLM,
+        SmearingMode,
+        Trie,
+        WordLMDecoder,
+    )
+except ImportError:
+    # wav2letter is required to run the speech_recognition example, but a
+    # missing install shouldn't break importing the rest of this module
+    pass
 
 
 class W2lDecoder(object):
diff --git a/examples/translation_moe/README.md b/examples/translation_moe/README.md
index debcde6630..33f1bee5cb 100644
--- a/examples/translation_moe/README.md
+++ b/examples/translation_moe/README.md
@@ -18,7 +18,7 @@ The following command will train a `hMoElp` model with `3` experts:
 fairseq-train --ddp-backend='no_c10d' \
   data-bin/wmt17_en_de \
   --max-update 100000 \
-  --task translation_moe \
+  --task translation_moe --user-dir examples/translation_moe/src \
   --method hMoElp --mean-pool-gating-network \
   --num-experts 3 \
   --arch transformer_wmt_en_de --share-all-embeddings \
@@ -37,7 +37,7 @@ For example, to generate from expert 0:
 fairseq-generate data-bin/wmt17_en_de \
   --path checkpoints/checkpoint_best.pt \
   --beam 1 --remove-bpe \
-  --task translation_moe \
+  --task translation_moe --user-dir examples/translation_moe/src \
   --method hMoElp --mean-pool-gating-network \
   --num-experts 3 \
   --gen-expert 0
@@ -61,7 +61,7 @@ for EXPERT in $(seq 0 2); do \
   --beam 1 \
   --bpe subword_nmt --bpe-codes $BPE_CODE \
   --buffer-size 500 --max-tokens 6000 \
-  --task translation_moe \
+  --task translation_moe --user-dir examples/translation_moe/src \
   --method hMoElp --mean-pool-gating-network \
   --num-experts 3 \
   --gen-expert $EXPERT ; \
diff --git a/examples/translation_moe/src/__init__.py b/examples/translation_moe/src/__init__.py
new file mode 100644
index 0000000000..c0abe53e97
--- /dev/null
+++ b/examples/translation_moe/src/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from . import translation_moe  # noqa
diff --git a/fairseq/modules/logsumexp_moe.py b/examples/translation_moe/src/logsumexp_moe.py
similarity index 100%
rename from fairseq/modules/logsumexp_moe.py
rename to examples/translation_moe/src/logsumexp_moe.py
diff --git a/fairseq/modules/mean_pool_gating_network.py b/examples/translation_moe/src/mean_pool_gating_network.py
similarity index 100%
rename from fairseq/modules/mean_pool_gating_network.py
rename to examples/translation_moe/src/mean_pool_gating_network.py
diff --git a/fairseq/tasks/translation_moe.py b/examples/translation_moe/src/translation_moe.py
similarity index 97%
rename from fairseq/tasks/translation_moe.py
rename to examples/translation_moe/src/translation_moe.py
index f0456dffab..8a3709183f 100644
--- a/fairseq/tasks/translation_moe.py
+++ b/examples/translation_moe/src/translation_moe.py
@@ -5,10 +5,13 @@
 
 import torch
 
-from fairseq import metrics, modules, utils
+from fairseq import metrics, utils
 from fairseq.tasks import register_task
 from fairseq.tasks.translation import TranslationTask
 
+from .logsumexp_moe import LogSumExpMoE
+from .mean_pool_gating_network import MeanPoolGatingNetwork
+
 
 @register_task('translation_moe')
 class TranslationMoETask(TranslationTask):
@@ -100,7 +103,7 @@ def build_model(self, args):
             else:
                 raise ValueError('Must specify --mean-pool-gating-network-dropout')
 
-            model.gating_network = modules.MeanPoolGatingNetwork(
+            model.gating_network = MeanPoolGatingNetwork(
                 encoder_dim, args.num_experts, dropout,
             )
         else:
@@ -171,7 +174,7 @@ def get_lprob_yz(winners=None):
             loss = -get_lprob_yz(winners)
         else:
             lprob_yz = get_lprob_yz()  # B x K
-            loss = -modules.LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)
+            loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)
 
         loss = loss.sum()
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py
index 64b10ea364..59da00807a 100644
--- a/fairseq/modules/__init__.py
+++ b/fairseq/modules/__init__.py
@@ -17,8 +17,6 @@
 from .learned_positional_embedding import LearnedPositionalEmbedding
 from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
 from .linearized_convolution import LinearizedConvolution
-from .logsumexp_moe import LogSumExpMoE
-from .mean_pool_gating_network import MeanPoolGatingNetwork
 from .multihead_attention import MultiheadAttention
 from .positional_embedding import PositionalEmbedding
 from .scalar_bias import ScalarBias
@@ -47,8 +45,6 @@
     'LightweightConv1dTBC',
     'LightweightConv',
     'LinearizedConvolution',
-    'LogSumExpMoE',
-    'MeanPoolGatingNetwork',
     'MultiheadAttention',
     'PositionalEmbedding',
     'ScalarBias',
diff --git a/fairseq/options.py b/fairseq/options.py
index f56c9bc54d..07034cec70 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -113,6 +113,13 @@ def parse_args_and_arch(
 
     from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
 
+    # Before creating the true parser, pre-parse and import the optional user
+    # module so that its custom tasks, optimizers, architectures, etc. are
+    # registered and visible to the true parser.
+    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
+    usr_parser.add_argument("--user-dir", default=None)
+    usr_args, _ = usr_parser.parse_known_args(input_args)
+    utils.import_user_module(usr_args)
+
     if modify_parser is not None:
         modify_parser(parser)
 
diff --git a/tests/test_binaries.py b/tests/test_binaries.py
index 0d6c2c5fe8..20bf3b6b30 100644
--- a/tests/test_binaries.py
+++ b/tests/test_binaries.py
@@ -375,6 +375,7 @@ def test_mixture_of_experts(self):
         preprocess_translation_data(data_dir)
         train_translation_model(data_dir, 'transformer_iwslt_de_en', [
             '--task', 'translation_moe',
+            '--user-dir', 'examples/translation_moe/src',
             '--method', 'hMoElp',
             '--mean-pool-gating-network',
             '--num-experts', '3',
@@ -385,6 +386,7 @@ def test_mixture_of_experts(self):
         ])
         generate_main(data_dir, [
             '--task', 'translation_moe',
+            '--user-dir', 'examples/translation_moe/src',
             '--method', 'hMoElp',
             '--mean-pool-gating-network',
             '--num-experts', '3',
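Taken together, the three speech_recognition hunks apply one idiom: make wav2letter/torchaudio optional at import time, and let a failure surface only on the code path that actually uses them. A minimal self-contained sketch of that idiom; `expensive_dep` is a hypothetical module standing in for wav2letter or torchaudio, and the error message is illustrative:

```python
# Optional-dependency idiom sketched from the hunks above. "expensive_dep"
# is a hypothetical stand-in for wav2letter / torchaudio.
try:
    import expensive_dep  # may be absent; importing *this* module still works
except ImportError:
    expensive_dep = None


def decode(data):
    # Fail lazily and loudly at the point of use, not at import time.
    if expensive_dep is None:
        raise ImportError("expensive_dep is required to call decode()")
    return expensive_dep.run(data)
```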
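The fairseq/options.py hunk is what makes `--user-dir` work: a throwaway parser extracts the flag from argv before the real parser is built, so the user module can register the tasks and architectures the real parser must then recognize. A standalone sketch of this two-pass argparse pattern (the helper name is ours; the flag handling mirrors the diff):

```python
import argparse


def preparse_user_dir(input_args=None):
    # A minimal parser that knows only --user-dir. parse_known_args silently
    # skips every other flag, so this is safe to run before the full parser
    # (and the plugins it depends on) exists.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument("--user-dir", default=None)
    usr_args, _ = usr_parser.parse_known_args(input_args)
    return usr_args.user_dir


# Unknown flags are ignored; only --user-dir is extracted.
assert preparse_user_dir(
    ["--task", "translation_moe", "--user-dir", "examples/translation_moe/src"]
) == "examples/translation_moe/src"
```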
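For context on why `--user-dir examples/translation_moe/src` is enough for fairseq to find the task: the directory is imported as a package, its new `__init__.py` imports the plugin modules, and decorators like `@register_task` add the task to fairseq's registry as a side effect of that import. A schematic plugin module under an assumed task name:

```python
# A user-dir package's __init__.py would import this module, exactly as the
# new examples/translation_moe/src/__init__.py does for translation_moe.
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask


@register_task('my_custom_task')  # hypothetical name, not part of this diff
class MyCustomTask(TranslationTask):
    """Registered with fairseq when --user-dir imports this module."""
    pass
```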