diff --git a/src/llmcompressor/transformers/sparsification/sparse_model.py b/src/llmcompressor/transformers/sparsification/sparse_model.py
index 42530f3e1..2bd23bc9a 100644
--- a/src/llmcompressor/transformers/sparsification/sparse_model.py
+++ b/src/llmcompressor/transformers/sparsification/sparse_model.py
@@ -14,10 +14,14 @@
 from torch.nn import Module
 from transformers import AutoModelForCausalLM, PreTrainedModel
 
+from llmcompressor.pytorch.model_load.helpers import initialize_recipe
 from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
     modify_save_pretrained,
 )
-from llmcompressor.transformers.utils.helpers import download_model_directory
+from llmcompressor.transformers.utils.helpers import (
+    download_model_directory,
+    resolve_recipe,
+)
 
 __all__ = ["SparseAutoModel", "SparseAutoModelForCausalLM", "get_shared_tokenizer_src"]
 
@@ -142,6 +146,10 @@ def skip(*args, **kwargs):
             compressor.decompress(
                 model_path=pretrained_model_name_or_path, model=model
             )
+        recipe = resolve_recipe(recipe=recipe, model_path=pretrained_model_name_or_path)
+
+        if recipe:
+            initialize_recipe(model=model, recipe_path=recipe)
 
         return model
 
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
index 5d784eb28..41690566a 100644
--- a/tests/testing_utils.py
+++ b/tests/testing_utils.py
@@ -104,6 +104,7 @@ def _parse_configs_dir(current_config_dir):
                 logging.info(
                     f"Skipping testing model: {file} for cadence: {expected_cadence}"
                 )
+
     if isinstance(configs_directory, list):
         for config in configs_directory:
             _parse_configs_dir(config)