Commit b6b232a

Approach changed

Signed-off-by: amitraj <quic_amitraj@quicinc.com>
1 parent: 750bc87

6 files changed (+33 -28 lines)

QEfficient/transformers/custom_attention.py (+1 -3)

@@ -12,8 +12,6 @@
 from torch import nn
 from transformers.models.bert.modeling_bert import BertSelfAttention
 
-from QEfficient.utils.constants import BLOCK_SIZE
-
 
 class QEffBertSelfAttention(BertSelfAttention):
     def forward(
@@ -25,7 +23,7 @@ def forward(
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
-        block_size: int = BLOCK_SIZE,
+        block_size: int = None,
     ) -> Tuple[torch.Tensor]:
         mixed_query_layer = self.query(hidden_states)
 
QEfficient/transformers/models/modeling_auto.py (+10 -21)

@@ -57,10 +57,11 @@ def __repr__(self) -> str:
     def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, *args, **kwargs):
         if kwargs.get("attn_implementation", None) not in {None, "eager"}:
             logger.warning('Updating attn_implementation="eager"')
-            kwargs.update({"attn_implementation": "eager"})
+
         if kwargs.get("low_cpu_mem_usage", None):
             logger.warning("Updating low_cpu_mem_usage=False")
-            kwargs.update({"low_cpu_mem_usage": False})
+
+        kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
 
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
         return cls(model, is_tlm=is_tlm)
@@ -430,20 +431,16 @@ class QEFFAutoModel(QEFFTransformersBase):
     _pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
 
-    def __init__(self, model: nn.Module, **kwargs):
-        if kwargs.get("block_size", None):
-            constants.BLOCK_SIZE = kwargs.get("block_size")
-            self._pytorch_transforms.append(BlockAttentionTransorm)
-            kwargs.update({"attn_implementation": "custom"})
-            kwargs.pop("block_size")
-
+    def __init__(self, model: nn.Module, block_size: Optional[int] = None, **kwargs):
+        if block_size:
+            BlockAttentionTransorm.apply(model, block_size=block_size)
         super().__init__(model)
         self.model.config.use_cache = True
         self.num_layers = model.config.num_hidden_layers
 
     @classmethod
     @with_replaced_quantizers
-    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, block_size: Optional[int] = None, *args, **kwargs):
         """
         This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModel.
         Once the model is initialized, you can use other methods such as export, compile, and generate on the same object.
@@ -470,28 +467,20 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         # You can now execute the model
         model.generate(inputs)
         """
-        if kwargs.get("block_size", None):
-            constants.BLOCK_SIZE = kwargs.get("block_size")
-            cls._pytorch_transforms.append(BlockAttentionTransorm)
-            kwargs.update({"attn_implementation": "custom"})
-            kwargs.pop("block_size")
-
-        if kwargs.get("attn_implementation", None) not in {None, "eager", "custom"}:
+        if kwargs.get("attn_implementation", None) not in {None, "eager"}:
             logger.warning('Updating attn_implementation="eager"')
-            kwargs.update({"attn_implementation": "eager"})
 
         if kwargs.get("low_cpu_mem_usage", None):
             logger.warning("Updating low_cpu_mem_usage=False")
-            kwargs.update({"low_cpu_mem_usage": False})
 
+        kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False, "add_pooling_layer": False})
         try:
-            kwargs.update({"add_pooling_layer": False})
             model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
             warnings.warn("Removing pooling layer from the model if exist")
         except TypeError:
             kwargs.pop("add_pooling_layer", None)
             model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
-        return cls(model)
+        return cls(model, block_size)
 
     @property
     def model_hash(self) -> str:
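
For reference, a minimal usage sketch of the new block_size keyword introduced above. This assumes QEFFAutoModel is re-exported at the package root (otherwise import it from QEfficient.transformers.models.modeling_auto); the checkpoint name and block size are placeholders, not part of this commit:

# Hypothetical usage of the new block_size argument (checkpoint and value are placeholders).
from QEfficient import QEFFAutoModel

# block_size is passed through from_pretrained to __init__, which applies
# BlockAttentionTransorm to the loaded model before the base class wraps it.
model = QEFFAutoModel.from_pretrained("bert-base-uncased", block_size=32)
# The usual export/compile/generate flow described in the docstring still applies.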

QEfficient/transformers/models/pytorch_transforms.py (+10 -2)

@@ -5,6 +5,7 @@
 #
 # -----------------------------------------------------------------------------
 
+from functools import partial
 from types import MethodType
 from typing import Tuple
 
@@ -355,6 +356,13 @@ class BlockAttentionTransorm(ModuleMappingTransform):
     }
 
     @classmethod
-    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
-        model, transformed = super().apply(model)
+    def apply(cls, model: nn.Module, block_size) -> Tuple[nn.Module, bool]:
+        transformed = False
+        for module in model.modules():
+            if repl_module := cls._module_mapping.get(type(module)):
+                module.__class__ = repl_module
+                # Bind the partial function to the instance
+                module.forward = MethodType(partial(repl_module.forward, block_size=block_size), module)
+                transformed = True
+                break
         return model, transformed
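
A standalone sketch of the binding technique used in the new apply(): functools.partial fixes block_size as a keyword argument, and types.MethodType re-binds the resulting callable to the instance so self is still passed implicitly. The class and values below are illustrative only, not QEfficient code:

from functools import partial
from types import MethodType

import torch
from torch import nn


class PatchedLinear(nn.Linear):
    # Toy replacement module whose forward consumes the injected keyword.
    def forward(self, x, block_size=None):
        print(f"block_size={block_size}")
        return super().forward(x)


layer = nn.Linear(4, 4)
layer.__class__ = PatchedLinear  # swap the class in place, as the transform does
# Bind a forward that always receives block_size=32, for this instance only.
layer.forward = MethodType(partial(PatchedLinear.forward, block_size=32), layer)
layer(torch.randn(2, 4))  # prints "block_size=32"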

QEfficient/utils/_utils.py (+11)

@@ -8,6 +8,7 @@
 import json
 import os
 import subprocess
+from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import requests
@@ -394,3 +395,13 @@ def create_json(file_path: str, json_data: object):
             json.dump(json_data, file, indent=4)
     except Exception as e:
         print(f"Failed to create JSON File {file_path}: {e}")
+
+
+@contextmanager
+def temporarily_remove_key(d, key):
+    value = d.pop(key, None)
+    try:
+        yield
+    finally:
+        if value is not None:
+            d[key] = value
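
An illustrative use of the new helper, assuming it stays importable from QEfficient.utils._utils as added here; the dictionary and key below are made up for the example:

from QEfficient.utils._utils import temporarily_remove_key

kwargs = {"attn_implementation": "eager", "add_pooling_layer": False}
with temporarily_remove_key(kwargs, "add_pooling_layer"):
    # The key is absent for the duration of the block ...
    assert "add_pooling_layer" not in kwargs
# ... and restored (for non-None values) once the block exits.
assert kwargs["add_pooling_layer"] is False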

QEfficient/utils/constants.py (-1)

@@ -49,7 +49,6 @@ def get_models_dir():
 ONNX_EXPORT_EXAMPLE_FBS = 4
 ONNX_EXPORT_EXAMPLE_NLK = 2  # Number of Logits to Keep
 ONNX_EXPORT_OPSET = 13
-BLOCK_SIZE = 32
 
 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"]
 

tests/peft/test_peft_onnx_transforms.py (+1 -1)

@@ -24,7 +24,7 @@ def test_adapter_weights_to_inputs_transform():
         <
             float[32, 32] layer1_{adapter_name}_weight = [ "location" : "{external_tensors_file}" ],
             float[32, 32] layer2_{adapter_name}_weight = [ "location" : "{external_tensors_file}" ]
-        >
+        >f
         {{
             layer1output = MatMul (input, layer1_{adapter_name}_weight)
             output = MatMul (layer1output, layer2_{adapter_name}_weight)
