diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py
index 8ab178e2e..85ff5e96f 100644
--- a/QEfficient/transformers/models/internvl/modeling_internvl.py
+++ b/QEfficient/transformers/models/internvl/modeling_internvl.py
@@ -20,8 +20,8 @@ def __init__(self, model):
         self.model = model
 
     def forward(self, pixel_values):
-        vit_embeds = self.model.extract_feature(pixel_values)
-        return vit_embeds
+        vision_embeds = self.model.extract_feature(pixel_values)
+        return vision_embeds
 
 
 class QEffInternDecoderWrapper(nn.Module):
@@ -31,7 +31,7 @@ def __init__(self, model):
         self.config = self.model.language_model.config
         self.language_model = self.model.language_model
 
-    def forward(self, input_ids, vit_embeds, position_ids, past_key_values):
+    def forward(self, input_ids, vision_embeds, position_ids, past_key_values):
         input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
@@ -39,13 +39,13 @@ def forward(self, input_ids, vit_embeds, position_ids, past_key_values):
         selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
-        image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
         outputs = self.model.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )
-        return outputs.logits, vit_embeds, outputs.past_key_values
+        return outputs.logits, vision_embeds, outputs.past_key_values
 
 
 class QEffInternVLModel(nn.Module):
@@ -122,7 +122,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         lang_dynamic_axes = {}
         lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        lang_dynamic_axes["vit_embeds"] = {0: "num_patches"}
+        lang_dynamic_axes["vision_embeds"] = {0: "num_patches"}
         vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}
 
         pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
@@ -139,7 +139,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         return dynamic_axes
 
     def get_output_names(self, kv_offload: bool = False):
-        vision_output_names = ["vit_embeds"]
+        vision_output_names = ["vision_embeds"]
         lang_output_names = ["logits"]
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
@@ -147,7 +147,7 @@
 
         output_names = {}
         if kv_offload:
-            lang_output_names.insert(1, "vit_embeds_RetainedState")
+            lang_output_names.insert(1, "vision_embeds_RetainedState")
             output_names["vision"] = vision_output_names
             output_names["lang"] = lang_output_names
         else:
