import torch
import torch.nn as nn
from base_data_model import BaseDataModel, BaseDataset
from data_preprocess import DataProcessor


class BertVerifierDataModel(BaseDataModel):
    """Builds batches of (question, candidate solution) pairs for a BERT-based verifier."""

    def __init__(self, args, tokenizer):
        super().__init__(args, tokenizer)

    def get_examples(self, path, type):
        examples = DataProcessor._read_jsonl(path)
        print(f"{len(examples)} examples")
        return examples

    @staticmethod
    def collate_fn(batch, args, tokenizer):
        # Transpose the batch from a list of dicts into a dict of lists.
        batch_data = {}
        for key in batch[0]:
            batch_data[key] = [example[key] for example in batch]
        # Jointly encode each (question, solution) pair, padded/truncated to 512 tokens.
        inputs_encoding = tokenizer(batch_data['question'], batch_data['solution'],
                                    add_special_tokens=True, return_tensors="pt",
                                    padding=True, truncation=True, max_length=512)
        # Position of the last non-padding token in each row.
        final_token_idx = inputs_encoding.attention_mask.sum(-1).view(-1, 1) - 1
        return dict(**batch_data, **inputs_encoding,
                    verifier_labels=torch.FloatTensor(batch_data['is_correct']),
                    final_token_idx=final_token_idx)
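
# A minimal sketch of what collate_fn returns, assuming each JSONL record
# carries 'question', 'solution', and 'is_correct' fields (as used above):
#   batch = BertVerifierDataModel.collate_fn(examples[:2], args, tokenizer)
#   batch['input_ids']        # (2, seq_len) token ids for the encoded pairs
#   batch['verifier_labels']  # FloatTensor of is_correct targets, e.g. [1., 0.]
#   batch['final_token_idx']  # (2, 1) last non-padding position in each row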


if __name__ == '__main__':
    import argparse
    import pytorch_lightning as pl
    from transformers import BertTokenizer, AutoModelForMaskedLM
    from base_model import BaseModel
    from base_trainer import BaseTrainer
    from bert_verifier_modeling_gsm8k import BertModelForVerifier
    import transformers

    transformers.logging.set_verbosity_error()

    total_parser = argparse.ArgumentParser()
    # * data preprocessing args
    total_parser = BertVerifierDataModel.add_data_specific_args(total_parser)
    # * training args
    total_parser = BaseTrainer.add_trainer_specific_args(total_parser)
    # * model specific args
    total_parser = BaseModel.add_model_specific_args(total_parser)
    # * Bert specific args
    total_parser = BertModelForVerifier.add_model_specific_args(total_parser)
    args = total_parser.parse_args()

    # Note: use_fast has no effect on BertTokenizer (it is an AutoTokenizer option).
    tokenizer = BertTokenizer.from_pretrained(args.model_name, use_fast=True)
    # Special markers for the question/answer/thought segments and the verifier token.
    tokenizer.add_tokens(['[QUES]', '[ANS]', '[THOUGHT]', '[VERIFIER]'])
    bert = AutoModelForMaskedLM.from_pretrained(args.model_name)
    # Grow the embedding matrix to cover the newly added special tokens.
    if bert.config.vocab_size < len(tokenizer):
        bert.resize_token_embeddings(new_num_tokens=len(tokenizer))

    # Scalar verifier head (a 1 -> 1 linear layer) consumed by BertModelForVerifier.
    verifier_head = nn.Linear(1, 1, bias=True)
    model = BertModelForVerifier(args, bert, tokenizer, verifier_head)

    verifier_data_model = BertVerifierDataModel(args, tokenizer)
    train_dataloader = verifier_data_model.train_dataloader()
    # val_dataloader = verifier_data_model.val_dataloader()

    trainer = BaseTrainer(args, model)
    trainer.train(verifier_data_model)
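
# Example invocation. The exact flags come from the add_*_specific_args helpers,
# so every flag below other than --model_name is an assumption:
#   python bert_verifier_data_model.py --model_name bert-base-uncased --train_data train.jsonl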