diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index d729a3c2c..93eee216f 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -383,7 +383,7 @@ def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tens return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value))) else: threshold = self.restrict_preprocess(threshold) - value = self.restrict_scaling_impl.combine_stats_threshold(value, threshold) + value = self.restrict_scaling_impl.combine_stats_threshold(self.value, threshold) return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value))) @brevitas.jit.script_method