Skip to content

Commit 920b084

Browse files
committed
Switch to new ao quant api for 8da4w (#8501)
1 parent 8cd1b93 commit 920b084

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

examples/models/llama/source_transformation/quantize.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,12 @@ def quantize( # noqa C901
138138
# Check for required args
139139
if group_size is None:
140140
raise Exception("For 8da4w quantization, group size must be specified.")
141-
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
142141

143-
# 1. Quantize in checkpoint dtype.
144-
model = Int8DynActInt4WeightQuantizer(
145-
precision=checkpoint_torch_dtype, groupsize=group_size
146-
).quantize(model)
147-
# 2. Set the computation dtype (what weights/acts dequantize to).
148-
model = set_8da4w_computation_dtype(model, computation_torch_dtype)
142+
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
143+
144+
quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
145+
146+
# TODO: deal with checkpoint / computation dtype decoupling.
149147

150148
if verbose:
151149
print("quantized model:", model)
@@ -700,7 +698,7 @@ def convert_for_runtime(self) -> nn.Module:
700698
def quantized_model(self) -> nn.Module:
701699
model_updated_state_dict = self.create_quantized_state_dict(self.packed)
702700
self.convert_for_runtime()
703-
self.mod.load_state_dict(model_updated_state_dict)
701+
self.mod.load_state_dict(model_updated_state_dict, assign=True)
704702
return self.mod
705703

706704

0 commit comments

Comments (0)