
Commit 03f0676

Added changes to ensure mxint8 compilations of VLMs work.

Modified the modeling files of InternVL and Llava so that the image embeddings are consistently named 'vision_embeds', and modified the modeling_auto file to incorporate the mxint8 modifications for VLMs.

LIMITATION: It is expected that a model's Processor always returns the vision components in 'float16'.

Signed-off-by: quic-dhirajku <quic_dhirajku@quicinc.com>
1 parent a706a01 commit 03f0676
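Note on the compile side (not shown in this diff): giving every VLM's vision output the same 'vision_embeds' name is what allows one generic custom-IO entry to mark the vision output as mxint8 regardless of model family. A minimal sketch of what such a mapping could look like; the helper name and structure here are illustrative, not the actual QEfficient compile API:

# Illustrative only: a single custom-IO mapping now covers both InternVL and Llava,
# because their vision outputs share the "vision_embeds" name after this change.
def build_vision_custom_io(mxint8: bool = True) -> dict:
    vision_dtype = "mxint8" if mxint8 else "float16"
    return {
        "pixel_values": "float16",                    # Processor is expected to hand over float16
        "vision_embeds": vision_dtype,                # shared output name introduced by this commit
        "vision_embeds_RetainedState": vision_dtype,  # retained-state twin used when kv_offload=True
    }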

File tree

3 files changed: +70 -150 lines changed


QEfficient/transformers/models/internvl/modeling_internvl.py

+13 -13
@@ -20,8 +20,8 @@ def __init__(self, model):
         self.model = model
 
     def forward(self, pixel_values):
-        vit_embeds = self.model.extract_feature(pixel_values)
-        return vit_embeds
+        vision_embeds = self.model.extract_feature(pixel_values)
+        return vision_embeds
 
 
 class QEffInternDecoderWrapper(nn.Module):
@@ -31,21 +31,21 @@ def __init__(self, model):
3131
self.config = self.model.language_model.config
3232
self.language_model = self.model.language_model
3333

34-
def forward(self, input_ids, vit_embeds, position_ids, past_key_values):
34+
def forward(self, input_ids, vision_embeds, position_ids, past_key_values):
3535
input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
3636
B, N, C = input_embeds.shape
3737
image_input_embeds = input_embeds.reshape(B * N, C)
3838
image_input_ids = input_ids.reshape(B * N)
3939
selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
4040
indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
4141
indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
42-
image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
42+
image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
4343
image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
4444
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
4545
outputs = self.model.language_model(
4646
inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
4747
)
48-
return outputs.logits, vit_embeds, outputs.past_key_values
48+
return outputs.logits, vision_embeds, outputs.past_key_values
4949

5050

5151
class QEffInternVLModel(nn.Module):
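The decoder wrappers above merge image features into the text embedding stream with a cumsum/arange gather rather than boolean masking, which keeps the graph export-friendly. A self-contained toy reproduction of that trick (token ids, shapes, and the IMG_TOKEN value are made up for illustration):

import torch

# Toy example of the gather used above: place the k-th vision embedding at the
# position of the k-th image-context token, leaving text positions untouched.
IMG_TOKEN = 99
input_ids = torch.tensor([[1, 99, 99, 99, 2, 3]])          # (B, N)
input_embeds = torch.zeros(1, 6, 4)                        # (B, N, C) text embeddings
vision_embeds = torch.arange(12.0).reshape(1, 3, 4)        # (B, num_image_tokens, C)

selected = input_ids == IMG_TOKEN                          # True at image-context positions
indices1 = selected.to(torch.int64).cumsum(1) - 1          # k-th image token -> row k of vision_embeds
indices0 = torch.arange(selected.shape[0]).view(-1, 1)     # batch index, broadcast over N
expanded = vision_embeds[indices0, indices1]               # (B, N, C); junk where selected is False
merged = torch.where(selected.unsqueeze(-1), expanded, input_embeds)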
@@ -122,7 +122,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         lang_dynamic_axes = {}
         lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        lang_dynamic_axes["vit_embeds"] = {0: "num_patches"}
+        lang_dynamic_axes["vision_embeds"] = {0: "num_patches"}
         vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}
 
         pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
@@ -139,15 +139,15 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         return dynamic_axes
 
     def get_output_names(self, kv_offload: bool = False):
-        vision_output_names = ["vit_embeds"]
+        vision_output_names = ["vision_embeds"]
         lang_output_names = ["logits"]
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
                 lang_output_names.append(f"past_{kv}.{i}_RetainedState")
 
         output_names = {}
         if kv_offload:
-            lang_output_names.insert(1, "vit_embeds_RetainedState")
+            lang_output_names.insert(1, "vision_embeds_RetainedState")
             output_names["vision"] = vision_output_names
             output_names["lang"] = lang_output_names
         else:
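For clarity, the rename also flows through to the ONNX output names, including the retained-state copy emitted when kv_offload is enabled. A rough standalone re-creation of the names for a hypothetical two-layer decoder (the layer count is invented for the example):

# Hypothetical 2-layer decoder: the output names the exporter would see after this change.
vision_output_names = ["vision_embeds"]
lang_output_names = ["logits"]
for i in range(2):  # stand-in for self.language_model.config.num_hidden_layers
    for kv in ["key", "value"]:
        lang_output_names.append(f"past_{kv}.{i}_RetainedState")

lang_output_names.insert(1, "vision_embeds_RetainedState")  # kv_offload=True branch
# lang_output_names == ["logits", "vision_embeds_RetainedState",
#                       "past_key.0_RetainedState", "past_value.0_RetainedState",
#                       "past_key.1_RetainedState", "past_value.1_RetainedState"]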
@@ -175,7 +175,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         # Define shapes
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
-        inputs_shapes["vit_embeds"] = (
+        inputs_shapes["vision_embeds"] = (
             constants.INTERN_NUM_PATCHES,
             constants.INTERN_FEATURE_SIZE,
             self.language_model.config.hidden_size,
@@ -196,7 +196,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         lang_inputs = {}
         vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
         lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
-        lang_inputs["vit_embeds"] = torch.zeros((inputs_shapes["vit_embeds"]), dtype=torch.float32)
+        lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32)
         lang_inputs["position_ids"] = (
             torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
             .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
@@ -220,21 +220,21 @@ def get_dummy_inputs(self, kv_offload: bool = False):
             inputs["vision"] = vision_inputs
             inputs["lang"] = lang_inputs
         else:
-            lang_inputs.pop("vit_embeds")
+            lang_inputs.pop("vision_embeds")
             inputs = {**vision_inputs, **lang_inputs}
 
         return inputs
 
     def forward(self, input_ids, pixel_values, position_ids, past_key_values):
         input_embeds = self.language_model.get_input_embeddings()(input_ids)
-        vit_embeds = self.extract_feature(pixel_values)
+        vision_embeds = self.extract_feature(pixel_values)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
         selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
-        image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
         outputs = self.language_model(

QEfficient/transformers/models/llava/modeling_llava.py

+15 -15
@@ -38,9 +38,9 @@ def forward(self, pixel_values):
             selected_image_feature = selected_image_feature
         else:
             raise ValueError(f"Unexpected select feature strategy: {self.model.config.vision_feature_select_strategy}")
-        image_features = self.model.multi_modal_projector(selected_image_feature)
+        vision_embeds = self.model.multi_modal_projector(selected_image_feature)
 
-        return image_features
+        return vision_embeds
 
 
 class QEFFLlavaDecoderWrapper(nn.Module):
@@ -50,21 +50,21 @@ def __init__(self, model):
         self.config = self.model.config
         self.language_model = self.model.language_model
 
-    def forward(self, input_ids, image_features, position_ids, past_key_values):
+    def forward(self, input_ids, vision_embeds, position_ids, past_key_values):
         inputs_embeds = self.model.get_input_embeddings()(input_ids)
-        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+        vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
         mask = input_ids == self.model.config.image_token_index
         indices1 = mask.to(torch.int64).cumsum(1) - 1
         indices0 = torch.arange(mask.shape[0]).view(-1, 1)
-        image_features_expanded = image_features[indices0, indices1]
-        inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
+        vision_embeds_expanded = vision_embeds[indices0, indices1]
+        inputs_embeds = torch.where(mask.unsqueeze(-1), vision_embeds_expanded, inputs_embeds)
         outputs = self.model.language_model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
             past_key_values=past_key_values,
         )
 
-        return outputs.logits, image_features, outputs.past_key_values
+        return outputs.logits, vision_embeds, outputs.past_key_values
 
 
 class QEffLlavaForConditionalGeneration(LlavaForConditionalGeneration):
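The decoder wrapper above casts vision_embeds to the language model's device and dtype, but the vision encoder itself consumes whatever the Processor emits, which ties in with the LIMITATION noted in the commit message. A hedged sketch of that expectation on the preprocessing side; the model id and explicit cast are illustrative and not part of this diff:

import torch
from PIL import Image
from transformers import AutoProcessor

# Illustrative preprocessing: the commit assumes the Processor always yields
# float16 vision tensors; the cast below simply makes that expectation explicit.
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
image = Image.new("RGB", (336, 336))
inputs = processor(text="<image>\nDescribe the image.", images=image, return_tensors="pt")
inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)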
@@ -86,14 +86,14 @@ def forward(self, input_ids, position_ids, pixel_values, past_key_values):
             selected_image_feature = selected_image_feature
         else:
             raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
-        image_features = self.multi_modal_projector(selected_image_feature)
-        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+        vision_embeds = self.multi_modal_projector(selected_image_feature)
+        vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
 
         mask = input_ids == self.config.image_token_index
         indices1 = mask.to(torch.int64).cumsum(1) - 1
         indices0 = torch.arange(mask.shape[0]).view(-1, 1)
-        image_features_expanded = image_features[indices0, indices1]
-        image_inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
+        vision_embeds_expanded = vision_embeds[indices0, indices1]
+        image_inputs_embeds = torch.where(mask.unsqueeze(-1), vision_embeds_expanded, inputs_embeds)
         # *where to skip image encoder for decode*
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds)
         outputs = self.language_model(
@@ -118,7 +118,7 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
         }
         lang_inputs = {
             "input_ids": torch.ones((BS, SEQ_LEN), dtype=torch.int64),
-            "image_features": torch.ones((BS, 576, self.language_model.config.hidden_size), dtype=torch.float32),
+            "vision_embeds": torch.ones((BS, 576, self.language_model.config.hidden_size), dtype=torch.float32),
             "attention_mask": torch.ones((BS, SEQ_LEN), dtype=torch.int64),
         }
         lang_inputs["position_ids"] = lang_inputs.pop("attention_mask").cumsum(1)
@@ -137,7 +137,7 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
             inputs["vision"] = vision_inputs
             inputs["lang"] = lang_inputs
         else:
-            lang_inputs.pop("image_features")
+            lang_inputs.pop("vision_embeds")
             inputs = {**vision_inputs, **lang_inputs}
         return inputs

@@ -218,15 +218,15 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         return dynamic_axes
 
     def get_output_names(self, kv_offload: bool = False):
-        vision_output_names = ["image_features"]
+        vision_output_names = ["vision_embeds"]
         lang_output_names = ["logits"]
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
                 lang_output_names.append(f"past_{kv}.{i}_RetainedState")
 
         output_names = {}
         if kv_offload:
-            lang_output_names.insert(1, "image_features_RetainedState")
+            lang_output_names.insert(1, "vision_embeds_RetainedState")
             output_names["vision"] = vision_output_names
             output_names["lang"] = lang_output_names
         else:
