diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index ff1898db2..d8bb253bb 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -406,15 +406,16 @@ def __sub__(self, other): def __truediv__(self, other): if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none: - output_tensor = self.value / other.tensor - output_scale = self.scale / other.scale - output_bit_width = self.bit_width - other.bit_width + output_tensor = self.value / other.value # Note, output tensor not guaranteed to pass self.is_valid() + max_int_denominator = 2 ** (other.bit_width - int(other.signed)) + output_scale = self.scale / (other.scale * max_int_denominator) + output_bit_width = self.bit_width + other.bit_width output_signed = self.signed or other.signed output_training = self.training or other.training if self.is_zero_zero_point(self) and self.is_zero_zero_point(other): output_zero_point = self.zero_point * other.zero_point # Output zero_point is a new, zero-valued tensor else: - output_zero_point = None # TODO non-zero zero point + raise RuntimeError("Zero-points of div operands are non-zero, not supported.") output = QuantTensor( value=output_tensor, scale=output_scale,