From 5526473a7ecd4f3bc2e018ab2d5cfc79effeda4c Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 11 Apr 2024 13:07:01 +0100
Subject: [PATCH] Feat (examples/ptq): support for dynamic act quant

---
 src/brevitas/nn/quant_layer.py                | 12 ++-
 .../common/generative/quant_blocks.py         |  1 +
 .../imagenet_classification/ptq/README.md     |  4 +
 .../imagenet_classification/ptq/ptq_common.py | 75 +++++++++++++------
 .../ptq/ptq_evaluate.py                       | 13 +++-
 5 files changed, 76 insertions(+), 29 deletions(-)

diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py
index 8cd2e10a7..c470d3003 100644
--- a/src/brevitas/nn/quant_layer.py
+++ b/src/brevitas/nn/quant_layer.py
@@ -121,9 +121,15 @@ def max_acc_bit_width(self, input_bit_width: Tensor, quant_weight_bit_width: Ten
 
     def quant_output_scale_impl(
             self, inp: Tensor, quant_input_scale: Tensor, quant_weight_scale: Tensor):
-        output_scale_shape = compute_channel_view_shape(inp, channel_dim=1)
-        output_scale = quant_weight_scale.view(output_scale_shape)
-        output_scale = output_scale * quant_input_scale.view(output_scale_shape)
+        channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1
+        output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim)
+
+        if len(quant_weight_scale.shape) == 0:
+            quant_weight_scale = quant_weight_scale.view(output_scale_shape)
+        if len(quant_input_scale.shape) == 0:
+            quant_input_scale = quant_input_scale.view(output_scale_shape)
+
+        output_scale = quant_weight_scale * quant_input_scale
         return output_scale
 
     @property
diff --git a/src/brevitas_examples/common/generative/quant_blocks.py b/src/brevitas_examples/common/generative/quant_blocks.py
index e403deecf..40516111f 100644
--- a/src/brevitas_examples/common/generative/quant_blocks.py
+++ b/src/brevitas_examples/common/generative/quant_blocks.py
@@ -106,6 +106,7 @@ def forward(self, x) -> Tensor:
         shape = x.shape
         x = self.scaling_stats_input_view_shape_impl(x)
         x = self.stats_impl(x)
+        x = self.dynamic_scaling_broadcastable_fn(x, shape)
         return x
 
diff --git a/src/brevitas_examples/imagenet_classification/ptq/README.md b/src/brevitas_examples/imagenet_classification/ptq/README.md
index 19aa4054c..5387014e9 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/README.md
+++ b/src/brevitas_examples/imagenet_classification/ptq/README.md
@@ -84,6 +84,7 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir
                        [--weight-quant-calibration-type {stats,mse}]
                        [--act-equalization {fx,layerwise,None}]
                        [--act-quant-calibration-type {stats,mse}]
+                       [--act-scale-computation-type {static,dynamic}]
                        [--graph-eq-iterations GRAPH_EQ_ITERATIONS]
                        [--learned-round-iters LEARNED_ROUND_ITERS]
                        [--learned-round-lr LEARNED_ROUND_LR]
@@ -184,6 +185,9 @@ options:
   --act-quant-calibration-type {stats,mse}
                         Activation quantization calibration type (default:
                         stats)
+  --act-scale-computation-type {static,dynamic}
+                        Activation quantization scale computation type
+                        (default: static)
   --graph-eq-iterations GRAPH_EQ_ITERATIONS
                         Numbers of iterations for graph equalization (default:
                         20)
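The new flag's two modes differ in when activation scales are produced: a static scale is fit once from calibration statistics and then frozen, while a dynamic scale is recomputed from every input at run time, which is why the calibration pass becomes conditional in ptq_evaluate.py below. A minimal sketch of the contrast, assuming int8 symmetric quantization with absmax statistics (function names are illustrative, not Brevitas API):

    import torch

    def static_scale(calibration_batches):
        # Static: one scale estimated over a calibration pass, then frozen.
        absmax = max(batch.abs().max() for batch in calibration_batches)
        return absmax / 127.

    def dynamic_scale(x):
        # Dynamic: the scale is derived from each incoming tensor on the fly,
        # so no calibration pass is needed.
        return x.abs().max() / 127.
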
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
index c4bc616e7..adc911dd7 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -8,6 +8,7 @@
 import torch.backends.cudnn as cudnn
 from tqdm import tqdm
 
+from brevitas.core.function_wrapper.shape import OverBatchOverTensorView
 from brevitas.core.scaling.standalone import ParameterFromStatsFromParameterScaling
 from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
 from brevitas.graph.calibrate import bias_correction_mode
@@ -49,10 +50,28 @@
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE
+from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
+from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat
 from brevitas_examples.imagenet_classification.ptq.learned_round_utils import learned_round_iterator
 from brevitas_examples.imagenet_classification.ptq.learned_round_utils import save_inp_out_data
 from brevitas_examples.imagenet_classification.ptq.learned_round_utils import split_layers
 
+
+# Every element of the batch will have its own scale factor and zero point
+class CNNShiftedUint8DynamicActPerTensorFloat(ShiftedUint8DynamicActPerTensorFloat):
+    scaling_stats_input_view_shape_impl = OverBatchOverTensorView
+    scaling_stats_permute_dims = None
+    stats_reduce_dim = 1
+    dynamic_scaling_broadcastable_fn = lambda x, shape: x.view(shape[0], *[1 for _ in range(len(shape[1:]))])
+
+
+class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
+    scaling_stats_input_view_shape_impl = OverBatchOverTensorView
+    scaling_stats_permute_dims = None
+    stats_reduce_dim = 1
+    dynamic_scaling_broadcastable_fn = lambda x, shape: x.view(shape[0], *[1 for _ in range(len(shape[1:]))])
+
+
 QUANTIZE_MAP = {'layerwise': layerwise_quantize, 'fx': quantize, 'flexml': quantize_flexml}
 
 BIAS_BIT_WIDTH_MAP = {32: Int32Bias, 16: Int16Bias, None: None}
@@ -98,21 +117,29 @@
 
 INPUT_QUANT_MAP = {
     'int': {
-        'float_scale': {
-            'stats': {
-                'per_tensor': {
-                    'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat}},
-            'mse': {
-                'per_tensor': {
-                    'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE}}},
-        'po2_scale': {
-            'stats': {
-                'per_tensor': {
-                    'sym': Int8ActPerTensorFixedPoint, 'asym': ShiftedUint8ActPerTensorFixedPoint},
-            },
-            'mse': {
-                'per_tensor': {
-                    'sym': Int8ActPerTensorFixedPointMSE}},}},
+        'static': {
+            'float_scale': {
+                'stats': {
+                    'per_tensor': {
+                        'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat}},
+                'mse': {
+                    'per_tensor': {
+                        'sym': Int8ActPerTensorFloatMSE,
+                        'asym': ShiftedUint8ActPerTensorFloatMSE}}},
+            'po2_scale': {
+                'stats': {
+                    'per_tensor': {
+                        'sym': Int8ActPerTensorFixedPoint,
+                        'asym': ShiftedUint8ActPerTensorFixedPoint},},
+                'mse': {
+                    'per_tensor': {
+                        'sym': Int8ActPerTensorFixedPointMSE}}}},
+        'dynamic': {
+            'float_scale': {
+                'stats': {
+                    'per_tensor': {
+                        'sym': CNNInt8DynamicActPerTensorFloat,
+                        'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}}},
     'float': {
         'float_scale': {
             'stats': {
@@ -146,6 +173,7 @@ def quantize_model(
         act_param_method='stats',
         weight_quant_type='sym',
         act_quant_granularity='per_tensor',
+        act_scale_computation_type='static',
         uint_sym_act_for_unsigned_values=True,
         dtype=torch.float32,
         device='cpu'):
@@ -157,8 +185,7 @@ def quantize_model(
     # We check all of the provided values are positive integers
     check_positive_int(
         weight_bit_width,
-        act_bit_width,
-        bias_bit_width,
+        act_bit_width,  # bias_bit_width,
        layerwise_first_last_bit_width,
         layerwise_first_last_mantissa_bit_width,
         layerwise_first_last_exponent_bit_width,
@@ -253,6 +280,7 @@ def layerwise_bit_width_fn_weight(module):
         act_quant_type=act_quant_type,
         act_quant_granularity=act_quant_granularity,
         act_quant_percentile=act_quant_percentile,
+        act_scale_computation_type=act_scale_computation_type,
         **weight_bit_width_dict,
         **act_bit_width_dict)
 
@@ -288,6 +316,7 @@ def create_quant_maps(
         act_exponent_bit_width=None,
         act_bit_width=None,
         act_scale_type=None,
+        act_scale_computation_type=None,
         act_param_method=None,
         act_quant_type=None,
         act_quant_granularity=None,
@@ -317,14 +346,14 @@ def kwargs_prefix(prefix, weight_kwargs):
         weight_quant = weight_quant.let(**weight_bit_width_dict)
 
     if act_bit_width is not None:
-        act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
-            act_quant_granularity][act_quant_type]
+        act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][act_scale_type][
+            act_param_method][act_quant_granularity][act_quant_type]
         # Some activations in MHA should always be symmetric
-        sym_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
-            act_quant_granularity]['sym']
+        sym_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][
+            act_scale_type][act_param_method][act_quant_granularity]['sym']
         # Linear layers with 2d input should always be per tensor
-        per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
-            'per_tensor'][act_quant_type]
+        per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][
+            act_scale_type][act_param_method]['per_tensor'][act_quant_type]
         act_quant = act_quant.let(**act_bit_width_dict)
         sym_act_quant = sym_act_quant.let(**act_bit_width_dict)
         per_tensor_act_quant = per_tensor_act_quant.let(**act_bit_width_dict)
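The two CNN quantizers defined above give every batch element its own scale: OverBatchOverTensorView flattens the activation to (N, -1), stats_reduce_dim = 1 collapses each row to a single statistic, and dynamic_scaling_broadcastable_fn views the result back to (N, 1, ..., 1) so it broadcasts over the input, which is exactly the reshape added to forward() in quant_blocks.py. A rough standalone sketch of that pipeline, assuming absmax statistics and an int8 range (the function name is illustrative, not Brevitas API):

    import torch

    def per_batch_element_scale(x: torch.Tensor) -> torch.Tensor:
        shape = x.shape
        stats_input = x.view(shape[0], -1)     # OverBatchOverTensorView: (N, C*H*W)
        stats = stats_input.abs().amax(dim=1)  # stats_reduce_dim = 1: one value per sample
        scale = stats / 127.                   # int8 symmetric range (illustrative)
        # dynamic_scaling_broadcastable_fn: back to a broadcastable (N, 1, 1, 1)
        return scale.view(shape[0], *[1 for _ in range(len(shape[1:]))])

    x = torch.randn(8, 3, 32, 32)
    assert per_batch_element_scale(x).shape == (8, 1, 1, 1)
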
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
index 1f7c06a2b..377e705ab 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -138,6 +138,11 @@ def parse_type(v, default_type):
     default='stats',
     choices=['stats', 'mse'],
     help='Activation quantization calibration type (default: stats)')
+parser.add_argument(
+    '--act-scale-computation-type',
+    default='static',
+    choices=['static', 'dynamic'],
+    help='Activation quantization scale computation type (default: static)')
 parser.add_argument(
     '--graph-eq-iterations',
     default=20,
@@ -411,11 +416,13 @@ def main():
         weight_exponent_bit_width=args.weight_exponent_bit_width,
         act_mantissa_bit_width=args.act_mantissa_bit_width,
         act_exponent_bit_width=args.act_exponent_bit_width,
+        act_scale_computation_type=args.act_scale_computation_type,
         uint_sym_act_for_unsigned_values=args.uint_sym_act_for_unsigned_values)
 
-    # Calibrate the quant_model on the calibration dataloader
-    print("Starting activation calibration:")
-    calibrate(calib_loader, quant_model)
+    if args.act_scale_computation_type == 'static':
+        # Calibrate the quant_model on the calibration dataloader
+        print("Starting activation calibration:")
+        calibrate(calib_loader, quant_model)
 
     if args.gpfq:
         print("Performing GPFQ:")
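The quant_layer.py change at the top of the patch is what lets these dynamic scales flow through the layer's output scale computation: a dynamic input scale already arrives shaped (N, 1, ..., 1) and broadcasts against a per-channel weight scale on its own, so only genuinely scalar (0-dim) per-tensor scales still need the explicit view. A small sketch of the shape arithmetic, with illustrative sizes:

    import torch

    # Dynamic activation scale: one value per batch element, conv-style layout.
    quant_input_scale = torch.rand(8, 1, 1, 1)
    # Per-channel weight scale for 16 output channels, channel_dim=1.
    quant_weight_scale = torch.rand(1, 16, 1, 1)

    # Plain broadcasting does the right thing; forcing the (8, 1, 1, 1) scale
    # through .view(1, 16, 1, 1) would simply error out.
    output_scale = quant_weight_scale * quant_input_scale
    assert output_scale.shape == (8, 16, 1, 1)

    # A static per-tensor scale is 0-dim and is viewed first, as in the patch.
    static_input_scale = torch.tensor(0.02).view(1, 1, 1, 1)
    assert (quant_weight_scale * static_input_scale).shape == (1, 16, 1, 1)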