CLIPModel.py
from __future__ import annotations

import torch
import transformers
from PIL import Image
from torch import nn


class CLIPModel(nn.Module):
    """Wraps a Hugging Face CLIP model so that images and texts can be encoded
    into a shared embedding space."""

    save_in_root: bool = True

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32", processor_name=None) -> None:
        super().__init__()

        if processor_name is None:
            processor_name = model_name

        self.model = transformers.CLIPModel.from_pretrained(model_name)
        self.processor = transformers.CLIPProcessor.from_pretrained(processor_name)

    def __repr__(self) -> str:
        return "CLIPModel()"

    def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        image_embeds = []
        text_embeds = []

        # Encode images (if any) with the vision tower and project to the shared space
        if "pixel_values" in features:
            vision_outputs = self.model.vision_model(pixel_values=features["pixel_values"])
            image_embeds = self.model.visual_projection(vision_outputs[1])

        # Encode texts (if any) with the text tower and project to the shared space
        if "input_ids" in features:
            text_outputs = self.model.text_model(
                input_ids=features.get("input_ids"),
                attention_mask=features.get("attention_mask", None),
                position_ids=features.get("position_ids", None),
                output_attentions=features.get("output_attentions", None),
                output_hidden_states=features.get("output_hidden_states", None),
            )
            text_embeds = self.model.text_projection(text_outputs[1])

        # Re-interleave image and text embeddings in the original input order
        sentence_embedding = []
        image_features = iter(image_embeds)
        text_features = iter(text_embeds)

        for idx, input_type in enumerate(features["image_text_info"]):
            if input_type == 0:
                sentence_embedding.append(next(image_features))
            else:
                sentence_embedding.append(next(text_features))

        features["sentence_embedding"] = torch.stack(sentence_embedding).float()

        return features

    def tokenize(self, texts, padding: str | bool = True) -> dict[str, torch.Tensor]:
        images = []
        texts_values = []
        image_text_info = []

        # Split the mixed input into images and texts, remembering the original order
        for idx, data in enumerate(texts):
            if isinstance(data, Image.Image):  # An image
                images.append(data)
                image_text_info.append(0)
            else:  # A text
                texts_values.append(data)
                image_text_info.append(1)

        encoding = {}
        if len(texts_values):
            encoding = self.processor.tokenizer(texts_values, return_tensors="pt", padding=padding)
        if len(images):
            image_features = self.processor.image_processor(images, return_tensors="pt")
            encoding["pixel_values"] = image_features.pixel_values

        encoding["image_text_info"] = image_text_info
        return dict(encoding)

    @property
    def tokenizer(self) -> transformers.CLIPProcessor:
        return self.processor

    def save(self, output_path: str) -> None:
        self.model.save_pretrained(output_path)
        self.processor.save_pretrained(output_path)

    @staticmethod
    def load(input_path: str) -> CLIPModel:
        return CLIPModel(model_name=input_path)
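

# --- Illustrative usage sketch (an addition, not part of the upstream module) ---
# This shows how the class above can be exercised directly: tokenize() splits a
# mixed list of PIL images and strings, and forward() returns one embedding per
# input in the original order. "cat.jpg" is a placeholder path (an assumption);
# within sentence-transformers this module is normally used indirectly, e.g. via
# SentenceTransformer("clip-ViT-B-32"), rather than instantiated by hand.
if __name__ == "__main__":
    model = CLIPModel()

    # One image and one text, encoded into the same shared CLIP embedding space
    inputs = [Image.open("cat.jpg"), "a photo of a cat"]
    features = model.tokenize(inputs)

    with torch.no_grad():
        output = model(features)

    print(output["sentence_embedding"].shape)  # torch.Size([2, 512]) for ViT-B/32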