diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 16b8aa61c..1aecaaeb4 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -265,7 +265,6 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None): if input_scale is not None: scale = input_scale scale_orig_shape = scale.shape - zero_point_orig_shape = zero_point.shape if input_bit_width is not None: bit_width = input_bit_width quant_axis = self.quant_axis(scale) @@ -279,7 +278,7 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None): y = self.dequantize_fn(int_bias, scale, zero_point, quant_axis) # Restore the original shapes to guarantee correct shape propagation downstream scale = scale.view(scale_orig_shape) - zero_point = zero_point.view(zero_point_orig_shape) + zero_point = zero_point.view_as(scale) return y, scale, zero_point, bit_width