From 181ef8077314db4ea97848bf5f09d060255c614f Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 18 Jun 2024 10:17:46 +0100
Subject: [PATCH 1/2] Fix (core/float): add default for float_scaling_impl

---
 src/brevitas/core/quant/float.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py
index 71f518bb5..5b582e195 100644
--- a/src/brevitas/core/quant/float.py
+++ b/src/brevitas/core/quant/float.py
@@ -52,6 +52,9 @@ def __init__(
         if scaling_impl is None:
             scaling_impl = ConstScaling(1., device=device, dtype=dtype)
 
+        if float_scaling_impl is None:
+            float_scaling_impl = ConstScaling(1., device=device, dtype=dtype)
+
         # Zero-point is currently hardcoded to 0
         self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype))
         self.float_scaling_impl = float_scaling_impl

From 70afcc310b7efee7d5eee5e1b6af09da66d8533b Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 18 Jun 2024 10:30:33 +0100
Subject: [PATCH 2/2] fix

---
 src/brevitas/core/quant/float.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py
index 5b582e195..195d42a96 100644
--- a/src/brevitas/core/quant/float.py
+++ b/src/brevitas/core/quant/float.py
@@ -52,9 +52,6 @@ def __init__(
         if scaling_impl is None:
             scaling_impl = ConstScaling(1., device=device, dtype=dtype)
 
-        if float_scaling_impl is None:
-            float_scaling_impl = ConstScaling(1., device=device, dtype=dtype)
-
         # Zero-point is currently hardcoded to 0
         self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype))
         self.float_scaling_impl = float_scaling_impl
@@ -68,10 +65,13 @@ def __init__(
 
     @brevitas.jit.script_method
     def quantize(self, x: torch.Tensor):
-        scaling_impl_value = self.scaling_impl(x)
-        float_scaling_impl_value = self.float_scaling_impl(
-            self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
-        scale = scaling_impl_value / float_scaling_impl_value
+        scale = self.scaling_impl(x)
+
+        if self.float_scaling_impl is not None:
+            float_scaling_impl_value = self.float_scaling_impl(
+                self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
+            scale = scale / float_scaling_impl_value
+
         scaled_x = x / scale
         internal_scale = float_internal_scale(
             scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps)