From c43892db14424b81c63f896220bd1c686e151411 Mon Sep 17 00:00:00 2001 From: Alessandro Pappalardo Date: Mon, 12 Jun 2023 19:30:08 +0100 Subject: [PATCH] Feat (export/qcdq): support row wise QCDQ export --- src/brevitas/export/common/handler/qcdq.py | 30 +++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 7a158350c..16b8aa61c 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -3,6 +3,7 @@ from abc import ABC from abc import abstractmethod +from copy import copy import torch from torch import Tensor @@ -97,6 +98,7 @@ def signed_dtype(cls, bit_width, is_signed): class CDQProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQMixin, ABC): def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): + scale_orig_shape = scale.shape axis = cls.quant_axis(scale) if cls.flatten_dequantize_params: scale = scale.flatten() @@ -106,7 +108,14 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): zp = to_0dim_if_scalar(zero_point) zp = zp.expand_as(scale) zp = cls.zero_point_with_dtype(is_signed, bit_width, zp) - return {'scale': scale, 'zero_point': zp, 'axis': axis} + return { + 'scale': scale, + 'zero_point': zp, + 'axis': axis, + # We save only the scale original shape + # as zero-point is being expanded to the same + # size as the scale + 'scale_orig_shape': scale_orig_shape} class QCDQWeightQuantProxyHandlerMixin(CDQProxyHandlerMixin, ABC): @@ -136,15 +145,20 @@ def symbolic_execution(self, x: Tensor): assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled' x = self.symbolic_kwargs['int_weights'][x.data_ptr()] clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs'] - dequantize_symbolic_kwargs = self.symbolic_kwargs['dequantize_symbolic_kwargs'] + # Copy dict to allow for popping kwargs even on shared quantizers + dequantize_symbolic_kwargs = copy(self.symbolic_kwargs['dequantize_symbolic_kwargs']) scale = dequantize_symbolic_kwargs['scale'] zero_point = dequantize_symbolic_kwargs['zero_point'] bit_width = self.symbolic_kwargs['bit_width'] + scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape') # Workaround to trick the tracer into believing all return values are used self.assert_ge_zero(scale, zero_point, bit_width) if clip_symbolic_kwargs is not None: x = self.clip_fn(x, *clip_symbolic_kwargs.values()) x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values()) + # Restore the original shapes to guarantee correct shape propagation downstream + scale = scale.view(scale_orig_shape) + zero_point = zero_point.view_as(scale) return x, scale, zero_point, bit_width @@ -199,9 +213,11 @@ def symbolic_execution(self, x: Tensor): assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled' quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs'] clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs'] - dequantize_symbolic_kwargs = self.symbolic_kwargs['dequantize_symbolic_kwargs'] + # Copy dict to allow for popping kwargs even on shared quantizers + dequantize_symbolic_kwargs = copy(self.symbolic_kwargs['dequantize_symbolic_kwargs']) scale = dequantize_symbolic_kwargs['scale'] zero_point = dequantize_symbolic_kwargs['zero_point'] + scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape') bit_width = self.symbolic_kwargs['bit_width'] # Workaround to trick the tracer into believing all return values are used self.assert_ge_zero(scale, zero_point, bit_width) @@ -209,6 +225,9 @@ def symbolic_execution(self, x: Tensor): if clip_symbolic_kwargs is not None: x = self.clip_fn(x, *clip_symbolic_kwargs.values()) x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values()) + # Restore the original shapes to guarantee correct shape propagation downstream + scale = scale.view(scale_orig_shape) + zero_point = zero_point.view_as(scale) return x, scale, zero_point, bit_width @@ -245,6 +264,8 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None): assert bit_width is not None or input_bit_width is not None, 'Input bit width required for bias export' if input_scale is not None: scale = input_scale + scale_orig_shape = scale.shape + zero_point_orig_shape = zero_point.shape if input_bit_width is not None: bit_width = input_bit_width quant_axis = self.quant_axis(scale) @@ -256,6 +277,9 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None): zero_point = self.zero_point_with_dtype( True, bit_width, zero_point) # assume signed is True y = self.dequantize_fn(int_bias, scale, zero_point, quant_axis) + # Restore the original shapes to guarantee correct shape propagation downstream + scale = scale.view(scale_orig_shape) + zero_point = zero_point.view(zero_point_orig_shape) return y, scale, zero_point, bit_width