diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index ce2de35b4..4b2eb2f79 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -16,6 +16,7 @@ ADD_FNS = [torch.add, operator.add, operator.iadd] ADD_METHODS = ['add', 'add_'] + CAT = brevitas.original_cat SIGN_PRESERVING_MODULES = ( @@ -46,6 +47,8 @@ nn.PixelUnshuffle, nn.Identity) +MAX_RESIDUAL_ITERS = 9999 + def inp_placeholder_handler(model, input_quantizer): """ @@ -253,8 +256,7 @@ def recursive_input_handler( else: assert align_output is None, f"align_output {str(align_output)} not supported." elif inp_node.op == 'call_function' and inp_node.target in [ - torch.flatten, torch.reshape, torch.transpose, operator.getitem, - operator.__getitem__]: + torch.flatten, torch.reshape, torch.transpose]: recursive_input_handler( model, inp_node, @@ -307,6 +309,7 @@ def _get_quant_module(model, node, quant_identity_map, quant_act_map, unsigned_a def residual_handler( model, quant_identity_map, quant_act_map, unsigned_act_tuple, align_input_quant_fn): + iter = 0 def is_converged(model): @@ -349,7 +352,11 @@ def is_converged(model): return True while not is_converged(model): - continue + iter += 1 + if iter == MAX_RESIDUAL_ITERS: + raise RuntimeError( + "Residual handler could not find a solution to align scale factors " + "across ADDs and CATs") return model