-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathbert.py
31 lines (23 loc) · 875 Bytes
/
bert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# labels: name::bert author::transformers task::Generative_AI license::apache-2.0
from turnkeyml.parser import parse
from transformers import BertModel, AutoConfig
import torch
# Seed PyTorch's global RNG with a fixed value before anything else runs
torch.manual_seed(0)

# Read the script's settings from the command line; values come back in the
# same order as the names requested here
requested_args = ["pretrained", "batch_size", "max_seq_length"]
pretrained, batch_size, max_seq_length = parse(requested_args)
# Model and input configurations
if pretrained:
    # Load the published bert-base-uncased checkpoint, weights included
    model = BertModel.from_pretrained("bert-base-uncased")
else:
    # Fetch only the bert-base-uncased configuration and build the model
    # from it, so no pretrained weights are loaded
    config = AutoConfig.from_pretrained("bert-base-uncased")
    model = BertModel(config)
# Make sure the user's sequence length fits within the model's maximum.
# This is raised explicitly instead of asserted so the check still runs under
# `python -O`, which strips assert statements.
max_positions = model.config.max_position_embeddings
if max_seq_length > max_positions:
    raise ValueError(
        f"max_seq_length ({max_seq_length}) exceeds the model's maximum of "
        f"{max_positions} position embeddings"
    )

# Dummy all-ones inputs of shape (batch_size, max_seq_length): integer token
# ids plus a float attention mask
inputs = {
    "input_ids": torch.ones(batch_size, max_seq_length, dtype=torch.long),
    "attention_mask": torch.ones(batch_size, max_seq_length, dtype=torch.float),
}

# Call model
model(**inputs)