diff --git a/awq/models/__init__.py b/awq/models/__init__.py index 2f1a88e2..9b3a4f27 100644 --- a/awq/models/__init__.py +++ b/awq/models/__init__.py @@ -24,3 +24,4 @@ from .deepseek_v2 import DeepseekV2AWQForCausalLM from .minicpm import MiniCPMAWQForCausalLM from .internlm2 import InternLM2AWQForCausalLM +from .qwen2vl import Qwen2VLAWQForCausalLM \ No newline at end of file diff --git a/awq/models/auto.py b/awq/models/auto.py index 1ce1b21d..495722ab 100644 --- a/awq/models/auto.py +++ b/awq/models/auto.py @@ -34,6 +34,7 @@ "deepseek_v2": DeepseekV2AWQForCausalLM, "minicpm": MiniCPMAWQForCausalLM, "internlm2": InternLM2AWQForCausalLM, + "qwen2_vl": Qwen2VLAWQForCausalLM, } diff --git a/awq/models/base.py b/awq/models/base.py index 2da5095d..abfb9b38 100644 --- a/awq/models/base.py +++ b/awq/models/base.py @@ -39,7 +39,7 @@ PreTrainedModel, PretrainedConfig, AutoProcessor, - CLIPImageProcessor, + BaseImageProcessor, PreTrainedTokenizer, ) from accelerate.big_modeling import ( @@ -74,6 +74,7 @@ "baichuan": "AutoModelForCausalLM", "llava": "AutoModelForVision2Seq", "qwen2": "AutoModelForCausalLM", + "qwen2_vl": "AutoModelForVision2Seq", "gemma": "AutoModelForCausalLM", "gemma2": "AutoModelForCausalLM", "stablelm": "AutoModelForCausalLM", @@ -84,6 +85,7 @@ "deepseek_v2": "AutoModelForCausalLM", "minicpm": "AutoModelForCausalLM", "internlm2": "AutoModelForCausalLM", + "qwen2_vl": "AutoModelForVision2Seq", } @@ -100,7 +102,7 @@ def __init__( AwqConfig, Doc("The quantization config of the model.") ], processor: Annotated[ - AutoProcessor, Doc("An optional processor, e.g. for vision models.") + BaseImageProcessor, Doc("An optional processor, e.g. for vision models.") ], ): """The base model for all AutoAWQ models.""" @@ -111,7 +113,7 @@ def __init__( self.search_result = None self.config: PretrainedConfig = config self.quant_config: AwqConfig = quant_config - self.processor: CLIPImageProcessor = processor + self.processor: BaseImageProcessor = processor def to(self, device: Annotated[str, Doc("The device to move your model to.")]): """A utility function for moving the model to a device.""" @@ -186,6 +188,11 @@ def quantize( ] = 1024 * 1024 * 1024, + quantizer_cls: Annotated[ + AwqQuantizer, + Doc("If you want to customize the quantization class, you can use AwqQuantizer as a base class.") + ] = AwqQuantizer, + **kwargs, ): """ The main quantization function that you can use to quantize your model. @@ -209,7 +216,7 @@ def quantize( if hasattr(self, "modules_to_not_convert"): self.quant_config.modules_to_not_convert = self.modules_to_not_convert - self.quantizer = AwqQuantizer( + self.quantizer = quantizer_cls( self, self.model, tokenizer, @@ -228,6 +235,7 @@ def quantize( max_calib_samples=max_calib_samples, max_calib_seq_len=max_calib_seq_len, max_chunk_memory=max_chunk_memory, + **kwargs, ) self.quantizer.quantize() @@ -373,7 +381,6 @@ def from_pretrained( processor = None if target_cls_name == "AutoModelForVision2Seq": processor = AutoProcessor.from_pretrained(model_weights_path) - processor: CLIPImageProcessor = processor.image_processor # If not quantized, must load with AutoModelForCausalLM model = target_cls.from_pretrained( diff --git a/awq/models/qwen2vl.py b/awq/models/qwen2vl.py new file mode 100644 index 00000000..4d8ad213 --- /dev/null +++ b/awq/models/qwen2vl.py @@ -0,0 +1,75 @@ +from .base import BaseAWQForCausalLM +from typing_extensions import TYPE_CHECKING + +if TYPE_CHECKING: + from transformers import Qwen2VLForConditionalGeneration + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLDecoderLayer + +class Qwen2VLAWQForCausalLM(BaseAWQForCausalLM): + layer_type = "Qwen2VLDecoderLayer" + max_seq_len_key = "max_position_embeddings" + modules_to_not_convert = ["visual"] + + @staticmethod + def get_model_layers(model: "Qwen2VLForConditionalGeneration"): + return model.model.layers + + @staticmethod + def get_act_for_scaling(module: "Qwen2VLForConditionalGeneration"): + return dict(is_scalable=False) + + @staticmethod + def move_embed(model: "Qwen2VLForConditionalGeneration", device: str): + model.model.embed_tokens = model.model.embed_tokens.to(device) + model.visual = model.visual.to(device) + + @staticmethod + def get_layers_for_scaling(module: "Qwen2VLDecoderLayer", input_feat, module_kwargs): + layers = [] + + # attention input + layers.append( + dict( + prev_op=module.input_layernorm, + layers=[ + module.self_attn.q_proj, + module.self_attn.k_proj, + module.self_attn.v_proj, + ], + inp=input_feat["self_attn.q_proj"], + module2inspect=module.self_attn, + kwargs=module_kwargs, + ) + ) + + # attention out + # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 + if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: + layers.append( + dict( + prev_op=module.self_attn.v_proj, + layers=[module.self_attn.o_proj], + inp=input_feat["self_attn.o_proj"], + ) + ) + + # linear 1 + layers.append( + dict( + prev_op=module.post_attention_layernorm, + layers=[module.mlp.gate_proj, module.mlp.up_proj], + inp=input_feat["mlp.gate_proj"], + module2inspect=module.mlp, + ) + ) + + # linear 2 + layers.append( + dict( + prev_op=module.mlp.up_proj, + layers=[module.mlp.down_proj], + inp=input_feat["mlp.down_proj"], + ) + ) + + return layers \ No newline at end of file diff --git a/awq/utils/qwen_vl_utils.py b/awq/utils/qwen_vl_utils.py new file mode 100644 index 00000000..08ba02f7 --- /dev/null +++ b/awq/utils/qwen_vl_utils.py @@ -0,0 +1,339 @@ +from __future__ import annotations + +import base64 +import logging +import math +import os +import sys +import time +import warnings +from functools import lru_cache +from io import BytesIO + +import requests +import torch +import torchvision +from packaging import version +from PIL import Image +from torchvision import io, transforms +from torchvision.transforms import InterpolationMode + + +logger = logging.getLogger(__name__) + +IMAGE_FACTOR = 28 +MIN_PIXELS = 4 * 28 * 28 +MAX_PIXELS = 16384 * 28 * 28 +MAX_RATIO = 200 + +VIDEO_MIN_PIXELS = 128 * 28 * 28 +VIDEO_MAX_PIXELS = 768 * 28 * 28 +VIDEO_TOTAL_PIXELS = 24576 * 28 * 28 +FRAME_FACTOR = 2 +FPS = 2.0 +FPS_MIN_FRAMES = 4 +FPS_MAX_FRAMES = 768 + + +def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +def smart_resize( + height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS +) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + +def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image: + if "image" in ele: + image = ele["image"] + else: + image = ele["image_url"] + image_obj = None + if isinstance(image, Image.Image): + image_obj = image + elif image.startswith("http://") or image.startswith("https://"): + image_obj = Image.open(requests.get(image, stream=True).raw) + elif image.startswith("file://"): + image_obj = Image.open(image[7:]) + elif image.startswith("data:image"): + if "base64," in image: + _, base64_data = image.split("base64,", 1) + data = base64.b64decode(base64_data) + image_obj = Image.open(BytesIO(data)) + else: + image_obj = Image.open(image) + if image_obj is None: + raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}") + image = image_obj.convert("RGB") + ## resize + if "resized_height" in ele and "resized_width" in ele: + resized_height, resized_width = smart_resize( + ele["resized_height"], + ele["resized_width"], + factor=size_factor, + ) + else: + width, height = image.size + min_pixels = ele.get("min_pixels", MIN_PIXELS) + max_pixels = ele.get("max_pixels", MAX_PIXELS) + resized_height, resized_width = smart_resize( + height, + width, + factor=size_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + + return image + + +def smart_nframes( + ele: dict, + total_frames: int, + video_fps: int | float, +) -> int: + """calculate the number of frames for video used for model inputs. + + Args: + ele (dict): a dict contains the configuration of video. + support either `fps` or `nframes`: + - nframes: the number of frames to extract for model inputs. + - fps: the fps to extract frames for model inputs. + - min_frames: the minimum number of frames of the video, only used when fps is provided. + - max_frames: the maximum number of frames of the video, only used when fps is provided. + total_frames (int): the original total number of frames of the video. + video_fps (int | float): the original fps of the video. + + Raises: + ValueError: nframes should in interval [FRAME_FACTOR, total_frames]. + + Returns: + int: the number of frames for video used for model inputs. + """ + assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`" + if "nframes" in ele: + nframes = round_by_factor(ele["nframes"], FRAME_FACTOR) + else: + fps = ele.get("fps", FPS) + min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR) + max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR) + nframes = total_frames / video_fps * fps + nframes = min(max(nframes, min_frames), max_frames) + nframes = round_by_factor(nframes, FRAME_FACTOR) + if not (FRAME_FACTOR <= nframes and nframes <= total_frames): + raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.") + return nframes + + +def _read_video_torchvision( + ele: dict, +) -> torch.Tensor: + """read video using torchvision.io.read_video + + Args: + ele (dict): a dict contains the configuration of video. + support keys: + - video: the path of video. support "file://", "http://", "https://" and local path. + - video_start: the start time of video. + - video_end: the end time of video. + Returns: + torch.Tensor: the video tensor with shape (T, C, H, W). + """ + video_path = ele["video"] + if version.parse(torchvision.__version__) < version.parse("0.19.0"): + if "http://" in video_path or "https://" in video_path: + warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.") + if "file://" in video_path: + video_path = video_path[7:] + st = time.time() + video, audio, info = io.read_video( + video_path, + start_pts=ele.get("video_start", 0.0), + end_pts=ele.get("video_end", None), + pts_unit="sec", + output_format="TCHW", + ) + total_frames, video_fps = video.size(0), info["video_fps"] + logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s") + nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) + idx = torch.linspace(0, total_frames - 1, nframes).round().long() + video = video[idx] + return video + + +def is_decord_available() -> bool: + import importlib.util + + return importlib.util.find_spec("decord") is not None + + +def _read_video_decord( + ele: dict, +) -> torch.Tensor: + """read video using decord.VideoReader + + Args: + ele (dict): a dict contains the configuration of video. + support keys: + - video: the path of video. support "file://", "http://", "https://" and local path. + - video_start: the start time of video. + - video_end: the end time of video. + Returns: + torch.Tensor: the video tensor with shape (T, C, H, W). + """ + import decord + video_path = ele["video"] + st = time.time() + vr = decord.VideoReader(video_path) + # TODO: support start_pts and end_pts + if 'video_start' in ele or 'video_end' in ele: + raise NotImplementedError("not support start_pts and end_pts in decord for now.") + total_frames, video_fps = len(vr), vr.get_avg_fps() + logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s") + nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) + idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() + video = vr.get_batch(idx).asnumpy() + video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format + return video + + +VIDEO_READER_BACKENDS = { + "decord": _read_video_decord, + "torchvision": _read_video_torchvision, +} + +FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None) + + +@lru_cache(maxsize=1) +def get_video_reader_backend() -> str: + if FORCE_QWENVL_VIDEO_READER is not None: + video_reader_backend = FORCE_QWENVL_VIDEO_READER + elif is_decord_available(): + video_reader_backend = "decord" + else: + video_reader_backend = "torchvision" + print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr) + return video_reader_backend + + +def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]: + if isinstance(ele["video"], str): + video_reader_backend = get_video_reader_backend() + video = VIDEO_READER_BACKENDS[video_reader_backend](ele) + nframes, _, height, width = video.shape + + min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS) + total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS) + max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05)) + max_pixels = ele.get("max_pixels", max_pixels) + if "resized_height" in ele and "resized_width" in ele: + resized_height, resized_width = smart_resize( + ele["resized_height"], + ele["resized_width"], + factor=image_factor, + ) + else: + resized_height, resized_width = smart_resize( + height, + width, + factor=image_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + video = transforms.functional.resize( + video, + [resized_height, resized_width], + interpolation=InterpolationMode.BICUBIC, + antialias=True, + ).float() + return video + else: + assert isinstance(ele["video"], (list, tuple)) + process_info = ele.copy() + process_info.pop("type", None) + process_info.pop("video", None) + images = [ + fetch_image({"image": video_element, **process_info}, size_factor=image_factor) + for video_element in ele["video"] + ] + nframes = ceil_by_factor(len(images), FRAME_FACTOR) + if len(images) < nframes: + images.extend([images[-1]] * (nframes - len(images))) + return images + + +def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]: + vision_infos = [] + if isinstance(conversations[0], dict): + conversations = [conversations] + for conversation in conversations: + for message in conversation: + if isinstance(message["content"], list): + for ele in message["content"]: + if ( + "image" in ele + or "image_url" in ele + or "video" in ele + or ele["type"] in ("image", "image_url", "video") + ): + vision_infos.append(ele) + return vision_infos + + +def process_vision_info( + conversations: list[dict] | list[list[dict]], +) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None]: + vision_infos = extract_vision_info(conversations) + ## Read images or videos + image_inputs = [] + video_inputs = [] + for vision_info in vision_infos: + if "image" in vision_info or "image_url" in vision_info: + image_inputs.append(fetch_image(vision_info)) + elif "video" in vision_info: + video_inputs.append(fetch_video(vision_info)) + else: + raise ValueError("image, image_url or video should in content.") + if len(image_inputs) == 0: + image_inputs = None + if len(video_inputs) == 0: + video_inputs = None + return image_inputs, video_inputs \ No newline at end of file diff --git a/docs/examples.md b/docs/examples.md index 2fc0259c..6032b212 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -274,6 +274,156 @@ subprocess.run([ ], shell=True, check=True) ``` +### Custom Quantizer (Qwen2 VL Example) + +Below, the Qwen team has provided an example of how to use a custom quantizer. This works to +effectively quantize the Qwen2 VL model using multimodal examples. + +```python +import torch +import torch.nn as nn + +from awq import AutoAWQForCausalLM +from awq.utils.qwen_vl_utils import process_vision_info +from awq.quantize.quantizer import AwqQuantizer, clear_memory, get_best_device + +# Specify paths and hyperparameters for quantization +model_path = "Qwen/Qwen2-VL-7B-Instruct" +quant_path = "qwen2-vl-7b-instruct" +quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"} + +model = AutoAWQForCausalLM.from_pretrained( + model_path, use_cache=False, attn_implementation="flash_attention_2" +) + +# We define our own quantizer by extending the AwqQuantizer. +# The main difference is in how the samples are processed when +# the quantization process initialized. +class Qwen2VLAwqQuantizer(AwqQuantizer): + def init_quant(self, n_samples=None, max_seq_len=None): + modules = self.awq_model.get_model_layers(self.model) + samples = self.calib_data + + inps = [] + layer_kwargs = {} + + best_device = get_best_device() + modules[0] = modules[0].to(best_device) + self.awq_model.move_embed(self.model, best_device) + + # get input and kwargs to layer 0 + # with_kwargs is only supported in PyTorch 2.0 + # use this Catcher hack for now + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, *args, **kwargs): + # assume first input to forward is hidden states + if len(args) > 0: + hidden_states = args[0] + del args + else: + first_key = list(kwargs.keys())[0] + hidden_states = kwargs.pop(first_key) + + inps.append(hidden_states) + layer_kwargs.update(kwargs) + raise ValueError # early exit to break later inference + + def move_to_device(obj: torch.Tensor | nn.Module, device: torch.device): + def get_device(obj: torch.Tensor | nn.Module): + if isinstance(obj, torch.Tensor): + return obj.device + return next(obj.parameters()).device + + if get_device(obj) != device: + obj = obj.to(device) + return obj + + # patch layer 0 to catch input and kwargs + modules[0] = Catcher(modules[0]) + for k, v in samples.items(): + if isinstance(v, (torch.Tensor, nn.Module)): + samples[k] = move_to_device(v, best_device) + try: + self.model(**samples) + except ValueError: # work with early exit + pass + finally: + for k, v in samples.items(): + if isinstance(v, (torch.Tensor, nn.Module)): + samples[k] = move_to_device(v, "cpu") + modules[0] = modules[0].module # restore + + del samples + inps = inps[0] + + modules[0] = modules[0].cpu() + self.awq_model.move_embed(self.model, "cpu") + + clear_memory() + + return modules, layer_kwargs, inps + +# Then you need to prepare your data for calibaration. What you need to do is just put samples into a list, +# each of which is a typical chat message as shown below. you can specify text and image in `content` field: +# dataset = [ +# # message 0 +# [ +# {"role": "system", "content": "You are a helpful assistant."}, +# {"role": "user", "content": "Tell me who you are."}, +# {"role": "assistant", "content": "I am a large language model named Qwen..."}, +# ], +# # message 1 +# [ +# { +# "role": "user", +# "content": [ +# {"type": "image", "image": "file:///path/to/your/image.jpg"}, +# {"type": "text", "text": "Output all text in the image"}, +# ], +# }, +# {"role": "assistant", "content": "The text in the image is balabala..."}, +# ], +# # other messages... +# ..., +# ] +# here, we use a caption dataset **only for demonstration**. You should replace it with your own sft dataset. +def prepare_dataset(n_sample: int = 8) -> list[list[dict]]: + from datasets import load_dataset + + dataset = load_dataset("laion/220k-GPT4Vision-captions-from-LIVIS", split=f"train[:{n_sample}]") + return [ + [ + { + "role": "user", + "content": [ + {"type": "image", "image": sample["url"]}, + {"type": "text", "text": "generate a caption for this image"}, + ], + }, + {"role": "assistant", "content": sample["caption"]}, + ] + for sample in dataset + ] + +dataset = prepare_dataset() + +# process the dataset into tensors +text = model.processor.apply_chat_template(dataset, tokenize=False, add_generation_prompt=True) +image_inputs, video_inputs = process_vision_info(dataset) +inputs = model.processor(text=text, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt") + +# Then just run the calibration process by one line of code +model.quantize(calib_data=inputs, quant_config=quant_config, quantizer_cls=Qwen2VLAwqQuantizer) + +# Save the model +model.model.config.use_cache = model.model.generation_config.use_cache = True +model.save_quantized(quant_path, safetensors=True, shard_size="4GB") +``` + ## Basic Inference ### Inference With GPU @@ -466,3 +616,52 @@ generation_output = model.generate( streamer=streamer ) ``` + +### Qwen2 VL + +Below is an example of how to run inference using Qwen2 VL. + +```python +from awq import AutoAWQForCausalLM +from awq.utils.qwen_vl_utils import process_vision_info +from transformers import AutoProcessor, TextStreamer + +# Load model +quant_path = "Qwen/Qwen2-VL-7B-Instruct-AWQ" +model = AutoAWQForCausalLM.from_quantized(quant_path) +processor = AutoProcessor.from_pretrained(quant_path) +streamer = TextStreamer(processor, skip_prompt=True) + +messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + }, + {"type": "text", "text": "Describe this image."}, + ], + } +] + +# Load inputs +text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True +) +image_inputs, video_inputs = process_vision_info(messages) +inputs = processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", +) +inputs = inputs.to("cuda") + +generation_output = model.generate( + **inputs, + max_new_tokens=512, + streamer=streamer +) +``` \ No newline at end of file