From d8a717c9600f2c8fc63cac55f8ba096848f3cac1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 4 Sep 2024 14:01:06 -0400 Subject: [PATCH] Support bool arguments for actorder (#150) --- src/compressed_tensors/quantization/quant_args.py | 5 ++++- tests/test_quantization/test_quant_args.py | 7 ++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 4b624fcd6..f33149e09 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -93,7 +93,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True): strategy: Optional[QuantizationStrategy] = None block_structure: Optional[str] = None dynamic: bool = False - actorder: Optional[ActivationOrdering] = None + actorder: Union[ActivationOrdering, bool, None] = None observer: str = Field( default="minmax", description=( @@ -151,6 +151,9 @@ def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]: @field_validator("actorder", mode="before") def validate_actorder(cls, value) -> Optional[ActivationOrdering]: + if isinstance(value, bool): + return ActivationOrdering.GROUP if value else None + if isinstance(value, str): return ActivationOrdering(value.lower()) diff --git a/tests/test_quantization/test_quant_args.py b/tests/test_quantization/test_quant_args.py index 027dc1c00..e189cb78e 100644 --- a/tests/test_quantization/test_quant_args.py +++ b/tests/test_quantization/test_quant_args.py @@ -82,11 +82,12 @@ def test_actorder(): with pytest.raises(ValueError): QuantizationArgs(strategy="tensor", actorder="weight") - # test boolean defaulting + # test boolean and none defaulting assert ( - QuantizationArgs(group_size=1, actorder="weight").actorder - == ActivationOrdering.WEIGHT + QuantizationArgs(group_size=1, actorder=True).actorder + == ActivationOrdering.GROUP ) + assert QuantizationArgs(group_size=1, actorder=False).actorder is None assert QuantizationArgs(group_size=1, actorder=None).actorder is None