Commit

Feat (graph): flag to enable/disable stats collection with disabled quant (#641)

* Fix (act_quant): flag to enable/disable stats collection

* review
Giuseppe5 authored Jul 6, 2023
1 parent e99ae57 commit 8d5035b
Showing 2 changed files with 18 additions and 5 deletions.
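
For context on how these pieces are driven, here is a minimal sketch of PTQ activation calibration through calibration_mode; the toy model and random data are hypothetical stand-ins, not part of the commit:

import torch
import torch.nn as nn

from brevitas.graph.calibrate import calibration_mode
from brevitas.nn import QuantIdentity

# Hypothetical toy model with a single activation quantizer.
model = nn.Sequential(QuantIdentity(return_quant_tensor=False))
model.eval()

with torch.no_grad():
    # Inside calibration_mode, activation quantizers still execute so that
    # their scale/zero-point statistics are collected, but their quantized
    # outputs are discarded through a forward hook (see the diff below).
    with calibration_mode(model):
        for _ in range(8):
            model(torch.randn(4, 16))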
16 changes: 12 additions & 4 deletions src/brevitas/graph/calibrate.py
@@ -62,7 +62,7 @@ class calibration_mode:
     def __init__(self, model, enabled=True):
         self.model = model
         self.previous_training_state = model.training
-        self.disable_quant_inference = DisableEnableQuantization()
+        self.disable_quant_inference = DisableEnableQuantization(call_act_quantizer_impl=True)
         self.enabled = enabled
 
     def __enter__(self):
@@ -111,9 +111,10 @@ def apply(self, model):
 
 class DisableEnableQuantization(Transform):
 
-    def __init__(self):
+    def __init__(self, call_act_quantizer_impl=False):
         super(DisableEnableQuantization, self).__init__()
         self.disable_act_quant_hooks = []
+        self.call_act_quantizer_impl = call_act_quantizer_impl
 
     def unpack_input(self, inp):
         if isinstance(inp, tuple):
@@ -136,11 +137,17 @@ def disable_act_quant_hook(self, module, inp, output):
         return QuantTensor(value=inp, training=module.training)
 
     def disable_act_quantization(self, model, is_training):
+        # If self.call_act_quantizer_impl is set to True, the quantization will be performed but the output
+        # will be discarded through the hook. It is useful for collecting activation stats,
+        # for example during activation calibration in PTQ
         for module in model.modules():
             if isinstance(module, ActQuantProxyFromInjector):
-                hook = module.register_forward_hook(self.disable_act_quant_hook)
                 module.train(is_training)
-                self.disable_act_quant_hooks.append(hook)
+                if self.call_act_quantizer_impl:
+                    hook = module.register_forward_hook(self.disable_act_quant_hook)
+                    self.disable_act_quant_hooks.append(hook)
+                else:
+                    module.disable_quant = True
             elif isinstance(module, _ACC_PROXIES):
                 module.train(is_training)
                 module.disable_quant = True
Expand All @@ -163,6 +170,7 @@ def enable_act_quantization(self, model, is_training):
module.train(is_training)
module.disable_quant = False
elif isinstance(module, ActQuantProxyFromInjector):
module.disable_quant = False
module.train(is_training)
for hook in self.disable_act_quant_hooks:
hook.remove()
Expand Down
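
To make the two code paths above concrete, a hedged sketch of driving DisableEnableQuantization directly (toy model as before is a hypothetical stand-in; calibration_mode wraps these same calls with call_act_quantizer_impl=True):

import torch
import torch.nn as nn

from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.nn import QuantIdentity

model = nn.Sequential(QuantIdentity())
model.eval()

# call_act_quantizer_impl=True: the activation quantizer still runs, so its
# stats collectors see the calibration data, but the forward hook replaces
# its output with the unquantized input.
disabler = DisableEnableQuantization(call_act_quantizer_impl=True)
disabler.disable_act_quantization(model, is_training=False)
model(torch.randn(4, 16))
disabler.enable_act_quantization(model, is_training=False)

# Default (call_act_quantizer_impl=False): the activation quantizer is
# skipped outright via disable_quant = True, so no stats are collected.
disabler = DisableEnableQuantization()
disabler.disable_act_quantization(model, is_training=False)
model(torch.randn(4, 16))
disabler.enable_act_quantization(model, is_training=False)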
7 changes: 6 additions & 1 deletion src/brevitas/proxy/runtime_quant.py
@@ -92,7 +92,7 @@ def __init__(self, quant_layer, quant_injector):
 
     @property
     def is_quant_enabled(self):
-        return self._is_quant_enabled
+        return self._is_quant_enabled and not self.disable_quant
 
     @is_quant_enabled.setter
     def is_quant_enabled(self, is_quant_enabled):
@@ -142,9 +142,12 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
             y = x
             if isinstance(y, QuantTensor):
                 y = y.value
+
             if self.export_mode:
                 y = self.fused_activation_quant_proxy.activation_impl(y)
                 y = self.export_handler(y)
+            elif not self.is_quant_enabled:
+                y = self.fused_activation_quant_proxy.activation_impl(y)
             else:
                 y = self.fused_activation_quant_proxy(y)
             # If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
@@ -156,6 +159,8 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
                     y = y[0]
                 return QuantTensor(y, x.scale, x.zero_point, x.bit_width, x.signed, self.training)
             else:
+                if isinstance(y, tuple):
+                    y = y[0]
                 return QuantTensor(y, training=self.training)
         else:
             if isinstance(x, QuantTensor):  # passthrough
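
Read together with the calibrate.py changes, the new forward can be summarized by the dispatch below; this is a simplified restatement of the control flow in the diff, not the actual Brevitas source:

def forward_sketch(proxy, y):
    # Export path: apply the float activation, then the export handler.
    if proxy.export_mode:
        y = proxy.fused_activation_quant_proxy.activation_impl(y)
        return proxy.export_handler(y)
    # Quantization disabled, e.g. disable_quant = True was set by
    # DisableEnableQuantization: apply only the float activation.
    elif not proxy.is_quant_enabled:
        return proxy.fused_activation_quant_proxy.activation_impl(y)
    # Normal path: fused activation plus quantizer.
    else:
        return proxy.fused_activation_quant_proxy(y)

The updated is_quant_enabled property is what ties the two files together: disable_quant = True set from calibrate.py now routes forward() through the float-only branch, while the hook-based path keeps the quantizer running for stats collection.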
