We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 07838a4 commit 3e848e9Copy full SHA for 3e848e9
examples/models/llama/source_transformation/quantize.py
@@ -138,8 +138,10 @@ def quantize( # noqa C901
138
raise Exception("For 8da4w quantization, group size must be specified.")
139
140
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
141
+ from torchao.utils import unwrap_tensor_subclass
142
143
quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
144
+ model = unwrap_tensor_subclass(model)
145
146
# TODO: deal with checkpoint / computation dtype decoupling.
147
0 commit comments