forked from facebookresearch/fairseq
Add dummy tasks and model for benchmarking (facebookresearch#1026)
Summary: Pull Request resolved: fairinternal/fairseq-py#1026
Differential Revision: D19834667
Pulled By: myleott
fbshipit-source-id: 56ab6df5d8145dc37431252de444a2a9728e7898
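For context, a benchmarking run with these pieces would look roughly like the sketch below. The flags come from the add_args methods in this diff; the leading dummy-data argument is a placeholder (the dummy tasks never read it), and whether train.py still requires that positional argument depends on the fairseq version:

    python train.py dummy-data \
        --task dummy_lm --arch dummy_model \
        --max-sentences 8 --tokens-per-sample 512 \
        --optimizer adam --lr 0.0001 --max-update 100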
1 parent 5b74c1e · commit 1a41d13
Showing 6 changed files with 338 additions and 8 deletions.
@@ -0,0 +1,107 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import FairseqTask, register_task


@register_task('dummy_lm')
class DummyLMTask(FairseqTask):

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument('--dict-size', default=50000, type=int)
        parser.add_argument('--dataset-size', default=100000, type=int)
        parser.add_argument('--tokens-per-sample', default=512, type=int,
                            help='max number of total tokens over all segments '
                                 'per sample for BERT dataset')

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed

        # one fixed (src, tgt) pair, shifted by one position, reused for every sample
        seq = torch.arange(args.tokens_per_sample + 1) + dictionary.pad() + 1

        self.dummy_src = seq[:-1]
        self.dummy_tgt = seq[1:]

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task."""
        dictionary = Dictionary()
        for i in range(args.dict_size):
            dictionary.add_symbol('word{}'.format(i))
        print('| dictionary: {} types'.format(len(dictionary)))

        return cls(args, dictionary)

    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        bsz = self.args.max_sentences
        self.datasets[split] = DummyDataset(
            {
                'id': 1,  # placeholder; real batches carry a tensor of sample ids
                'net_input': {
                    'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
                    'src_lengths': torch.full(
                        (bsz, ), self.args.tokens_per_sample, dtype=torch.long
                    ),
                },
                'target': torch.stack([self.dummy_tgt for _ in range(bsz)]),
                'nsentences': bsz,
                'ntokens': bsz * self.args.tokens_per_sample,
            },
            num_items=self.args.dataset_size,
            item_size=self.args.tokens_per_sample,
        )

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary


class DummyDataset(FairseqDataset):

    def __init__(self, batch, num_items, item_size):
        super().__init__()
        self.batch = batch
        self.num_items = num_items
        self.item_size = item_size

    def __getitem__(self, index):
        return index

    def __len__(self):
        return self.num_items

    def collater(self, samples):
        # ignore the sampled indices and always return the same pre-built batch
        return self.batch

    @property
    def sizes(self):
        return np.array([self.item_size] * self.num_items)

    def num_tokens(self, index):
        return self.item_size

    def size(self, index):
        return self.item_size

    def ordered_indices(self):
        return np.arange(self.num_items)

    @property
    def supports_prefetch(self):
        return False
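To sanity-check the task in isolation, one can drive it directly from Python. A minimal sketch, assuming the file above is importable as dummy_lm (the diff rendering omits the actual file paths); seed and max_sentences are filled in by hand since fairseq normally supplies them:

    import argparse

    from dummy_lm import DummyLMTask  # hypothetical module name

    parser = argparse.ArgumentParser()
    DummyLMTask.add_args(parser)
    args = parser.parse_args([])
    args.seed = 1           # normally set by fairseq's common arguments
    args.max_sentences = 8  # batch size, normally set by fairseq's dataset arguments

    task = DummyLMTask.setup_task(args)
    task.load_dataset('train')
    batch = task.datasets['train'].collater(None)
    print(batch['net_input']['src_tokens'].shape)  # torch.Size([8, 512])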
@@ -0,0 +1,118 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import FairseqTask, register_task


@register_task('dummy_masked_lm')
class DummyMaskedLMTask(FairseqTask):

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # 49995 words + 4 special symbols + <mask> = 50000 types, which
        # satisfies the `len(dictionary) % 8 == 0` assertion below
        parser.add_argument('--dict-size', default=49995, type=int)
        parser.add_argument('--dataset-size', default=100000, type=int)
        parser.add_argument('--tokens-per-sample', default=512, type=int,
                            help='max number of total tokens over all segments '
                                 'per sample for BERT dataset')

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed

        # add mask token
        self.mask_idx = dictionary.add_symbol('<mask>')
        assert len(dictionary) % 8 == 0

        # the data is synthetic, so fixed indices are fine here; only the
        # shapes of the batch matter for benchmarking
        mask_idx = 0
        pad_idx = 1
        seq = torch.arange(args.tokens_per_sample) + pad_idx + 1
        mask = torch.arange(2, args.tokens_per_sample, 7)  # ~15%
        src = seq.clone()
        src[mask] = mask_idx
        tgt = torch.full_like(seq, pad_idx)
        tgt[mask] = seq[mask]

        self.dummy_src = src
        self.dummy_tgt = tgt

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task."""
        dictionary = Dictionary()
        for i in range(args.dict_size):
            dictionary.add_symbol('word{}'.format(i))
        print('| dictionary: {} types'.format(len(dictionary)))

        return cls(args, dictionary)

    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        bsz = self.args.max_sentences
        self.datasets[split] = DummyDataset(
            {
                'id': 1,  # placeholder; real batches carry a tensor of sample ids
                'net_input': {
                    'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
                    'src_lengths': torch.full(
                        (bsz, ), self.args.tokens_per_sample, dtype=torch.long
                    ),
                },
                'target': torch.stack([self.dummy_tgt for _ in range(bsz)]),
                'nsentences': bsz,
                'ntokens': bsz * self.args.tokens_per_sample,
            },
            num_items=self.args.dataset_size,
            item_size=self.args.tokens_per_sample,
        )

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary


class DummyDataset(FairseqDataset):

    def __init__(self, batch, num_items, item_size):
        super().__init__()
        self.batch = batch
        self.num_items = num_items
        self.item_size = item_size

    def __getitem__(self, index):
        return index

    def __len__(self):
        return self.num_items

    def collater(self, samples):
        # ignore the sampled indices and always return the same pre-built batch
        return self.batch

    @property
    def sizes(self):
        return np.array([self.item_size] * self.num_items)

    def num_tokens(self, index):
        return self.item_size

    def size(self, index):
        return self.item_size

    def ordered_indices(self):
        return np.arange(self.num_items)

    @property
    def supports_prefetch(self):
        return False
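The "~15%" comment can be verified directly: masking every 7th position starting at index 2 selects 73 of the default 512 positions, about 14%, close to BERT's 15% masking rate. A quick check:

    import torch

    mask = torch.arange(2, 512, 7)  # same pattern as above, with tokens_per_sample=512
    print(len(mask), len(mask) / 512)  # 73 0.142578125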
@@ -0,0 +1,93 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch.nn.functional as F

from fairseq.data import Dictionary
from fairseq.models import (
    FairseqDecoder,
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)


@register_model('dummy_model')
class DummyModel(FairseqLanguageModel):

    def __init__(self, args, encoder):
        # FairseqLanguageModel stores this module as `self.decoder`
        super().__init__(encoder)
        self.args = args

    @staticmethod
    def add_args(parser):
        parser.add_argument('--num-layers', type=int, default=24)
        parser.add_argument('--embed-dim', type=int, default=1024)

    @classmethod
    def build_model(cls, args, task):
        encoder = DummyEncoder(
            num_embed=len(task.target_dictionary),
            embed_dim=args.embed_dim,
            num_layers=args.num_layers,
        )
        return cls(args, encoder)

    def forward(self, src_tokens, **kwargs):
        return self.decoder(src_tokens)


class DummyEncoder(FairseqDecoder):
    """A stack of Linear layers sized to mimic a Transformer's compute
    profile, but without the actual self-attention computation."""

    def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
        super().__init__(Dictionary())
        self.embed = nn.Embedding(
            num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
        )
        self.layers_a = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, 3*embed_dim),  # q, k, v input projection
                nn.Linear(3*embed_dim, embed_dim),  # skip self-attention
                nn.Linear(embed_dim, embed_dim),    # output projection
                nn.Dropout(),
            )
            for i in range(num_layers)
        ])
        self.layers_b = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, 4*embed_dim),  # FFN
                nn.ReLU(),
                nn.Linear(4*embed_dim, embed_dim),  # FFN
                nn.Dropout(0.1),
            )
            for i in range(num_layers)
        ])
        self.out_proj = nn.Linear(embed_dim, num_embed)

    def forward(self, tokens):
        x = self.embed(tokens)
        for layer_a, layer_b in zip(self.layers_a, self.layers_b):
            x = x + layer_a(x)  # attention-like block (residual)
            x = x + layer_b(x)  # feed-forward block (residual)
        x = self.out_proj(x)
        return (x,)

    def max_positions(self):
        return 1024

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        logits = net_output[0].float()
        if log_probs:
            return F.log_softmax(logits, dim=-1)
        else:
            return F.softmax(logits, dim=-1)


@register_model_architecture('dummy_model', 'dummy_model')
def base_architecture(args):
    pass
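With 50,000 dictionary types and the default embed_dim=1024 and num_layers=24, this stack weighs in at roughly 480M parameters, almost all of them in the Linear layers and the two embedding/projection matrices. A quick count, assuming DummyEncoder from the file above is in scope:

    model = DummyEncoder(num_embed=50000, embed_dim=1024, num_layers=24)
    n_params = sum(p.numel() for p in model.parameters())
    print('{:.0f}M parameters'.format(n_params / 1e6))  # ~480M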