From 229c1a169643137cd0f4e11583a79575a26febc7 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 18 Jul 2023 18:39:29 +0200 Subject: [PATCH] Fix (graph): fix fx quantize for conv->bn (#680) * Fix (fx): fix fx quantize for conv->bn * Formatter --- src/brevitas/graph/quantize_impl.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index 4b2eb2f79..130f977a2 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -49,6 +49,8 @@ MAX_RESIDUAL_ITERS = 9999 +BATCH_NORM = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) + def inp_placeholder_handler(model, input_quantizer): """ @@ -187,6 +189,18 @@ def output_quant_handler( user_module = get_module(model, user.target) if hasattr(user_module, 'act_quant'): output_quant = False + elif isinstance(user_module, BATCH_NORM): + # If the user is BatchNorm, check BN's users and potentially requentize at + # the output of BN + output_quant = False + output_quant_handler( + model, + user, + rewriters, + is_sign_preserving, + quant_identity_map, + quant_act_map, + unsigned_act_tuple) if output_quant: if quant_module_name is None and quant_module is None: if is_sign_preserving and are_inputs_unsigned(