Naive Run Compressed Pt. 2 #62

Merged · 14 commits · Aug 30, 2024
@@ -70,9 +70,7 @@ class QuantizationModifier(Modifier):
     calibration_function_: Any = None

     def on_initialize_structure(self, state: State, **kwargs):
-        module = state.model
-        self._apply_modifier_to_model(module)
-        module.apply(freeze_module_quantization)
+        pass

     def on_initialize(self, state: State, **kwargs) -> bool:
         if self.end and self.end != -1:
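Note: with this change on_initialize_structure becomes a no-op. Applying the quantization config (and freezing or decompressing weights) now happens at load time in SparseAutoModelForCausalLM.from_pretrained, as the sparse_model.py diff below shows.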
39 changes: 28 additions & 11 deletions src/llmcompressor/transformers/sparsification/sparse_model.py
@@ -4,19 +4,20 @@
 from typing import Optional, Union

 import torch
+from accelerate import load_checkpoint_and_dispatch
 from compressed_tensors.compressors import ModelCompressor
+from compressed_tensors.quantization import (
+    QuantizationStatus,
+    apply_quantization_config,
+)
 from loguru import logger
 from torch.nn import Module
 from transformers import AutoModelForCausalLM, PreTrainedModel

-from llmcompressor.pytorch.model_load.helpers import initialize_recipe
 from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
     modify_save_pretrained,
 )
-from llmcompressor.transformers.utils.helpers import (
-    download_model_directory,
-    resolve_recipe,
-)
+from llmcompressor.transformers.utils.helpers import download_model_directory

 __all__ = ["SparseAutoModel", "SparseAutoModelForCausalLM", "get_shared_tokenizer_src"]
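Aside: load_checkpoint_and_dispatch is accelerate's utility for loading (possibly sharded) checkpoint weights onto an already-constructed model and dispatching them across available devices; it is what lets the run_compressed path below attach weights without a decompression pass.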

@@ -40,6 +41,7 @@ class SparseAutoModelForCausalLM(AutoModelForCausalLM):
     def from_pretrained(
         cls,
         pretrained_model_name_or_path,
+        run_compressed: bool = False,
         recipe: Optional[Union[str, Path]] = None,
         *model_args,
         **kwargs,
@@ -115,12 +117,27 @@ def skip(*args, **kwargs):
         # If model is quantized or compressed on disk, initialize quantization
         # structure and run decompression
         if compressor is not None:
-            # initialize quantization and decompress weights
-            compressor.decompress(model_path=pretrained_model_name_or_path, model=model)
-
-            recipe = resolve_recipe(recipe=recipe, model_path=pretrained_model_name_or_path)
-            if recipe:
-                initialize_recipe(model=model, recipe_path=recipe)
+            quantization_config = compressor.quantization_config
+            is_compressed = (
+                quantization_config is not None
+                and quantization_config.quantization_status
+                == QuantizationStatus.COMPRESSED
+            )
+            if run_compressed and is_compressed:
+                # initialize quantization, don't decompress
+                apply_quantization_config(
+                    model, quantization_config, run_compressed=True
+                )
+                model = load_checkpoint_and_dispatch(
+                    model, pretrained_model_name_or_path
+                )
+            else:
+                # initialize quantization and decompress weights
+                if quantization_config is not None:
+                    quantization_config.quantization_status = QuantizationStatus.FROZEN
+                compressor.decompress(
+                    model_path=pretrained_model_name_or_path, model=model
+                )

         return model
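
For orientation, a minimal usage sketch of the new flag. It assumes SparseAutoModelForCausalLM is imported from llmcompressor.transformers as elsewhere in the repo; the checkpoint path is hypothetical.

    from llmcompressor.transformers import SparseAutoModelForCausalLM

    # Hypothetical path to a checkpoint whose quantization status is COMPRESSED
    MODEL_PATH = "./my-compressed-model"

    # Default behavior: the quantization status is forced to FROZEN and the
    # weights are decompressed on load
    model = SparseAutoModelForCausalLM.from_pretrained(MODEL_PATH)

    # New behavior from this PR: apply the quantization config with
    # run_compressed=True, then dispatch the still-compressed checkpoint
    model = SparseAutoModelForCausalLM.from_pretrained(MODEL_PATH, run_compressed=True)

Note that the run_compressed branch only triggers when the on-disk quantization status is COMPRESSED; for any other status the loader falls back to the decompress path, so existing callers (run_compressed defaults to False) are unaffected.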
