From 12302ca7d38da5f922da66f5caddaf5070f3642d Mon Sep 17 00:00:00 2001 From: Alessandro Pappalardo Date: Wed, 19 Apr 2023 16:56:46 +0100 Subject: [PATCH] Feat (QuantTensor): initial support for interpolate and pixel_shuffle --- src/brevitas/quant_tensor/torch_handler.py | 35 ++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/brevitas/quant_tensor/torch_handler.py b/src/brevitas/quant_tensor/torch_handler.py index 3bea44d70..5708fd35f 100644 --- a/src/brevitas/quant_tensor/torch_handler.py +++ b/src/brevitas/quant_tensor/torch_handler.py @@ -119,3 +119,38 @@ def adaptive_max_pool2d_handler(*args, **kwargs): @implements(F.adaptive_max_pool3d) def adaptive_max_pool3d_handler(*args, **kwargs): return quant_invariant_handler(F.adaptive_max_pool3d, *args, **kwargs) + + +@implements(F.interpolate) +def interpolate_handler( + inp, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None, + recompute_scale_factor=None, + **kwargs): # support newer kwargs added in recent pytorch versions + if mode == 'nearest' or mode == 'nearest_exact': + return quant_invariant_handler( + F.interpolate, + inp, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor, + **kwargs) + else: + return F.interpolate( + inp.value, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor, + **kwargs) + + +@implements(F.pixel_shuffle) +def pixel_shuffle_handler(*args, **kwargs): + return quant_invariant_handler(F.pixel_shuffle_handler, *args, **kwargs)