From db161270a1d94ff1251be4ff0d4d3bed642d484c Mon Sep 17 00:00:00 2001 From: Alessandro Pappalardo Date: Mon, 19 Jun 2023 12:11:15 +0200 Subject: [PATCH] Fix bias qcdq export --- src/brevitas/export/common/handler/qcdq.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 16b8aa61c..1aecaaeb4 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -265,7 +265,6 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None): if input_scale is not None: scale = input_scale scale_orig_shape = scale.shape - zero_point_orig_shape = zero_point.shape if input_bit_width is not None: bit_width = input_bit_width quant_axis = self.quant_axis(scale) @@ -279,7 +278,7 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None): y = self.dequantize_fn(int_bias, scale, zero_point, quant_axis) # Restore the original shapes to guarantee correct shape propagation downstream scale = scale.view(scale_orig_shape) - zero_point = zero_point.view(zero_point_orig_shape) + zero_point = zero_point.view_as(scale) return y, scale, zero_point, bit_width