added sentence ranking task and loss (facebookresearch#809)
Summary: This task and loss are used for sentence ranking and multiple-choice tasks such as RACE.

Pull Request resolved: fairinternal/fairseq-py#809

Reviewed By: myleott

Differential Revision: D16715745

Pulled By: myleott

fbshipit-source-id: b3f3eae048017910e8c7e881026603a5e427ddbc
1 parent 838e108, commit a778c94
Showing 5 changed files with 471 additions and 0 deletions.
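For orientation, here is a minimal sketch (plain PyTorch, not part of the diff below) of the sentence-ranking formulation the new criterion implements: each of the `K` candidates for a question receives a scalar score, and a softmax over the `K` scores is trained with negative log-likelihood against the index of the correct candidate.

```python
import torch
import torch.nn.functional as F

K = 4                                # e.g. four answer options per RACE question
scores = torch.randn(8, K)           # one scalar score per (context, candidate) pair, batch of 8 questions
targets = torch.randint(0, K, (8,))  # index of the correct candidate for each question
loss = F.nll_loss(F.log_softmax(scores, dim=-1), targets, reduction='sum')
```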
@@ -0,0 +1,47 @@
# Finetuning RoBERTa on RACE tasks

### 1) Download the data from the RACE website (http://www.cs.cmu.edu/~glai1/data/race/)

### 2) Preprocess RACE data:
```bash
python ./examples/roberta/preprocess_RACE.py --input-dir <input-dir> --output-dir <extracted-data-dir>
./examples/roberta/preprocess_RACE.sh <extracted-data-dir> <output-dir>
```

### 3) Fine-tuning on RACE:

```bash
MAX_EPOCHS=5      # Number of training epochs.
LR=1e-05          # Peak LR for fixed LR scheduler.
NUM_CLASSES=4     # Number of answer options per question.
MAX_SENTENCES=2   # Batch size per GPU.
ROBERTA_PATH=/path/to/roberta/model.pt

CUDA_VISIBLE_DEVICES=0 python train.py <race-preprocessed-dir>/ \
    --restore-file $ROBERTA_PATH \
    --max-positions 512 \
    --max-sentences $MAX_SENTENCES \
    --task sentence_ranking \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --init-token 0 --separator-token 2 \
    --arch roberta_large \
    --criterion sentence_ranking \
    --num-classes $NUM_CLASSES \
    --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
    --clip-norm 0.0 \
    --lr-scheduler fixed --lr $LR \
    --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
    --max-epoch $MAX_EPOCHS \
    --update-freq 8 \
    --find-unused-parameters \
    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
```

**Note:**

a) Since contexts in RACE are relatively long, we use a smaller batch size per GPU and increase `--update-freq` to achieve a larger effective batch size (see the worked figure after these notes).

b) The above command-line arguments and hyperparameters were tested on one Nvidia `V100` GPU with `32GB` of memory for each task. Depending on the GPU memory available to you, you can increase `--update-freq` and reduce `--max-sentences`.

c) The settings in the above command are based on our hyperparameter search within a fixed search space (for careful comparison across models). You might be able to find better metrics with a wider hyperparameter search.
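
For reference, with the values above and a single GPU (`CUDA_VISIBLE_DEVICES=0`), the effective batch size works out to `MAX_SENTENCES × update-freq = 2 × 8 = 16` questions per update, each scored against `NUM_CLASSES = 4` candidate answers.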
@@ -0,0 +1,92 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import os


class InputExample:
    def __init__(self, paragraph, qa_list, label):
        self.paragraph = paragraph
        self.qa_list = qa_list
        self.label = label


def get_examples(data_dir, set_type):
    """
    Extract paragraph and question-answer list from each json file
    """
    examples = []

    levels = ["middle", "high"]
    set_type_c = set_type.split('-')
    if len(set_type_c) == 2:
        levels = [set_type_c[1]]
        set_type = set_type_c[0]
    for level in levels:
        cur_dir = os.path.join(data_dir, set_type, level)
        for filename in os.listdir(cur_dir):
            cur_path = os.path.join(cur_dir, filename)
            with open(cur_path, 'r') as f:
                cur_data = json.load(f)
                answers = cur_data["answers"]
                options = cur_data["options"]
                questions = cur_data["questions"]
                context = cur_data["article"].replace("\n", " ")
                for i in range(len(answers)):
                    label = ord(answers[i]) - ord("A")
                    qa_list = []
                    question = questions[i]
                    for j in range(4):
                        option = options[i][j]
                        if "_" in question:
                            qa_cat = question.replace("_", option)
                        else:
                            qa_cat = " ".join([question, option])
                        qa_list.append(qa_cat)
                    examples.append(InputExample(context, qa_list, label))

    return examples


def main():
    """
    Helper script to extract paragraphs, questions and answers from RACE datasets.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-dir",
        help='input directory for downloaded RACE dataset',
    )
    parser.add_argument(
        "--output-dir",
        help='output directory for extracted data',
    )
    args = parser.parse_args()

    for set_type in ["train", "dev", "test-middle", "test-high"]:
        examples = get_examples(args.input_dir, set_type)
        # use os.path.join so the output directory does not need a trailing slash
        qa_file_paths = [
            os.path.join(args.output_dir, set_type + ".input" + str(i + 1)) for i in range(4)
        ]
        qa_files = [open(qa_file_path, 'w') for qa_file_path in qa_file_paths]
        outf_context_path = os.path.join(args.output_dir, set_type + ".input0")
        outf_label_path = os.path.join(args.output_dir, set_type + ".label")
        outf_context = open(outf_context_path, 'w')
        outf_label = open(outf_label_path, 'w')
        for example in examples:
            outf_context.write(example.paragraph + '\n')
            for i in range(4):
                qa_files[i].write(example.qa_list[i] + '\n')
            outf_label.write(str(example.label) + '\n')

        for f in qa_files:
            f.close()
        outf_label.close()
        outf_context.close()


if __name__ == '__main__':
    main()
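
As a sanity check on the extraction step, a small sketch (not part of the commit; the output directory name is only an assumption) that verifies each split produced six line-aligned files: `<split>.input0` with one context per line, `<split>.input1` through `<split>.input4` with the four question+option candidates, and `<split>.label` with the 0-3 index of the correct option.

```python
import os

out_dir = "race-extracted"  # hypothetical --output-dir passed to preprocess_RACE.py
for split in ["train", "dev", "test-middle", "test-high"]:
    paths = [os.path.join(out_dir, split + ".input" + str(i)) for i in range(5)]
    paths.append(os.path.join(out_dir, split + ".label"))
    counts = []
    for path in paths:
        with open(path) as f:
            counts.append(sum(1 for _ in f))
    # all six files must have one line per question
    assert len(set(counts)) == 1, "misaligned files for {}: {}".format(split, counts)
    print(split, counts[0], "examples")
```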
@@ -0,0 +1,60 @@
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


# data should be downloaded and processed with preprocess_RACE.py
if [[ $# -ne 2 ]]; then
  echo "Run as follows:"
  echo "./examples/roberta/preprocess_RACE.sh <race_data_folder> <output_folder>"
  exit 1
fi

RACE_DATA_FOLDER=$1
OUT_DATA_FOLDER=$2

# download bpe encoder.json, vocabulary and fairseq dictionary
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'

SPLITS="train dev test-middle test-high"
INPUT_TYPES="input0 input1 input2 input3 input4"
for INPUT_TYPE in $INPUT_TYPES
do
  for SPLIT in $SPLITS
  do
    echo "BPE encoding $SPLIT/$INPUT_TYPE"
    python -m examples.roberta.multiprocessing_bpe_encoder \
      --encoder-json encoder.json \
      --vocab-bpe vocab.bpe \
      --inputs "$RACE_DATA_FOLDER/$SPLIT.$INPUT_TYPE" \
      --outputs "$RACE_DATA_FOLDER/$SPLIT.$INPUT_TYPE.bpe" \
      --workers 10 \
      --keep-empty;
  done
done

for INPUT_TYPE in $INPUT_TYPES
do
  fairseq-preprocess \
    --dataset-impl cached \
    --only-source \
    --trainpref "$RACE_DATA_FOLDER/train.$INPUT_TYPE.bpe" \
    --validpref "$RACE_DATA_FOLDER/dev.$INPUT_TYPE.bpe" \
    --testpref "$RACE_DATA_FOLDER/test-middle.$INPUT_TYPE.bpe,$RACE_DATA_FOLDER/test-high.$INPUT_TYPE.bpe" \
    --destdir "$OUT_DATA_FOLDER/$INPUT_TYPE" \
    --workers 10 \
    --srcdict dict.txt;
done

rm -rf "$OUT_DATA_FOLDER/label"
mkdir -p "$OUT_DATA_FOLDER/label"
cp "$RACE_DATA_FOLDER/train.label" "$OUT_DATA_FOLDER/label/"
cp "$RACE_DATA_FOLDER/dev.label" "$OUT_DATA_FOLDER/label/valid.label"
cp "$RACE_DATA_FOLDER/test-middle.label" "$OUT_DATA_FOLDER/label/test.label"
cp "$RACE_DATA_FOLDER/test-high.label" "$OUT_DATA_FOLDER/label/test1.label"
@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F

from fairseq import utils

from . import FairseqCriterion, register_criterion


@register_criterion('sentence_ranking')
class SentenceRankingCriterion(FairseqCriterion):

    def forward(self, model, sample, reduce=True):
        """Compute ranking loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        scores = []
        for idx in range(self.args.num_classes):
            score, _ = model(
                **sample['net_input{idx}'.format(idx=idx+1)],
                features_only=True,
                classification_head_name='sentence_classification_head',
            )
            scores.append(score)

        logits = torch.cat(scores, dim=1)
        targets = model.get_targets(sample, [logits]).view(-1)
        sample_size = targets.numel()

        loss = F.nll_loss(
            F.log_softmax(logits, dim=-1, dtype=torch.float32),
            targets,
            reduction='sum',
        )

        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample_size,
            'sample_size': sample_size,
        }
        logging_output.update(
            ncorrect=(logits.max(dim=1)[1] == targets).sum().item()
        )
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)

        agg_output = {
            'loss': loss_sum / sample_size / math.log(2),
            'ntokens': ntokens,
            'nsentences': nsentences,
            'sample_size': sample_size,
        }

        if len(logging_outputs) > 0 and 'ncorrect' in logging_outputs[0]:
            ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs)
            agg_output.update(accuracy=ncorrect/nsentences)

        if sample_size != ntokens:
            agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
        return agg_output
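
A note on the logged metrics: the aggregated `loss` is the summed negative log-likelihood divided by the number of ranked groups and by `math.log(2)`, i.e. a per-question loss in bits, and `accuracy` is the fraction of questions whose highest-scoring candidate matches the target, which is the metric that `--best-checkpoint-metric accuracy` in the fine-tuning command selects on. Because `sample_size` counts questions rather than tokens, the `sample_size != ntokens` branch additionally reports a per-token `nll_loss` in bits.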