From 45d4a9e75851d4bbe6530e015a1c3fbbbf4946ae Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Fri, 23 Jun 2023 14:16:52 +0100
Subject: [PATCH 1/2] Fix (act_quant): flag to enable/disable stats collection

---
 src/brevitas/graph/calibrate.py     | 15 +++++++++++----
 src/brevitas/proxy/runtime_quant.py |  7 ++++++-
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index 614bcf671..7c4566808 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -135,12 +135,18 @@ def disable_act_quant_hook(self, module, inp, output):
             inp, min_val=module.quant_injector.min_val, max_val=module.quant_injector.max_val)
         return QuantTensor(value=inp, training=module.training)
 
-    def disable_act_quantization(self, model, is_training):
+    def disable_act_quantization(self, model, is_training, call_quantizer_impl=False):
+        # If call_quantizer_impl is set to True, the quantization will be performed but the output
+        # will be discarded through the hook. It is useful for collecting activation stats,
+        # for example during activation calibration in PTQ
         for module in model.modules():
             if isinstance(module, ActQuantProxyFromInjector):
-                hook = module.register_forward_hook(self.disable_act_quant_hook)
                 module.train(is_training)
-                self.disable_act_quant_hooks.append(hook)
+                if call_quantizer_impl:
+                    hook = module.register_forward_hook(self.disable_act_quant_hook)
+                    self.disable_act_quant_hooks.append(hook)
+                else:
+                    module.disable_quant = True
             elif isinstance(module, _ACC_PROXIES):
                 module.train(is_training)
                 module.disable_quant = True
@@ -163,6 +169,7 @@ def enable_act_quantization(self, model, is_training):
             module.train(is_training)
             module.disable_quant = False
         elif isinstance(module, ActQuantProxyFromInjector):
+            module.disable_quant = False
             module.train(is_training)
             for hook in self.disable_act_quant_hooks:
                 hook.remove()
@@ -182,7 +189,7 @@ def enable_bias_quantization(self, model, is_training):
 
     def apply(self, model, is_training, quantization_enabled):
         if not quantization_enabled:
-            self.disable_act_quantization(model, is_training)
+            self.disable_act_quantization(model, is_training, call_quantizer_impl=True)
             self.disable_param_quantization(model, is_training)
         else:
             self.enable_act_quantization(model, is_training)
diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py
index 6566580fe..a650a7755 100644
--- a/src/brevitas/proxy/runtime_quant.py
+++ b/src/brevitas/proxy/runtime_quant.py
@@ -92,7 +92,7 @@ def __init__(self, quant_layer, quant_injector):
 
     @property
     def is_quant_enabled(self):
-        return self._is_quant_enabled
+        return self._is_quant_enabled and not self.disable_quant
 
     @is_quant_enabled.setter
     def is_quant_enabled(self, is_quant_enabled):
@@ -142,9 +142,12 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
             y = x
             if isinstance(y, QuantTensor):
                 y = y.value
+
             if self.export_mode:
                 y = self.fused_activation_quant_proxy.activation_impl(y)
                 y = self.export_handler(y)
+            elif not self.is_quant_enabled:
+                y = self.fused_activation_quant_proxy.activation_impl(y)
             else:
                 y = self.fused_activation_quant_proxy(y)
             # If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
@@ -156,6 +159,8 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
                     y = y[0]
                 return QuantTensor(y, x.scale, x.zero_point, x.bit_width, x.signed, self.training)
             else:
+                if isinstance(y, tuple):
+                    y = y[0]
                 return QuantTensor(y, training=self.training)
         else:
             if isinstance(x, QuantTensor):  # passthrough

From 5ae46de3a35f918ce7569706bff203d61cfa6211 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Wed, 5 Jul 2023 10:03:18 +0100
Subject: [PATCH 2/2] review

---
 src/brevitas/graph/calibrate.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index 7c4566808..d206e016a 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -62,7 +62,7 @@ class calibration_mode:
     def __init__(self, model, enabled=True):
         self.model = model
         self.previous_training_state = model.training
-        self.disable_quant_inference = DisableEnableQuantization()
+        self.disable_quant_inference = DisableEnableQuantization(call_act_quantizer_impl=True)
         self.enabled = enabled
 
     def __enter__(self):
@@ -111,9 +111,10 @@ def apply(self, model):
 
 class DisableEnableQuantization(Transform):
 
-    def __init__(self):
+    def __init__(self, call_act_quantizer_impl=False):
         super(DisableEnableQuantization, self).__init__()
         self.disable_act_quant_hooks = []
+        self.call_act_quantizer_impl = call_act_quantizer_impl
 
     def unpack_input(self, inp):
         if isinstance(inp, tuple):
@@ -135,14 +136,14 @@ def disable_act_quant_hook(self, module, inp, output):
             inp, min_val=module.quant_injector.min_val, max_val=module.quant_injector.max_val)
         return QuantTensor(value=inp, training=module.training)
 
-    def disable_act_quantization(self, model, is_training, call_quantizer_impl=False):
-        # If call_quantizer_impl is set to True, the quantization will be performed but the output
+    def disable_act_quantization(self, model, is_training):
+        # If self.call_act_quantizer_impl is set to True, the quantization will be performed but the output
         # will be discarded through the hook. It is useful for collecting activation stats,
        # for example during activation calibration in PTQ
         for module in model.modules():
             if isinstance(module, ActQuantProxyFromInjector):
                 module.train(is_training)
-                if call_quantizer_impl:
+                if self.call_act_quantizer_impl:
                     hook = module.register_forward_hook(self.disable_act_quant_hook)
                     self.disable_act_quant_hooks.append(hook)
                 else:
                     module.disable_quant = True
@@ -189,7 +190,7 @@ def enable_bias_quantization(self, model, is_training):
 
     def apply(self, model, is_training, quantization_enabled):
        if not quantization_enabled:
-            self.disable_act_quantization(model, is_training, call_quantizer_impl=True)
+            self.disable_act_quantization(model, is_training)
             self.disable_param_quantization(model, is_training)
         else:
             self.enable_act_quantization(model, is_training)
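
A minimal usage sketch (not part of the patches above) of the behaviour they enable, assuming a Brevitas-quantized `model` and a `calib_loader` of representative inputs (both hypothetical names): `calibration_mode` now constructs `DisableEnableQuantization(call_act_quantizer_impl=True)`, so during calibration the activation quantizers are still executed and their statistics collected, while the forward hook discards the quantized output; a plain `DisableEnableQuantization()` keeps the default `call_act_quantizer_impl=False` and simply sets `disable_quant = True` on the activation proxies.

    import torch
    from brevitas.graph.calibrate import calibration_mode

    # model: a Brevitas-quantized nn.Module; calib_loader: calibration data (assumed)
    model.eval()
    with torch.no_grad():
        with calibration_mode(model):
            # Activation quantizer impls run here so that statistics are collected,
            # but quantization is not applied to the tensors flowing through the
            # network (the hook returns the unquantized activations instead).
            for images, _ in calib_loader:
                model(images)
    # On exit, quantization is re-enabled and uses the collected statistics.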