[quant] Run weight_post_process for QAT (#33852)
Summary:
Pull Request resolved: pytorch/pytorch#33852

This fixes an issue with QAT models. During eval, if we call `prepare_qat` and `convert` before calling `load_state_dict`, an error is thrown because the weight info (number of channels) is never updated in the observer module.
This is not an issue for the per-tensor case.

Fixes issue #33830
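
To make the failure mode concrete, here is a minimal sketch of the flow this change targets. The standalone TwoLayerLinearModel definition and layer sizes are stand-ins for the test helpers, and the description of the old failure point is paraphrased from the summary above rather than taken from an actual traceback:

import torch
from torch.quantization import QuantWrapper, get_default_qat_qconfig, prepare_qat, convert

# Stand-in for the float model used in the tests (layer sizes are assumptions).
class TwoLayerLinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 8)
        self.fc2 = torch.nn.Linear(8, 5)

    def forward(self, x):
        return self.fc2(self.fc1(x))

def make_qat_model():
    m = QuantWrapper(TwoLayerLinearModel())
    m.qconfig = get_default_qat_qconfig('fbgemm')  # fbgemm defaults use a per-channel weight observer
    return prepare_qat(m.train())

# Train side: run QAT (one forward pass is enough for the weight observers to see the weights),
# convert, and save the quantized state_dict.
qat_model = make_qat_model()
qat_model(torch.rand(2, 5))
quant_state_dict = convert(qat_model.eval()).state_dict()

# Eval side: rebuild the quantized model, then load the saved weights.
# Before this change, the eval-side observers had never seen the weight tensor, so the
# per-channel weight info did not match and loading the state_dict failed.
eval_model = convert(make_qat_model().eval())
eval_model.load_state_dict(quant_state_dict)

The new tests below exercise exactly this save/load round trip for both the PTQ and QAT flows.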

Test Plan:
python test/test_quantization.py EagerModePostTrainingQuantTest.test_eval_after_train
python test/test_quantization.py EagerModeQuantizationAwareTrainingTest.test_eval_after_train

Imported from OSS

Differential Revision: D20212996

fbshipit-source-id: a04af8fe4df2e555270ae4d6693f5777d86f8a46
supriyar authored and facebook-github-bot committed Mar 4, 2020
1 parent d59e036 commit e236e15
Showing 3 changed files with 94 additions and 3 deletions.
93 changes: 92 additions & 1 deletion test/test_quantization.py
@@ -23,7 +23,7 @@
from torch.quantization import default_per_channel_qconfig
from torch.quantization._quantize_script import quantize_script

-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import run_tests, TEST_WITH_UBSAN, IS_WINDOWS
from torch.testing._internal.common_quantization import QuantizationTestCase, \
    AnnotatedSingleLayerLinearModel, SingleLayerLinearModel, \
    AnnotatedConvModel, ConvModel, \
@@ -311,6 +311,44 @@ def checkQuantized(model):

        checkQuantized(model)

    @given(qengine=st.sampled_from(("qnnpack", "fbgemm")))
    def test_save_load_state_dict(self, qengine):
        r"""Test the PTQ flow of creating a model, quantizing it and saving the quantized state_dict.
        Load the quantized state_dict for eval and compare the results against the original model.
        """
        if qengine == 'qnnpack':
            if IS_WINDOWS or TEST_WITH_UBSAN:
                return
        with override_quantized_engine(qengine):
            model = TwoLayerLinearModel()
            model = torch.quantization.QuantWrapper(model)
            model.qconfig = torch.quantization.get_default_qconfig(qengine)

            model = prepare(model)
            # calibrate
            test_only_eval_fn(model, self.calib_data)
            model = convert(model)
            x = torch.rand(2, 5, dtype=torch.float)
            ref = model(x)

            quant_state_dict = model.state_dict()

            # Create model again for eval
            model = TwoLayerLinearModel()
            model = torch.quantization.QuantWrapper(model)
            model.qconfig = torch.quantization.get_default_qconfig(qengine)
            model = prepare(model)
            model = convert(model)
            new_state_dict = model.state_dict()

            # Check to make sure the state dict keys match the original model after convert.
            self.assertEqual(set(new_state_dict.keys()), set(quant_state_dict.keys()))

            model.load_state_dict(quant_state_dict)

            out = model(x)
            self.assertEqual(ref, out)

    @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                         " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                         " with instruction set support avx2 or newer.")
@@ -793,6 +831,59 @@ def checkQuantized(model):
        model = quantize_qat(model, test_only_train_fn, self.img_data)
        checkQuantized(model)

    @given(qengine=st.sampled_from(("qnnpack", "fbgemm")))
    def test_train_save_load_eval(self, qengine):
        r"""Test the QAT flow of creating a model, doing QAT, and saving the quantized state_dict.
        During eval we first call prepare_qat and convert on the model and then load the state_dict,
        and compare the results against the original model.
        """
        if qengine == 'qnnpack':
            if IS_WINDOWS or TEST_WITH_UBSAN:
                return
        with override_quantized_engine(qengine):
            model = TwoLayerLinearModel()
            model = torch.quantization.QuantWrapper(model)
            model.qconfig = torch.quantization.get_default_qat_qconfig(qengine)
            model = prepare_qat(model)

            fq_state_dict = model.state_dict()

            test_only_train_fn(model, self.train_data)
            model = convert(model)

            quant_state_dict = model.state_dict()

            x = torch.rand(2, 5, dtype=torch.float)
            ref = model(x)

            # Create model again for eval. Check the result using the quantized state_dict
            model = TwoLayerLinearModel()
            model = torch.quantization.QuantWrapper(model)
            model.qconfig = torch.quantization.get_default_qat_qconfig(qengine)
            torch.quantization.prepare_qat(model, inplace=True)
            new_state_dict = model.state_dict()

            # Check to make sure the model after prepare_qat has the same state_dict keys as the original.
            self.assertEqual(set(fq_state_dict.keys()), set(new_state_dict.keys()))

            torch.quantization.convert(model, inplace=True)
            model.eval()
            model.load_state_dict(quant_state_dict)
            out = model(x)
            self.assertEqual(ref, out)

            # Check that a model created using prepare has the same state dict keys as the quantized state_dict
            model = TwoLayerLinearModel()
            model.eval()
            model = torch.quantization.QuantWrapper(model)
            model.qconfig = torch.quantization.get_default_qconfig(qengine)
            torch.quantization.prepare(model, inplace=True)
            torch.quantization.convert(model, inplace=True)
            self.assertEqual(set(model.state_dict().keys()), set(quant_state_dict.keys()))
            model.eval()
            model.load_state_dict(quant_state_dict)
            out = model(x)
            self.assertEqual(ref, out)

    @unittest.skipUnless(
        'fbgemm' in torch.backends.quantized.supported_engines,
2 changes: 1 addition & 1 deletion torch/nn/quantized/modules/conv.py
@@ -247,7 +247,7 @@ def from_float(cls, mod):
            else:
                activation_post_process = mod.activation_post_process
            weight_post_process = mod.qconfig.weight()
-           weight_post_process(mod.weight)
+       weight_post_process(mod.weight)
        act_scale, act_zp = activation_post_process.calculate_qparams()
        assert weight_post_process.dtype == torch.qint8, \
            'Weight observer must have a dtype of qint8'
2 changes: 1 addition & 1 deletion torch/nn/quantized/modules/linear.py
@@ -215,7 +215,7 @@ def from_float(cls, mod):
        else:
            activation_post_process = mod.activation_post_process
            weight_post_process = mod.qconfig.weight()
-           weight_post_process(mod.weight)
+       weight_post_process(mod.weight)
        dtype = weight_post_process.dtype
        act_scale, act_zp = activation_post_process.calculate_qparams()
        assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
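
For context (not part of the diff): a rough illustration of why the weight observer has to run at least once before convert when per-channel quantization is used. Until a per-channel observer has seen a tensor, it has no per-channel statistics, so the qparams it reports are not sized per output channel. The observer class, arguments, and shapes below are illustrative assumptions rather than the exact code path touched here:

import torch
from torch.quantization.observer import PerChannelMinMaxObserver

obs = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8,
                               qscheme=torch.per_channel_symmetric)

# Before the observer has seen any tensor it falls back to default, size-1 qparams
# (with a warning), which does not match a per-channel weight layout.
scale, zero_point = obs.calculate_qparams()
print(scale.shape)   # torch.Size([1])

# After observing a weight once, the qparams are sized per output channel.
weight = torch.randn(8, 5)   # e.g. a Linear weight with 8 output channels
obs(weight)
scale, zero_point = obs.calculate_qparams()
print(scale.shape)   # torch.Size([8])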
