From 4ab7dbb7c260137bc90a39b07abcf0b4f4747055 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8F=AF=E4=BA=B2?= Date: Thu, 8 Aug 2024 20:49:56 +0800 Subject: [PATCH 1/4] add qwen2vl support --- awq/models/base.py | 1 + awq/models/qwen2vl.py | 255 +++++++++++++++++++++++++++++++++++ examples/quantize_qwen2vl.py | 87 ++++++++++++ 3 files changed, 343 insertions(+) create mode 100644 awq/models/qwen2vl.py create mode 100644 examples/quantize_qwen2vl.py diff --git a/awq/models/base.py b/awq/models/base.py index 66dc02e6..e0938b95 100644 --- a/awq/models/base.py +++ b/awq/models/base.py @@ -75,6 +75,7 @@ "baichuan": "AutoModelForCausalLM", "llava": "AutoModelForVision2Seq", "qwen2": "AutoModelForCausalLM", + "qwen2_vl": "AutoModelForVision2Seq", "gemma": "AutoModelForCausalLM", "stablelm": "AutoModelForCausalLM", "starcoder2": "AutoModelForCausalLM", diff --git a/awq/models/qwen2vl.py b/awq/models/qwen2vl.py new file mode 100644 index 00000000..dbd483a2 --- /dev/null +++ b/awq/models/qwen2vl.py @@ -0,0 +1,255 @@ +"""hack to use qwen2vl model with awq""" +import torch +from torch import nn +from typing_extensions import TYPE_CHECKING + +from ..quantize.quantizer import AwqQuantizer, clear_memory, get_best_device +from .base import ( + Annotated, + AwqConfig, + BaseAWQForCausalLM, + Dict, + Doc, + List, + PreTrainedTokenizer, + Union, +) + + +if TYPE_CHECKING: + from transformers import Qwen2VLForConditionalGeneration + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLDecoderLayer + + + +# hack to +# 1. use `self.calib_data` as processed input data +# 2. set the `layer_kwargs` and `inps` correctly +class Qwen2VLAwqQuantizer(AwqQuantizer): + def init_quant(self, n_samples=None, max_seq_len=None): + modules = self.awq_model.get_model_layers(self.model) + samples = self.calib_data + + inps = [] + layer_kwargs = {} + + best_device = get_best_device() + modules[0] = modules[0].to(best_device) + self.awq_model.move_embed(self.model, best_device) + + # get input and kwargs to layer 0 + # with_kwargs is only supported in PyTorch 2.0 + # use this Catcher hack for now + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, *args, **kwargs): + # assume first input to forward is hidden states + if len(args) > 0: + hidden_states = args[0] + del args + else: + first_key = list(kwargs.keys())[0] + hidden_states = kwargs.pop(first_key) + + inps.append(hidden_states) + layer_kwargs.update(kwargs) + raise ValueError # early exit to break later inference + + def move_to_device(obj: torch.Tensor | nn.Module, device: torch.device): + def get_device(obj: torch.Tensor | nn.Module): + if isinstance(obj, torch.Tensor): + return obj.device + return next(obj.parameters()).device + + if get_device(obj) != device: + obj = obj.to(device) + return obj + + # patch layer 0 to catch input and kwargs + modules[0] = Catcher(modules[0]) + for k, v in samples.items(): + if isinstance(v, (torch.Tensor, nn.Module)): + samples[k] = move_to_device(v, best_device) + try: + self.model(**samples) + except ValueError: # work with early exit + pass + finally: + for k, v in samples.items(): + if isinstance(v, (torch.Tensor, nn.Module)): + samples[k] = move_to_device(v, "cpu") + modules[0] = modules[0].module # restore + + del samples + inps = inps[0] + + modules[0] = modules[0].cpu() + self.awq_model.move_embed(self.model, "cpu") + + clear_memory() + + return modules, layer_kwargs, inps + + +class Qwen2VLAWQForConditionalGeneration(BaseAWQForCausalLM): + 
layer_type = "Qwen2VLDecoderLayer" + max_seq_len_key = "max_position_embeddings" + modules_to_not_convert = ["visual"] + + @staticmethod + def get_model_layers(model: "Qwen2VLForConditionalGeneration"): + return model.model.layers + + @staticmethod + def get_act_for_scaling(module: "Qwen2VLForConditionalGeneration"): + return dict(is_scalable=False) + + @staticmethod + def move_embed(model: "Qwen2VLForConditionalGeneration", device: str): + model.model.embed_tokens = model.model.embed_tokens.to(device) + model.visual = model.visual.to(device) + + @staticmethod + def get_layers_for_scaling(module: "Qwen2VLDecoderLayer", input_feat, module_kwargs): + layers = [] + + # attention input + layers.append( + dict( + prev_op=module.input_layernorm, + layers=[ + module.self_attn.q_proj, + module.self_attn.k_proj, + module.self_attn.v_proj, + ], + inp=input_feat["self_attn.q_proj"], + module2inspect=module.self_attn, + kwargs=module_kwargs, + ) + ) + + # attention out + # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 + if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: + layers.append( + dict( + prev_op=module.self_attn.v_proj, + layers=[module.self_attn.o_proj], + inp=input_feat["self_attn.o_proj"], + ) + ) + + # linear 1 + layers.append( + dict( + prev_op=module.post_attention_layernorm, + layers=[module.mlp.gate_proj, module.mlp.up_proj], + inp=input_feat["mlp.gate_proj"], + module2inspect=module.mlp, + ) + ) + + # linear 2 + layers.append( + dict( + prev_op=module.mlp.up_proj, + layers=[module.mlp.down_proj], + inp=input_feat["mlp.down_proj"], + ) + ) + + return layers + + # hack to use `Qwen2VLAwqQuantizer` as quantizer + @torch.no_grad() + def quantize( + self, + tokenizer: Annotated[ + PreTrainedTokenizer, Doc("The tokenizer to use for quantization.") + ] = None, + quant_config: Annotated[ + Dict, Doc("The quantization config you want to use.") + ] = {}, + calib_data: Annotated[ + Union[str, List[str]], + Doc( + "The calibration dataset. Either a string pointing to Huggingface or a list of preloaded examples." + ), + ] = "pileval", + split: Annotated[str, Doc("The split of calib_data.")] = "train", + text_column: Annotated[str, Doc("The text column of calib_data.")] = "text", + duo_scaling: Annotated[ + bool, Doc("Whether to scale using both w/x or just x.") + ] = True, + export_compatible: Annotated[ + bool, + Doc( + "This argument avoids real quantization by only applying the scales without quantizing down to FP16." + ), + ] = False, + apply_clip: Annotated[ + bool, + Doc( + "Whether to apply clipping to the model during quantization. Some models may perform better with this set to False." + ), + ] = True, + n_parallel_calib_samples: Annotated[ + int, + Doc( + "The number of parallel samples to run through the model. " + "A high number of parallel samples can result in OOM during quantization if max_calib_samples is high enough. " + "If None, runs through all samples at the same time. " + "You can set this to a low number for more memory efficient quantization." + ), + ] = None, + max_calib_samples: Annotated[ + int, Doc("The maximum number of samples to run through the model.") + ] = 128, + max_calib_seq_len: Annotated[ + int, + Doc( + "The maximum sequence length of the calibration dataset. Discard samples greater than max_calib_seq_len." + ), + ] = 512, + max_chunk_memory: Annotated[ + int, + Doc( + "The loss computation and per-channel mean is optimized into chunked computations." 
+ " Adjust this parameter to increase or decrease memory usage for these computations." + " Default is 1GB (1024 * 1024 * 1024)." + ), + ] = 1024 + * 1024 + * 1024, + ): + self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config) + + if hasattr(self, "modules_to_not_convert"): + self.quant_config.modules_to_not_convert = self.modules_to_not_convert + + self.quantizer = Qwen2VLAwqQuantizer( + self, + self.model, + tokenizer, + self.quant_config.w_bit, + self.quant_config.q_group_size, + self.quant_config.zero_point, + self.quant_config.version, + calib_data, + split, + text_column, + duo_scaling, + modules_to_not_convert=self.quant_config.modules_to_not_convert, + export_compatible=export_compatible, + apply_clip=apply_clip, + n_parallel_calib_samples=n_parallel_calib_samples, + max_calib_samples=max_calib_samples, + max_calib_seq_len=max_calib_seq_len, + max_chunk_memory=max_chunk_memory, + ) + self.quantizer.quantize() + + self.is_quantized = True diff --git a/examples/quantize_qwen2vl.py b/examples/quantize_qwen2vl.py new file mode 100644 index 00000000..08f0964d --- /dev/null +++ b/examples/quantize_qwen2vl.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import logging + +from qwen_vl_utils import process_vision_info +from transformers import Qwen2VLProcessor + +from awq.models.qwen2vl import Qwen2VLAWQForConditionalGeneration + + +logging.basicConfig( + format="%(asctime)s %(levelname)s [%(name)s] %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", +) + +# Specify paths and hyperparameters for quantization +model_path = "your_model_path" +quant_path = "your_quantized_model_path" +quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"} + +# Load your processor and model with AutoAWQ +processor = Qwen2VLProcessor.from_pretrained(model_path) +model = Qwen2VLAWQForConditionalGeneration.from_pretrained( + model_path, model_type="qwen2_vl", use_cache=False, attn_implementation="flash_attention_2" +) + + +# Then you need to prepare your data for calibaration. What you need to do is just put samples into a list, +# each of which is a typical chat message as shown below. you can specify text and image in `content` field: +# dataset = [ +# # message 0 +# [ +# {"role": "system", "content": "You are a helpful assistant."}, +# {"role": "user", "content": "Tell me who you are."}, +# {"role": "assistant", "content": "I am a large language model named Qwen..."}, +# ], +# # message 1 +# [ +# { +# "role": "user", +# "content": [ +# {"type": "image", "image": "file:///path/to/your/image.jpg"}, +# {"type": "text", "text": "Output all text in the image"}, +# ], +# }, +# {"role": "assistant", "content": "The text in the image is balabala..."}, +# ], +# # other messages... +# ..., +# ] +# here, we use a caption dataset **only for demonstration**. You should replace it with your own sft dataset. 
+def prepare_dataset(n_sample: int = 8) -> list[list[dict]]: + from datasets import load_dataset + + dataset = load_dataset("laion/220k-GPT4Vision-captions-from-LIVIS", split=f"train[:{n_sample}]") + return [ + [ + { + "role": "user", + "content": [ + {"type": "image", "image": sample["url"]}, + {"type": "text", "text": "generate a caption for this image"}, + ], + }, + {"role": "assistant", "content": sample["caption"]}, + ] + for sample in dataset + ] + + +dataset = prepare_dataset() + +# process the dataset into tensors +text = processor.apply_chat_template(dataset, tokenize=False, add_generation_prompt=True) +image_inputs, video_inputs = process_vision_info(dataset) +inputs = processor(text=text, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt") + +# Then just run the calibration process by one line of code: +model.quantize(calib_data=inputs, quant_config=quant_config) + +# Finally, save the quantized model: +model.model.config.use_cache = model.model.generation_config.use_cache = True +model.save_quantized(quant_path, safetensors=True, shard_size="4GB") +processor.save_pretrained(quant_path) + +# Then you can obtain your own AWQ quantized model for deployment. Enjoy! From acbdf26a3dcc3a5e89f90b9c1cfd5c644ba961cf Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Thu, 14 Nov 2024 06:09:59 +0000 Subject: [PATCH 2/4] Streamline quantization of Qwen2 VL --- awq/models/__init__.py | 1 + awq/models/auto.py | 1 + awq/models/base.py | 16 +- awq/models/qwen2vl.py | 186 +------------------ awq/utils/qwen_vl_utils.py | 339 +++++++++++++++++++++++++++++++++++ examples/quantize_qwen2vl.py | 117 ++++++++---- 6 files changed, 440 insertions(+), 220 deletions(-) create mode 100644 awq/utils/qwen_vl_utils.py diff --git a/awq/models/__init__.py b/awq/models/__init__.py index 2f1a88e2..9b3a4f27 100644 --- a/awq/models/__init__.py +++ b/awq/models/__init__.py @@ -24,3 +24,4 @@ from .deepseek_v2 import DeepseekV2AWQForCausalLM from .minicpm import MiniCPMAWQForCausalLM from .internlm2 import InternLM2AWQForCausalLM +from .qwen2vl import Qwen2VLAWQForCausalLM \ No newline at end of file diff --git a/awq/models/auto.py b/awq/models/auto.py index 1ce1b21d..495722ab 100644 --- a/awq/models/auto.py +++ b/awq/models/auto.py @@ -34,6 +34,7 @@ "deepseek_v2": DeepseekV2AWQForCausalLM, "minicpm": MiniCPMAWQForCausalLM, "internlm2": InternLM2AWQForCausalLM, + "qwen2_vl": Qwen2VLAWQForCausalLM, } diff --git a/awq/models/base.py b/awq/models/base.py index 4250dcac..abfb9b38 100644 --- a/awq/models/base.py +++ b/awq/models/base.py @@ -39,7 +39,7 @@ PreTrainedModel, PretrainedConfig, AutoProcessor, - CLIPImageProcessor, + BaseImageProcessor, PreTrainedTokenizer, ) from accelerate.big_modeling import ( @@ -85,6 +85,7 @@ "deepseek_v2": "AutoModelForCausalLM", "minicpm": "AutoModelForCausalLM", "internlm2": "AutoModelForCausalLM", + "qwen2_vl": "AutoModelForVision2Seq", } @@ -101,7 +102,7 @@ def __init__( AwqConfig, Doc("The quantization config of the model.") ], processor: Annotated[ - AutoProcessor, Doc("An optional processor, e.g. for vision models.") + BaseImageProcessor, Doc("An optional processor, e.g. 
for vision models.") ], ): """The base model for all AutoAWQ models.""" @@ -112,7 +113,7 @@ def __init__( self.search_result = None self.config: PretrainedConfig = config self.quant_config: AwqConfig = quant_config - self.processor: CLIPImageProcessor = processor + self.processor: BaseImageProcessor = processor def to(self, device: Annotated[str, Doc("The device to move your model to.")]): """A utility function for moving the model to a device.""" @@ -187,6 +188,11 @@ def quantize( ] = 1024 * 1024 * 1024, + quantizer_cls: Annotated[ + AwqQuantizer, + Doc("If you want to customize the quantization class, you can use AwqQuantizer as a base class.") + ] = AwqQuantizer, + **kwargs, ): """ The main quantization function that you can use to quantize your model. @@ -210,7 +216,7 @@ def quantize( if hasattr(self, "modules_to_not_convert"): self.quant_config.modules_to_not_convert = self.modules_to_not_convert - self.quantizer = AwqQuantizer( + self.quantizer = quantizer_cls( self, self.model, tokenizer, @@ -229,6 +235,7 @@ def quantize( max_calib_samples=max_calib_samples, max_calib_seq_len=max_calib_seq_len, max_chunk_memory=max_chunk_memory, + **kwargs, ) self.quantizer.quantize() @@ -374,7 +381,6 @@ def from_pretrained( processor = None if target_cls_name == "AutoModelForVision2Seq": processor = AutoProcessor.from_pretrained(model_weights_path) - processor: CLIPImageProcessor = processor.image_processor # If not quantized, must load with AutoModelForCausalLM model = target_cls.from_pretrained( diff --git a/awq/models/qwen2vl.py b/awq/models/qwen2vl.py index dbd483a2..4d8ad213 100644 --- a/awq/models/qwen2vl.py +++ b/awq/models/qwen2vl.py @@ -1,100 +1,11 @@ -"""hack to use qwen2vl model with awq""" -import torch -from torch import nn +from .base import BaseAWQForCausalLM from typing_extensions import TYPE_CHECKING -from ..quantize.quantizer import AwqQuantizer, clear_memory, get_best_device -from .base import ( - Annotated, - AwqConfig, - BaseAWQForCausalLM, - Dict, - Doc, - List, - PreTrainedTokenizer, - Union, -) - - if TYPE_CHECKING: from transformers import Qwen2VLForConditionalGeneration from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLDecoderLayer - - -# hack to -# 1. use `self.calib_data` as processed input data -# 2. 
set the `layer_kwargs` and `inps` correctly -class Qwen2VLAwqQuantizer(AwqQuantizer): - def init_quant(self, n_samples=None, max_seq_len=None): - modules = self.awq_model.get_model_layers(self.model) - samples = self.calib_data - - inps = [] - layer_kwargs = {} - - best_device = get_best_device() - modules[0] = modules[0].to(best_device) - self.awq_model.move_embed(self.model, best_device) - - # get input and kwargs to layer 0 - # with_kwargs is only supported in PyTorch 2.0 - # use this Catcher hack for now - class Catcher(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - - def forward(self, *args, **kwargs): - # assume first input to forward is hidden states - if len(args) > 0: - hidden_states = args[0] - del args - else: - first_key = list(kwargs.keys())[0] - hidden_states = kwargs.pop(first_key) - - inps.append(hidden_states) - layer_kwargs.update(kwargs) - raise ValueError # early exit to break later inference - - def move_to_device(obj: torch.Tensor | nn.Module, device: torch.device): - def get_device(obj: torch.Tensor | nn.Module): - if isinstance(obj, torch.Tensor): - return obj.device - return next(obj.parameters()).device - - if get_device(obj) != device: - obj = obj.to(device) - return obj - - # patch layer 0 to catch input and kwargs - modules[0] = Catcher(modules[0]) - for k, v in samples.items(): - if isinstance(v, (torch.Tensor, nn.Module)): - samples[k] = move_to_device(v, best_device) - try: - self.model(**samples) - except ValueError: # work with early exit - pass - finally: - for k, v in samples.items(): - if isinstance(v, (torch.Tensor, nn.Module)): - samples[k] = move_to_device(v, "cpu") - modules[0] = modules[0].module # restore - - del samples - inps = inps[0] - - modules[0] = modules[0].cpu() - self.awq_model.move_embed(self.model, "cpu") - - clear_memory() - - return modules, layer_kwargs, inps - - -class Qwen2VLAWQForConditionalGeneration(BaseAWQForCausalLM): +class Qwen2VLAWQForCausalLM(BaseAWQForCausalLM): layer_type = "Qwen2VLDecoderLayer" max_seq_len_key = "max_position_embeddings" modules_to_not_convert = ["visual"] @@ -161,95 +72,4 @@ def get_layers_for_scaling(module: "Qwen2VLDecoderLayer", input_feat, module_kwa ) ) - return layers - - # hack to use `Qwen2VLAwqQuantizer` as quantizer - @torch.no_grad() - def quantize( - self, - tokenizer: Annotated[ - PreTrainedTokenizer, Doc("The tokenizer to use for quantization.") - ] = None, - quant_config: Annotated[ - Dict, Doc("The quantization config you want to use.") - ] = {}, - calib_data: Annotated[ - Union[str, List[str]], - Doc( - "The calibration dataset. Either a string pointing to Huggingface or a list of preloaded examples." - ), - ] = "pileval", - split: Annotated[str, Doc("The split of calib_data.")] = "train", - text_column: Annotated[str, Doc("The text column of calib_data.")] = "text", - duo_scaling: Annotated[ - bool, Doc("Whether to scale using both w/x or just x.") - ] = True, - export_compatible: Annotated[ - bool, - Doc( - "This argument avoids real quantization by only applying the scales without quantizing down to FP16." - ), - ] = False, - apply_clip: Annotated[ - bool, - Doc( - "Whether to apply clipping to the model during quantization. Some models may perform better with this set to False." - ), - ] = True, - n_parallel_calib_samples: Annotated[ - int, - Doc( - "The number of parallel samples to run through the model. " - "A high number of parallel samples can result in OOM during quantization if max_calib_samples is high enough. 
" - "If None, runs through all samples at the same time. " - "You can set this to a low number for more memory efficient quantization." - ), - ] = None, - max_calib_samples: Annotated[ - int, Doc("The maximum number of samples to run through the model.") - ] = 128, - max_calib_seq_len: Annotated[ - int, - Doc( - "The maximum sequence length of the calibration dataset. Discard samples greater than max_calib_seq_len." - ), - ] = 512, - max_chunk_memory: Annotated[ - int, - Doc( - "The loss computation and per-channel mean is optimized into chunked computations." - " Adjust this parameter to increase or decrease memory usage for these computations." - " Default is 1GB (1024 * 1024 * 1024)." - ), - ] = 1024 - * 1024 - * 1024, - ): - self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config) - - if hasattr(self, "modules_to_not_convert"): - self.quant_config.modules_to_not_convert = self.modules_to_not_convert - - self.quantizer = Qwen2VLAwqQuantizer( - self, - self.model, - tokenizer, - self.quant_config.w_bit, - self.quant_config.q_group_size, - self.quant_config.zero_point, - self.quant_config.version, - calib_data, - split, - text_column, - duo_scaling, - modules_to_not_convert=self.quant_config.modules_to_not_convert, - export_compatible=export_compatible, - apply_clip=apply_clip, - n_parallel_calib_samples=n_parallel_calib_samples, - max_calib_samples=max_calib_samples, - max_calib_seq_len=max_calib_seq_len, - max_chunk_memory=max_chunk_memory, - ) - self.quantizer.quantize() - - self.is_quantized = True + return layers \ No newline at end of file diff --git a/awq/utils/qwen_vl_utils.py b/awq/utils/qwen_vl_utils.py new file mode 100644 index 00000000..08ba02f7 --- /dev/null +++ b/awq/utils/qwen_vl_utils.py @@ -0,0 +1,339 @@ +from __future__ import annotations + +import base64 +import logging +import math +import os +import sys +import time +import warnings +from functools import lru_cache +from io import BytesIO + +import requests +import torch +import torchvision +from packaging import version +from PIL import Image +from torchvision import io, transforms +from torchvision.transforms import InterpolationMode + + +logger = logging.getLogger(__name__) + +IMAGE_FACTOR = 28 +MIN_PIXELS = 4 * 28 * 28 +MAX_PIXELS = 16384 * 28 * 28 +MAX_RATIO = 200 + +VIDEO_MIN_PIXELS = 128 * 28 * 28 +VIDEO_MAX_PIXELS = 768 * 28 * 28 +VIDEO_TOTAL_PIXELS = 24576 * 28 * 28 +FRAME_FACTOR = 2 +FPS = 2.0 +FPS_MIN_FRAMES = 4 +FPS_MAX_FRAMES = 768 + + +def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +def smart_resize( + height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS +) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. 
+ """ + if max(height, width) / min(height, width) > MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + +def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image: + if "image" in ele: + image = ele["image"] + else: + image = ele["image_url"] + image_obj = None + if isinstance(image, Image.Image): + image_obj = image + elif image.startswith("http://") or image.startswith("https://"): + image_obj = Image.open(requests.get(image, stream=True).raw) + elif image.startswith("file://"): + image_obj = Image.open(image[7:]) + elif image.startswith("data:image"): + if "base64," in image: + _, base64_data = image.split("base64,", 1) + data = base64.b64decode(base64_data) + image_obj = Image.open(BytesIO(data)) + else: + image_obj = Image.open(image) + if image_obj is None: + raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}") + image = image_obj.convert("RGB") + ## resize + if "resized_height" in ele and "resized_width" in ele: + resized_height, resized_width = smart_resize( + ele["resized_height"], + ele["resized_width"], + factor=size_factor, + ) + else: + width, height = image.size + min_pixels = ele.get("min_pixels", MIN_PIXELS) + max_pixels = ele.get("max_pixels", MAX_PIXELS) + resized_height, resized_width = smart_resize( + height, + width, + factor=size_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + + return image + + +def smart_nframes( + ele: dict, + total_frames: int, + video_fps: int | float, +) -> int: + """calculate the number of frames for video used for model inputs. + + Args: + ele (dict): a dict contains the configuration of video. + support either `fps` or `nframes`: + - nframes: the number of frames to extract for model inputs. + - fps: the fps to extract frames for model inputs. + - min_frames: the minimum number of frames of the video, only used when fps is provided. + - max_frames: the maximum number of frames of the video, only used when fps is provided. + total_frames (int): the original total number of frames of the video. + video_fps (int | float): the original fps of the video. + + Raises: + ValueError: nframes should in interval [FRAME_FACTOR, total_frames]. + + Returns: + int: the number of frames for video used for model inputs. 
+ """ + assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`" + if "nframes" in ele: + nframes = round_by_factor(ele["nframes"], FRAME_FACTOR) + else: + fps = ele.get("fps", FPS) + min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR) + max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR) + nframes = total_frames / video_fps * fps + nframes = min(max(nframes, min_frames), max_frames) + nframes = round_by_factor(nframes, FRAME_FACTOR) + if not (FRAME_FACTOR <= nframes and nframes <= total_frames): + raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.") + return nframes + + +def _read_video_torchvision( + ele: dict, +) -> torch.Tensor: + """read video using torchvision.io.read_video + + Args: + ele (dict): a dict contains the configuration of video. + support keys: + - video: the path of video. support "file://", "http://", "https://" and local path. + - video_start: the start time of video. + - video_end: the end time of video. + Returns: + torch.Tensor: the video tensor with shape (T, C, H, W). + """ + video_path = ele["video"] + if version.parse(torchvision.__version__) < version.parse("0.19.0"): + if "http://" in video_path or "https://" in video_path: + warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.") + if "file://" in video_path: + video_path = video_path[7:] + st = time.time() + video, audio, info = io.read_video( + video_path, + start_pts=ele.get("video_start", 0.0), + end_pts=ele.get("video_end", None), + pts_unit="sec", + output_format="TCHW", + ) + total_frames, video_fps = video.size(0), info["video_fps"] + logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s") + nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) + idx = torch.linspace(0, total_frames - 1, nframes).round().long() + video = video[idx] + return video + + +def is_decord_available() -> bool: + import importlib.util + + return importlib.util.find_spec("decord") is not None + + +def _read_video_decord( + ele: dict, +) -> torch.Tensor: + """read video using decord.VideoReader + + Args: + ele (dict): a dict contains the configuration of video. + support keys: + - video: the path of video. support "file://", "http://", "https://" and local path. + - video_start: the start time of video. + - video_end: the end time of video. + Returns: + torch.Tensor: the video tensor with shape (T, C, H, W). 
+ """ + import decord + video_path = ele["video"] + st = time.time() + vr = decord.VideoReader(video_path) + # TODO: support start_pts and end_pts + if 'video_start' in ele or 'video_end' in ele: + raise NotImplementedError("not support start_pts and end_pts in decord for now.") + total_frames, video_fps = len(vr), vr.get_avg_fps() + logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s") + nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) + idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() + video = vr.get_batch(idx).asnumpy() + video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format + return video + + +VIDEO_READER_BACKENDS = { + "decord": _read_video_decord, + "torchvision": _read_video_torchvision, +} + +FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None) + + +@lru_cache(maxsize=1) +def get_video_reader_backend() -> str: + if FORCE_QWENVL_VIDEO_READER is not None: + video_reader_backend = FORCE_QWENVL_VIDEO_READER + elif is_decord_available(): + video_reader_backend = "decord" + else: + video_reader_backend = "torchvision" + print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr) + return video_reader_backend + + +def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]: + if isinstance(ele["video"], str): + video_reader_backend = get_video_reader_backend() + video = VIDEO_READER_BACKENDS[video_reader_backend](ele) + nframes, _, height, width = video.shape + + min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS) + total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS) + max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05)) + max_pixels = ele.get("max_pixels", max_pixels) + if "resized_height" in ele and "resized_width" in ele: + resized_height, resized_width = smart_resize( + ele["resized_height"], + ele["resized_width"], + factor=image_factor, + ) + else: + resized_height, resized_width = smart_resize( + height, + width, + factor=image_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + video = transforms.functional.resize( + video, + [resized_height, resized_width], + interpolation=InterpolationMode.BICUBIC, + antialias=True, + ).float() + return video + else: + assert isinstance(ele["video"], (list, tuple)) + process_info = ele.copy() + process_info.pop("type", None) + process_info.pop("video", None) + images = [ + fetch_image({"image": video_element, **process_info}, size_factor=image_factor) + for video_element in ele["video"] + ] + nframes = ceil_by_factor(len(images), FRAME_FACTOR) + if len(images) < nframes: + images.extend([images[-1]] * (nframes - len(images))) + return images + + +def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]: + vision_infos = [] + if isinstance(conversations[0], dict): + conversations = [conversations] + for conversation in conversations: + for message in conversation: + if isinstance(message["content"], list): + for ele in message["content"]: + if ( + "image" in ele + or "image_url" in ele + or "video" in ele + or ele["type"] in ("image", "image_url", "video") + ): + vision_infos.append(ele) + return vision_infos + + +def process_vision_info( + conversations: list[dict] | list[list[dict]], +) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None]: + vision_infos = extract_vision_info(conversations) + ## Read images or videos + 
image_inputs = [] + video_inputs = [] + for vision_info in vision_infos: + if "image" in vision_info or "image_url" in vision_info: + image_inputs.append(fetch_image(vision_info)) + elif "video" in vision_info: + video_inputs.append(fetch_video(vision_info)) + else: + raise ValueError("image, image_url or video should in content.") + if len(image_inputs) == 0: + image_inputs = None + if len(video_inputs) == 0: + video_inputs = None + return image_inputs, video_inputs \ No newline at end of file diff --git a/examples/quantize_qwen2vl.py b/examples/quantize_qwen2vl.py index 08f0964d..49fb3819 100644 --- a/examples/quantize_qwen2vl.py +++ b/examples/quantize_qwen2vl.py @@ -1,31 +1,19 @@ -from __future__ import annotations +import torch +import torch.nn as nn -import logging - -from qwen_vl_utils import process_vision_info -from transformers import Qwen2VLProcessor - -from awq.models.qwen2vl import Qwen2VLAWQForConditionalGeneration - - -logging.basicConfig( - format="%(asctime)s %(levelname)s [%(name)s] %(message)s", - level=logging.INFO, - datefmt="%Y-%m-%d %H:%M:%S", -) +from awq import AutoAWQForCausalLM +from awq.utils.qwen_vl_utils import process_vision_info +from awq.quantize.quantizer import AwqQuantizer, clear_memory, get_best_device # Specify paths and hyperparameters for quantization -model_path = "your_model_path" -quant_path = "your_quantized_model_path" +model_path = "Qwen/Qwen2-VL-7B-Instruct" +quant_path = "qwen2-vl-7b-instruct" quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"} -# Load your processor and model with AutoAWQ -processor = Qwen2VLProcessor.from_pretrained(model_path) -model = Qwen2VLAWQForConditionalGeneration.from_pretrained( - model_path, model_type="qwen2_vl", use_cache=False, attn_implementation="flash_attention_2" +model = AutoAWQForCausalLM.from_pretrained( + model_path, use_cache=False, attn_implementation="flash_attention_2" ) - # Then you need to prepare your data for calibaration. What you need to do is just put samples into a list, # each of which is a typical chat message as shown below. 
you can specify text and image in `content` field: # dataset = [ @@ -72,16 +60,81 @@ def prepare_dataset(n_sample: int = 8) -> list[list[dict]]: dataset = prepare_dataset() # process the dataset into tensors -text = processor.apply_chat_template(dataset, tokenize=False, add_generation_prompt=True) +text = model.processor.apply_chat_template(dataset, tokenize=False, add_generation_prompt=True) image_inputs, video_inputs = process_vision_info(dataset) -inputs = processor(text=text, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt") - -# Then just run the calibration process by one line of code: -model.quantize(calib_data=inputs, quant_config=quant_config) - -# Finally, save the quantized model: +inputs = model.processor(text=text, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt") + +class Qwen2VLAwqQuantizer(AwqQuantizer): + def init_quant(self, n_samples=None, max_seq_len=None): + modules = self.awq_model.get_model_layers(self.model) + samples = self.calib_data + + inps = [] + layer_kwargs = {} + + best_device = get_best_device() + modules[0] = modules[0].to(best_device) + self.awq_model.move_embed(self.model, best_device) + + # get input and kwargs to layer 0 + # with_kwargs is only supported in PyTorch 2.0 + # use this Catcher hack for now + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, *args, **kwargs): + # assume first input to forward is hidden states + if len(args) > 0: + hidden_states = args[0] + del args + else: + first_key = list(kwargs.keys())[0] + hidden_states = kwargs.pop(first_key) + + inps.append(hidden_states) + layer_kwargs.update(kwargs) + raise ValueError # early exit to break later inference + + def move_to_device(obj: torch.Tensor | nn.Module, device: torch.device): + def get_device(obj: torch.Tensor | nn.Module): + if isinstance(obj, torch.Tensor): + return obj.device + return next(obj.parameters()).device + + if get_device(obj) != device: + obj = obj.to(device) + return obj + + # patch layer 0 to catch input and kwargs + modules[0] = Catcher(modules[0]) + for k, v in samples.items(): + if isinstance(v, (torch.Tensor, nn.Module)): + samples[k] = move_to_device(v, best_device) + try: + self.model(**samples) + except ValueError: # work with early exit + pass + finally: + for k, v in samples.items(): + if isinstance(v, (torch.Tensor, nn.Module)): + samples[k] = move_to_device(v, "cpu") + modules[0] = modules[0].module # restore + + del samples + inps = inps[0] + + modules[0] = modules[0].cpu() + self.awq_model.move_embed(self.model, "cpu") + + clear_memory() + + return modules, layer_kwargs, inps + +# Then just run the calibration process by one line of code +model.quantize(calib_data=inputs, quant_config=quant_config, quantizer_cls=Qwen2VLAwqQuantizer) + +# Save the model model.model.config.use_cache = model.model.generation_config.use_cache = True -model.save_quantized(quant_path, safetensors=True, shard_size="4GB") -processor.save_pretrained(quant_path) - -# Then you can obtain your own AWQ quantized model for deployment. Enjoy! 
+model.save_quantized(quant_path, safetensors=True, shard_size="4GB")
\ No newline at end of file

From f3a5436905bdcfc20c337c9c91fe5914e15fcd8a Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Thu, 14 Nov 2024 06:12:59 +0000
Subject: [PATCH 3/4] Add inference example to docs

---
 docs/examples.md | 49 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 49 insertions(+)

diff --git a/docs/examples.md b/docs/examples.md
index 2fc0259c..83fd10cc 100644
--- a/docs/examples.md
+++ b/docs/examples.md
@@ -466,3 +466,52 @@ generation_output = model.generate(
     streamer=streamer
 )
 ```
+
+### Qwen2 VL
+
+Below is an example of how to run inference using Qwen2 VL.
+
+```python
+from awq import AutoAWQForCausalLM
+from awq.utils.qwen_vl_utils import process_vision_info
+from transformers import AutoProcessor, TextStreamer
+
+# Load model
+quant_path = "Qwen/Qwen2-VL-7B-Instruct-AWQ"
+model = AutoAWQForCausalLM.from_quantized(quant_path)
+processor = AutoProcessor.from_pretrained(quant_path)
+streamer = TextStreamer(processor, skip_prompt=True)
+
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {
+                "type": "image",
+                "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
+            },
+            {"type": "text", "text": "Describe this image."},
+        ],
+    }
+]
+
+# Load inputs
+text = processor.apply_chat_template(
+    messages, tokenize=False, add_generation_prompt=True
+)
+image_inputs, video_inputs = process_vision_info(messages)
+inputs = processor(
+    text=[text],
+    images=image_inputs,
+    videos=video_inputs,
+    padding=True,
+    return_tensors="pt",
+)
+inputs = inputs.to("cuda")
+
+generation_output = model.generate(
+    **inputs,
+    max_new_tokens=512,
+    streamer=streamer
+)
+```
\ No newline at end of file

From 0c3ea1d68522042a80a03b43af0418843721061b Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Thu, 14 Nov 2024 06:15:38 +0000
Subject: [PATCH 4/4] Move example to docs

---
 docs/examples.md             | 150 +++++++++++++++++++++++++++++++++++
 examples/quantize_qwen2vl.py | 140 --------------------------------
 2 files changed, 150 insertions(+), 140 deletions(-)
 delete mode 100644 examples/quantize_qwen2vl.py

diff --git a/docs/examples.md b/docs/examples.md
index 83fd10cc..6032b212 100644
--- a/docs/examples.md
+++ b/docs/examples.md
@@ -274,6 +274,156 @@ subprocess.run([
 ], shell=True, check=True)
 ```
+
+### Custom Quantizer (Qwen2 VL Example)
+
+Below, the Qwen team has provided an example of how to use a custom quantizer. It effectively
+quantizes the Qwen2 VL model using multimodal calibration examples.
+
+```python
+import torch
+import torch.nn as nn
+
+from awq import AutoAWQForCausalLM
+from awq.utils.qwen_vl_utils import process_vision_info
+from awq.quantize.quantizer import AwqQuantizer, clear_memory, get_best_device
+
+# Specify paths and hyperparameters for quantization
+model_path = "Qwen/Qwen2-VL-7B-Instruct"
+quant_path = "qwen2-vl-7b-instruct"
+quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
+
+model = AutoAWQForCausalLM.from_pretrained(
+    model_path, use_cache=False, attn_implementation="flash_attention_2"
+)
+
+# We define our own quantizer by extending AwqQuantizer.
+# The main difference is in how the samples are processed when
+# the quantization process is initialized.
+class Qwen2VLAwqQuantizer(AwqQuantizer):
+    def init_quant(self, n_samples=None, max_seq_len=None):
+        modules = self.awq_model.get_model_layers(self.model)
+        samples = self.calib_data
+
+        inps = []
+        layer_kwargs = {}
+
+        best_device = get_best_device()
+        modules[0] = modules[0].to(best_device)
+        self.awq_model.move_embed(self.model, best_device)
+
+        # get input and kwargs to layer 0
+        # with_kwargs is only supported in PyTorch 2.0
+        # use this Catcher hack for now
+        class Catcher(nn.Module):
+            def __init__(self, module):
+                super().__init__()
+                self.module = module
+
+            def forward(self, *args, **kwargs):
+                # assume first input to forward is hidden states
+                if len(args) > 0:
+                    hidden_states = args[0]
+                    del args
+                else:
+                    first_key = list(kwargs.keys())[0]
+                    hidden_states = kwargs.pop(first_key)
+
+                inps.append(hidden_states)
+                layer_kwargs.update(kwargs)
+                raise ValueError  # early exit to break later inference
+
+        def move_to_device(obj: torch.Tensor | nn.Module, device: torch.device):
+            def get_device(obj: torch.Tensor | nn.Module):
+                if isinstance(obj, torch.Tensor):
+                    return obj.device
+                return next(obj.parameters()).device
+
+            if get_device(obj) != device:
+                obj = obj.to(device)
+            return obj
+
+        # patch layer 0 to catch input and kwargs
+        modules[0] = Catcher(modules[0])
+        for k, v in samples.items():
+            if isinstance(v, (torch.Tensor, nn.Module)):
+                samples[k] = move_to_device(v, best_device)
+        try:
+            self.model(**samples)
+        except ValueError:  # work with early exit
+            pass
+        finally:
+            for k, v in samples.items():
+                if isinstance(v, (torch.Tensor, nn.Module)):
+                    samples[k] = move_to_device(v, "cpu")
+        modules[0] = modules[0].module  # restore
+
+        del samples
+        inps = inps[0]
+
+        modules[0] = modules[0].cpu()
+        self.awq_model.move_embed(self.model, "cpu")
+
+        clear_memory()
+
+        return modules, layer_kwargs, inps
+
+# Then you need to prepare your data for calibration. What you need to do is just put samples into a list,
+# each of which is a typical chat message as shown below. You can specify text and image in the `content` field:
+# dataset = [
+#     # message 0
+#     [
+#         {"role": "system", "content": "You are a helpful assistant."},
+#         {"role": "user", "content": "Tell me who you are."},
+#         {"role": "assistant", "content": "I am a large language model named Qwen..."},
+#     ],
+#     # message 1
+#     [
+#         {
+#             "role": "user",
+#             "content": [
+#                 {"type": "image", "image": "file:///path/to/your/image.jpg"},
+#                 {"type": "text", "text": "Output all text in the image"},
+#             ],
+#         },
+#         {"role": "assistant", "content": "The text in the image is balabala..."},
+#     ],
+#     # other messages...
+#     ...,
+# ]
+# Here, we use a caption dataset **only for demonstration**. You should replace it with your own SFT dataset.
+def prepare_dataset(n_sample: int = 8) -> list[list[dict]]: + from datasets import load_dataset + + dataset = load_dataset("laion/220k-GPT4Vision-captions-from-LIVIS", split=f"train[:{n_sample}]") + return [ + [ + { + "role": "user", + "content": [ + {"type": "image", "image": sample["url"]}, + {"type": "text", "text": "generate a caption for this image"}, + ], + }, + {"role": "assistant", "content": sample["caption"]}, + ] + for sample in dataset + ] + +dataset = prepare_dataset() + +# process the dataset into tensors +text = model.processor.apply_chat_template(dataset, tokenize=False, add_generation_prompt=True) +image_inputs, video_inputs = process_vision_info(dataset) +inputs = model.processor(text=text, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt") + +# Then just run the calibration process by one line of code +model.quantize(calib_data=inputs, quant_config=quant_config, quantizer_cls=Qwen2VLAwqQuantizer) + +# Save the model +model.model.config.use_cache = model.model.generation_config.use_cache = True +model.save_quantized(quant_path, safetensors=True, shard_size="4GB") +``` + ## Basic Inference ### Inference With GPU diff --git a/examples/quantize_qwen2vl.py b/examples/quantize_qwen2vl.py deleted file mode 100644 index 49fb3819..00000000 --- a/examples/quantize_qwen2vl.py +++ /dev/null @@ -1,140 +0,0 @@ -import torch -import torch.nn as nn - -from awq import AutoAWQForCausalLM -from awq.utils.qwen_vl_utils import process_vision_info -from awq.quantize.quantizer import AwqQuantizer, clear_memory, get_best_device - -# Specify paths and hyperparameters for quantization -model_path = "Qwen/Qwen2-VL-7B-Instruct" -quant_path = "qwen2-vl-7b-instruct" -quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"} - -model = AutoAWQForCausalLM.from_pretrained( - model_path, use_cache=False, attn_implementation="flash_attention_2" -) - -# Then you need to prepare your data for calibaration. What you need to do is just put samples into a list, -# each of which is a typical chat message as shown below. you can specify text and image in `content` field: -# dataset = [ -# # message 0 -# [ -# {"role": "system", "content": "You are a helpful assistant."}, -# {"role": "user", "content": "Tell me who you are."}, -# {"role": "assistant", "content": "I am a large language model named Qwen..."}, -# ], -# # message 1 -# [ -# { -# "role": "user", -# "content": [ -# {"type": "image", "image": "file:///path/to/your/image.jpg"}, -# {"type": "text", "text": "Output all text in the image"}, -# ], -# }, -# {"role": "assistant", "content": "The text in the image is balabala..."}, -# ], -# # other messages... -# ..., -# ] -# here, we use a caption dataset **only for demonstration**. You should replace it with your own sft dataset. 
-def prepare_dataset(n_sample: int = 8) -> list[list[dict]]: - from datasets import load_dataset - - dataset = load_dataset("laion/220k-GPT4Vision-captions-from-LIVIS", split=f"train[:{n_sample}]") - return [ - [ - { - "role": "user", - "content": [ - {"type": "image", "image": sample["url"]}, - {"type": "text", "text": "generate a caption for this image"}, - ], - }, - {"role": "assistant", "content": sample["caption"]}, - ] - for sample in dataset - ] - - -dataset = prepare_dataset() - -# process the dataset into tensors -text = model.processor.apply_chat_template(dataset, tokenize=False, add_generation_prompt=True) -image_inputs, video_inputs = process_vision_info(dataset) -inputs = model.processor(text=text, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt") - -class Qwen2VLAwqQuantizer(AwqQuantizer): - def init_quant(self, n_samples=None, max_seq_len=None): - modules = self.awq_model.get_model_layers(self.model) - samples = self.calib_data - - inps = [] - layer_kwargs = {} - - best_device = get_best_device() - modules[0] = modules[0].to(best_device) - self.awq_model.move_embed(self.model, best_device) - - # get input and kwargs to layer 0 - # with_kwargs is only supported in PyTorch 2.0 - # use this Catcher hack for now - class Catcher(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - - def forward(self, *args, **kwargs): - # assume first input to forward is hidden states - if len(args) > 0: - hidden_states = args[0] - del args - else: - first_key = list(kwargs.keys())[0] - hidden_states = kwargs.pop(first_key) - - inps.append(hidden_states) - layer_kwargs.update(kwargs) - raise ValueError # early exit to break later inference - - def move_to_device(obj: torch.Tensor | nn.Module, device: torch.device): - def get_device(obj: torch.Tensor | nn.Module): - if isinstance(obj, torch.Tensor): - return obj.device - return next(obj.parameters()).device - - if get_device(obj) != device: - obj = obj.to(device) - return obj - - # patch layer 0 to catch input and kwargs - modules[0] = Catcher(modules[0]) - for k, v in samples.items(): - if isinstance(v, (torch.Tensor, nn.Module)): - samples[k] = move_to_device(v, best_device) - try: - self.model(**samples) - except ValueError: # work with early exit - pass - finally: - for k, v in samples.items(): - if isinstance(v, (torch.Tensor, nn.Module)): - samples[k] = move_to_device(v, "cpu") - modules[0] = modules[0].module # restore - - del samples - inps = inps[0] - - modules[0] = modules[0].cpu() - self.awq_model.move_embed(self.model, "cpu") - - clear_memory() - - return modules, layer_kwargs, inps - -# Then just run the calibration process by one line of code -model.quantize(calib_data=inputs, quant_config=quant_config, quantizer_cls=Qwen2VLAwqQuantizer) - -# Save the model -model.model.config.use_cache = model.model.generation_config.use_cache = True -model.save_quantized(quant_path, safetensors=True, shard_size="4GB") \ No newline at end of file