diff --git a/mmdeploy/mmocr/apis/inference.py b/mmdeploy/mmocr/apis/inference.py index 8c6d7c07a..5cbaeea7c 100644 --- a/mmdeploy/mmocr/apis/inference.py +++ b/mmdeploy/mmocr/apis/inference.py @@ -44,7 +44,11 @@ def aug_test(self, imgs, img_metas, **kwargs): def extract_feat(self, imgs): raise NotImplementedError('This method is not implemented.') - def simple_test(self, img: torch.Tensor, img_metas: Sequence[dict], *args, + def simple_test(self, + img: torch.Tensor, + img_metas: Sequence[dict], + rescale: bool = False, + *args, **kwargs): """Run forward test. @@ -59,13 +63,13 @@ def simple_test(self, img: torch.Tensor, img_metas: Sequence[dict], *args, if len(img_metas) > 1: boundaries = [ self.bbox_head.get_boundary( - *(pred[i].unsqueeze(0)), [img_metas[i]], rescale=False) + *(pred[i].unsqueeze(0)), [img_metas[i]], rescale=rescale) for i in range(len(img_metas)) ] else: boundaries = [ - self.bbox_head.get_boundary(*pred, img_metas, rescale=False) + self.bbox_head.get_boundary(*pred, img_metas, rescale=rescale) ] return boundaries @@ -103,12 +107,12 @@ def aug_test(self, imgs, img_metas, **kwargs): def extract_feat(self, imgs): raise NotImplementedError('This method is not implemented.') - def forward(self, img: torch.Tensor, img_metas: Sequence[dict], *args, - **kwargs): + def forward(self, img: Union[torch.Tensor, Sequence[torch.Tensor]], + img_metas: Sequence[dict], *args, **kwargs): """Run forward. Args: - imgs (torch.Tensor): Image input tensor. + imgs (torch.Tensor | Sequence[torch.Tensor]): Image input tensor. img_metas (Sequence[dict]): List of image information. Returns: