Fix (core): MSE support with row-wise quantization
volcacius committed Jun 19, 2023
1 parent f88a8cc commit b217067
Showing 3 changed files with 33 additions and 14 deletions.
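
At a glance, the fix moves the per-channel (row-wise) view of the tensor into the MSE search itself instead of applying it only in the outer stats pipeline, so the loss can be reduced row by row. Below is a minimal sketch of the reduction pattern this enables, in plain PyTorch with illustrative names rather than Brevitas code: view the elementwise loss as (out_channels, -1), then sum along dim 1 so each output channel scores its own scale candidate.

import torch
import torch.nn.functional as F

# Illustrative stand-in for the OVER_OUTPUT_CHANNELS view: flatten
# everything except the leading output-channel dim.
def over_output_channels(t: torch.Tensor) -> torch.Tensor:
    return t.reshape(t.shape[0], -1)

def per_channel_mse(x: torch.Tensor, x_quant: torch.Tensor) -> torch.Tensor:
    loss = F.mse_loss(x, x_quant, reduction='none')  # elementwise loss
    loss = over_output_channels(loss)                # shape (C, numel // C)
    return loss.sum(dim=1)                           # one loss value per channel

w = torch.randn(8, 4, 3, 3)           # e.g. a conv weight with 8 output channels
w_q = w + 0.01 * torch.randn_like(w)  # stand-in for the fake-quantized tensor
print(per_channel_mse(w, w_q).shape)  # torch.Size([8])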
src/brevitas/core/function_wrapper/shape.py (8 changes: 1 addition & 7 deletions)
@@ -22,16 +22,10 @@ class PermuteDims(brevitas.jit.ScriptModule):
     def __init__(self, permute_dims: Tuple[int, ...]) -> None:
         super(PermuteDims, self).__init__()
         self.permute_dims = permute_dims
-        # This is a workaroud to avoid re-computing permute during evaluation of a local loss
-        # like MSE over permuted data such as per channel scaled activations or tranposed conv weights
-        self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool)
 
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
-        if not self.local_loss_mode:
-            return x.permute(*self.permute_dims).contiguous()
-        else:
-            return x
+        return x.permute(*self.permute_dims).contiguous()
 
 
 class OverTensorView(brevitas.jit.ScriptModule):
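With the view now handled inside MSE itself (see stats_op.py below), PermuteDims no longer needs the local_loss_mode escape hatch and always permutes. A self-contained sketch of the simplified module, equivalent up to the omitted brevitas.jit script decorators:

from typing import Tuple

import torch
from torch import nn

class PermuteDimsSketch(nn.Module):
    """Unconditional permute, mirroring the simplified forward above."""

    def __init__(self, permute_dims: Tuple[int, ...]) -> None:
        super().__init__()
        self.permute_dims = permute_dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # contiguous() materializes the permuted layout so downstream
        # reshape-style views are valid.
        return x.permute(*self.permute_dims).contiguous()

# e.g. bring the output-channel dim of a transposed-conv weight to the front
m = PermuteDimsSketch((1, 0, 2, 3))
print(m(torch.randn(4, 8, 3, 3)).shape)  # torch.Size([8, 4, 3, 3])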
src/brevitas/core/stats/stats_op.py (12 changes: 8 additions & 4 deletions)
@@ -412,11 +412,13 @@ def __init__(
             self,
             proxy_module,
             mse_init_op,
+            inner_stats_input_view_shape_impl: torch.nn.Module,
             stats_reduce_dim: Optional[int] = None,
             mse_search_method='fibonacci',
             mse_iters=10):
         super(MSE, self).__init__()
         self.mse_init_op = mse_init_op
+        self.input_view_shape_impl = inner_stats_input_view_shape_impl
         self.proxy_forward = proxy_module.forward
         self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
         self.internal_candidate = None
@@ -426,10 +428,10 @@ def __init__(
         self.local_loss_mode: bool = False
 
     def mse_loss_fn(self, x, quant_value):
-        # squeeze is a workaround for ConvTranpose per-channel weights
-        # where broadcasting generates an extra leading dim of size 1
-        loss = torch.nn.functional.mse_loss(x, quant_value.squeeze(), reduction='none')
+        loss = torch.nn.functional.mse_loss(x, quant_value, reduction='none')
         if self.stats_reduce_dim is not None:
+            # stats_reduce_dim applies to the permuted and reshaped tensor
+            loss = self.input_view_shape_impl(loss)
             loss = torch.sum(loss, dim=self.stats_reduce_dim)
         else:
             loss = torch.sum(loss)
@@ -483,7 +485,8 @@ def fib_seq(n):
         return torch.where(f1 <= f2, x1, x2)
 
     def mse_search(self, x):
-        init = self.mse_init_op(x).detach()
+        x_view = self.input_view_shape_impl(x)
+        init = self.mse_init_op(x_view).detach()
         base = init / self.num
         if self.search_method == 'grid':
             best_candidate = self.mse_grid_search(base, x)
@@ -502,5 +505,6 @@ def forward(self, x):
         else:
             # This is invoked for the zero-point whenever scale is being optimized first
             if self.internal_candidate is None:
+                x = self.input_view_shape_impl(x)
                 self.internal_candidate = self.mse_init_op(x).detach()
             return self.internal_candidate
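
The same inner view is applied before the init op, both in mse_search and in the zero-point branch of forward, so the starting candidates also come out one per row. A hedged sketch of the resulting shapes, with abs().amax(dim=1) standing in for AbsMax reduced over stats_reduce_dim and an illustrative candidate count in place of self.num:

import torch

def view_over_output_channels(t: torch.Tensor) -> torch.Tensor:
    # stand-in for inner_stats_input_view_shape_impl in the per-channel case
    return t.reshape(t.shape[0], -1)

x = torch.randn(8, 4, 3, 3)
x_view = view_over_output_channels(x)     # shape (8, 36)
init = x_view.abs().amax(dim=1).detach()  # one init candidate per row
num = 100                                 # illustrative stand-in for self.num
base = init / num                         # per-row step for the candidate search
print(init.shape, base.shape)             # torch.Size([8]) torch.Size([8])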
src/brevitas/quant/base.py (27 changes: 24 additions & 3 deletions)
@@ -3,6 +3,7 @@

 from dependencies import this
 from dependencies import value
+from torch import nn
 
 from brevitas.core.bit_width import BitWidthConst
 from brevitas.core.bit_width import BitWidthStatefulConst
@@ -11,6 +12,7 @@
 from brevitas.core.function_wrapper import TensorClamp
 from brevitas.core.function_wrapper import TensorClampSte
 from brevitas.core.function_wrapper.ops_ste import CeilSte
+from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl
 from brevitas.core.quant import ClampedBinaryQuant
 from brevitas.core.quant.int import DecoupledRescalingIntQuant
 from brevitas.core.quant.int import DecoupledRescalingIntQuantWithInput
@@ -392,21 +394,37 @@ def accumulator_bit_width_impl(accumulator_bit_width):
     float_to_int_impl = RoundToZeroSte  # required to ensure no upwards rounding violates constraints
 
 
-class MSESymmetricScaleSubInjector(ExtendedInjector):
+class MSESubInjectorBase(ExtendedInjector):
+
+    @value
+    def inner_stats_input_view_shape_impl(per_channel):
+        if per_channel:
+            return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
+        else:
+            return StatsInputViewShapeImpl.OVER_TENSOR
+
+    permute_dims = (this << 1).permute_dims
+
+
+class MSESymmetricScaleSubInjector(MSESubInjectorBase):
+    per_channel = (this << 1).scaling_per_output_channel
     proxy_module = (this << 1).proxy_module
     mse_init_op = AbsMax
     stats_impl = MSE
     stats_reduce_dim = (this << 1).stats_reduce_dim
 
 
-class MSEAsymmetricScaleSubInjector(ExtendedInjector):
+class MSEAsymmetricScaleSubInjector(MSESubInjectorBase):
+    per_channel = (this << 1).scaling_per_output_channel
     proxy_module = (this << 1).proxy_module
     mse_init_op = AbsMinMax
     stats_impl = MSE
     stats_reduce_dim = (this << 1).stats_reduce_dim
 
 
-class MSEZeroPointSubInjector(ExtendedInjector):
+class MSEZeroPointSubInjector(MSESubInjectorBase):
+    # zp is per channel when scaling is per channel
+    per_channel = (this << 1).scaling_per_output_channel
     proxy_module = (this << 1).proxy_module
     mse_init_op = NegativeMinOrZero
     stats_impl = MSE
@@ -420,6 +438,7 @@ class MSEAsymmetricScale(ExtendedInjector):

     mse_scale = MSEAsymmetricScaleSubInjector
     scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
+    scaling_stats_input_view_shape_impl = nn.Identity()
 
     @value
     def scaling_stats_impl():
@@ -433,6 +452,7 @@ class MSESymmetricScale(ExtendedInjector):

     mse_scale = MSESymmetricScaleSubInjector
     scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
+    scaling_stats_input_view_shape_impl = nn.Identity()
 
     @value
     def scaling_stats_impl():
@@ -445,6 +465,7 @@ class MSEZeroPoint(ExtendedInjector):
"""

mse_zero_point = MSEZeroPointSubInjector
zero_point_stats_input_view_shape_impl = nn.Identity()

@value
def zero_point_stats_impl():
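Taken together: the new MSESubInjectorBase resolves the inner view from the parent quantizer's scaling_per_output_channel flag, and the top-level mixins set the outer stats view to nn.Identity() because MSE now applies the view internally. A hedged usage sketch follows; the mixin pattern is how these injectors are consumed, but the exact class combination shown here is illustrative:

from brevitas.nn import QuantConv2d
from brevitas.quant.base import MSESymmetricScale
from brevitas.quant.scaled_int import Int8WeightPerChannelFloat

# Mix the MSE scale search into an existing per-channel weight quantizer.
# scaling_per_output_channel=True in the base quantizer makes
# inner_stats_input_view_shape_impl resolve to OVER_OUTPUT_CHANNELS.
class Int8WeightPerChannelFloatMSE(MSESymmetricScale, Int8WeightPerChannelFloat):
    pass

conv = QuantConv2d(4, 8, kernel_size=3, weight_quant=Int8WeightPerChannelFloatMSE)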
