Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix (act_quant): flag to enable/disable stats collection #641

Merged
merged 2 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/brevitas/graph/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class calibration_mode:
def __init__(self, model, enabled=True):
self.model = model
self.previous_training_state = model.training
self.disable_quant_inference = DisableEnableQuantization()
self.disable_quant_inference = DisableEnableQuantization(call_act_quantizer_impl=True)
self.enabled = enabled

def __enter__(self):
Expand Down Expand Up @@ -111,9 +111,10 @@ def apply(self, model):

class DisableEnableQuantization(Transform):

def __init__(self):
def __init__(self, call_act_quantizer_impl=False):
super(DisableEnableQuantization, self).__init__()
self.disable_act_quant_hooks = []
self.call_act_quantizer_impl = call_act_quantizer_impl

def unpack_input(self, inp):
if isinstance(inp, tuple):
Expand All @@ -136,11 +137,17 @@ def disable_act_quant_hook(self, module, inp, output):
return QuantTensor(value=inp, training=module.training)

def disable_act_quantization(self, model, is_training):
# If self.call_act_quantizer_impl is set to True, the quantization will be performed but the output
# will be discarded through the hook. It is useful for collecting activation stats,
# for example during activation calibration in PTQ
for module in model.modules():
if isinstance(module, ActQuantProxyFromInjector):
hook = module.register_forward_hook(self.disable_act_quant_hook)
module.train(is_training)
self.disable_act_quant_hooks.append(hook)
if self.call_act_quantizer_impl:
hook = module.register_forward_hook(self.disable_act_quant_hook)
self.disable_act_quant_hooks.append(hook)
else:
module.disable_quant = True
elif isinstance(module, _ACC_PROXIES):
module.train(is_training)
module.disable_quant = True
Expand All @@ -163,6 +170,7 @@ def enable_act_quantization(self, model, is_training):
module.train(is_training)
module.disable_quant = False
elif isinstance(module, ActQuantProxyFromInjector):
module.disable_quant = False
module.train(is_training)
for hook in self.disable_act_quant_hooks:
hook.remove()
Expand Down
7 changes: 6 additions & 1 deletion src/brevitas/proxy/runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(self, quant_layer, quant_injector):

@property
def is_quant_enabled(self):
return self._is_quant_enabled
return self._is_quant_enabled and not self.disable_quant

@is_quant_enabled.setter
def is_quant_enabled(self, is_quant_enabled):
Expand Down Expand Up @@ -142,9 +142,12 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
y = x
if isinstance(y, QuantTensor):
y = y.value

if self.export_mode:
y = self.fused_activation_quant_proxy.activation_impl(y)
y = self.export_handler(y)
elif not self.is_quant_enabled:
y = self.fused_activation_quant_proxy.activation_impl(y)
else:
y = self.fused_activation_quant_proxy(y)
# If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
Expand All @@ -156,6 +159,8 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
y = y[0]
return QuantTensor(y, x.scale, x.zero_point, x.bit_width, x.signed, self.training)
else:
if isinstance(y, tuple):
y = y[0]
return QuantTensor(y, training=self.training)
else:
if isinstance(x, QuantTensor): # passthrough
Expand Down