Feat (export/qcdq): support row-wise QCDQ export
volcacius committed Jun 19, 2023
1 parent b217067 commit c43892d
Showing 1 changed file with 27 additions and 3 deletions.
src/brevitas/export/common/handler/qcdq.py
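With this change, parameters quantized with a row-wise (per-output-channel) scale can be exported through the QCDQ path (QuantizeLinear / Clip / DequantizeLinear). The ONNX Q/DQ ops take a 1-D scale along `axis`, while Brevitas keeps per-channel scales in a broadcastable multi-dimensional layout, so the diff flattens the parameters for the ops, records the original shape, and restores it afterwards. A minimal sketch of that round-trip, with hypothetical shapes not taken from the diff:

    import torch

    # Round-trip of a per-row weight scale, e.g. for an (8, 16) weight:
    scale = torch.rand(8, 1)            # broadcastable per-output-channel scale
    scale_orig_shape = scale.shape      # saved before flattening
    flat_scale = scale.flatten()        # 1-D, as ONNX Q/DQ expect with axis=0
    # ... QuantizeLinear / Clip / DequantizeLinear consume flat_scale here ...
    scale = flat_scale.view(scale_orig_shape)
    assert scale.shape == (8, 1)        # original layout restored downstream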
@@ -3,6 +3,7 @@
 
 from abc import ABC
 from abc import abstractmethod
+from copy import copy
 
 import torch
 from torch import Tensor
@@ -97,6 +98,7 @@ def signed_dtype(cls, bit_width, is_signed):
 class CDQProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQMixin, ABC):
 
     def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
+        scale_orig_shape = scale.shape
         axis = cls.quant_axis(scale)
         if cls.flatten_dequantize_params:
             scale = scale.flatten()
@@ -106,7 +108,14 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
         zp = to_0dim_if_scalar(zero_point)
         zp = zp.expand_as(scale)
         zp = cls.zero_point_with_dtype(is_signed, bit_width, zp)
-        return {'scale': scale, 'zero_point': zp, 'axis': axis}
+        return {
+            'scale': scale,
+            'zero_point': zp,
+            'axis': axis,
+            # We save only the scale original shape
+            # as zero-point is being expanded to the same
+            # size as the scale
+            'scale_orig_shape': scale_orig_shape}
 
 
 class QCDQWeightQuantProxyHandlerMixin(CDQProxyHandlerMixin, ABC):
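`dequantize_symbolic_kwargs` now records the scale's original shape before flattening; as the in-diff comment notes, the zero-point needs no separate entry because it is expanded to the scale's flattened size, so it can later be reshaped via the scale. A standalone sketch of the resulting dict, using made-up values:

    import torch

    scale = torch.rand(4, 1)                  # broadcastable per-row scale
    scale_orig_shape = scale.shape            # recorded before flattening
    scale = scale.flatten()                   # flatten_dequantize_params path
    zp = torch.tensor(0.).expand_as(scale)    # zero-point expanded like the scale
    zp = zp.to(torch.int8)                    # stand-in for zero_point_with_dtype
    kwargs = {
        'scale': scale,
        'zero_point': zp,
        'axis': 0,
        'scale_orig_shape': scale_orig_shape}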
@@ -136,15 +145,20 @@ def symbolic_execution(self, x: Tensor):
         assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'
         x = self.symbolic_kwargs['int_weights'][x.data_ptr()]
         clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs']
-        dequantize_symbolic_kwargs = self.symbolic_kwargs['dequantize_symbolic_kwargs']
+        # Copy dict to allow for popping kwargs even on shared quantizers
+        dequantize_symbolic_kwargs = copy(self.symbolic_kwargs['dequantize_symbolic_kwargs'])
         scale = dequantize_symbolic_kwargs['scale']
         zero_point = dequantize_symbolic_kwargs['zero_point']
         bit_width = self.symbolic_kwargs['bit_width']
+        scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
         # Workaround to trick the tracer into believing all return values are used
         self.assert_ge_zero(scale, zero_point, bit_width)
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
+        # Restore the original shapes to guarantee correct shape propagation downstream
+        scale = scale.view(scale_orig_shape)
+        zero_point = zero_point.view_as(scale)
         return x, scale, zero_point, bit_width


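The shallow `copy` matters when several layers share one quantizer and therefore one kwargs dict: `pop` on the shared dict would strip `'scale_orig_shape'` for every later caller. An illustration with a plain dict, unrelated to the Brevitas classes:

    from copy import copy

    shared = {'scale': 1.0, 'zero_point': 0, 'axis': 0, 'scale_orig_shape': (4, 1)}
    local = copy(shared)                    # per-call shallow copy
    shape = local.pop('scale_orig_shape')   # consume the extra entry locally
    assert 'scale_orig_shape' in shared     # the shared dict is untouched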
@@ -199,16 +213,21 @@ def symbolic_execution(self, x: Tensor):
         assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'
         quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs']
         clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs']
-        dequantize_symbolic_kwargs = self.symbolic_kwargs['dequantize_symbolic_kwargs']
+        # Copy dict to allow for popping kwargs even on shared quantizers
+        dequantize_symbolic_kwargs = copy(self.symbolic_kwargs['dequantize_symbolic_kwargs'])
         scale = dequantize_symbolic_kwargs['scale']
         zero_point = dequantize_symbolic_kwargs['zero_point']
+        scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
         bit_width = self.symbolic_kwargs['bit_width']
         # Workaround to trick the tracer into believing all return values are used
         self.assert_ge_zero(scale, zero_point, bit_width)
         x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
+        # Restore the original shapes to guarantee correct shape propagation downstream
+        scale = scale.view(scale_orig_shape)
+        zero_point = zero_point.view_as(scale)
         return x, scale, zero_point, bit_width


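Both `symbolic_execution` overrides end the same way: the flattened scale is viewed back to its saved shape, and the expanded zero-point simply follows the scale via `view_as`, so consumers of the returned tuple see the pre-flattening layout. A sketch, again with hypothetical shapes:

    import torch

    flat_scale = torch.rand(8)                   # as consumed by DequantizeLinear
    flat_zp = torch.zeros(8, dtype=torch.int8)   # was expanded to the scale's size
    scale = flat_scale.view((8, 1))              # view back to scale_orig_shape
    zero_point = flat_zp.view_as(scale)          # zero-point just follows the scale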
Expand Down Expand Up @@ -245,6 +264,8 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None):
assert bit_width is not None or input_bit_width is not None, 'Input bit width required for bias export'
if input_scale is not None:
scale = input_scale
scale_orig_shape = scale.shape
zero_point_orig_shape = zero_point.shape
if input_bit_width is not None:
bit_width = input_bit_width
quant_axis = self.quant_axis(scale)
Expand All @@ -256,6 +277,9 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None):
zero_point = self.zero_point_with_dtype(
True, bit_width, zero_point) # assume signed is True
y = self.dequantize_fn(int_bias, scale, zero_point, quant_axis)
# Restore the original shapes to guarantee correct shape propagation downstream
scale = scale.view(scale_orig_shape)
zero_point = zero_point.view(zero_point_orig_shape)
return y, scale, zero_point, bit_width


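A hypothetical end-to-end use of the feature, assuming the `export_onnx_qcdq` entry point and Brevitas' keyword-override mechanism for enabling per-output-channel weight scaling; the layer configuration and file name are illustrative rather than quoted from this commit:

    import torch
    import brevitas.nn as qnn
    from brevitas.export import export_onnx_qcdq

    # Linear layer with a row-wise (per-output-channel) weight scale.
    model = qnn.QuantLinear(
        16, 8, bias=True, weight_scaling_per_output_channel=True)
    export_onnx_qcdq(model, torch.randn(1, 16), export_path='rowwise_qcdq.onnx')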
