add qwen2vl support (#599)
Co-authored-by: Casper <casperbh.96@gmail.com>
kq-chen and casper-hansen authored Nov 14, 2024
1 parent 12c91b7 commit 76bc0a8
Showing 6 changed files with 627 additions and 5 deletions.
1 change: 1 addition & 0 deletions awq/models/__init__.py
@@ -24,3 +24,4 @@
 from .deepseek_v2 import DeepseekV2AWQForCausalLM
 from .minicpm import MiniCPMAWQForCausalLM
 from .internlm2 import InternLM2AWQForCausalLM
+from .qwen2vl import Qwen2VLAWQForCausalLM
1 change: 1 addition & 0 deletions awq/models/auto.py
@@ -34,6 +34,7 @@
     "deepseek_v2": DeepseekV2AWQForCausalLM,
     "minicpm": MiniCPMAWQForCausalLM,
     "internlm2": InternLM2AWQForCausalLM,
+    "qwen2_vl": Qwen2VLAWQForCausalLM,
 }

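For context beyond the diff: this map is keyed by the model_type field of the Hugging Face config, so a checkpoint whose config reports "qwen2_vl" now resolves to Qwen2VLAWQForCausalLM. A minimal, hedged sketch of loading an already-quantized checkpoint through the generic entry point; the path is a placeholder, and fuse_layers=False is a cautious choice since fused modules for this architecture are not part of this commit:

from awq import AutoAWQForCausalLM

# Placeholder path to an already-quantized Qwen2-VL checkpoint; dispatch to
# Qwen2VLAWQForCausalLM happens via config.model_type == "qwen2_vl".
model = AutoAWQForCausalLM.from_quantized(
    "path/to/qwen2-vl-awq",
    fuse_layers=False,  # conservative: fused kernels for Qwen2-VL are not added here
)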
17 changes: 12 additions & 5 deletions awq/models/base.py
@@ -39,7 +39,7 @@
     PreTrainedModel,
     PretrainedConfig,
     AutoProcessor,
-    CLIPImageProcessor,
+    BaseImageProcessor,
     PreTrainedTokenizer,
 )
 from accelerate.big_modeling import (
@@ -74,6 +74,7 @@
     "baichuan": "AutoModelForCausalLM",
     "llava": "AutoModelForVision2Seq",
     "qwen2": "AutoModelForCausalLM",
+    "qwen2_vl": "AutoModelForVision2Seq",
     "gemma": "AutoModelForCausalLM",
     "gemma2": "AutoModelForCausalLM",
     "stablelm": "AutoModelForCausalLM",
@@ -84,6 +85,7 @@
     "deepseek_v2": "AutoModelForCausalLM",
     "minicpm": "AutoModelForCausalLM",
     "internlm2": "AutoModelForCausalLM",
+    "qwen2_vl": "AutoModelForVision2Seq",
 }

@@ -100,7 +102,7 @@ def __init__(
             AwqConfig, Doc("The quantization config of the model.")
         ],
         processor: Annotated[
-            AutoProcessor, Doc("An optional processor, e.g. for vision models.")
+            BaseImageProcessor, Doc("An optional processor, e.g. for vision models.")
         ],
     ):
         """The base model for all AutoAWQ models."""
@@ -111,7 +113,7 @@ def __init__(
         self.search_result = None
         self.config: PretrainedConfig = config
         self.quant_config: AwqConfig = quant_config
-        self.processor: CLIPImageProcessor = processor
+        self.processor: BaseImageProcessor = processor

     def to(self, device: Annotated[str, Doc("The device to move your model to.")]):
         """A utility function for moving the model to a device."""
@@ -186,6 +188,11 @@ def quantize(
         ] = 1024
         * 1024
         * 1024,
+        quantizer_cls: Annotated[
+            AwqQuantizer,
+            Doc("If you want to customize the quantization class, you can use AwqQuantizer as a base class.")
+        ] = AwqQuantizer,
+        **kwargs,
     ):
         """
         The main quantization function that you can use to quantize your model.
@@ -209,7 +216,7 @@
         if hasattr(self, "modules_to_not_convert"):
             self.quant_config.modules_to_not_convert = self.modules_to_not_convert

-        self.quantizer = AwqQuantizer(
+        self.quantizer = quantizer_cls(
             self,
             self.model,
             tokenizer,
@@ -228,6 +235,7 @@
             max_calib_samples=max_calib_samples,
             max_calib_seq_len=max_calib_seq_len,
             max_chunk_memory=max_chunk_memory,
+            **kwargs,
         )
         self.quantizer.quantize()

@@ -373,7 +381,6 @@ def from_pretrained(
         processor = None
         if target_cls_name == "AutoModelForVision2Seq":
             processor = AutoProcessor.from_pretrained(model_weights_path)
-            processor: CLIPImageProcessor = processor.image_processor

         # If not quantized, must load with AutoModelForCausalLM
         model = target_cls.from_pretrained(
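Beyond the diff itself: the new quantizer_cls argument, together with the **kwargs forwarding, lets callers plug in their own quantizer. A minimal sketch under the usual AutoAWQ workflow; the subclass, its overridden behaviour, and the model id below are illustrative assumptions, not part of this commit:

from transformers import AutoTokenizer
from awq import AutoAWQForCausalLM
from awq.quantize.quantizer import AwqQuantizer

class LoggingAwqQuantizer(AwqQuantizer):
    """Illustrative subclass: print a message, then run the stock AWQ pass."""
    def quantize(self):
        print("Starting AWQ quantization with a custom quantizer class")
        super().quantize()

model_path = "Qwen/Qwen2.5-7B-Instruct"  # placeholder model id
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

model.quantize(
    tokenizer,
    quant_config={"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"},
    quantizer_cls=LoggingAwqQuantizer,  # new keyword introduced by this commit
)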
75 changes: 75 additions & 0 deletions awq/models/qwen2vl.py
@@ -0,0 +1,75 @@
from .base import BaseAWQForCausalLM
from typing_extensions import TYPE_CHECKING

if TYPE_CHECKING:
from transformers import Qwen2VLForConditionalGeneration
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLDecoderLayer

class Qwen2VLAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "Qwen2VLDecoderLayer"
max_seq_len_key = "max_position_embeddings"
modules_to_not_convert = ["visual"]

@staticmethod
def get_model_layers(model: "Qwen2VLForConditionalGeneration"):
return model.model.layers

@staticmethod
def get_act_for_scaling(module: "Qwen2VLForConditionalGeneration"):
return dict(is_scalable=False)

@staticmethod
def move_embed(model: "Qwen2VLForConditionalGeneration", device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
model.visual = model.visual.to(device)

@staticmethod
def get_layers_for_scaling(module: "Qwen2VLDecoderLayer", input_feat, module_kwargs):
layers = []

# attention input
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)

# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)

# linear 1
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp,
)
)

# linear 2
layers.append(
dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)

return layers
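Only four of the six changed files are rendered above, so the accompanying example script is not shown here. A rough, hedged sketch of exercising the new class end to end; the model id, output directory, and plain text-only calibration are assumptions, and the real example may prepare multimodal calibration data and pass a custom quantizer_cls:

from transformers import AutoTokenizer
from awq import AutoAWQForCausalLM

model_path = "Qwen/Qwen2-VL-7B-Instruct"  # assumed model id
quant_path = "qwen2-vl-7b-instruct-awq"   # placeholder output directory

# Resolves to Qwen2VLAWQForCausalLM through the "qwen2_vl" mappings added above;
# the vision tower stays in full precision because "visual" is in modules_to_not_convert.
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

model.quantize(
    tokenizer,
    quant_config={"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"},
)

model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)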
