
Commit fc323f4

Added support for Llava model single QPC (#265)

Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>

1 parent 87e07d0 · commit fc323f4

File tree

8 files changed, +403 -135 lines changed

QEfficient/base/modeling_qeff.py (+1)

@@ -190,6 +190,7 @@ def _export(
 
         except Exception as e:
             logger.error(f"ONNX export (or) ONNXTransforms failed: {e}")
+
            raise e
 
        finally:
QEfficient/generation/text_generation_inference.py (+13)

@@ -63,6 +63,19 @@ def __repr__(self):
             \nTotal (E2E) inference time is= {round(self.perf_metrics.total_time, 2)}"
 
 
+@dataclass
+class CloudAI100ExecInfoNew:
+    batch_size: int
+    generated_ids: Union[List[np.ndarray], np.ndarray]
+    perf_metrics: PerfMetrics
+
+    def __repr__(self):
+        return f"Average Prefill time a.k.a TTFT is= {round(self.perf_metrics.prefill_time, 2)}\
+            \nDecode token/sec is= {round(self.perf_metrics.decode_perf * self.batch_size, 2)}\
+            \nTotal token/sec is= {round(self.perf_metrics.total_perf * self.batch_size, 2)}\
+            \nTotal (E2E) inference time is= {round(self.perf_metrics.total_time, 2)}"
+
+
 io_files = []
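
The added CloudAI100ExecInfoNew mirrors the execution-info dataclass whose __repr__ appears in the context above, but carries only the batch size, the generated ids, and a PerfMetrics object. A minimal usage sketch (not part of the commit; it assumes QEfficient at this revision is installed and that PerfMetrics accepts the four timing fields its __repr__ reads):

import numpy as np

from QEfficient.generation.text_generation_inference import (
    CloudAI100ExecInfoNew,
    PerfMetrics,
)

# Assumed PerfMetrics signature: the four fields read by __repr__ above.
metrics = PerfMetrics(prefill_time=0.08, decode_perf=95.3, total_perf=88.1, total_time=0.45)
info = CloudAI100ExecInfoNew(
    batch_size=1,
    generated_ids=np.zeros((1, 16), dtype=np.int64),
    perf_metrics=metrics,
)
print(info)  # TTFT, decode tokens/sec, total tokens/sec and E2E time on four lines

Because the decode and total rates are multiplied by batch_size, the printed throughput is the aggregate across the batch rather than per sequence.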

New file (+6)

@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
New file (+294)

@@ -0,0 +1,294 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.models.llava.modeling_llava import (
    LlavaCausalLMOutputWithPast,
    LlavaForConditionalGeneration,
    logger,
)

BS = 1
NUM_CHANNEL = 3
SEQ_LEN = 592
IMAGE_SIZE = 336
CTX_LEN = 1024


class QEffLlavaForConditionalGeneration(LlavaForConditionalGeneration):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.


        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, LlavaForConditionalGeneration

        >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
        >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

        >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

        legacy_processing = False
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

            # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
            # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
            # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
            legacy_processing = (
                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
            ) or (input_ids.shape[-1] == 1 and pixel_values is not None)

        if pixel_values is not None:
            image_features = self.get_image_features(
                pixel_values=pixel_values,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )

            if legacy_processing:
                logger.warning_once(
                    "Expanding inputs for image tokens in LLaVa should be done in processing. "
                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
                    "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
                )
                # prefill stage vs decoding stage (legacy behavior copied)
                if input_ids.shape[1] != 1:
                    inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                        image_features, inputs_embeds, input_ids, attention_mask, labels
                    )
                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
                else:
                    # Retrieve the first layer to inspect the logits and mask out the hidden states
                    # that are set to 0
                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                    # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                    # Get the target length
                    target_length = input_ids.shape[1]
                    past_length = first_layer_past_key_value.shape[-1]

                    extended_attention_mask = torch.ones(
                        (attention_mask.shape[0], past_length),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )

                    # Filter out only the tokens that can be un-attended, this can happen
                    # if one uses Llava + Fused modules where the cache on the
                    # first iteration is already big enough, or if one passes custom cache
                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                    new_batch_index = batch_index[valid_indices]
                    new_non_attended_tokens = non_attended_tokens[valid_indices]

                    # Zero-out the places where we don't need to attend
                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
                        -target_length:
                    ]

            # TODO: @raushan retain only the new behavior after v4.47
            else:
                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
                n_image_features = image_features.shape[1]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )

                mask = input_ids == self.config.image_token_index
                indices1 = mask.to(torch.int64).cumsum(1) - 1
                indices0 = torch.arange(mask.shape[0]).view(-1, 1)
                image_features_expanded = image_features[indices0, indices1]
                image_inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
                # *where to skip image encoder for decode*
                inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds)

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
        )

        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return logits, pixel_values, outputs.past_key_values

    def get_dummy_inputs(self, **kwargs):
        num_layers = self.config.text_config.num_hidden_layers
        num_key_value_heads = self.config.text_config.num_key_value_heads
        head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads

        inputs = {
            "input_ids": torch.ones((BS, SEQ_LEN), dtype=torch.int64),
            "attention_mask": torch.ones((BS, SEQ_LEN), dtype=torch.int64),
            "pixel_values": torch.zeros((BS, NUM_CHANNEL, IMAGE_SIZE, IMAGE_SIZE), dtype=torch.float32),
        }
        inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1)
        inputs["past_key_values"] = []
        for i in range(num_layers):
            inputs["past_key_values"].append(
                (
                    torch.zeros(BS, num_key_value_heads, CTX_LEN, head_dim),
                    torch.zeros(BS, num_key_value_heads, CTX_LEN, head_dim),
                )
            )
        inputs["position_ids"] = torch.full(inputs["position_ids"].shape, CTX_LEN - 1)
        return inputs

    def get_specializations(
        self, batch_size: int, prefill_seq_len: int, ctx_len: int, img_size: int, **compiler_options
    ):
        # TODO: check if this should be named num_crops or something else
        max_num_images = compiler_options.get("max_num_images", 1)
        prefill_seq_len = prefill_seq_len if prefill_seq_len else SEQ_LEN
        ctx_len = ctx_len if ctx_len else CTX_LEN
        img_size = img_size if img_size else IMAGE_SIZE

        return [
            {
                "batch_size": batch_size,
                "seq_len": prefill_seq_len,
                "ctx_len": ctx_len,
                "max_num_images": max_num_images,
                "img_size": img_size,
            },
            {
                "batch_size": batch_size,
                "seq_len": "1",
                "ctx_len": ctx_len,
                "max_num_images": max_num_images,
                "img_size": img_size,
            },
        ]

    def get_onnx_dynamic_axes(
        self,
    ):
        # Define dynamic axes
        num_layers = self.config.text_config.num_hidden_layers

        dynamic_axes = {
            "input_ids": {0: "batch_size", 1: "seq_len"},
            "position_ids": {0: "batch_size", 1: "seq_len"},
            "pixel_values": {0: "batch_size", 2: "img_size", 3: "img_size"},
        }
        for i in range(num_layers):
            dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"}
            dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"}

        return dynamic_axes

    def get_output_names(
        self,
    ):
        output_names = ["logits", "pixel_values_RetainedState"]
        for i in range(self.language_model.config.num_hidden_layers):
            for kv in ["key", "value"]:
                output_names.append(f"past_{kv}.{i}_RetainedState")
        return output_names
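
The core of the single-QPC design is the else branch in forward above: image features are spliced into the placeholder positions of the text embeddings with a gather plus torch.where, so shapes stay static, and the same expression falls back to the plain text embeddings when the exported graph runs with seq_len == 1 at decode. A self-contained sketch of that merge with toy sizes (the placeholder id and dimensions below are illustrative, not taken from the commit):

import torch

IMAGE_TOKEN_ID = 32000                    # illustrative placeholder id (the model's image_token_index)
hidden_size, n_image_features = 8, 4      # toy sizes; Llava-1.5 expands one 336x336 image to 576 tokens

input_ids = torch.tensor([[1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 42, 43, 2]])
inputs_embeds = torch.randn(1, input_ids.shape[1], hidden_size)   # text-token embeddings
image_features = torch.randn(1, n_image_features, hidden_size)    # vision-tower projections

mask = input_ids == IMAGE_TOKEN_ID                  # True at the image placeholder positions
indices1 = mask.to(torch.int64).cumsum(1) - 1       # k-th placeholder -> k-th image feature
indices0 = torch.arange(mask.shape[0]).view(-1, 1)  # batch indices for fancy indexing
image_features_expanded = image_features[indices0, indices1]
image_inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)

# At decode the graph sees a single new token, so this outer where selects the plain text
# embeddings and the vision path is effectively skipped.
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds)
print(inputs_embeds.shape)  # torch.Size([1, 8, 8]); placeholder rows now hold image features

Keeping the sequence length fixed this way is what lets one compiled program serve both the SEQ_LEN=592 prefill specialization and the one-token decode specialization returned by get_specializations.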

QEfficient/transformers/models/mllama/modeling_mllama.py (-54)

@@ -39,7 +39,6 @@
     rotate_half,
 )
 
-from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_utils import (
     _create_causal_mask,
     _prepare_aspect_ratio_attention_mask,

@@ -1204,56 +1203,3 @@ def generate_dummy_io_info(self, kv_offload=False):
         output_names = lang_output_names
 
         return inputs, output_names, dynamic_axes, inputs_shape
-
-
-class ModelWrapper(nn.Module):
-    def __init__(self, mllama):
-        super().__init__()
-        self.mllama = mllama
-        self.num_hidden_layers = mllama.config.get_text_config().num_hidden_layers
-        self.config = self.mllama.config.get_text_config()
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        pixel_values: Optional[torch.FloatTensor] = None,
-        aspect_ratio_mask: Optional[torch.Tensor] = None,
-        aspect_ratio_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        cross_attention_mask: Optional[torch.Tensor] = None,
-        cross_attention_states: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        num_logits_to_keep: int = 0,
-    ):
-        if past_key_values is not None:
-            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)
-        outputs = self.mllama(
-            input_ids=input_ids,
-            pixel_values=pixel_values,
-            aspect_ratio_mask=aspect_ratio_mask,
-            aspect_ratio_ids=aspect_ratio_ids,
-            attention_mask=attention_mask,
-            cross_attention_mask=cross_attention_mask,
-            cross_attention_states=cross_attention_states,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            labels=labels,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            cache_position=cache_position,
-            num_logits_to_keep=num_logits_to_keep,
-        )
-        if "past_key_values" in outputs:
-            outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()
-        return outputs
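
Stepping back from the mllama cleanup: the single-QPC flow rests on get_specializations in the new Llava modeling file, which hands the compiler two graph specializations for the same binary. Written out with the file's default constants, they look roughly like this (a hedged illustration, not captured output):

# Defaults from the new Llava modeling file: BS = 1, SEQ_LEN = 592, CTX_LEN = 1024, IMAGE_SIZE = 336.
BATCH_SIZE, SEQ_LEN, CTX_LEN, IMAGE_SIZE = 1, 592, 1024, 336

prefill_spec = {
    "batch_size": BATCH_SIZE,
    "seq_len": SEQ_LEN,          # prompt tokens plus the expanded image tokens (576 for a 336x336 Llava-1.5 image)
    "ctx_len": CTX_LEN,
    "max_num_images": 1,         # compiler_options default in get_specializations
    "img_size": IMAGE_SIZE,
}
decode_spec = {**prefill_spec, "seq_len": "1"}   # note: the decode seq_len is the string "1" in the source

specializations = [prefill_spec, decode_spec]
print(specializations)

Both graphs share ctx_len and the retained pixel_values / past_key.* / past_value.* states listed in get_output_names, which is what lets decode run inside the same QPC without re-invoking the vision encoder.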