@@ -175,7 +175,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         # Define shapes
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
-        inputs_shapes["vit_embeds"] = (
+        inputs_shapes["vision_embeds"] = (
             constants.INTERN_NUM_PATCHES,
             constants.INTERN_FEATURE_SIZE,
             self.language_model.config.hidden_size,
@@ -196,7 +196,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         lang_inputs = {}
         vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
         lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
-        lang_inputs["vit_embeds"] = torch.zeros((inputs_shapes["vit_embeds"]), dtype=torch.float32)
+        lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32)
         lang_inputs["position_ids"] = (
             torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
             .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
@@ -220,21 +220,21 @@ def get_dummy_inputs(self, kv_offload: bool = False):
             inputs["vision"] = vision_inputs
             inputs["lang"] = lang_inputs
         else:
-            lang_inputs.pop("vit_embeds")
+            lang_inputs.pop("vision_embeds")
             inputs = {**vision_inputs, **lang_inputs}
 
         return inputs
 
     def forward(self, input_ids, pixel_values, position_ids, past_key_values):
         input_embeds = self.language_model.get_input_embeddings()(input_ids)
-        vit_embeds = self.extract_feature(pixel_values)
+        vision_embeds = self.extract_feature(pixel_values)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
         selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
-        image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
         outputs = self.language_model(
diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py
index 4ce9f087e..d99d8dfc1 100644
--- a/QEfficient/transformers/models/llava/modeling_llava.py
+++ b/QEfficient/transformers/models/llava/modeling_llava.py
@@ -38,9 +38,9 @@ def forward(self, pixel_values):
             selected_image_feature = selected_image_feature
         else:
             raise ValueError(f"Unexpected select feature strategy: {self.model.config.vision_feature_select_strategy}")
-        image_features = self.model.multi_modal_projector(selected_image_feature)
+        vision_embeds = self.model.multi_modal_projector(selected_image_feature)
 
-        return image_features
+        return vision_embeds
 
 
 class QEFFLlavaDecoderWrapper(nn.Module):
@@ -50,21 +50,21 @@ def __init__(self, model):
         self.config = self.model.config
         self.language_model = self.model.language_model
 
-    def forward(self, input_ids, image_features, position_ids, past_key_values):
+    def forward(self, input_ids, vision_embeds, position_ids, past_key_values):
         inputs_embeds = self.model.get_input_embeddings()(input_ids)
-        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+        vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
 
         mask = input_ids == self.model.config.image_token_index
         indices1 = mask.to(torch.int64).cumsum(1) - 1
         indices0 = torch.arange(mask.shape[0]).view(-1, 1)
-        image_features_expanded = image_features[indices0, indices1]
-        inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
+        vision_embeds_expanded = vision_embeds[indices0, indices1]
+        inputs_embeds = torch.where(mask.unsqueeze(-1), vision_embeds_expanded, inputs_embeds)
         outputs = self.model.language_model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
             past_key_values=past_key_values,
         )
-        return outputs.logits, image_features, outputs.past_key_values
+        return outputs.logits, vision_embeds, outputs.past_key_values
 
 
 class QEffLlavaForConditionalGeneration(LlavaForConditionalGeneration):
@@ -86,14 +86,14 @@ def forward(self, input_ids, position_ids, pixel_values, past_key_values):
             selected_image_feature = selected_image_feature
         else:
             raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
-        image_features = self.multi_modal_projector(selected_image_feature)
-        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+        vision_embeds = self.multi_modal_projector(selected_image_feature)
+        vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
 
         mask = input_ids == self.config.image_token_index
         indices1 = mask.to(torch.int64).cumsum(1) - 1
         indices0 = torch.arange(mask.shape[0]).view(-1, 1)
-        image_features_expanded = image_features[indices0, indices1]
-        image_inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
+        vision_embeds_expanded = vision_embeds[indices0, indices1]
+        image_inputs_embeds = torch.where(mask.unsqueeze(-1), vision_embeds_expanded, inputs_embeds)
         # *where to skip image encoder for decode*
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds)
         outputs = self.language_model(
@@ -118,7 +118,7 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
         }
         lang_inputs = {
             "input_ids": torch.ones((BS, SEQ_LEN), dtype=torch.int64),
-            "image_features": torch.ones((BS, 576, self.language_model.config.hidden_size), dtype=torch.float32),
+            "vision_embeds": torch.ones((BS, 576, self.language_model.config.hidden_size), dtype=torch.float32),
             "attention_mask": torch.ones((BS, SEQ_LEN), dtype=torch.int64),
         }
         lang_inputs["position_ids"] = lang_inputs.pop("attention_mask").cumsum(1)
@@ -137,7 +137,7 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
             inputs["vision"] = vision_inputs
             inputs["lang"] = lang_inputs
         else:
-            lang_inputs.pop("image_features")
+            lang_inputs.pop("vision_embeds")
             inputs = {**vision_inputs, **lang_inputs}
 
         return inputs
@@ -218,7 +218,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         return dynamic_axes
 
     def get_output_names(self, kv_offload: bool = False):
-        vision_output_names = ["image_features"]
+        vision_output_names = ["vision_embeds"]
         lang_output_names = ["logits"]
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
@@ -226,7 +226,7 @@
 
         output_names = {}
         if kv_offload:
-            lang_output_names.insert(1, "image_features_RetainedState")
+            lang_output_names.insert(1, "vision_embeds_RetainedState")
             output_names["vision"] = vision_output_names
             output_names["lang"] = lang_output_names
         else:
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 7faaff590..b96444b25 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -638,9 +638,12 @@ def compile(
 
         custom_io_vision = {}
         kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
-        custom_io_vision["pixel_values"] = kv_cache_dtype
+        custom_io_vision["pixel_values"] = "float16"
         for output_name in output_names["vision"]:
-            custom_io_vision[output_name] = kv_cache_dtype
+            if output_name.startswith("past_"):
+                custom_io_vision[output_name] = kv_cache_dtype
+            else:
+                custom_io_vision[output_name] = "float16"
 
         if vision_onnx_path:
             self.vision_model.onnx_path = vision_onnx_path
@@ -670,12 +673,14 @@
         # Inputs
         for output_name in output_names["lang"]:
             if output_name.endswith("_RetainedState"):
-                custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype
+                custom_io_lang[output_name[: -len("_RetainedState")]] = (
+                    "float16" if "vision_embeds" in output_name else kv_cache_dtype
+                )
 
         # outputs
         for output_name in output_names["lang"]:
             if output_name.endswith("_RetainedState"):
-                custom_io_lang[output_name] = kv_cache_dtype
+                custom_io_lang[output_name] = "float16" if "vision_embeds" in output_name else kv_cache_dtype
 
         self.lang_model._compile(
             compile_dir,
@@ -964,12 +969,14 @@
         # inputs
         for input_name in output_names:
             if input_name.endswith("_RetainedState"):
-                custom_io[input_name[: -len("_RetainedState")]] = kv_cache_dtype
+                custom_io[input_name[: -len("_RetainedState")]] = (
+                    "float16" if "pixel_values" in input_name else kv_cache_dtype
+                )
 
         # outputs
         for output_name in output_names:
             if output_name.endswith("_RetainedState"):
-                custom_io[output_name] = kv_cache_dtype
+                custom_io[output_name] = "float16" if "pixel_values" in output_name else kv_cache_dtype
 
         self._compile(
             onnx_path,