diff --git a/src/brevitas/nn/quant_linear.py b/src/brevitas/nn/quant_linear.py index cb26a5625..576529a32 100644 --- a/src/brevitas/nn/quant_linear.py +++ b/src/brevitas/nn/quant_linear.py @@ -74,6 +74,9 @@ def quant_output_scale_impl( if quant_input_scale.shape == (): input_broadcast_shape = tuple([1] * len(inp.size())) quant_input_scale = quant_input_scale.view(input_broadcast_shape) + if quant_weight_scale.shape == (): + weight_broadcast_shape = tuple([1] * len(self.weight.size())) + quant_weight_scale = quant_weight_scale.view(weight_broadcast_shape) quant_output_scale = linear(quant_input_scale, quant_weight_scale) return quant_output_scale