
Commit 068d3c6

Rebased to update the transformers version and addressed review comments by removing the older QNN-based changes.
Signed-off-by: quic-dhirajku <quic_dhirajku@quicinc.com>
1 parent 796bc33 commit 068d3c6

1 file changed (+115 -29 lines)

QEfficient/transformers/models/modeling_auto.py

@@ -584,6 +584,7 @@ def export(
         )

         self.lang_model.export(inputs["lang"], output_names["lang"], dynamic_axes["lang"], export_dir)
+        return self.onnx_path

     def compile(
         self,
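With this change the multimodal (dual QPC) export path returns the ONNX artifact location instead of None. A minimal usage sketch, not taken from this commit; the model card name and argument values are illustrative:

    from QEfficient import QEFFAutoModelForImageTextToText

    model = QEFFAutoModelForImageTextToText.from_pretrained(
        "meta-llama/Llama-3.2-11B-Vision-Instruct", kv_offload=True
    )
    onnx_path = model.export()  # previously returned None; now returns the exported ONNX path
    model.compile(onnx_path=onnx_path, num_cores=16)  # the returned path can be fed back into compile()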
@@ -916,7 +917,7 @@ def export(
         inputs = self.model.get_dummy_inputs()
         dynamic_axes = self.model.get_onnx_dynamic_axes()
         output_names = self.model.get_output_names()
-        self._export(inputs, output_names, dynamic_axes, export_dir=export_dir)
+        return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir)

     def compile(
         self,
@@ -1170,9 +1171,69 @@ def get_model_config(self) -> dict:
 
 class QEFFAutoModelForImageTextToText:
     """
-    A factory class for creating QEFFAutoModelForImageTextToText instances with for single and Dual QPC approach
+    The QEFFAutoModelForImageTextToText class is used to work with multimodal language models from the HuggingFace hub.
+    While you can initialize the class directly, it's best to use the ``from_pretrained`` method for this purpose. This class supports both single and dual QPC approaches.
     Attributes:
         _hf_auto_class (class): The Hugging Face AutoModel class for ImageTextToText models.
+
+    ``Mandatory`` Args:
+        :pretrained_model_name_or_path (str): Model card name from HuggingFace or local path to model directory.
+
+    ``Optional`` Args:
+        :kv_offload (bool): Flag to toggle between single and dual QPC approaches. If set to False, the Single QPC approach will be used; otherwise, the dual QPC approach will be applied. Defaults to True.
+
+    .. code-block:: python
+        import requests
+        from PIL import Image
+        from transformers import AutoProcessor, TextStreamer
+
+        from QEfficient import QEFFAutoModelForImageTextToText
+
+        # Add HuggingFace Token to access the model
+        HF_TOKEN = ""
+        model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+        query = "Describe this image."
+        image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
+
+        ## STEP - 1 Load the Processor and Model; set kv_offload=True/False for the dual or single QPC approach
+        processor = AutoProcessor.from_pretrained(model_name, token=HF_TOKEN)
+        model = QEFFAutoModelForImageTextToText.from_pretrained(model_name, token=HF_TOKEN, attn_implementation="eager", kv_offload=False)
+
+        ## STEP - 2 Export & Compile the Model
+        model.compile(
+            prefill_seq_len=32,
+            ctx_len=512,
+            img_size=560,
+            num_cores=16,
+            num_devices=1,
+            mxfp6_matmul=False,
+        )
+
+        ## STEP - 3 Load and process the inputs for Inference
+        image = Image.open(requests.get(image_url, stream=True).raw)
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": query},
+                ],
+            }
+        ]
+        input_text = [processor.apply_chat_template(messages, add_generation_prompt=True)]
+        inputs = processor(
+            text=input_text,
+            images=image,
+            return_tensors="pt",
+            add_special_tokens=False,
+            padding="max_length",
+            max_length=32,  # keep in sync with the prefill_seq_len passed to compile()
+        )
+
+        ## STEP - 4 Run Inference on the compiled model
+        streamer = TextStreamer(processor.tokenizer)
+        model.generate(inputs=inputs, streamer=streamer, generation_len=128)
+
     """
 
     _hf_auto_class = AutoModelForImageTextToText
@@ -1219,7 +1280,6 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
         :model (nn.Module): PyTorch model
         :continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set True here, the model cannot be exported/compiled for continuous batching later.
         :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode.
-        :enable_qnn (bool): Enables QNN Compilation path for the model.


     .. code-block:: python
@@ -1250,7 +1310,6 @@ def __init__(
         model: nn.Module,
         continuous_batching: bool = False,
         is_tlm: bool = False,
-        enable_qnn: bool = False,
         **kwargs,
     ):
         model_class_name = model.__class__.__name__
@@ -1282,8 +1341,6 @@ def __init__(
             self.model, transformed = SpDTransform.apply(self.model)
         self.is_tlm = is_tlm

-        self.enable_qnn = enable_qnn
-
     @property
     def model_name(self) -> str:
         mname = self.model.__class__.__name__
@@ -1292,18 +1349,12 @@ def model_name(self) -> str:
         return mname

     def __repr__(self) -> str:
-        return self.__class__.__name__ + "\n" + self.model.__repr__
+        return self.__class__.__name__ + "\n" + self.model.__repr__()

     @classmethod
     @with_replaced_quantizers
     def from_pretrained(
-        cls,
-        pretrained_model_name_or_path,
-        continuous_batching: bool = False,
-        is_tlm: bool = False,
-        enable_qnn: bool = False,
-        *args,
-        **kwargs,
+        cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, *args, **kwargs
     ):
         """
         This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM.
@@ -1314,7 +1365,6 @@ def from_pretrained(
             :pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory.
             :continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set True here, the model cannot be exported/compiled for continuous batching later.
             :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode.
-            :enable_qnn (bool): Enables QNN Compilation path for the model.
             :args, kwargs: Additional arguments to pass to transformers.AutoModelForCausalLM.

         .. code-block:: python
@@ -1348,7 +1398,6 @@ def from_pretrained(
         kv_offload = kwargs.pop("kv_offload", None)

         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
-
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

         # This is to support models that should be classified into a different auto class, but transformers loads them via this class
@@ -1358,7 +1407,7 @@ def from_pretrained(
                 model, kv_offload=kv_offload
             )

-        return cls(model, is_tlm=is_tlm, continuous_batching=continuous_batching, enable_qnn=enable_qnn)
+        return cls(model, is_tlm=is_tlm, continuous_batching=continuous_batching)

     @property
     def model_hash(self) -> str:
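Callers that previously toggled QNN at load time no longer pass enable_qnn to from_pretrained. A minimal sketch of the adjusted flow, assuming the QNN switch is instead supplied at compile time (the flag placement and values here are illustrative, not confirmed by this diff):

    from QEfficient import QEFFAutoModelForCausalLM

    model = QEFFAutoModelForCausalLM.from_pretrained("gpt2", continuous_batching=False, is_tlm=False)
    model.compile(num_cores=16, enable_qnn=False)  # assumed compile-time option; no longer a from_pretrained argument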
@@ -1738,20 +1787,26 @@ def export(self, export_dir: Optional[str] = None) -> str:
         inputs = self.model.get_dummy_inputs()
         dynamic_axes = self.model.get_onnx_dynamic_axes()
         output_names = self.model.get_output_names()
-        self._export(inputs, output_names, dynamic_axes, export_dir=export_dir)
+        return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir)

     def compile(
         self,
         onnx_path: Optional[str] = None,
         compile_dir: Optional[str] = None,
         *,
-        encoder_ctx_len: int = 1500,
-        decoder_ctx_len: int = 150,
-        feature_len: int = 3000,
+        prefill_seq_len: Optional[int] = 1,
+        encoder_ctx_len: Optional[int] = None,
+        ctx_len: int = 150,
+        full_batch_size: Optional[int] = None,
+        kv_cache_batch_size: Optional[int] = None,
         batch_size: int = 1,
         num_devices: int = 1,
         num_cores: int = 16,  # FIXME: Make this mandatory arg
         mxfp6_matmul: bool = False,
+        mxint8_kv_cache: bool = False,
+        num_speculative_tokens: Optional[int] = None,
+        enable_qnn: bool = False,
+        qnn_config: Optional[str] = None,
         **compiler_options,
     ) -> str:
         """
@@ -1762,19 +1817,41 @@ def compile(
         ``Optional`` Args:
             :onnx_path (str, optional): Path to pre-exported onnx model.
             :compile_dir (str, optional): Path for saving the qpc generated.
-            :seq_len (int, optional): The length of the prompt should be less that ``seq_len``. ``Defaults to 32``.
+            :encoder_ctx_len (int, optional): The maximum length of context for encoder, based on the AutoProcessor output. ``Defaults to checking config, if None in config then 1500``
+            :ctx_len (int, optional): The maximum length of context to keep for decoding. ``Defaults to 150``.
             :batch_size (int, optional): Batch size. ``Defaults to 1``.
             :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1.
             :num_cores (int): Number of cores used to compile the model.
             :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``.
             :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
-            :allow_mxint8_mdp_io (bool, optional): Allows MXINT8 compression of MDP IO traffic. ``Defaults to False.``
+
+            Other args are not yet implemented for AutoModelForSpeechSeq2Seq
         Returns:
             :str: Path of the compiled ``qpc`` package.
         """
-        specializations = self.model.get_specializations(batch_size, encoder_ctx_len, decoder_ctx_len, feature_len)
+        specializations, compiler_options = self.model.get_specializations(
+            batch_size,
+            encoder_ctx_len,
+            ctx_len,
+            **compiler_options,
+        )

-        self._compile(
+        if full_batch_size:
+            logger.warning("Continuous batching is not yet enabled for AutoModelForSpeechSeq2Seq")
+
+        if kv_cache_batch_size:
+            logger.warning("Prefix caching is not yet enabled for AutoModelForSpeechSeq2Seq")
+
+        if mxint8_kv_cache:
+            logger.warning("mxint8 cache is not yet enabled for AutoModelForSpeechSeq2Seq")
+
+        if num_speculative_tokens:
+            logger.warning("Speculative decoding is not yet enabled for AutoModelForSpeechSeq2Seq")
+
+        if enable_qnn or qnn_config:
+            logger.warning("QNN compile is not yet enabled for AutoModelForSpeechSeq2Seq")
+
+        return self._compile(
             onnx_path,
             compile_dir,
             compile_only=True,
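A short illustrative call under the new speech-to-text compile signature; the checkpoint name and values below are assumptions for the sketch, not part of this commit:

    from QEfficient import QEFFAutoModelForSpeechSeq2Seq

    model = QEFFAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")
    qpc_path = model.compile(
        encoder_ctx_len=1500,  # if omitted, taken from the model config, else 1500
        ctx_len=150,           # maximum decode-side context
        num_cores=16,
        num_devices=1,
        mxfp6_matmul=False,
    )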
@@ -1792,7 +1869,6 @@ def generate(
         inputs: torch.Tensor,
         generation_len: int,
         streamer: Optional[TextStreamer] = None,
-        enable_debug_logs: bool = False,
         device_ids: List[int] = None,
     ) -> Union[torch.Tensor, np.ndarray]:
         """
@@ -1801,9 +1877,8 @@ def generate(
 
         ``Mandatory`` Args:
             :processor: autoprocessor to process inputs and decode logits
-            :inputs (np.ndarray): inputs to run the execution.
+            :inputs (torch.Tensor): inputs to run the execution.
             :generation_len (int): length up to which to generate
-            :sample_rate (int): sampling rate at which input audio is stored in inputs (needed for processor)
             :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model
         Returns:
             :dict: Output from the ``AI_100`` or ``PyTorch`` runtime.
@@ -1814,9 +1889,20 @@ def generate(
         inputs = self.auto_correct_inputs(inputs)

         if self.qpc_session is None:
-            self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids, enable_debug_logs=enable_debug_logs)
+            self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids)
             self.batch_size = self.qpc_session.bindings[0].dims[0]

+        inputs["input_features"] = inputs["input_features"].numpy().astype(np.float32)
+
+        # add start token id and initial position ids to inputs
+        seq_len = 1
+        inputs["decoder_input_ids"] = (
+            torch.ones((self.batch_size, seq_len), dtype=torch.int64) * self.model.config.decoder_start_token_id
+        ).numpy()
+        inputs["decoder_position_ids"] = (
+            torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(self.batch_size, 1).numpy()
+        )
+
         self.qpc_session.skip_buffers(
             [x for x in self.qpc_session.input_names + self.qpc_session.output_names if x.startswith("past_")]
         )
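Since generate() now builds decoder_input_ids and decoder_position_ids internally and accepts torch tensors, a caller only supplies the processor output. Continuing the compile sketch above; the dataset and names here are illustrative assumptions:

    from datasets import load_dataset
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    sample = ds[0]["audio"]

    # `model` is the compiled QEFFAutoModelForSpeechSeq2Seq instance from the previous sketch.
    # sample_rate and enable_debug_logs are no longer generate() arguments.
    inputs = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt")
    output = model.generate(inputs=inputs, generation_len=25)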
