From 8d5035bd77b91dd19627e805e7d57f59c67e225c Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 6 Jul 2023 10:57:59 +0200
Subject: [PATCH] Feat (graph): flag to enable/disable stats collection with
 disabled quant (#641)

* Fix (act_quant): flag to enable/disable stats collection

* review
---
 src/brevitas/graph/calibrate.py     | 16 ++++++++++++----
 src/brevitas/proxy/runtime_quant.py |  7 ++++++-
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index 614bcf671..d206e016a 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -62,7 +62,7 @@ class calibration_mode:
     def __init__(self, model, enabled=True):
         self.model = model
         self.previous_training_state = model.training
-        self.disable_quant_inference = DisableEnableQuantization()
+        self.disable_quant_inference = DisableEnableQuantization(call_act_quantizer_impl=True)
         self.enabled = enabled
 
     def __enter__(self):
@@ -111,9 +111,10 @@ def apply(self, model):
 
 class DisableEnableQuantization(Transform):
 
-    def __init__(self):
+    def __init__(self, call_act_quantizer_impl=False):
         super(DisableEnableQuantization, self).__init__()
         self.disable_act_quant_hooks = []
+        self.call_act_quantizer_impl = call_act_quantizer_impl
 
     def unpack_input(self, inp):
         if isinstance(inp, tuple):
@@ -136,11 +137,17 @@ def disable_act_quant_hook(self, module, inp, output):
         return QuantTensor(value=inp, training=module.training)
 
     def disable_act_quantization(self, model, is_training):
+        # If self.call_act_quantizer_impl is set to True, the quantization will be performed but the output
+        # will be discarded through the hook. It is useful for collecting activation stats,
+        # for example during activation calibration in PTQ
         for module in model.modules():
             if isinstance(module, ActQuantProxyFromInjector):
-                hook = module.register_forward_hook(self.disable_act_quant_hook)
                 module.train(is_training)
-                self.disable_act_quant_hooks.append(hook)
+                if self.call_act_quantizer_impl:
+                    hook = module.register_forward_hook(self.disable_act_quant_hook)
+                    self.disable_act_quant_hooks.append(hook)
+                else:
+                    module.disable_quant = True
             elif isinstance(module, _ACC_PROXIES):
                 module.train(is_training)
                 module.disable_quant = True
@@ -163,6 +170,7 @@ def enable_act_quantization(self, model, is_training):
                 module.train(is_training)
                 module.disable_quant = False
             elif isinstance(module, ActQuantProxyFromInjector):
+                module.disable_quant = False
                 module.train(is_training)
         for hook in self.disable_act_quant_hooks:
             hook.remove()
diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py
index 6566580fe..a650a7755 100644
--- a/src/brevitas/proxy/runtime_quant.py
+++ b/src/brevitas/proxy/runtime_quant.py
@@ -92,7 +92,7 @@ def __init__(self, quant_layer, quant_injector):
 
     @property
     def is_quant_enabled(self):
-        return self._is_quant_enabled
+        return self._is_quant_enabled and not self.disable_quant
 
     @is_quant_enabled.setter
     def is_quant_enabled(self, is_quant_enabled):
@@ -142,9 +142,12 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
             y = x
             if isinstance(y, QuantTensor):
                 y = y.value
+
             if self.export_mode:
                 y = self.fused_activation_quant_proxy.activation_impl(y)
                 y = self.export_handler(y)
+            elif not self.is_quant_enabled:
+                y = self.fused_activation_quant_proxy.activation_impl(y)
             else:
                 y = self.fused_activation_quant_proxy(y)
             # If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
@@ -156,6 +159,8 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
                     y = y[0]
                 return QuantTensor(y, x.scale, x.zero_point, x.bit_width, x.signed, self.training)
             else:
+                if isinstance(y, tuple):
+                    y = y[0]
                 return QuantTensor(y, training=self.training)
         else:
             if isinstance(x, QuantTensor):  # passthrough
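
Usage sketch (not part of the patch): a minimal, hedged example of how the new
default surfaces during PTQ activation calibration. calibration_mode,
QuantIdentity and QuantReLU are real Brevitas APIs; the toy model, tensor
shapes and iteration count are illustrative assumptions.

    import torch
    import torch.nn as nn

    import brevitas.nn as qnn
    from brevitas.graph.calibrate import calibration_mode

    # Toy model: layers and shapes are assumptions for illustration only.
    model = nn.Sequential(qnn.QuantIdentity(), qnn.QuantReLU())
    model.eval()

    # calibration_mode now builds DisableEnableQuantization with
    # call_act_quantizer_impl=True: each ActQuantProxyFromInjector still runs
    # its quantizer, so observers keep collecting activation statistics,
    # while disable_act_quant_hook discards the quantized output and returns
    # the unquantized input instead.
    with calibration_mode(model):
        for _ in range(4):
            model(torch.randn(8, 16))

    out = model(torch.randn(8, 16))  # quantization active again after exit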
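
For contrast, a sketch of the default path, reusing the toy model above. With
call_act_quantizer_impl left at False, the flow described in the comments
follows the patched code; the transform and method names come from the diff
itself.

    from brevitas.graph.calibrate import DisableEnableQuantization

    # Default call_act_quantizer_impl=False: each ActQuantProxyFromInjector
    # gets disable_quant = True, the patched is_quant_enabled property then
    # returns False, and forward() takes the new
    # `elif not self.is_quant_enabled` branch: only activation_impl runs,
    # the quantizer is skipped, and no activation stats are collected.
    transform = DisableEnableQuantization()
    transform.disable_act_quantization(model, is_training=False)
    model(torch.randn(8, 16))  # unquantized forward pass, no stats collected
    transform.enable_act_quantization(model, is_training=False)  # restore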