Skip to content

Commit 3e848e9

Browse files
committed
Fix export with unwrap_tensor_subclass
1 parent 07838a4 commit 3e848e9

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

examples/models/llama/source_transformation/quantize.py

+2
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,10 @@ def quantize( # noqa C901
138138
raise Exception("For 8da4w quantization, group size must be specified.")
139139

140140
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
141+
from torchao.utils import unwrap_tensor_subclass
141142

142143
quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
144+
model = unwrap_tensor_subclass(model)
143145

144146
# TODO: deal with checkpoint / computation dtype decoupling.
145147

0 commit comments

Comments
 (0)