Naive Run Compressed Pt. 2 #62

Merged: 14 commits, Aug 30, 2024
@@ -70,9 +70,7 @@ class QuantizationModifier(Modifier):
     calibration_function_: Any = None

     def on_initialize_structure(self, state: State, **kwargs):
-        module = state.model
-        self._apply_modifier_to_model(module)
-        module.apply(freeze_module_quantization)
+        pass

     def on_initialize(
         self, state: State, freeze_quantization: bool = True, **kwargs
29 changes: 27 additions & 2 deletions src/llmcompressor/transformers/sparsification/sparse_model.py
@@ -4,7 +4,12 @@
 from typing import Optional, Union

 import torch
+from accelerate import load_checkpoint_and_dispatch
 from compressed_tensors.compressors import ModelCompressor
+from compressed_tensors.quantization import (
+    QuantizationStatus,
+    apply_quantization_config,
+)
 from loguru import logger
 from torch.nn import Module
 from transformers import AutoModelForCausalLM, PreTrainedModel
@@ -40,6 +45,7 @@ class SparseAutoModelForCausalLM(AutoModelForCausalLM):
     def from_pretrained(
         cls,
         pretrained_model_name_or_path,
+        run_compressed: bool = False,
         recipe: Optional[Union[str, Path]] = None,
         *model_args,
         **kwargs,
@@ -115,8 +121,27 @@ def skip(*args, **kwargs):
         # If model is quantized or compressed on disk, initialize quantization
         # structure and run decompression
         if compressor is not None:
-            # initialize quantization and decompress weights
-            compressor.decompress(model_path=pretrained_model_name_or_path, model=model)
+            quantization_config = compressor.quantization_config
+            is_compressed = (
+                quantization_config is not None
+                and quantization_config.quantization_status
+                == QuantizationStatus.COMPRESSED
+            )
+            if run_compressed and is_compressed:
+                # initialize quantization, don't decompress
+                apply_quantization_config(
+                    model, quantization_config, run_compressed=True
+                )
+                model = load_checkpoint_and_dispatch(
+                    model, pretrained_model_name_or_path
+                )
+            else:
+                # initialize quantization and decompress weights
+                if quantization_config is not None:
+                    quantization_config.quantization_status = QuantizationStatus.FROZEN
+                compressor.decompress(
+                    model_path=pretrained_model_name_or_path, model=model
+                )

         recipe = resolve_recipe(recipe=recipe, model_path=pretrained_model_name_or_path)
         if recipe:
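For reference, a minimal usage sketch of the run_compressed flag added in this PR. The model path below is a placeholder, and the sketch assumes SparseAutoModelForCausalLM is importable from llmcompressor.transformers; by default a compressed checkpoint is decompressed on load, while run_compressed=True keeps a COMPRESSED-status quantized checkpoint compressed and dispatches it as-is.

from llmcompressor.transformers import SparseAutoModelForCausalLM

# Placeholder path to a checkpoint saved with compressed, quantized weights.
MODEL_PATH = "./my-compressed-model"

# Default path: quantization structure is applied and weights are decompressed.
model_dense = SparseAutoModelForCausalLM.from_pretrained(MODEL_PATH)

# New path from this PR: if the checkpoint's quantization status is COMPRESSED,
# apply_quantization_config(..., run_compressed=True) is applied and weights
# are loaded via load_checkpoint_and_dispatch without decompression.
model_compressed = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    run_compressed=True,
)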