Skip to content

Commit

Permalink
Fix scalar weight scale shape
Browse files Browse the repository at this point in the history
  • Loading branch information
volcacius committed Jun 19, 2023
1 parent dbd5ef8 commit bb6cc13
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/brevitas/nn/quant_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def quant_output_scale_impl(
if quant_input_scale.shape == ():
input_broadcast_shape = tuple([1] * len(inp.size()))
quant_input_scale = quant_input_scale.view(input_broadcast_shape)
if quant_weight_scale.shape == ():
weight_broadcast_shape = tuple([1] * len(self.weight.size()))
quant_weight_scale = quant_weight_scale.view(weight_broadcast_shape)
quant_output_scale = linear(quant_input_scale, quant_weight_scale)
return quant_output_scale

Expand Down

0 comments on commit bb6cc13

Please # to comment.