diff --git a/src/brevitas/quant_tensor/torch_handler.py b/src/brevitas/quant_tensor/torch_handler.py index fcbe35a42..79934864f 100644 --- a/src/brevitas/quant_tensor/torch_handler.py +++ b/src/brevitas/quant_tensor/torch_handler.py @@ -3,12 +3,13 @@ import functools import math +from typing import Callable import warnings import torch +from torch import Tensor import torch.nn.functional as F -import brevitas from brevitas.function.ops import max_int from brevitas.function.ops_ste import ceil_ste from brevitas.utils.torch_utils import compute_channel_view_shape @@ -358,11 +359,16 @@ def create_quant_tensor(tensor, scale, bit_width, zero_point, signed, training): training=training) -def quant_output_scale_impl(fn, inp, quant_input_scale, quant_weight_scale): +def quant_output_scale_impl( + fn: Callable, inp: Tensor, quant_input_scale: Tensor, quant_weight_scale: Tensor): channel_dim = -1 if fn == F.linear else 1 output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim) - output_scale = quant_weight_scale.view(output_scale_shape) - output_scale = output_scale * quant_input_scale.view(output_scale_shape) + + quant_weight_scale = quant_weight_scale.view(output_scale_shape) + if len(quant_input_scale.shape) == 0: + quant_input_scale = quant_input_scale.view(output_scale_shape) + + output_scale = quant_weight_scale * quant_input_scale return output_scale