diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index e78f7aa87..e48cfd6e5 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -295,7 +295,7 @@ def cat(tensors, dim, out=None): else: tensors = [qt.value if isinstance(qt, QuantTensor) else qt for qt in tensors] output_value = torch.cat(tensors, dim=dim) - return QuantTensor(output_value) + return output_value # Reference: https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types @@ -366,9 +366,9 @@ def __add__(self, other): signed=output_signed, training=output_training) elif isinstance(other, QuantTensor): - output = QuantTensor(self.value + other.value) + output = self.value + other.value else: - output = QuantTensor(self.value + other) + output = self.value + other return output def __radd__(self, other): @@ -396,9 +396,9 @@ def __mul__(self, other): signed=output_signed, training=output_training) elif isinstance(other, QuantTensor): - output = QuantTensor(self.value * other.value) + output = self.value * other.value else: - output = QuantTensor(self.value * other) + output = self.value * other return output def __sub__(self, other): @@ -423,9 +423,9 @@ def __truediv__(self, other): signed=output_signed, training=output_training) elif isinstance(other, QuantTensor): - output = QuantTensor(self.value / other.value) + output = self.value / other.value else: - output = QuantTensor(self.value / other) + output = self.value / other return output def __abs__(self):