Exllama kernels support for AWQ models #28634

Merged: 16 commits, Mar 5, 2024
Changes from 10 commits
2 changes: 1 addition & 1 deletion docker/transformers-all-latest-gpu/Dockerfile
@@ -56,7 +56,7 @@ RUN python3 -m pip install --no-cache-dir auto-gptq --extra-index-url https://hu
RUN python3 -m pip install --no-cache-dir einops

# Add autoawq for quantization testing
-RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp38-cp38-linux_x86_64.whl
+RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.9/autoawq-0.1.9+cu118-cp38-cp38-linux_x86_64.whl

# For bettertransformer + gptq
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/optimum@main#egg=optimum
12 changes: 10 additions & 2 deletions src/transformers/integrations/__init__.py
@@ -17,7 +17,11 @@


_import_structure = {
-    "awq": ["fuse_awq_modules", "replace_with_awq_linear"],
+    "awq": [
+        "fuse_awq_modules",
+        "post_init_awq_exllama_modules",
+        "replace_with_awq_linear",
+    ],
    "bitsandbytes": [
        "get_keys_to_not_convert",
        "replace_8bit_linear",
@@ -80,7 +84,11 @@
}

if TYPE_CHECKING:
-    from .awq import fuse_awq_modules, replace_with_awq_linear
+    from .awq import (
+        fuse_awq_modules,
+        post_init_awq_exllama_modules,
+        replace_with_awq_linear,
+    )
    from .bitsandbytes import (
        get_keys_to_not_convert,
        replace_8bit_linear,
41 changes: 35 additions & 6 deletions src/transformers/integrations/awq.py
@@ -15,7 +15,11 @@
from ..activations import ACT2FN
from ..modeling_utils import PreTrainedModel
from ..utils import is_auto_awq_available, is_torch_available
-from ..utils.quantization_config import AwqBackendPackingMethod, AwqConfig, AWQLinearVersion
+from ..utils.quantization_config import (
+    AwqBackendPackingMethod,
+    AwqConfig,
+    AWQLinearVersion,
+)


if is_torch_available():
@@ -91,13 +95,23 @@ def replace_with_awq_linear(
        )

    if backend == AwqBackendPackingMethod.AUTOAWQ:
-        from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-    elif backend == AwqBackendPackingMethod.LLMAWQ:
-        from awq.quantize.qmodule import WQLinear
+        if quantization_config.version == AWQLinearVersion.GEMM:
+            from awq.modules.linear.gemm import WQLinear_GEMM

-    if backend == AwqBackendPackingMethod.AUTOAWQ:
-        target_cls = WQLinear_GEMM if quantization_config.version == AWQLinearVersion.GEMM else WQLinear_GEMV
+            target_cls = WQLinear_GEMM
+        elif quantization_config.version == AWQLinearVersion.GEMV:
+            from awq.modules.linear.gemv import WQLinear_GEMV
+
+            target_cls = WQLinear_GEMV
+        elif quantization_config.version == AWQLinearVersion.EXLLAMA:
+            from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2
+
+            target_cls = WQLinear_ExllamaV2
+        else:
+            raise ValueError(f"Unsupported AWQLinearVersion: {quantization_config.version}")
    else:
+        from awq.quantize.qmodule import WQLinear

        target_cls = WQLinear

    for name, module in model.named_children():
@@ -372,3 +386,18 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
            setattr(parent, child_name, fused_attention_layer.to(previous_device))

        del q_proj, k_proj, v_proj, o_proj


+def post_init_awq_exllama_modules(model):
+    """
+    Runs post init for Exllama layers which performs:
+        - Weights unpacking, reordering and repacking
+        - Devices scratch space allocation
+    """
+    from awq.modules.linear.exllamav2 import exllamav2_post_init
+
+    # default values for exllamav2 from
+    # https://github.com/AutoGPTQ/AutoGPTQ/blob/6ba14f17ef73c161c2c4707cbf0b41e569a9c6dd/auto_gptq/nn_modules/qlinear/qlinear_exllamav2.py#L171
+    model = exllamav2_post_init(model, max_input_len=2048, max_batch_size=8)
Contributor: Couldn't we make max_input_len configurable through AwqConfig? Wdyt?

Member Author: Marc suggested we leave it as is for now: #28634 (comment)

Contributor: Oh okay! I think it would make sense to directly expose an exllama_config - wdyt @SunMarc?

Member: Yes, it would make more sense to expose it in an exllama_config!

Member Author: I guess in another PR, right?

Contributor: Hmmm, I think it would be better to add it now and not leave the main branch with hardcoded config values. It shouldn't be super complex - you can just copy over the existing logic from GptqConfig, right?


+    return model
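A minimal sketch of the direction discussed in the thread above - letting AwqConfig carry an optional exllama_config dict and forwarding it to this post-init step, similar to the exllama settings GPTQConfig already exposes. This is not part of the diff: the exllama_config name and its keys are assumptions for illustration only, and the defaults simply mirror the hardcoded values above.

def post_init_awq_exllama_modules(model, exllama_config=None):
    # Sketch only (not in this PR): read the exllamav2 scratch-space settings from an
    # assumed `exllama_config` dict instead of hardcoding them. The underlying call,
    # exllamav2_post_init(model, max_input_len, max_batch_size), is taken from the diff.
    from awq.modules.linear.exllamav2 import exllamav2_post_init

    exllama_config = exllama_config or {}
    return exllamav2_post_init(
        model,
        max_input_len=exllama_config.get("max_input_len", 2048),
        max_batch_size=exllama_config.get("max_batch_size", 8),
    )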
11 changes: 11 additions & 0 deletions src/transformers/quantizers/quantizer_awq.py
@@ -20,6 +20,7 @@
    from ..modeling_utils import PreTrainedModel

from ..utils import is_accelerate_available, is_auto_awq_available, is_torch_available, logging
+from ..utils.quantization_config import AWQLinearVersion


if is_torch_available():
@@ -95,12 +96,22 @@ def _process_model_after_weight_loading(self, model):
            model = fuse_awq_modules(model, self.quantization_config)
            model._awq_is_fused = True  # TODO: consider storing this flag in model.config instead

+        if self.quantization_config.version == AWQLinearVersion.EXLLAMA:
+            from ..integrations import post_init_awq_exllama_modules
+
+            model = post_init_awq_exllama_modules(model)

    @property
    def is_serializable(self):
        # AWQ through auto-awq has been always serializable, except if the model is fused.
        if self.quantization_config.do_fuse:
            logger.warning("You cannot save an AWQ model that uses fused modules!")
            return False

+        if self.quantization_config.version == AWQLinearVersion.EXLLAMA:
+            logger.warning("You cannot save an AWQ model that uses Exllama backend!")
+            return False

        return True

    @property
23 changes: 20 additions & 3 deletions src/transformers/utils/quantization_config.py
@@ -43,6 +43,7 @@ class QuantizationMethod(str, Enum):
class AWQLinearVersion(str, Enum):
    GEMM = "gemm"
    GEMV = "gemv"
+    EXLLAMA = "exllama"

    @staticmethod
    def from_str(version: str):
@@ -51,6 +52,8 @@ def from_str(version: str):
            return AWQLinearVersion.GEMM
        elif version == "gemv":
            return AWQLinearVersion.GEMV
+        elif version == "exllama":
+            return AWQLinearVersion.EXLLAMA
        else:
            raise ValueError(f"Unknown AWQLinearVersion {version}")

@@ -669,9 +672,9 @@ def post_init(self):
            )

        self.version = AWQLinearVersion.from_str(self.version)
-        if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV]:
+        if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA]:
            raise ValueError(
-                f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV] - not recognized version {self.version}"
+                f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA] - not recognized version {self.version}"
            )

        if self.backend == AwqBackendPackingMethod.LLMAWQ:
@@ -726,8 +729,22 @@ def post_init(self):
                    f"Required fields are missing in the fusing mapping, required fields are {required_keys}"
                )

+        if self.version == AWQLinearVersion.EXLLAMA:
+            awq_version_supports_exllama = False
+            MIN_AWQ_VERSION = "0.1.9"
+            if is_auto_awq_available():
+                awq_version_supports_exllama = version.parse(importlib.metadata.version("autoawq")) >= version.parse(
+                    MIN_AWQ_VERSION
+                )
+
+            if not awq_version_supports_exllama:
+                raise ValueError(
+                    f"You current version of `autoawq` does not support exllama backend, "
+                    f"please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
+                )

    def get_loading_attributes(self):
        attibutes_dict = copy.deepcopy(self.__dict__)
-        loading_attibutes = ["do_fuse", "modules_to_fuse", "fuse_max_seq_len"]
+        loading_attibutes = ["version", "do_fuse", "modules_to_fuse", "fuse_max_seq_len"]
        loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
        return loading_attibutes_dict
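Putting the config changes together: the exllama kernels are opted into purely through AwqConfig, and the chosen version now travels with the loading attributes. A small usage sketch, assuming autoawq >= 0.1.9 is installed and an AWQ checkpoint is available; the model id below is a placeholder, not taken from this diff.

from transformers import AutoModelForCausalLM, AwqConfig

# "exllama" is normalized to AWQLinearVersion.EXLLAMA by AWQLinearVersion.from_str in post_init
quantization_config = AwqConfig(version="exllama")

# "version" is now included in get_loading_attributes(), alongside do_fuse,
# modules_to_fuse and fuse_max_seq_len, so it reaches the quantizer at load time
print(quantization_config.get_loading_attributes())

model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Mistral-7B-v0.1-AWQ",  # placeholder AWQ checkpoint
    quantization_config=quantization_config,
)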
14 changes: 14 additions & 0 deletions tests/quantization/autoawq/test_awq.py
@@ -192,6 +192,20 @@ def test_quantized_model_bf16(self):
        output = quantized_model.generate(**input_ids, max_new_tokens=40)
        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT_BF16)

+    def test_quantized_model_exllama(self):
+        """
+        Simple test that checks if the quantized model is working properly with exllama backend
+        """
+        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+
+        quantization_config = AwqConfig(version="exllama")
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name, quantization_config=quantization_config
+        ).to(torch_device)
+
+        output = quantized_model.generate(**input_ids, max_new_tokens=40)
+        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+
    def test_quantized_model_no_device_map(self):
        """
        Simple test that checks if the quantized model is working properly