examples/models/llama/source_transformation (1 file changed: +6 -8 lines)

```diff
@@ -138,14 +138,12 @@ def quantize(  # noqa C901
         # Check for required args
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
-        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
-        # 1. Quantize in checkpoint dtype.
-        model = Int8DynActInt4WeightQuantizer(
-            precision=checkpoint_torch_dtype, groupsize=group_size
-        ).quantize(model)
-        # 2. Set the computation dtype (what weights/acts dequantize to).
-        model = set_8da4w_computation_dtype(model, computation_torch_dtype)
+        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+
+        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
+
+        # TODO: deal with checkpoint / computation dtype decoupling.
 
         if verbose:
             print("quantized model:", model)
```
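This hunk swaps torchao's legacy `Int8DynActInt4WeightQuantizer` class for the functional `quantize_` API, which mutates the module in place rather than returning a transformed copy. A minimal sketch of that flow, assuming a toy `nn.Module` stand-in for the real Llama checkpoint (the `TinyModel` name and sizes are illustrative, not part of this PR):

```python
import torch
import torch.nn as nn
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_

class TinyModel(nn.Module):  # hypothetical stand-in for the real model
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(256, 256, bias=False)

    def forward(self, x):
        return self.linear(x)

model = TinyModel()
# group_size must evenly divide the weight's inner dimension (256 here).
quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))
# The linear weight is now a quantized tensor subclass, not a plain fp32 Parameter.
print(type(model.linear.weight))
```

Note the TODO left in the diff: the removed path quantized in the checkpoint dtype and then set a separate computation dtype via `set_8da4w_computation_dtype`; the functional API call as written does not yet decouple the two.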
```diff
@@ -700,7 +698,7 @@ def convert_for_runtime(self) -> nn.Module:
     def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict)
+        self.mod.load_state_dict(model_updated_state_dict, assign=True)
         return self.mod
 
 
```
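The second hunk adds `assign=True` to `load_state_dict` (available since PyTorch 2.1). Without it, loading copies values into the module's existing parameters, which keeps their original dtype and tensor class; with it, the parameters are replaced by the state-dict tensors themselves, so packed or quantized tensors survive the load intact. A standalone sketch of the difference, not the ExecuTorch code (the fp16 weight is just a stand-in for a packed quantized tensor):

```python
import torch
import torch.nn as nn

mod = nn.Linear(4, 4, bias=False)  # fp32 parameter

packed = {"weight": torch.randn(4, 4, dtype=torch.float16)}

mod.load_state_dict(packed)               # copies: weight stays fp32 (values cast)
print(mod.weight.dtype)                   # torch.float32

mod.load_state_dict(packed, assign=True)  # assigns: weight becomes the fp16 tensor
print(mod.weight.dtype)                   # torch.float16
```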