Fix (act_quant): flag to enable/disable stats collection
Giuseppe5 committed Jun 27, 2023
1 parent 0a55295 commit 45d4a9e
Showing 2 changed files with 17 additions and 5 deletions.
15 changes: 11 additions & 4 deletions src/brevitas/graph/calibrate.py
@@ -135,12 +135,18 @@ def disable_act_quant_hook(self, module, inp, output):
                 inp, min_val=module.quant_injector.min_val, max_val=module.quant_injector.max_val)
         return QuantTensor(value=inp, training=module.training)

-    def disable_act_quantization(self, model, is_training):
+    def disable_act_quantization(self, model, is_training, call_quantizer_impl=False):
+        # If call_quantizer_impl is set to True, the quantization will be performed but the output
+        # will be discarded through the hook. It is useful for collecting activation stats,
+        # for example during activation calibration in PTQ
         for module in model.modules():
             if isinstance(module, ActQuantProxyFromInjector):
-                hook = module.register_forward_hook(self.disable_act_quant_hook)
                 module.train(is_training)
-                self.disable_act_quant_hooks.append(hook)
+                if call_quantizer_impl:
+                    hook = module.register_forward_hook(self.disable_act_quant_hook)
+                    self.disable_act_quant_hooks.append(hook)
+                else:
+                    module.disable_quant = True
             elif isinstance(module, _ACC_PROXIES):
                 module.train(is_training)
                 module.disable_quant = True
@@ -163,6 +169,7 @@ def enable_act_quantization(self, model, is_training):
                 module.train(is_training)
                 module.disable_quant = False
             elif isinstance(module, ActQuantProxyFromInjector):
+                module.disable_quant = False
                 module.train(is_training)
         for hook in self.disable_act_quant_hooks:
             hook.remove()
@@ -182,7 +189,7 @@ def enable_bias_quantization(self, model, is_training):

     def apply(self, model, is_training, quantization_enabled):
         if not quantization_enabled:
-            self.disable_act_quantization(model, is_training)
+            self.disable_act_quantization(model, is_training, call_quantizer_impl=True)
             self.disable_param_quantization(model, is_training)
         else:
             self.enable_act_quantization(model, is_training)
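Taken together, the calibrate.py changes let a caller disable activation quantization while still feeding data through the quantizers, so their statistics observers are updated during PTQ calibration. A minimal usage sketch follows, assuming the methods above live on brevitas' DisableEnableQuantization helper (class name and the surrounding model/calib_loader objects are assumptions for illustration, not part of this diff):

    import torch
    from brevitas.graph.calibrate import DisableEnableQuantization  # assumed host class

    disabler = DisableEnableQuantization()

    # Disable quantization for calibration. apply() passes call_quantizer_impl=True,
    # so activation quantizers still run (collecting stats) but the registered hook
    # discards their output and the model computes in float.
    disabler.apply(model, is_training=False, quantization_enabled=False)
    with torch.no_grad():
        for images, _ in calib_loader:  # calib_loader: a user-provided DataLoader
            model(images)

    # Re-enable quantization: hooks are removed and disable_quant flags are reset.
    disabler.apply(model, is_training=True, quantization_enabled=True)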
7 changes: 6 additions & 1 deletion src/brevitas/proxy/runtime_quant.py
@@ -92,7 +92,7 @@ def __init__(self, quant_layer, quant_injector):

     @property
     def is_quant_enabled(self):
-        return self._is_quant_enabled
+        return self._is_quant_enabled and not self.disable_quant

     @is_quant_enabled.setter
     def is_quant_enabled(self, is_quant_enabled):
@@ -142,9 +142,12 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
         y = x
         if isinstance(y, QuantTensor):
             y = y.value
+
         if self.export_mode:
             y = self.fused_activation_quant_proxy.activation_impl(y)
             y = self.export_handler(y)
+        elif not self.is_quant_enabled:
+            y = self.fused_activation_quant_proxy.activation_impl(y)
         else:
             y = self.fused_activation_quant_proxy(y)
             # If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
@@ -156,6 +159,8 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
                     y = y[0]
                 return QuantTensor(y, x.scale, x.zero_point, x.bit_width, x.signed, self.training)
             else:
+                if isinstance(y, tuple):
+                    y = y[0]
                 return QuantTensor(y, training=self.training)
         else:
             if isinstance(x, QuantTensor):  # passthrough
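After this change, forward() dispatches along three paths. A condensed sketch of the control flow (names mirror the diff; this is an illustrative outline, not the full implementation):

    def forward_dispatch(proxy, y):
        if proxy.export_mode:
            # Export path: apply the non-linearity, then the export handler.
            y = proxy.fused_activation_quant_proxy.activation_impl(y)
            return proxy.export_handler(y)
        elif not proxy.is_quant_enabled:
            # Quantization fully off (disable_quant=True, no hook registered):
            # run only the wrapped non-linearity, skipping the quantizer entirely.
            return proxy.fused_activation_quant_proxy.activation_impl(y)
        else:
            # Normal path: non-linearity plus quantization. When stats collection
            # is on, the disable_act_quant_hook registered on the proxy discards
            # the quantized output after the observers have seen the data.
            return proxy.fused_activation_quant_proxy(y)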