diff --git a/src/brevitas/core/function_wrapper/shape.py b/src/brevitas/core/function_wrapper/shape.py
index c082b3d0b..4cda8da20 100644
--- a/src/brevitas/core/function_wrapper/shape.py
+++ b/src/brevitas/core/function_wrapper/shape.py
@@ -22,16 +22,10 @@ class PermuteDims(brevitas.jit.ScriptModule):
     def __init__(self, permute_dims: Tuple[int, ...]) -> None:
         super(PermuteDims, self).__init__()
         self.permute_dims = permute_dims
-        # This is a workaroud to avoid re-computing permute during evaluation of a local loss
-        # like MSE over permuted data such as per channel scaled activations or tranposed conv weights
-        self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool)
 
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
-        if not self.local_loss_mode:
-            return x.permute(*self.permute_dims).contiguous()
-        else:
-            return x
+        return x.permute(*self.permute_dims).contiguous()
 
 
 class OverTensorView(brevitas.jit.ScriptModule):
diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py
index 4282deaf4..161025ad2 100644
--- a/src/brevitas/core/stats/stats_op.py
+++ b/src/brevitas/core/stats/stats_op.py
@@ -412,11 +412,13 @@ def __init__(
             self,
             proxy_module,
             mse_init_op,
+            inner_stats_input_view_shape_impl: torch.nn.Module,
             stats_reduce_dim: Optional[int] = None,
             mse_search_method='fibonacci',
             mse_iters=10):
         super(MSE, self).__init__()
         self.mse_init_op = mse_init_op
+        self.input_view_shape_impl = inner_stats_input_view_shape_impl
         self.proxy_forward = proxy_module.forward
         self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
         self.internal_candidate = None
@@ -426,10 +428,10 @@ def __init__(
         self.local_loss_mode: bool = False
 
     def mse_loss_fn(self, x, quant_value):
-        # squeeze is a workaround for ConvTranpose per-channel weights
-        # where broadcasting generates an extra leading dim of size 1
-        loss = torch.nn.functional.mse_loss(x, quant_value.squeeze(), reduction='none')
+        loss = torch.nn.functional.mse_loss(x, quant_value, reduction='none')
         if self.stats_reduce_dim is not None:
+            # stats_reduce_dim applies to the permuted and reshaped tensor
+            loss = self.input_view_shape_impl(loss)
             loss = torch.sum(loss, dim=self.stats_reduce_dim)
         else:
             loss = torch.sum(loss)
@@ -483,7 +485,8 @@ def fib_seq(n):
         return torch.where(f1 <= f2, x1, x2)
 
     def mse_search(self, x):
-        init = self.mse_init_op(x).detach()
+        x_view = self.input_view_shape_impl(x)
+        init = self.mse_init_op(x_view).detach()
         base = init / self.num
         if self.search_method == 'grid':
             best_candidate = self.mse_grid_search(base, x)
@@ -502,5 +505,6 @@ def forward(self, x):
         else:
             # This is invoked for the zero-point whenever scale is being optimized first
             if self.internal_candidate is None:
+                x = self.input_view_shape_impl(x)
                 self.internal_candidate = self.mse_init_op(x).detach()
             return self.internal_candidate
diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py
index 457b3b854..2c1fe33f0 100644
--- a/src/brevitas/quant/base.py
+++ b/src/brevitas/quant/base.py
@@ -3,6 +3,7 @@
 
 from dependencies import this
 from dependencies import value
+from torch import nn
 
 from brevitas.core.bit_width import BitWidthConst
 from brevitas.core.bit_width import BitWidthStatefulConst
@@ -11,6 +12,7 @@
 from brevitas.core.function_wrapper import TensorClamp
 from brevitas.core.function_wrapper import TensorClampSte
 from brevitas.core.function_wrapper.ops_ste import CeilSte
+from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl
 from brevitas.core.quant import ClampedBinaryQuant
 from brevitas.core.quant.int import DecoupledRescalingIntQuant
 from brevitas.core.quant.int import DecoupledRescalingIntQuantWithInput
@@ -392,21 +394,37 @@ def accumulator_bit_width_impl(accumulator_bit_width):
     float_to_int_impl = RoundToZeroSte  # required to ensure no upwards rounding violates constraints
 
 
-class MSESymmetricScaleSubInjector(ExtendedInjector):
+class MSESubInjectorBase(ExtendedInjector):
+
+    @value
+    def inner_stats_input_view_shape_impl(per_channel):
+        if per_channel:
+            return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
+        else:
+            return StatsInputViewShapeImpl.OVER_TENSOR
+
+    permute_dims = (this << 1).permute_dims
+
+
+class MSESymmetricScaleSubInjector(MSESubInjectorBase):
+    per_channel = (this << 1).scaling_per_output_channel
     proxy_module = (this << 1).proxy_module
     mse_init_op = AbsMax
     stats_impl = MSE
     stats_reduce_dim = (this << 1).stats_reduce_dim
 
 
-class MSEAsymmetricScaleSubInjector(ExtendedInjector):
+class MSEAsymmetricScaleSubInjector(MSESubInjectorBase):
+    per_channel = (this << 1).scaling_per_output_channel
     proxy_module = (this << 1).proxy_module
     mse_init_op = AbsMinMax
     stats_impl = MSE
     stats_reduce_dim = (this << 1).stats_reduce_dim
 
 
-class MSEZeroPointSubInjector(ExtendedInjector):
+class MSEZeroPointSubInjector(MSESubInjectorBase):
+    # zp is per channel when scaling is per channel
+    per_channel = (this << 1).scaling_per_output_channel
     proxy_module = (this << 1).proxy_module
     mse_init_op = NegativeMinOrZero
     stats_impl = MSE
@@ -420,6 +438,7 @@ class MSEAsymmetricScale(ExtendedInjector):
 
     mse_scale = MSEAsymmetricScaleSubInjector
     scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
+    scaling_stats_input_view_shape_impl = nn.Identity()
 
     @value
     def scaling_stats_impl():
@@ -433,6 +452,7 @@ class MSESymmetricScale(ExtendedInjector):
 
     mse_scale = MSESymmetricScaleSubInjector
     scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
+    scaling_stats_input_view_shape_impl = nn.Identity()
 
     @value
     def scaling_stats_impl():
@@ -445,6 +465,7 @@ class MSEZeroPoint(ExtendedInjector):
     """
 
     mse_zero_point = MSEZeroPointSubInjector
+    zero_point_stats_input_view_shape_impl = nn.Identity()
 
     @value
     def zero_point_stats_impl():
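
For reference, the per-channel behavior the patch sets up can be sketched standalone: the elementwise MSE loss is computed in the tensor's native layout and only then reshaped by the inner view, so `stats_reduce_dim` always refers to the viewed tensor and the old `squeeze()` workaround for ConvTranspose weights is no longer needed. This is a minimal sketch, not the Brevitas API: `over_output_channels` and `per_channel_mse` are hypothetical stand-ins for `StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS` and `MSE.mse_loss_fn`.

```python
import torch


# Hypothetical stand-in for StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS:
# flatten everything but the leading output-channel dim into one trailing dim.
def over_output_channels(t: torch.Tensor) -> torch.Tensor:
    return t.reshape(t.shape[0], -1)


def per_channel_mse(x, quant_x, stats_reduce_dim=1):
    # Elementwise loss in the tensor's native layout, then viewed so that
    # stats_reduce_dim indexes the flattened per-channel elements.
    loss = torch.nn.functional.mse_loss(x, quant_x, reduction='none')
    return over_output_channels(loss).sum(dim=stats_reduce_dim)


weight = torch.randn(8, 4, 3, 3)  # conv-style weight, output channels leading
quant_weight = weight + 0.01 * torch.randn_like(weight)

# One loss value per output channel, without squeezing the quantized tensor.
print(per_channel_mse(weight, quant_weight).shape)  # torch.Size([8])

# The search init is computed on the same view, e.g. an AbsMax-style op.
init = over_output_channels(weight).abs().max(dim=1).values.detach()
print(init.shape)  # torch.Size([8])
```

Applying the view to the loss (rather than permuting the input up front) is what lets `PermuteDims` drop its `local_loss_mode` flag: the MSE module now owns the reshaping, so the view modules no longer need a mode switch.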