Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed May 14, 2024
1 parent a626478 commit ea3068a
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/brevitas/quant_tensor/torch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@

import functools
import math
from typing import Callable
import warnings

import torch
from torch import Tensor
import torch.nn.functional as F

import brevitas
from brevitas.function.ops import max_int
from brevitas.function.ops_ste import ceil_ste
from brevitas.utils.torch_utils import compute_channel_view_shape
Expand Down Expand Up @@ -358,11 +359,16 @@ def create_quant_tensor(tensor, scale, bit_width, zero_point, signed, training):
training=training)


def quant_output_scale_impl(
        fn: Callable, inp: Tensor, quant_input_scale: Tensor,
        quant_weight_scale: Tensor) -> Tensor:
    """Return the output scale as the product of the weight and input scales.

    Args:
        fn: The functional op being handled. ``F.linear`` selects the last
            dimension as the channel dimension; any other callable selects
            dimension 1.
        inp: Tensor passed to ``compute_channel_view_shape`` to derive the
            view shape used for the scales.
        quant_input_scale: Scale of the quantized input. It is reshaped only
            when it is a 0-dim tensor; otherwise it is used with its own shape.
        quant_weight_scale: Scale of the quantized weight. It is always
            reshaped to the computed view shape.

    Returns:
        ``quant_weight_scale * quant_input_scale`` after the reshapes above.
    """
    channel_dim = -1 if fn == F.linear else 1
    output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim)

    quant_weight_scale = quant_weight_scale.view(output_scale_shape)
    # A 0-dim input scale is viewed to the same shape as the weight scale.
    # A scale that already has dimensions is left untouched, so the product
    # below relies on regular broadcasting instead of a forced view.
    if quant_input_scale.dim() == 0:
        quant_input_scale = quant_input_scale.view(output_scale_shape)

    return quant_weight_scale * quant_input_scale


Expand Down

0 comments on commit ea3068a

Please sign in to comment.