From e236e1593468a68e47c5bcafd7272eca01684294 Mon Sep 17 00:00:00 2001
From: Supriya Rao
Date: Wed, 4 Mar 2020 13:58:44 -0800
Subject: [PATCH] [quant] Run weight_post_process for QAT (#33852)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33852

This fixes an issue for QAT models. During eval, if we call `prepare_qat` and
`convert` before calling `load_state_dict`, it throws an error because the
weight info (num channels) is not updated in the observer module. This is not
an issue for the per-tensor case.

Fixes issue #33830

Test Plan:
python test/test_quantization.py EagerModePostTrainingQuantTest.test_eval_after_train
python test/test_quantization.py EagerModeQuantizationAwareTrainingTest.test_eval_after_train

Imported from OSS

Differential Revision: D20212996

fbshipit-source-id: a04af8fe4df2e555270ae4d6693f5777d86f8a46
---
 test/test_quantization.py            | 93 +++++++++++++++++++++++++++-
 torch/nn/quantized/modules/conv.py   |  2 +-
 torch/nn/quantized/modules/linear.py |  2 +-
 3 files changed, 94 insertions(+), 3 deletions(-)

diff --git a/test/test_quantization.py b/test/test_quantization.py
index 674162fb2bb..f42ce00829e 100644
--- a/test/test_quantization.py
+++ b/test/test_quantization.py
@@ -23,7 +23,7 @@ from torch.quantization import default_per_channel_qconfig
 from torch.quantization._quantize_script import quantize_script
 
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import run_tests, TEST_WITH_UBSAN, IS_WINDOWS
 from torch.testing._internal.common_quantization import QuantizationTestCase, \
     AnnotatedSingleLayerLinearModel, SingleLayerLinearModel, \
     AnnotatedConvModel, ConvModel, \
@@ -311,6 +311,44 @@ def checkQuantized(model):
 
         checkQuantized(model)
 
+    @given(qengine=st.sampled_from(("qnnpack", "fbgemm")))
+    def test_save_load_state_dict(self, qengine):
+        r"""Test the PTQ flow of creating a model, quantizing it and saving the quantized state_dict.
+        Load the quantized state_dict for eval and compare results against the original model.
+        """
+        if qengine == 'qnnpack':
+            if IS_WINDOWS or TEST_WITH_UBSAN:
+                return
+        with override_quantized_engine(qengine):
+            model = TwoLayerLinearModel()
+            model = torch.quantization.QuantWrapper(model)
+            model.qconfig = torch.quantization.get_default_qconfig(qengine)
+
+            model = prepare(model)
+            # calibrate
+            test_only_eval_fn(model, self.calib_data)
+            model = convert(model)
+            x = torch.rand(2, 5, dtype=torch.float)
+            ref = model(x)
+
+            quant_state_dict = model.state_dict()
+
+            # Create model again for eval
+            model = TwoLayerLinearModel()
+            model = torch.quantization.QuantWrapper(model)
+            model.qconfig = torch.quantization.get_default_qconfig(qengine)
+            model = prepare(model)
+            model = convert(model)
+            new_state_dict = model.state_dict()
+
+            # Check to make sure the state dict keys match the original model after convert.
+            self.assertEqual(set(new_state_dict.keys()), set(quant_state_dict.keys()))
+
+            model.load_state_dict(quant_state_dict)
+
+            out = model(x)
+            self.assertEqual(ref, out)
+
     @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                          " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                          " with instruction set support avx2 or newer.")
@@ -793,6 +831,59 @@ def checkQuantized(model):
         model = quantize_qat(model, test_only_train_fn, self.img_data)
         checkQuantized(model)
 
+    @given(qengine=st.sampled_from(("qnnpack", "fbgemm")))
+    def test_train_save_load_eval(self, qengine):
+        r"""Test the QAT flow of creating a model, doing QAT and saving the quantized state_dict.
+        During eval, we first call prepare_qat and convert on the model and then load the state_dict
+        and compare results against the original model.
+        """
+        if qengine == 'qnnpack':
+            if IS_WINDOWS or TEST_WITH_UBSAN:
+                return
+        with override_quantized_engine(qengine):
+            model = TwoLayerLinearModel()
+            model = torch.quantization.QuantWrapper(model)
+            model.qconfig = torch.quantization.get_default_qat_qconfig(qengine)
+            model = prepare_qat(model)
+
+            fq_state_dict = model.state_dict()
+
+            test_only_train_fn(model, self.train_data)
+            model = convert(model)
+
+            quant_state_dict = model.state_dict()
+
+            x = torch.rand(2, 5, dtype=torch.float)
+            ref = model(x)
+
+            # Create model again for eval. Check result using quantized state_dict
+            model = TwoLayerLinearModel()
+            model = torch.quantization.QuantWrapper(model)
+            model.qconfig = torch.quantization.get_default_qat_qconfig(qengine)
+            torch.quantization.prepare_qat(model, inplace=True)
+            new_state_dict = model.state_dict()
+
+            # Check to make sure the model after prepare_qat has the same state_dict as the original.
+            self.assertEqual(set(fq_state_dict.keys()), set(new_state_dict.keys()))
+
+            torch.quantization.convert(model, inplace=True)
+            model.eval()
+            model.load_state_dict(quant_state_dict)
+            out = model(x)
+            self.assertEqual(ref, out)
+
+            # Check that the model created using prepare has the same state dict as the quantized state_dict
+            model = TwoLayerLinearModel()
+            model.eval()
+            model = torch.quantization.QuantWrapper(model)
+            model.qconfig = torch.quantization.get_default_qconfig(qengine)
+            torch.quantization.prepare(model, inplace=True)
+            torch.quantization.convert(model, inplace=True)
+            self.assertEqual(set(model.state_dict().keys()), set(quant_state_dict.keys()))
+            model.eval()
+            model.load_state_dict(quant_state_dict)
+            out = model(x)
+            self.assertEqual(ref, out)
 
 @unittest.skipUnless(
     'fbgemm' in torch.backends.quantized.supported_engines,
diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py
index 54f0c934ba2..bfca7d53583 100644
--- a/torch/nn/quantized/modules/conv.py
+++ b/torch/nn/quantized/modules/conv.py
@@ -247,7 +247,7 @@ def from_float(cls, mod):
             else:
                 activation_post_process = mod.activation_post_process
             weight_post_process = mod.qconfig.weight()
-            weight_post_process(mod.weight)
+        weight_post_process(mod.weight)
         act_scale, act_zp = activation_post_process.calculate_qparams()
         assert weight_post_process.dtype == torch.qint8, \
             'Weight observer must have a dtype of qint8'
diff --git a/torch/nn/quantized/modules/linear.py b/torch/nn/quantized/modules/linear.py
index 4f337d75067..fb618dcaf5b 100644
--- a/torch/nn/quantized/modules/linear.py
+++ b/torch/nn/quantized/modules/linear.py
@@ -215,7 +215,7 @@ def from_float(cls, mod):
             else:
                 activation_post_process = mod.activation_post_process
             weight_post_process = mod.qconfig.weight()
-            weight_post_process(mod.weight)
+        weight_post_process(mod.weight)
         dtype = weight_post_process.dtype
         act_scale, act_zp = activation_post_process.calculate_qparams()
         assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
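
For context, a minimal sketch of the eval-time flow this change unblocks, assuming the
eager-mode torch.quantization API used in the tests above; the Sequential model and the
`quant_state_dict` checkpoint are illustrative stand-ins, not the test's TwoLayerLinearModel:

    import torch
    import torch.nn as nn

    # Build a float model and attach a QAT qconfig (with the fbgemm backend this
    # typically selects a per-channel weight observer).
    model = torch.quantization.QuantWrapper(
        nn.Sequential(nn.Linear(5, 8), nn.ReLU(), nn.Linear(8, 2)))
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

    # Recreate the quantized module structure at eval time so a checkpoint saved
    # after QAT + convert can be loaded.
    torch.quantization.prepare_qat(model, inplace=True)
    # Per the summary, this convert() step is where the error occurred before the
    # patch: from_float() never ran the weight observer for QAT modules, so the
    # per-channel weight info (num channels) was never populated.
    torch.quantization.convert(model, inplace=True)
    model.eval()
    # model.load_state_dict(quant_state_dict)  # checkpoint saved after QAT + convert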