From d0c10a521026e1d50bc96878ed3f1c01398eb5f3 Mon Sep 17 00:00:00 2001 From: nickfraser Date: Thu, 21 Dec 2023 10:58:39 +0000 Subject: [PATCH] Feat (quant_tensor): update `__truediv__` behaviour to match "standard fixed point rules" (#769) * [quant_tensor] Updated `__truediv__` behaviour based on #740 * [quant_tensor] Updated div behaviour to throw RuntimeError when non-zero zero-point operands are used * Fix: changed other.tensor -> other.value --- src/brevitas/quant_tensor/__init__.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index ff1898db2..d8bb253bb 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -406,15 +406,16 @@ def __sub__(self, other): def __truediv__(self, other): if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none: - output_tensor = self.value / other.tensor - output_scale = self.scale / other.scale - output_bit_width = self.bit_width - other.bit_width + output_tensor = self.value / other.value # Note, output tensor not guaranteed to pass self.is_valid() + max_int_denominator = 2 ** (other.bit_width - int(other.signed)) + output_scale = self.scale / (other.scale * max_int_denominator) + output_bit_width = self.bit_width + other.bit_width output_signed = self.signed or other.signed output_training = self.training or other.training if self.is_zero_zero_point(self) and self.is_zero_zero_point(other): output_zero_point = self.zero_point * other.zero_point # Output zero_point is a new, zero-valued tensor else: - output_zero_point = None # TODO non-zero zero point + raise RuntimeError("Zero-points of div operands are non-zero, not supported.") output = QuantTensor( value=output_tensor, scale=output_scale,