From 45873224418d4a86acd9c89c8c3656976b22248a Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Wed, 29 Sep 2021 15:04:32 +0800 Subject: [PATCH] [Fix] fix the bug of mmocr visualizing (#105) * fix mmcor show image * remove rescale in ocr recognition --- mmdeploy/mmocr/apis/inference.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/mmdeploy/mmocr/apis/inference.py b/mmdeploy/mmocr/apis/inference.py index 8c6d7c07a..5cbaeea7c 100644 --- a/mmdeploy/mmocr/apis/inference.py +++ b/mmdeploy/mmocr/apis/inference.py @@ -44,7 +44,11 @@ def aug_test(self, imgs, img_metas, **kwargs): def extract_feat(self, imgs): raise NotImplementedError('This method is not implemented.') - def simple_test(self, img: torch.Tensor, img_metas: Sequence[dict], *args, + def simple_test(self, + img: torch.Tensor, + img_metas: Sequence[dict], + rescale: bool = False, + *args, **kwargs): """Run forward test. @@ -59,13 +63,13 @@ def simple_test(self, img: torch.Tensor, img_metas: Sequence[dict], *args, if len(img_metas) > 1: boundaries = [ self.bbox_head.get_boundary( - *(pred[i].unsqueeze(0)), [img_metas[i]], rescale=False) + *(pred[i].unsqueeze(0)), [img_metas[i]], rescale=rescale) for i in range(len(img_metas)) ] else: boundaries = [ - self.bbox_head.get_boundary(*pred, img_metas, rescale=False) + self.bbox_head.get_boundary(*pred, img_metas, rescale=rescale) ] return boundaries @@ -103,12 +107,12 @@ def aug_test(self, imgs, img_metas, **kwargs): def extract_feat(self, imgs): raise NotImplementedError('This method is not implemented.') - def forward(self, img: torch.Tensor, img_metas: Sequence[dict], *args, - **kwargs): + def forward(self, img: Union[torch.Tensor, Sequence[torch.Tensor]], + img_metas: Sequence[dict], *args, **kwargs): """Run forward. Args: - imgs (torch.Tensor): Image input tensor. + imgs (torch.Tensor | Sequence[torch.Tensor]): Image input tensor. img_metas (Sequence[dict]): List of image information. Returns: