From 37ac00671975314f5e8ad656438b2197cd18c13a Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Thu, 4 Jul 2024 00:55:49 +0000
Subject: [PATCH 1/8] update auto model

---
 .../transformers/sparsification/sparse_model.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/src/llmcompressor/transformers/sparsification/sparse_model.py b/src/llmcompressor/transformers/sparsification/sparse_model.py
index 6a148c784..6b4b2525f 100644
--- a/src/llmcompressor/transformers/sparsification/sparse_model.py
+++ b/src/llmcompressor/transformers/sparsification/sparse_model.py
@@ -5,6 +5,7 @@

 import torch
 from compressed_tensors.compressors import ModelCompressor
+from compressed_tensors.quantization import apply_quantization_config, QuantizationStatus
 from loguru import logger
 from torch.nn import Module
 from transformers import AutoModelForCausalLM, PreTrainedModel
@@ -40,6 +41,7 @@ class SparseAutoModelForCausalLM(AutoModelForCausalLM):
     def from_pretrained(
         cls,
         pretrained_model_name_or_path,
+        run_compressed: bool = False,
         recipe: Optional[Union[str, Path]] = None,
         *model_args,
         **kwargs,
@@ -103,8 +105,14 @@ def skip(*args, **kwargs):
         # If model is quantized or compressed on disk, initialize quantization
         # structure and run decompression
         if compressor is not None:
-            # initialize quantization and decompress weights
-            compressor.decompress(model_path=pretrained_model_name_or_path, model=model)
+            quantization_config = compressor.quantization_config
+            is_compressed = (quantization_config.quantization_status == QuantizationStatus.COMPRESSED)
+            if run_compressed and is_compressed:
+                # initialize quantization, don't decompress
+                apply_quantization_config(model, quantization_config)
+            else:
+                # initialize quantization and decompress weights
+                compressor.decompress(model_path=pretrained_model_name_or_path, model=model)

         recipe = resolve_recipe(recipe=recipe, model_path=pretrained_model_name_or_path)
         if recipe:

From 847b4a47a15494e0481e53d1ca28549d2658e6db Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Tue, 9 Jul 2024 17:55:05 +0000
Subject: [PATCH 2/8] style and reload

---
 .../transformers/sparsification/sparse_model.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/src/llmcompressor/transformers/sparsification/sparse_model.py b/src/llmcompressor/transformers/sparsification/sparse_model.py
index 6b4b2525f..f4dd88c66 100644
--- a/src/llmcompressor/transformers/sparsification/sparse_model.py
+++ b/src/llmcompressor/transformers/sparsification/sparse_model.py
@@ -4,8 +4,12 @@
 from typing import Optional, Union

 import torch
+from accelerate import load_checkpoint_and_dispatch
 from compressed_tensors.compressors import ModelCompressor
-from compressed_tensors.quantization import apply_quantization_config, QuantizationStatus
+from compressed_tensors.quantization import (
+    QuantizationStatus,
+    apply_quantization_config,
+)
 from loguru import logger
 from torch.nn import Module
 from transformers import AutoModelForCausalLM, PreTrainedModel
@@ -106,13 +110,20 @@ def skip(*args, **kwargs):
         # structure and run decompression
         if compressor is not None:
             quantization_config = compressor.quantization_config
-            is_compressed = (quantization_config.quantization_status == QuantizationStatus.COMPRESSED)
+            is_compressed = (
+                quantization_config.quantization_status == QuantizationStatus.COMPRESSED
+            )
             if run_compressed and is_compressed:
                 # initialize quantization, don't decompress
                 apply_quantization_config(model, quantization_config)
+                model = load_checkpoint_and_dispatch(
+                    model, pretrained_model_name_or_path, *model_args, **kwargs
+                )
             else:
                 # initialize quantization and decompress weights
-                compressor.decompress(model_path=pretrained_model_name_or_path, model=model)
+                compressor.decompress(
+                    model_path=pretrained_model_name_or_path, model=model
+                )

         recipe = resolve_recipe(recipe=recipe, model_path=pretrained_model_name_or_path)
         if recipe:

From 278fc209d0437b0b07c720d512034800fbd60b7b Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Tue, 6 Aug 2024 19:40:42 +0000
Subject: [PATCH 3/8] fixing tests

---
 .../modifiers/quantization/gptq/utils/gptq_wrapper.py         | 2 +-
 src/llmcompressor/transformers/sparsification/sparse_model.py | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
index d8e565777..6d7966672 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -161,7 +161,7 @@ def compress(
         quant_scheme = self.layer.quantization_scheme
         if quant_scheme.weights is not None:
             scale = self.layer.weight_scale
-            zero_point = self.layer.weight_zero_point
+            zero_point = getattr(self.layer, "weight_zero_point", None)
             from compressed_tensors.quantization import QuantizationStrategy
             from compressed_tensors.quantization.lifecycle.forward import (
                 fake_quantize,
diff --git a/src/llmcompressor/transformers/sparsification/sparse_model.py b/src/llmcompressor/transformers/sparsification/sparse_model.py
index b512ed045..8a229d67b 100644
--- a/src/llmcompressor/transformers/sparsification/sparse_model.py
+++ b/src/llmcompressor/transformers/sparsification/sparse_model.py
@@ -123,6 +123,7 @@ def skip(*args, **kwargs):
         if compressor is not None:
             quantization_config = compressor.quantization_config
             is_compressed = (
+                quantization_config is not None and
                 quantization_config.quantization_status == QuantizationStatus.COMPRESSED
             )
             if run_compressed and is_compressed:
@@ -133,6 +134,8 @@ def skip(*args, **kwargs):
                 )
             else:
                 # initialize quantization and decompress weights
+                if quantization_config is not None:
+                    quantization_config.quantization_status = QuantizationStatus.FROZEN
                 compressor.decompress(
                     model_path=pretrained_model_name_or_path, model=model
                 )

From f1e269f6ac5d23c76b1ed61b0799689ff8393487 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Wed, 7 Aug 2024 18:19:47 +0000
Subject: [PATCH 4/8] sparseautomodel compatability

---
 src/llmcompressor/transformers/compression/helpers.py |  3 ++-
 .../transformers/sparsification/sparse_model.py       | 11 +++++++----
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/src/llmcompressor/transformers/compression/helpers.py b/src/llmcompressor/transformers/compression/helpers.py
index 17b215f86..d81c9dfab 100644
--- a/src/llmcompressor/transformers/compression/helpers.py
+++ b/src/llmcompressor/transformers/compression/helpers.py
@@ -1,8 +1,9 @@
 from typing import Dict, List, Optional, Union
-from accelerate.accelerator import get_state_dict_offloaded_model
+
 import psutil
 import torch
 from accelerate import infer_auto_device_map, init_empty_weights
+from accelerate.accelerator import get_state_dict_offloaded_model
 from torch.nn.modules import Linear
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM
diff --git a/src/llmcompressor/transformers/sparsification/sparse_model.py b/src/llmcompressor/transformers/sparsification/sparse_model.py
index 8a229d67b..9bd0faac0 100644
--- a/src/llmcompressor/transformers/sparsification/sparse_model.py
+++ b/src/llmcompressor/transformers/sparsification/sparse_model.py
@@ -123,14 +123,17 @@ def skip(*args, **kwargs):
         if compressor is not None:
             quantization_config = compressor.quantization_config
             is_compressed = (
-                quantization_config is not None and
-                quantization_config.quantization_status == QuantizationStatus.COMPRESSED
+                quantization_config is not None
+                and quantization_config.quantization_status
+                == QuantizationStatus.COMPRESSED
             )
             if run_compressed and is_compressed:
                 # initialize quantization, don't decompress
-                apply_quantization_config(model, quantization_config)
+                apply_quantization_config(
+                    model, quantization_config, run_compressed=True
+                )
                 model = load_checkpoint_and_dispatch(
-                    model, pretrained_model_name_or_path, *model_args, **kwargs
+                    model, pretrained_model_name_or_path
                 )
             else:
                 # initialize quantization and decompress weights

From bef05bf18c0dcd2b343c0c3a9d4d3bb664f14749 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Wed, 7 Aug 2024 19:26:39 +0000
Subject: [PATCH 5/8] revert un-needed change

---
 .../modifiers/quantization/gptq/utils/gptq_wrapper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
index 6d7966672..d8e565777 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -161,7 +161,7 @@ def compress(
         quant_scheme = self.layer.quantization_scheme
         if quant_scheme.weights is not None:
             scale = self.layer.weight_scale
-            zero_point = getattr(self.layer, "weight_zero_point", None)
+            zero_point = self.layer.weight_zero_point
             from compressed_tensors.quantization import QuantizationStrategy
             from compressed_tensors.quantization.lifecycle.forward import (
                 fake_quantize,

From 7892cb26f2fd3d6768312b628f7486ba0c103791 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Mon, 12 Aug 2024 14:26:11 +0000
Subject: [PATCH 6/8] skip init

---
 .../modifiers/quantization/quantization/base.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py
index 319350dcb..96e917bc5 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/base.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -57,10 +57,8 @@ class QuantizationModifier(Modifier):
     calibration_function_: Any = None

     def on_initialize_structure(self, state: State, **kwargs):
-        module = state.model
-        self._apply_modifier_to_model(module)
-        module.apply(freeze_module_quantization)
-
+        pass
+    
     def on_initialize(
         self, state: State, freeze_quantization: bool = True, **kwargs
     ) -> bool:

From b6001c664cc6231d9dd31db7b4bb944bbfb49303 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Mon, 12 Aug 2024 15:47:16 +0000
Subject: [PATCH 7/8] quality

---
 src/llmcompressor/modifiers/quantization/quantization/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py
index 96e917bc5..09db774b8 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/base.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -58,7 +58,7 @@ class QuantizationModifier(Modifier):

     def on_initialize_structure(self, state: State, **kwargs):
         pass
-    
+
     def on_initialize(
         self, state: State, freeze_quantization: bool = True, **kwargs
     ) -> bool:

From 6d79ca842403f9d54b84e419a99f0d10539a83b2 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Fri, 30 Aug 2024 04:41:17 +0000
Subject: [PATCH 8/8] remove recipe app

---
 .../transformers/sparsification/sparse_model.py | 10 +---------
 1 file changed, 1 insertion(+), 9 deletions(-)

diff --git a/src/llmcompressor/transformers/sparsification/sparse_model.py b/src/llmcompressor/transformers/sparsification/sparse_model.py
index 3af7d2b4d..4153ec4f4 100644
--- a/src/llmcompressor/transformers/sparsification/sparse_model.py
+++ b/src/llmcompressor/transformers/sparsification/sparse_model.py
@@ -14,14 +14,10 @@
 from torch.nn import Module
 from transformers import AutoModelForCausalLM, PreTrainedModel

-from llmcompressor.pytorch.model_load.helpers import initialize_recipe
 from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
     modify_save_pretrained,
 )
-from llmcompressor.transformers.utils.helpers import (
-    download_model_directory,
-    resolve_recipe,
-)
+from llmcompressor.transformers.utils.helpers import download_model_directory

 __all__ = ["SparseAutoModel", "SparseAutoModelForCausalLM", "get_shared_tokenizer_src"]

@@ -143,10 +139,6 @@ def skip(*args, **kwargs):
                     model_path=pretrained_model_name_or_path, model=model
                 )

-        recipe = resolve_recipe(recipe=recipe, model_path=pretrained_model_name_or_path)
-        if recipe:
-            initialize_recipe(model=model, recipe_path=recipe)
-
         return model
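
Editor's note: the snippet below is not part of the patch series; it is a minimal usage sketch of the run_compressed flag introduced above. It assumes SparseAutoModelForCausalLM is importable from llmcompressor.transformers (as in the repository's examples), uses a hypothetical checkpoint name, and forwards any extra kwargs to the underlying Hugging Face loader.

from llmcompressor.transformers import SparseAutoModelForCausalLM

# Hypothetical checkpoint saved in the compressed-tensors format
stub = "org-name/model-w4a16-compressed"

# Default path (run_compressed=False): the quantization structure is initialized
# and the weights are decompressed on load via compressor.decompress(...).
model = SparseAutoModelForCausalLM.from_pretrained(stub)

# run_compressed=True keeps the weights compressed: the quantization config is
# applied to the module structure and the checkpoint is dispatched as-is with
# accelerate.load_checkpoint_and_dispatch (see patches 2/8 and 4/8 above).
model_compressed = SparseAutoModelForCausalLM.from_pretrained(
    stub, run_compressed=True
)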