diff --git a/src/brevitas/config.py b/src/brevitas/config.py
index 082a6508b..a5685721c 100644
--- a/src/brevitas/config.py
+++ b/src/brevitas/config.py
@@ -25,4 +25,3 @@ def env_to_bool(name, default):
 _FULL_STATE_DICT = False
 _IS_INSIDE_QUANT_LAYER = None
 _ONGOING_EXPORT = None
-_RETROCOMPATIBLE_SCALING = False
diff --git a/src/brevitas/core/restrict_val.py b/src/brevitas/core/restrict_val.py
index 0720e595e..aa9a4fa2e 100644
--- a/src/brevitas/core/restrict_val.py
+++ b/src/brevitas/core/restrict_val.py
@@ -90,8 +90,8 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()

-    def retrocompatibility_op(self, x):
-        return x
+    def combine_stats_threshold(self, x, threshold):
+        return x / threshold

     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor) -> Tensor:
@@ -116,8 +116,8 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()

-    def retrocompatibility_op(self, x):
-        return self.power_of_two(x)
+    def combine_stats_threshold(self, x, threshold):
+        return x - threshold

     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
@@ -143,8 +143,8 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()

-    def retrocompatibility_op(self, x):
-        return x
+    def combine_stats_threshold(self, x, threshold):
+        return x / threshold

     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
@@ -171,8 +171,8 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()

-    def retrocompatibility_op(self, x):
-        return self.power_of_two(x)
+    def combine_stats_threshold(self, x, threshold):
+        return x - threshold

     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
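Note on the renaming above: `retrocompatibility_op` undid the restriction when loading old checkpoints, while `combine_stats_threshold(x, threshold)` folds the threshold into the statistics inside each restriction's own domain. The float restrictions therefore divide, and the log2 (power-of-two) restrictions subtract, since log2(s / t) == log2(s) - log2(t). A minimal standalone sketch of that equivalence (plain PyTorch, not Brevitas API):

    import torch

    stats = torch.tensor([6.0])
    threshold = torch.tensor([1.5])

    # Linear domain: the restriction is the identity, so combining is division.
    linear = stats / threshold

    # Log2 domain: both operands were already mapped through log2, so the same
    # combination becomes a subtraction: log2(s / t) == log2(s) - log2(t).
    log_domain = torch.log2(stats) - torch.log2(threshold)

    assert torch.allclose(torch.exp2(log_domain), linear)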
diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py
index 51abdc1a6..3b50142d9 100644
--- a/src/brevitas/core/scaling/runtime.py
+++ b/src/brevitas/core/scaling/runtime.py
@@ -81,13 +81,16 @@ def __init__(
             self.affine_rescaling = Identity()
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
         self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
+        self.restrict_scaling_impl = restrict_scaling_impl

     @brevitas.jit.script_method
     def forward(
             self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(stats)
-        stats = self.restrict_scaling_pre(stats / threshold)
+        threshold = self.restrict_scaling_pre(threshold)
+        stats = self.restrict_scaling_pre(stats)
+        stats = self.restrict_scaling_impl.combine_stats_threshold(stats, threshold)
         stats = self.affine_rescaling(stats)
         stats = self.restrict_clamp_scaling(stats)
         return stats
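Note: in `_StatsScaling.forward` above, stats and threshold are now each mapped through `restrict_scaling_pre` before being combined, instead of dividing in the unrestricted domain and restricting the quotient. A rough stand-in for the power-of-two case (hypothetical free function; log2/exp2 stand in for the restriction and its inverse, clamping and affine rescaling omitted):

    import torch

    def po2_stats_scaling(stats: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
        threshold = torch.log2(threshold)   # restrict_scaling_pre(threshold)
        stats = torch.log2(stats)           # restrict_scaling_pre(stats)
        stats = stats - threshold           # combine_stats_threshold(stats, threshold)
        return torch.exp2(stats)            # restrict_clamp_scaling, minus the clamp

    scale = po2_stats_scaling(torch.tensor([8.0]), torch.tensor([2.0]))
    assert torch.allclose(scale, torch.tensor([4.0]))  # same result as 8 / 2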
diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py
index 50e896466..d729a3c2c 100644
--- a/src/brevitas/core/scaling/standalone.py
+++ b/src/brevitas/core/scaling/standalone.py
@@ -70,18 +70,27 @@ def __init__(
             scaling_init = scaling_init.to(device=device, dtype=dtype)
             if restrict_scaling_impl is not None:
                 scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
+                self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
+            else:
+                self.restrict_init_module = Identity()
             self.value = StatelessBuffer(scaling_init.detach())
         else:
             if restrict_scaling_impl is not None:
                 scaling_init = restrict_scaling_impl.restrict_init_float(scaling_init)
+                self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
+            else:
+                self.restrict_init_module = Identity()
             self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device))

     @brevitas.jit.script_method
     def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(placeholder)
-        value = self.value() / threshold
-        restricted_value = self.restrict_clamp_scaling(value)
+        # We first apply any restriction to the scaling factor.
+        # For IntQuant this is a no-op, hence retrocompatible.
+        threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold))
+        restricted_value = self.restrict_clamp_scaling(self.value())
+        restricted_value = restricted_value / threshold
         return restricted_value


@@ -145,6 +154,9 @@ def __init__(
             scaling_init = torch.tensor(scaling_init, dtype=dtype, device=device)
         if restrict_scaling_impl is not None:
             scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
+            self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
+        else:
+            self.restrict_init_module = Identity()
         if scaling_init.shape == SCALAR_SHAPE and scaling_shape is not None:
             scaling_init = torch.full(scaling_shape, scaling_init, dtype=dtype, device=device)
         self.value = Parameter(scaling_init)
@@ -154,8 +166,11 @@ def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(placeholder)
-        value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value) / threshold)
-        return value
+        # We first apply any restriction to the scaling factor.
+        # For IntQuant this is a no-op, hence retrocompatible.
+        threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold))
+        value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
+        return value / threshold

     def _load_from_state_dict(
             self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
             error_msgs):
@@ -217,7 +232,8 @@ def forward(
         # This is because we don't want to store a parameter dependant on a runtime value (threshold)
         # And because restrict needs to happen after we divide by threshold
         if self.init_done:
-            value = self.restrict_preprocess(self.value / threshold)
+            value = self.restrict_scaling_impl.combine_stats_threshold(
+                self.value, self.restrict_inplace_preprocess(threshold))
             value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
             return value
         else:
@@ -225,10 +241,12 @@ def forward(
             # workaround to avoid find_ununsed_parameter=True in DDP
             stats = stats + 0. * self.value
             if self.local_loss_mode:
-                return self.stats_scaling_impl(stats, threshold)
+                return self.stats_scaling_impl(stats)
+            stats = self.restrict_inplace_preprocess(stats)
+            threshold = self.restrict_inplace_preprocess(threshold)
             inplace_tensor_mul(self.value.detach(), stats)
-            value = self.restrict_preprocess(self.value / threshold)
-            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
+            value = self.restrict_scaling_impl.combine_stats_threshold(stats, threshold)
+            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
             self.init_done = True
             return value

@@ -245,14 +263,6 @@ def _load_from_state_dict(
             error_msgs):
         value_key = prefix + 'value'

-        # Before, the parameter would be stored after restrict_preprocess (e.g., Log2)
-        # When we load, if retrocompatibility is enabled, we perform the opposite operation (e.g., Po2)
-        # Knowing that during the forward pass we will re-apply restrict_preprocess (e.g., again Log2)
-        if config._RETROCOMPATIBLE_SCALING:
-            if not isinstance(self.restrict_scaling_impl, Identity):
-                state_dict[value_key] = self.restrict_scaling_impl.retrocompatibility_op(
-                    state_dict[value_key])
-
         super(ParameterFromStatsFromParameterScaling, self)._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
         # disable stats collection when a pretrained value is loaded
@@ -365,12 +375,15 @@ def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor:
             self.counter = new_counter
             return abs_binary_sign_grad(clamped_stats / threshold)
         elif self.counter == self.collect_stats_steps:
+            self.restrict_inplace_preprocess(self.buffer)
             inplace_tensor_mul(self.value.detach(), self.buffer)
-            value = self.restrict_preprocess(self.value / threshold)
+            threshold = self.restrict_preprocess(threshold)
+            value = self.restrict_scaling_impl.combine_stats_threshold(self.value, threshold)
             self.counter = self.counter + 1
             return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
         else:
-            value = self.restrict_preprocess(self.value / threshold)
+            threshold = self.restrict_preprocess(threshold)
+            value = self.restrict_scaling_impl.combine_stats_threshold(self.value, threshold)
             return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))

     @brevitas.jit.script_method
@@ -385,7 +398,8 @@ def forward(self, stats_input: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
                 out = self.buffer / threshold
                 out = self.restrict_preprocess(out)
             else:
-                out = self.restrict_preprocess(self.value / threshold)
+                threshold = self.restrict_preprocess(threshold)
+                out = self.restrict_scaling_impl.combine_stats_threshold(self.value, threshold)
             out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out)))
         return out

@@ -411,14 +425,6 @@ def _load_from_state_dict(
         if retrocomp_value_key in state_dict:
             state_dict[value_key] = state_dict.pop(retrocomp_value_key)
-        # Before, the parameter would be stored after restrict_preprocess (e.g., Log2)
-        # When we load, if retrocompatibility is enabled, we perform the opposite operation (e.g., Po2)
-        # Knowing that during the forward pass we will re-apply restrict_preprocess (e.g., again Log2)
-        if config._RETROCOMPATIBLE_SCALING:
-            if not isinstance(self.restrict_scaling_impl, Identity):
-                state_dict[value_key] = self.restrict_scaling_impl.retrocompatibility_op(
-                    state_dict[value_key])
-
         super(ParameterFromRuntimeStatsScaling, self)._load_from_state_dict(
             state_dict, prefix, local_metadata, strict,
             missing_keys, unexpected_keys, error_msgs)
         # Buffer is supposed to be always missing
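Note: after stats collection, `ParameterFromRuntimeStatsScaling.value` now caches statistics already mapped into the restricted domain (via `restrict_inplace_preprocess`), and the threshold, restricted the same way, is recombined on every call instead of being baked into the stored parameter. A rough sketch of that staging under a log2 restriction (hypothetical class, no clamping or gradient logic):

    import torch

    class RuntimeStatsScale:
        """Hypothetical stand-in for the post-collection phase."""

        def __init__(self, collected_stats: torch.Tensor):
            # Stats enter the restricted (log2) domain exactly once.
            self.value = torch.log2(collected_stats)  # restrict_inplace_preprocess

        def forward(self, threshold: torch.Tensor) -> torch.Tensor:
            threshold = torch.log2(threshold)          # restrict_preprocess
            value = self.value - threshold             # combine_stats_threshold
            return torch.exp2(value)                   # restrict_scaling (Po2)

    s = RuntimeStatsScale(torch.tensor([16.0]))
    # The threshold no longer contaminates the stored value, so different
    # thresholds can be applied to the same cached stats across calls.
    assert torch.allclose(s.forward(torch.tensor([2.0])), torch.tensor([8.0]))
    assert torch.allclose(s.forward(torch.tensor([4.0])), torch.tensor([4.0]))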
diff --git a/tests/brevitas_finn/brevitas_examples/test_quartznet_finn_export.py b/tests/brevitas_finn/brevitas_examples/test_quartznet_finn_export.py
index c72a2d1f8..1e774eed0 100644
--- a/tests/brevitas_finn/brevitas_examples/test_quartznet_finn_export.py
+++ b/tests/brevitas_finn/brevitas_examples/test_quartznet_finn_export.py
@@ -16,7 +16,6 @@
 from brevitas.export import export_qonnx
 from brevitas_examples.speech_to_text import quant_quartznet_perchannelscaling_4b

-config._RETROCOMPATIBLE_SCALING = True
 QUARTZNET_POSTPROCESSED_INPUT_SIZE = (1, 64, 256)  # B, features, sequence
 MIN_INP_VAL = 0.0
 MAX_INP_VAL = 200.0
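Note: with the load-time conversion removed, the test no longer opts into it; checkpoints keep the scaling parameter in its restricted (e.g., log2) domain, which is what the new forward passes expect. A hypothetical round-trip check of that contract (plain PyTorch, not Brevitas API):

    import torch
    from torch import nn

    class Log2Scale(nn.Module):
        """Hypothetical module whose 'value' parameter lives in the log2 domain."""

        def __init__(self, scale: float):
            super().__init__()
            self.value = nn.Parameter(torch.log2(torch.tensor([scale])))

        def forward(self) -> torch.Tensor:
            return torch.exp2(self.value)

    saved = Log2Scale(0.25).state_dict()   # stores log2(0.25) == -2.0
    reloaded = Log2Scale(1.0)
    reloaded.load_state_dict(saved)        # no inverse op needed on load
    assert torch.allclose(reloaded(), torch.tensor([0.25]))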