Added changes to ensure mxint8 compilations of VLMs work. #336


Open · wants to merge 4 commits into base: main
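
In short: the diff renames the vision-side tensors to `vision_embeds` across the InternVL and LLaVA modeling code, and changes the custom IO dtype selection so that only the KV-cache retained-state tensors pick up mxint8 while `pixel_values` and `vision_embeds` stay float16. Below is a minimal sketch of that dtype-selection rule; `build_custom_io` is an illustrative helper and not repo code, while `output_names`, `mxint8_kv_cache`, `pixel_values`, and `vision_embeds` follow the diff.

```python
def build_custom_io(output_names, mxint8_kv_cache=True):
    """Illustrative sketch of the dtype split this PR applies: only the
    past_key/past_value retained-state tensors use mxint8, while
    pixel_values and vision_embeds are always bound to float16."""
    kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
    custom_io = {"pixel_values": "float16"}
    for name in output_names:
        if not name.endswith("_RetainedState"):
            continue  # non-retained outputs (e.g. logits) keep their default dtype
        dtype = "float16" if "vision_embeds" in name else kv_cache_dtype
        custom_io[name] = dtype                            # as a model output
        custom_io[name[: -len("_RetainedState")]] = dtype  # as the matching input
    return custom_io
```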
26 changes: 13 additions & 13 deletions QEfficient/transformers/models/internvl/modeling_internvl.py
@@ -20,8 +20,8 @@ def __init__(self, model):
self.model = model

def forward(self, pixel_values):
- vit_embeds = self.model.extract_feature(pixel_values)
- return vit_embeds
+ vision_embeds = self.model.extract_feature(pixel_values)
+ return vision_embeds


class QEffInternDecoderWrapper(nn.Module):
@@ -31,21 +31,21 @@ def __init__(self, model):
self.config = self.model.language_model.config
self.language_model = self.model.language_model

- def forward(self, input_ids, vit_embeds, position_ids, past_key_values):
+ def forward(self, input_ids, vision_embeds, position_ids, past_key_values):
input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
B, N, C = input_embeds.shape
image_input_embeds = input_embeds.reshape(B * N, C)
image_input_ids = input_ids.reshape(B * N)
selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
- image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+ image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
outputs = self.model.language_model(
inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
)
- return outputs.logits, vit_embeds, outputs.past_key_values
+ return outputs.logits, vision_embeds, outputs.past_key_values


class QEffInternVLModel(nn.Module):
@@ -122,7 +122,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
lang_dynamic_axes = {}
lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
lang_dynamic_axes["vit_embeds"] = {0: "num_patches"}
lang_dynamic_axes["vision_embeds"] = {0: "num_patches"}
vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}

pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
@@ -139,15 +139,15 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
return dynamic_axes

def get_output_names(self, kv_offload: bool = False):
vision_output_names = ["vit_embeds"]
vision_output_names = ["vision_embeds"]
lang_output_names = ["logits"]
for i in range(self.language_model.config.num_hidden_layers):
for kv in ["key", "value"]:
lang_output_names.append(f"past_{kv}.{i}_RetainedState")

output_names = {}
if kv_offload:
lang_output_names.insert(1, "vit_embeds_RetainedState")
lang_output_names.insert(1, "vision_embeds_RetainedState")
output_names["vision"] = vision_output_names
output_names["lang"] = lang_output_names
else:
@@ -175,7 +175,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
# Define shapes
inputs_shapes = {}
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
inputs_shapes["vit_embeds"] = (
inputs_shapes["vision_embeds"] = (
constants.INTERN_NUM_PATCHES,
constants.INTERN_FEATURE_SIZE,
self.language_model.config.hidden_size,
@@ -196,7 +196,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
lang_inputs = {}
vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
lang_inputs["vit_embeds"] = torch.zeros((inputs_shapes["vit_embeds"]), dtype=torch.float32)
lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32)
lang_inputs["position_ids"] = (
torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
.view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
@@ -220,21 +220,21 @@ def get_dummy_inputs(self, kv_offload: bool = False):
inputs["vision"] = vision_inputs
inputs["lang"] = lang_inputs
else:
lang_inputs.pop("vit_embeds")
lang_inputs.pop("vision_embeds")
inputs = {**vision_inputs, **lang_inputs}

return inputs

def forward(self, input_ids, pixel_values, position_ids, past_key_values):
input_embeds = self.language_model.get_input_embeddings()(input_ids)
- vit_embeds = self.extract_feature(pixel_values)
+ vision_embeds = self.extract_feature(pixel_values)
B, N, C = input_embeds.shape
image_input_embeds = input_embeds.reshape(B * N, C)
image_input_ids = input_ids.reshape(B * N)
selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
- image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+ image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
outputs = self.language_model(
30 changes: 15 additions & 15 deletions QEfficient/transformers/models/llava/modeling_llava.py
@@ -38,9 +38,9 @@ def forward(self, pixel_values):
selected_image_feature = selected_image_feature
else:
raise ValueError(f"Unexpected select feature strategy: {self.model.config.vision_feature_select_strategy}")
- image_features = self.model.multi_modal_projector(selected_image_feature)
+ vision_embeds = self.model.multi_modal_projector(selected_image_feature)

- return image_features
+ return vision_embeds


class QEFFLlavaDecoderWrapper(nn.Module):
@@ -50,21 +50,21 @@ def __init__(self, model):
self.config = self.model.config
self.language_model = self.model.language_model

- def forward(self, input_ids, image_features, position_ids, past_key_values):
+ def forward(self, input_ids, vision_embeds, position_ids, past_key_values):
inputs_embeds = self.model.get_input_embeddings()(input_ids)
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
mask = input_ids == self.model.config.image_token_index
indices1 = mask.to(torch.int64).cumsum(1) - 1
indices0 = torch.arange(mask.shape[0]).view(-1, 1)
- image_features_expanded = image_features[indices0, indices1]
- inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
+ vision_embeds_expanded = vision_embeds[indices0, indices1]
+ inputs_embeds = torch.where(mask.unsqueeze(-1), vision_embeds_expanded, inputs_embeds)
outputs = self.model.language_model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
past_key_values=past_key_values,
)

- return outputs.logits, image_features, outputs.past_key_values
+ return outputs.logits, vision_embeds, outputs.past_key_values


class QEffLlavaForConditionalGeneration(LlavaForConditionalGeneration):
@@ -86,14 +86,14 @@ def forward(self, input_ids, position_ids, pixel_values, past_key_values):
selected_image_feature = selected_image_feature
else:
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
- image_features = self.multi_modal_projector(selected_image_feature)
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ vision_embeds = self.multi_modal_projector(selected_image_feature)
+ vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)

mask = input_ids == self.config.image_token_index
indices1 = mask.to(torch.int64).cumsum(1) - 1
indices0 = torch.arange(mask.shape[0]).view(-1, 1)
- image_features_expanded = image_features[indices0, indices1]
- image_inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
+ vision_embeds_expanded = vision_embeds[indices0, indices1]
+ image_inputs_embeds = torch.where(mask.unsqueeze(-1), vision_embeds_expanded, inputs_embeds)
# *where to skip image encoder for decode*
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds)
outputs = self.language_model(
@@ -118,7 +118,7 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
}
lang_inputs = {
"input_ids": torch.ones((BS, SEQ_LEN), dtype=torch.int64),
"image_features": torch.ones((BS, 576, self.language_model.config.hidden_size), dtype=torch.float32),
"vision_embeds": torch.ones((BS, 576, self.language_model.config.hidden_size), dtype=torch.float32),
"attention_mask": torch.ones((BS, SEQ_LEN), dtype=torch.int64),
}
lang_inputs["position_ids"] = lang_inputs.pop("attention_mask").cumsum(1)
@@ -137,7 +137,7 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
inputs["vision"] = vision_inputs
inputs["lang"] = lang_inputs
else:
lang_inputs.pop("image_features")
lang_inputs.pop("vision_embeds")
inputs = {**vision_inputs, **lang_inputs}
return inputs

@@ -218,15 +218,15 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
return dynamic_axes

def get_output_names(self, kv_offload: bool = False):
vision_output_names = ["image_features"]
vision_output_names = ["vision_embeds"]
lang_output_names = ["logits"]
for i in range(self.language_model.config.num_hidden_layers):
for kv in ["key", "value"]:
lang_output_names.append(f"past_{kv}.{i}_RetainedState")

output_names = {}
if kv_offload:
lang_output_names.insert(1, "image_features_RetainedState")
lang_output_names.insert(1, "vision_embeds_RetainedState")
output_names["vision"] = vision_output_names
output_names["lang"] = lang_output_names
else:
19 changes: 13 additions & 6 deletions QEfficient/transformers/models/modeling_auto.py
@@ -638,9 +638,12 @@ def compile(

custom_io_vision = {}
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
custom_io_vision["pixel_values"] = kv_cache_dtype
custom_io_vision["pixel_values"] = "float16"
for output_name in output_names["vision"]:
- custom_io_vision[output_name] = kv_cache_dtype
+ if output_name.startswith("past_"):
+ custom_io_vision[output_name] = kv_cache_dtype
+ else:
+ custom_io_vision[output_name] = "float16"

if vision_onnx_path:
self.vision_model.onnx_path = vision_onnx_path
@@ -670,12 +673,14 @@
# Inputs
for output_name in output_names["lang"]:
if output_name.endswith("_RetainedState"):
- custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype
+ custom_io_lang[output_name[: -len("_RetainedState")]] = (
+ "float16" if "vision_embeds" in output_name else kv_cache_dtype
+ )

# outputs
for output_name in output_names["lang"]:
if output_name.endswith("_RetainedState"):
- custom_io_lang[output_name] = kv_cache_dtype
+ custom_io_lang[output_name] = "float16" if "vision_embeds" in output_name else kv_cache_dtype

self.lang_model._compile(
compile_dir,
@@ -964,12 +969,14 @@ def compile(
# inputs
for input_name in output_names:
if input_name.endswith("_RetainedState"):
- custom_io[input_name[: -len("_RetainedState")]] = kv_cache_dtype
+ custom_io[input_name[: -len("_RetainedState")]] = (
+ "float16" if "pixel_values" in input_name else kv_cache_dtype
+ )

# outputs
for output_name in output_names:
if output_name.endswith("_RetainedState"):
- custom_io[output_name] = kv_cache_dtype
+ custom_io[output_name] = "float16" if "pixel_values" in output_name else kv_cache_dtype

self._compile(
onnx_path,
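
For context, a hedged usage sketch of how the compile path would exercise these changes. `mxint8_kv_cache` and `kv_offload` match names in the diff; the auto class, checkpoint, and other argument names are assumptions based on QEfficient's public interface and may differ on this branch.

```python
from QEfficient import QEFFAutoModelForImageTextToText  # assumed entry point

# Compile a VLM with an mxint8 KV cache. With this change, pixel_values and
# the vision_embeds outputs are bound to float16 custom IO, and only the
# past_key/past_value retained-state tensors use the mxint8 format.
model = QEFFAutoModelForImageTextToText.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",  # illustrative checkpoint
    kv_offload=True,             # export vision and language as separate QPCs
)
model.compile(num_devices=1, mxint8_kv_cache=True)
```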