[Feature] Support lsq new #501

Merged · 13 commits · Apr 11, 2023
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -25,7 +25,7 @@ concurrency:

jobs:
test_linux:
runs-on: ubuntu-18.04
runs-on: ubuntu-20.04
strategy:
matrix:
python-version: [3.7]
7 changes: 6 additions & 1 deletion mmrazor/engine/runner/quantization_loops.py
@@ -18,10 +18,14 @@

from torch.utils.data import DataLoader

from mmrazor.models import register_torch_fake_quants, register_torch_observers
from mmrazor.models.fake_quants import (enable_param_learning,
enable_static_estimate, enable_val)
from mmrazor.registry import LOOPS

TORCH_observers = register_torch_observers()
TORCH_fake_quants = register_torch_fake_quants()


@LOOPS.register_module()
class QATEpochBasedLoop(EpochBasedTrainLoop):
@@ -87,7 +91,6 @@ def run(self):
and self._epoch % self.val_interval == 0):
# observer disabled during evaluation
self.prepare_for_val()
self.runner.model.sync_qparams(src_mode='loss')
self.runner.val_loop.run()

self.runner.call_hook('after_train')
@@ -108,6 +111,7 @@ def run_epoch(self) -> None:
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)

self.runner.model.sync_qparams(src_mode='loss')
self.runner.call_hook('after_train_epoch')
self._epoch += 1

@@ -181,6 +185,7 @@ def run_epoch(self) -> None:
self.runner.model.apply(enable_param_learning)
self.run_iter(idx, data_batch)

self.runner.model.sync_qparams(src_mode='loss')
self.runner.call_hook('after_train_epoch')
self._epoch += 1

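For context, a rough sketch of what the `sync_qparams(src_mode='loss')` call that this diff moves to the end of `run_epoch` does conceptually: only the 'loss' graph is updated during training, so its quantization scales and zero-points have to be copied onto the matching modules of the other forward-mode graphs. The code below is illustrative only, not the mmrazor implementation:

def sync_qparams_sketch(qmodels, src_mode='loss'):
    # Collect the learned scale/zero_point from the trained graph.
    src_params = {
        name: (m.scale.detach().clone(), m.zero_point.detach().clone())
        for name, m in qmodels[src_mode].named_modules()
        if hasattr(m, 'scale') and hasattr(m, 'zero_point')
    }
    # Copy them onto the graphs used for the other forward modes, which share
    # module names but are traced separately.
    for mode, graph in qmodels.items():
        if mode == src_mode:
            continue
        for name, m in graph.named_modules():
            if name in src_params:
                scale, zero_point = src_params[name]
                m.scale.data.resize_(scale.shape).copy_(scale)
                m.zero_point.data.resize_(zero_point.shape).copy_(zero_point)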
45 changes: 25 additions & 20 deletions mmrazor/models/algorithms/quantization/mm_architecture.py
@@ -20,6 +20,7 @@
disable_observer)
except ImportError:
from mmrazor.utils import get_placeholder

FakeQuantizeBase = get_placeholder('torch>=1.13')
MinMaxObserver = get_placeholder('torch>=1.13')
PerChannelMinMaxObserver = get_placeholder('torch>=1.13')
@@ -283,23 +284,31 @@ def _build_qmodels(self, model: BaseModel):
"""

rewriter_context = self._get_rewriter_context_in_mmdeploy(
self.deploy_cfg)
self.deploy_cfg) if self.deploy_cfg is not None else None

# Pop function records in `quantizer.tracer.skipped_method` temporarily
function_record_backup = self._pop_function_record_in_rewriter_context(
rewriter_context)
if rewriter_context is not None:
# Pop function records in `quantizer.tracer.skipped_method`
# temporarily
function_record_backup = \
self._pop_function_record_in_rewriter_context(rewriter_context)

qmodels = nn.ModuleDict()
for mode in self.forward_modes:
concrete_args = {'mode': mode}
# TODO: support QAT.
with rewriter_context:

if rewriter_context is not None:
with rewriter_context:
observed_module = self.quantizer.prepare(
model, concrete_args)
else:
observed_module = self.quantizer.prepare(model, concrete_args)

qmodels[mode] = observed_module

# Add these popped function records back.
rewriter_context._rewriter_manager.function_rewriter. \
_registry._rewrite_records.update(function_record_backup)
if rewriter_context is not None:
# Add these popped function records back.
rewriter_context._rewriter_manager.function_rewriter. \
_registry._rewrite_records.update(function_record_backup)

# data_samples can not be None in detectors during prediction.
# But we need to make the dummy prediction in _build_qmodels.
@@ -357,7 +366,10 @@ def get_deploy_model(self):
observed_model.load_state_dict(quantized_state_dict)

self.quantizer.post_process_for_deploy(
observed_model, device=device, keep_w_fake_quant=True)
observed_model,
device=device,
keep_w_fake_quant=True,
update_weight_with_fakequant=True)

# replace various activation fakequants with the base fakequant, which
# makes it easier to deploy our model to various backends.
@@ -406,21 +418,14 @@ def calibrate_step(self, data: Union[Dict, Tuple, List]):

return self.module.calibrate_step(data)

def sync_qparams(self, src: str):
def sync_qparams(self, src_mode: str):
"""Same as in 'MMArchitectureQuant'. Sync all quantize parameters in
different `forward_modes`. We could have several modes to generate
graphs, but in training, only one graph will be updated, so we need to
sync qparams on the other graphs.

Args:
src (str): The src modes of forward method.

Note:
`traverse()` function recursively traverses all module to sync
quantized graph generated from different `forward_modes`.
This is because We have different mode ('tensor', 'predict',
'loss') in OpenMMLab architecture which have different graph
in some subtle ways, so we need to sync them here.
src_mode (str): The source mode of the forward method.
"""

self.module.sync_qparams(src)
self.module.sync_qparams(src_mode)
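The repeated `if rewriter_context is not None` checks added in `_build_qmodels` boil down to "enter the mmdeploy rewriter only when a deploy config was given". A minimal sketch of that pattern using `contextlib.nullcontext` as a stand-in for the no-deploy-cfg branch (an illustrative alternative, not the code in this PR):

from contextlib import nullcontext

def prepare_with_optional_rewriter(quantizer, model, concrete_args,
                                   rewriter_context=None):
    # Use the mmdeploy rewriter context only when deploy_cfg produced one;
    # otherwise fall back to a no-op context manager.
    ctx = rewriter_context if rewriter_context is not None else nullcontext()
    with ctx:
        return quantizer.prepare(model, concrete_args)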
40 changes: 40 additions & 0 deletions mmrazor/models/fake_quants/lsq.py
@@ -258,6 +258,46 @@ def forward(self, X):

return X

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
"""Removing this function throws an error that the the size of the
loaded tensor does not match the original size i.e., These buffers
start out with numel 0 and become numel 1 once they have their first
forward pass.

Modified from https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fake_quantize.py # noqa:E501
"""
local_state = ['scale', 'zero_point']
for name in local_state:
key = prefix + name
if key in state_dict:
val = state_dict[key]
# Custom handling to allow loading scale and zero_point
# of size N into uninitialized buffers of size 0. The
# buffers are resized here, and the values are copied in
# the default state_dict loading code of the parent.
if name == 'scale':
self.scale.data = self.scale.data.resize_(val.shape)
else:
assert name == 'zero_point'
self.zero_point.data = self.zero_point.data.resize_(
val.shape)
# For torchscript module we need to update the attributes here
# since we do not call the `_load_from_state_dict` function
# defined in module.py
if torch.jit.is_scripting():
if name == 'scale':
self.scale.copy_(val)
else:
assert name == 'zero_point'
self.zero_point.copy_(val)
elif strict:
missing_keys.append(key)
super(LearnableFakeQuantize,
self)._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys,
unexpected_keys, error_msgs)

@torch.jit.export
def extra_repr(self):
"""The printable representational string."""
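For context, a minimal standalone example of the failure mode the `_load_from_state_dict` override above works around: a buffer registered with numel 0 cannot receive a checkpoint tensor of a larger size unless it is resized first. The toy module below is illustrative and not part of the PR:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # LSQ-style buffers start out empty and only get their real shape
        # after the first forward pass.
        self.register_buffer('scale', torch.tensor([]))

toy = Toy()
ckpt = {'scale': torch.ones(4)}  # per-channel scale saved after training
# Loading directly would raise a size-mismatch error, so the buffer is
# resized to the incoming shape first, exactly as the override does.
toy.scale.data = toy.scale.data.resize_(ckpt['scale'].shape)
toy.load_state_dict(ckpt)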
14 changes: 6 additions & 8 deletions mmrazor/models/fake_quants/torch_fake_quants.py
@@ -11,6 +11,12 @@
torch_fake_quant_src = get_package_placeholder('torch>=1.13')


# TORCH_fake_quants = register_torch_fake_quants()
# TORCH_fake_quants including:
# FakeQuantize
# FakeQuantizeBase
# FixedQParamsFakeQuantize
# FusedMovingAvgObsFakeQuantize
def register_torch_fake_quants() -> List[str]:
"""Register fake_quants in ``torch.ao.quantization.fake_quantize`` to the
``MODELS`` registry.
@@ -30,11 +36,3 @@ def register_torch_fake_quants() -> List[str]:
MODELS.register_module(module=_fake_quant)
torch_fake_quants.append(module_name)
return torch_fake_quants


TORCH_fake_quants = register_torch_fake_quants()
# TORCH_fake_quants including:
# FakeQuantize
# FakeQuantizeBase
# FixedQParamsFakeQuantize
# FusedMovingAvgObsFakeQuantize
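A minimal sketch of what the registration still provides now that `register_torch_fake_quants()` is triggered from `quantization_loops.py` instead of at import time here: once the function has run, the native torch fake-quant classes listed in the comment above can be looked up from the `MODELS` registry by name, e.g. from a qconfig. Illustrative usage, assuming registration has already happened:

from mmrazor.registry import MODELS

# A qconfig can refer to the torch classes by their registered names.
w_fake_quant_cfg = dict(type='FakeQuantize')
fake_quant_cls = MODELS.get(w_fake_quant_cfg['type'])
# -> torch.ao.quantization.fake_quantize.FakeQuantize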
30 changes: 14 additions & 16 deletions mmrazor/models/observers/torch_observers.py
@@ -31,6 +31,20 @@ def reset_min_max_vals(self):
PerChannelMinMaxObserver.reset_min_max_vals = reset_min_max_vals


# TORCH_observers = register_torch_observers()
# TORCH_observers including:
# FixedQParamsObserver
# HistogramObserver
# MinMaxObserver
# MovingAverageMinMaxObserver
# MovingAveragePerChannelMinMaxObserver
# NoopObserver
# ObserverBase
# PerChannelMinMaxObserver
# PlaceholderObserver
# RecordingObserver
# ReuseInputObserver
# UniformQuantizationObserverBase
def register_torch_observers() -> List[str]:
"""Register observers in ``torch.ao.quantization.observer`` to the
``MODELS`` registry.
@@ -50,19 +64,3 @@ def register_torch_observers() -> List[str]:
MODELS.register_module(module=_observer)
torch_observers.append(module_name)
return torch_observers


TORCH_observers = register_torch_observers()
# TORCH_observers including:
# FixedQParamsObserver
# HistogramObserver
# MinMaxObserver
# MovingAverageMinMaxObserver
# MovingAveragePerChannelMinMaxObserver
# NoopObserver
# ObserverBase
# PerChannelMinMaxObserver
# PlaceholderObserver
# RecordingObserver
# ReuseInputObserver
# UniformQuantizationObserverBase
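For context, a short illustration of why `reset_min_max_vals` is patched onto `PerChannelMinMaxObserver` above: calibration loops may want to clear an observer's running statistics before a fresh pass, which (to my understanding) the upstream per-channel observer does not support out of the box. Illustrative usage only, assuming this module has been imported so the patch is applied:

import torch
from torch.ao.quantization.observer import PerChannelMinMaxObserver

obs = PerChannelMinMaxObserver(ch_axis=0)
obs(torch.randn(8, 16))   # accumulate per-channel min/max during calibration
obs.reset_min_max_vals()  # method provided by the patch in torch_observers.py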
74 changes: 59 additions & 15 deletions tests/test_models/test_algorithms/test_mm_architecture.py
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import shutil
import tempfile
from unittest import TestCase, skip
from unittest import TestCase, skipIf

import torch
import torch.nn as nn
@@ -13,8 +14,15 @@
from mmrazor.utils import get_placeholder
GraphModule = get_placeholder('torch>=1.13')

from mmengine import ConfigDict
from mmengine.model import BaseModel

try:
import mmdeploy
except ImportError:
from mmrazor.utils import get_package_placeholder
mmdeploy = get_package_placeholder('mmdeploy')

from mmrazor import digit_version
from mmrazor.models.algorithms import MMArchitectureQuant
from mmrazor.registry import MODELS
@@ -101,12 +109,44 @@ def forward(self, inputs, data_samples, mode: str = 'tensor'):
return outputs


@skip
DEPLOY_CFG = ConfigDict(
onnx_config=dict(
type='onnx',
export_params=True,
keep_initializers_as_inputs=False,
opset_version=11,
save_file='end2end.onnx',
input_names=['input'],
output_names=['output'],
input_shape=None,
optimize=True,
dynamic_axes={
'input': {
0: 'batch',
2: 'height',
3: 'width'
},
'output': {
0: 'batch'
}
}),
backend_config=dict(
type='openvino',
model_inputs=[dict(opt_shapes=dict(input=[1, 3, 224, 224]))]),
codebase_config=dict(type='mmcls', task='Classification'),
function_record_to_pop=[
'mmcls.models.classifiers.ImageClassifier.forward',
'mmcls.models.classifiers.BaseClassifier.forward'
],
)


@skipIf(
digit_version(torch.__version__) < digit_version('1.13.0'),
'PyTorch version lower than 1.13.0 is not supported.')
class TestMMArchitectureQuant(TestCase):

def setUp(self):
if digit_version(torch.__version__) < digit_version('1.13.0'):
self.skipTest('version of torch < 1.13.0')

MODELS.register_module(module=ToyQuantModel, force=True)

@@ -116,7 +156,7 @@ def setUp(self):
toymodel = ToyQuantModel()
torch.save(toymodel.state_dict(), filename)

global_qconfig = dict(
global_qconfig = ConfigDict(
w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
w_fake_quant=dict(type='mmrazor.FakeQuantize'),
@@ -132,7 +172,7 @@ def setUp(self):
is_symmetry=True,
averaging_constant=0.1),
)
alg_kwargs = dict(
alg_kwargs = ConfigDict(
type='mmrazor.MMArchitectureQuant',
architecture=dict(type='ToyQuantModel'),
float_checkpoint=filename,
@@ -141,23 +181,23 @@
global_qconfig=global_qconfig,
tracer=dict(type='mmrazor.CustomTracer')))
self.alg_kwargs = alg_kwargs
self.toy_model = MODELS.build(self.alg_kwargs)

def tearDown(self):
if digit_version(torch.__version__) < digit_version('1.13.0'):
self.skipTest('version of torch < 1.13.0')
MODELS.module_dict.pop('ToyQuantModel')
shutil.rmtree(self.temp_dir)

def test_init(self):
if digit_version(torch.__version__) < digit_version('1.13.0'):
self.skipTest('version of torch < 1.13.0')
self.toy_model = MODELS.build(self.alg_kwargs)
assert isinstance(self.toy_model, MMArchitectureQuant)
assert hasattr(self.toy_model, 'quantizer')

alg_kwargs = copy.deepcopy(self.alg_kwargs)
alg_kwargs.deploy_cfg = DEPLOY_CFG
assert isinstance(self.toy_model, MMArchitectureQuant)
assert hasattr(self.toy_model, 'quantizer')

def test_sync_qparams(self):
if digit_version(torch.__version__) < digit_version('1.13.0'):
self.skipTest('version of torch < 1.13.0')
self.toy_model = MODELS.build(self.alg_kwargs)
mode = self.toy_model.forward_modes[0]
self.toy_model.sync_qparams(mode)
w_loss = self.toy_model.qmodels[
@@ -170,12 +210,16 @@ def test_sync_qparams(self):
assert w_loss.equal(w_tensor)

def test_build_qmodels(self):
if digit_version(torch.__version__) < digit_version('1.13.0'):
self.skipTest('version of torch < 1.13.0')
self.toy_model = MODELS.build(self.alg_kwargs)
for forward_modes in self.toy_model.forward_modes:
qmodels = self.toy_model.qmodels[forward_modes]
assert isinstance(qmodels, GraphModule)

def test_get_deploy_model(self):
self.toy_model = MODELS.build(self.alg_kwargs)
deploy_model = self.toy_model.get_deploy_model()
self.assertIsInstance(deploy_model, torch.fx.graph_module.GraphModule)

def test_calibrate_step(self):
# TODO
pass
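A small sketch of how the new DEPLOY_CFG is meant to be exercised, mirroring the `alg_kwargs.deploy_cfg = DEPLOY_CFG` line in `test_init` above: attaching the mmdeploy config makes `_build_qmodels` take the rewriter-context branch instead of the plain `quantizer.prepare` path. The helper below is illustrative, not part of the PR:

import copy

def build_with_deploy_cfg(alg_kwargs, deploy_cfg, registry):
    # Deep-copy so the shared test config is not mutated, then attach the
    # mmdeploy deploy_cfg before building the MMArchitectureQuant instance.
    # Assumes alg_kwargs is a ConfigDict, as in the test above.
    kwargs = copy.deepcopy(alg_kwargs)
    kwargs.deploy_cfg = deploy_cfg
    return registry.build(kwargs)

# e.g. inside a test:
#     toy_model = build_with_deploy_cfg(self.alg_kwargs, DEPLOY_CFG, MODELS)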