From 698c2fbbdc43acc47ef13cbfa080b5604e3df9c3 Mon Sep 17 00:00:00 2001 From: Alessandro Pappalardo Date: Mon, 17 Jul 2023 15:45:02 +0100 Subject: [PATCH] Feat (QuantTensor): QuantTensor x Tensor elementary ops dequantize to Tensor --- src/brevitas/quant_tensor/__init__.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index e78f7aa87..e48cfd6e5 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -295,7 +295,7 @@ def cat(tensors, dim, out=None): else: tensors = [qt.value if isinstance(qt, QuantTensor) else qt for qt in tensors] output_value = torch.cat(tensors, dim=dim) - return QuantTensor(output_value) + return output_value # Reference: https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types @@ -366,9 +366,9 @@ def __add__(self, other): signed=output_signed, training=output_training) elif isinstance(other, QuantTensor): - output = QuantTensor(self.value + other.value) + output = self.value + other.value else: - output = QuantTensor(self.value + other) + output = self.value + other return output def __radd__(self, other): @@ -396,9 +396,9 @@ def __mul__(self, other): signed=output_signed, training=output_training) elif isinstance(other, QuantTensor): - output = QuantTensor(self.value * other.value) + output = self.value * other.value else: - output = QuantTensor(self.value * other) + output = self.value * other return output def __sub__(self, other): @@ -423,9 +423,9 @@ def __truediv__(self, other): signed=output_signed, training=output_training) elif isinstance(other, QuantTensor): - output = QuantTensor(self.value / other.value) + output = self.value / other.value else: - output = QuantTensor(self.value / other) + output = self.value / other return output def __abs__(self):