diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index 55f787e26..a6c44f144 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -60,7 +60,8 @@ def __init__(
             inplace: bool = True,
             use_quant_activations: bool = True,
             num_blocks: int = 100,
-            act_order: bool = False) -> None:
+            act_order: bool = False,
+            return_forward_output: bool = False) -> None:
         if not inplace:
             model = deepcopy(model)
         self.model = model
@@ -80,6 +81,7 @@ def __init__(
         self.orig_forward = self.model.forward
         self.model.forward = self.catch_stopfwd
         self.group_of_parallel_layers = group_of_parallel_layers
+        self.return_forward_output = return_forward_output
 
     def _is_module_supported(self, module):
         if isinstance(module, SUPPORTED_CONV_OP):
@@ -158,6 +160,15 @@ def catch_stopfwd(self, *args, **kwargs):
             self.orig_forward(*args, **kwargs)
         except StopFwdException:
             pass
+        finally:
+            if self.return_forward_output:
+                # If we want to return the output of the network, we need to disable all hooks
+                for name, gptq_class in self.gptq_layers.items():
+                    gptq_class.disable_pre_forward_hook = True
+                out = self.orig_forward(*args, **kwargs)
+                for name, gptq_class in self.gptq_layers.items():
+                    gptq_class.disable_pre_forward_hook = False
+                return out
 
 
 class GPTQ():
@@ -211,7 +222,11 @@ def __init__(self, layer, name, num_blocks, act_order, parallel_layers=1) -> Non
         self.nsamples = 0
         self.parallel_layers = parallel_layers
 
+        self.disable_pre_forward_hook = False
+
     def update_batch(self, module, input, current_layer):
+        if self.disable_pre_forward_hook:
+            return input
         # Update reference to current layer
         current_layer.layer_names.add(self.name)
 
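
For context, a minimal usage sketch of the new flag. It follows the gptq_mode calibration-loop pattern from the brevitas PTQ examples (gptq.model, gptq.num_layers, gptq.update()); model and calib_loader are hypothetical placeholders, not part of this patch.

# Usage sketch (not part of the patch): exercising return_forward_output=True.
# `model` and `calib_loader` are hypothetical placeholders.
import torch
from brevitas.graph.gptq import gptq_mode

model.eval()
with torch.no_grad():
    with gptq_mode(model, return_forward_output=True) as gptq:
        gptq_model = gptq.model
        for _ in range(gptq.num_layers):
            for images, _ in calib_loader:
                # With the flag set, catch_stopfwd re-runs the original
                # forward with every GPTQ pre-forward hook disabled and
                # returns the real network output instead of None.
                out = gptq_model(images)
            gptq.update()

The design point of the patch: catch_stopfwd normally swallows StopFwdException and returns nothing, since the forward pass exists only to feed the per-layer hooks. When return_forward_output is set, the finally block temporarily flips disable_pre_forward_hook on every GPTQ instance so that update_batch becomes a no-op pass-through, re-runs the unmodified forward to obtain the actual output, then re-enables the hooks.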