Skip to content

Commit

Permalink
[Fix] fix the bug of mmocr visualizing (open-mmlab#105)
Browse files Browse the repository at this point in the history
* fix mmcor show image

* remove rescale in ocr recognition
  • Loading branch information
AllentDan authored Sep 29, 2021
1 parent 6318e9f commit 4587322
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions mmdeploy/mmocr/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ def aug_test(self, imgs, img_metas, **kwargs):
def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.')

def simple_test(self, img: torch.Tensor, img_metas: Sequence[dict], *args,
def simple_test(self,
img: torch.Tensor,
img_metas: Sequence[dict],
rescale: bool = False,
*args,
**kwargs):
"""Run forward test.
Expand All @@ -59,13 +63,13 @@ def simple_test(self, img: torch.Tensor, img_metas: Sequence[dict], *args,
if len(img_metas) > 1:
boundaries = [
self.bbox_head.get_boundary(
*(pred[i].unsqueeze(0)), [img_metas[i]], rescale=False)
*(pred[i].unsqueeze(0)), [img_metas[i]], rescale=rescale)
for i in range(len(img_metas))
]

else:
boundaries = [
self.bbox_head.get_boundary(*pred, img_metas, rescale=False)
self.bbox_head.get_boundary(*pred, img_metas, rescale=rescale)
]
return boundaries

Expand Down Expand Up @@ -103,12 +107,12 @@ def aug_test(self, imgs, img_metas, **kwargs):
def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.')

def forward(self, img: torch.Tensor, img_metas: Sequence[dict], *args,
**kwargs):
def forward(self, img: Union[torch.Tensor, Sequence[torch.Tensor]],
img_metas: Sequence[dict], *args, **kwargs):
"""Run forward.
Args:
imgs (torch.Tensor): Image input tensor.
imgs (torch.Tensor | Sequence[torch.Tensor]): Image input tensor.
img_metas (Sequence[dict]): List of image information.
Returns:
Expand Down

0 comments on commit 4587322

Please # to comment.