# get_pretrained_embeddings.py
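"""Precompute and save raw embeddings from pretrained vision and text encoders.

Loads a pretrained vision model and a BERT-style text model, runs the train
and validation splits through them, and writes the resulting raw embeddings
to disk for later use.

Example invocation (model names taken from the argument help below):

    python get_pretrained_embeddings.py --vision_model dinov2-b --text_model biobert --batch_size 64
"""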
import argparse
import logging
import os

import torch
from transformers import AutoTokenizer, BertModel

from dataset.dataset import multimodal_collator
from models.configurations import TEXT_PRETRAINED, VISION_PRETRAINED
from utils.dataset_utils import get_dataloader, pickle_dataset
from utils.model_utils import load_vision_model
from utils.utils import get_image_embeds_raw, get_text_embeds_raw
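# Command-line arguments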
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
    "--vision_model", type=str, help="Choose from [resnet-ae, dinov2-s, dinov2-b]"
)
parser.add_argument(
    "--text_model",
    type=str,
    help="Choose from [bert, biobert, pubmedbert, cxrbert, clinicalbert]",
)
parser.add_argument(
    "--text_embeds_raw_dir",
    type=str,
    default="/vol/bitbucket/jq619/adaptor-thesis/saved_embeddings/text_embeds",
    help="Directory where raw text embeddings are saved",
)
parser.add_argument(
    "--image_embeds_raw_dir",
    type=str,
    default="/vol/bitbucket/jq619/adaptor-thesis/saved_embeddings/image_embeds",
    help="Directory where raw image embeddings are saved",
)
# parser.add_argument('--num_of_batches', type=int, default=100, help='number of batches to use for training')
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument(
    "--force_rebuild_dataset",
    action="store_true",
    help="Force a rebuild of the dataset; otherwise a pickled dataset is loaded if available",
)
parser.add_argument("--cpu", action="store_true", help="Run on CPU even if CUDA is available")
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument(
    "--data_pct",
    type=float,
    default=1.0,
    help="Fraction of the data to use; 1.0 uses all data without shuffling",
)
parser.add_argument("--crop_size", type=int, default=224)
parser.add_argument(
    "--full", action="store_true", help="Compute both global and local vision features"
)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument(
    "--local_rank", default=-1, type=int, help="Local process rank for distributed training"
)
args = parser.parse_args()
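# Configure logging so the logging.info calls below are visible (an added
# assumption about desired verbosity; the original relied on print for output).
logging.basicConfig(level=logging.INFO)

# Seed PyTorch for reproducibility, then pick the compute device.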
torch.manual_seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
logging.info(f"Using device: {device}")
do_text = args.text_model is not None
do_vision = args.vision_model is not None
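# Validate the requested models against the available pretrained configs; a
# skipped modality falls back to a default so its config can still be loaded.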
if do_vision:
    if args.vision_model not in VISION_PRETRAINED.keys():
        raise ValueError(
            f"Vision model {args.vision_model} not available. "
            f"Choose from {list(VISION_PRETRAINED.keys())}"
        )
else:
    args.vision_model = "swin-base"
if do_text:
    if args.text_model not in TEXT_PRETRAINED.keys():
        raise ValueError(
            f"Text model {args.text_model} not available. "
            f"Choose from {list(TEXT_PRETRAINED.keys())}"
        )
else:
    args.text_model = "biobert"
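# Unpack the vision and text model configurations.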
vision_model_config = VISION_PRETRAINED[args.vision_model]
args.vision_pretrained = vision_model_config["pretrained_weight"]
args.vision_model_type = vision_model_config["vision_model_type"]
args.vision_output_dim = vision_model_config["vision_output_dim"]
data_transforms = vision_model_config["data_transform"]
args.text_pretrained = TEXT_PRETRAINED[args.text_model]
# Load pretrained vision model
vision_model = load_vision_model(args.vision_model_type, args.vision_pretrained)
vision_model.to(device)

# Load pretrained text model and tokenizer
text_model = BertModel.from_pretrained(args.text_pretrained)
tokenizer = AutoTokenizer.from_pretrained(args.text_pretrained)
text_model.to(device)
# Load datasets (from pickle when available, unless --force_rebuild_dataset is set)
postfix = "_ae" if args.vision_model_type == "ae" else ""
train_dataset_pkl = f"saved_datasets/train_dataset_{args.text_model}{postfix}.pkl"
val_dataset_pkl = f"saved_datasets/val_dataset_{args.text_model}{postfix}.pkl"
train_dataset = pickle_dataset(
    train_dataset_pkl,
    split="train",
    transform=data_transforms(True, args.crop_size),
    data_pct=args.data_pct,
    force_rebuild=args.force_rebuild_dataset,
    validate_path=args.force_rebuild_dataset,
    tokenizer=tokenizer,
)
val_dataset = pickle_dataset(
    val_dataset_pkl,
    split="valid",
    transform=data_transforms(False, args.crop_size),
    data_pct=args.data_pct,
    force_rebuild=args.force_rebuild_dataset,
    validate_path=args.force_rebuild_dataset,
    tokenizer=tokenizer,
)
# Build dataloaders with the multimodal collator
train_dataloader = get_dataloader(
    train_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    collate_fn=multimodal_collator,
)
val_dataloader = get_dataloader(
    val_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    collate_fn=multimodal_collator,
)
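# Ensure the output directories exist.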
os.makedirs(args.image_embeds_raw_dir, exist_ok=True)
os.makedirs(args.text_embeds_raw_dir, exist_ok=True)
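# Compute and save raw embeddings for the train and validation splits.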
for split, dataloader in zip(["train", "valid"], [train_dataloader, val_dataloader]):
    if do_text:
        logging.info(f"Getting text embeddings for {split} split")
        get_text_embeds_raw(
            dataloader,
            text_model=text_model,
            save_path=args.text_embeds_raw_dir,
            model_name=args.text_model,
            batch_size=args.batch_size,
            split=split,
            device=device,
        )
    if do_vision:
        logging.info(f"Getting vision embeddings for {split} split")
        get_image_embeds_raw(
            dataloader,
            vision_model=vision_model,
            vision_model_type=args.vision_model_type,
            save_path=args.image_embeds_raw_dir,
            model_name=args.vision_model,
            batch_size=args.batch_size,
            embedding_dim=args.vision_output_dim,
            split=split,
            device=device,
            full=args.full,
        )