[Dy2St] transforms.Resize Support static mode #49008

Closed
wants to merge 1 commit into from


Aurelius84
Contributor

@Aurelius84 Aurelius84 commented Dec 12, 2022

PR types

New features

PR changes

OPs

Describe

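A template PR for adding static-graph support to the vision.transform APIs, provided as a reference for #48612.

Supporting the static-graph branch so that the API behaves the same in dynamic and static mode involves three core steps:

Step 1: Upgrade the tensor check

The Resize API uses the _is_tensor_image function. In static mode this function must also recognize the Variable type and return True for it: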

def _is_tensor_image(img):
    """
    Return True if img is a Tensor for dynamic mode or Variable for static mode.
    """
    return isinstance(img, (paddle.Tensor, Variable))
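
The check above works because isinstance accepts a tuple of types. A minimal sketch of the same pattern without Paddle installed is shown below; EagerTensor and GraphVariable are hypothetical stand-ins for paddle.Tensor and Variable, not real Paddle classes.

```python
class EagerTensor:
    """Hypothetical stand-in for paddle.Tensor (dynamic mode)."""


class GraphVariable:
    """Hypothetical stand-in for the static-mode Variable type."""


def is_tensor_image(img):
    # isinstance with a tuple returns True if img matches any listed type.
    return isinstance(img, (EagerTensor, GraphVariable))


accepted = [is_tensor_image(EagerTensor()), is_tensor_image(GraphVariable())]
rejected = is_tensor_image([[0, 1], [2, 3]])  # a plain list is not accepted
```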

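Step 2: Upgrade the core interfaces in functional_tensor.py

The Tensor-related transform logic of Resize is implemented in functional_tensor.py and needs to be adapted for static mode. In static mode the network is built by adding operators through append_op. Most of this logic is already wrapped in the framework's public APIs, so only minor adjustments are needed.

Note: pay extra attention to dynamic shapes in static mode, e.g. image.shape = [-1, -1, 3]. Depending on the implementation, decide whether such cases need special handling or should raise an error.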

def resize(img, size, interpolation='bilinear', data_format='CHW'):
   
    _assert_image_tensor(img, data_format)   # <----- must be adapted to accept the static-mode Variable type

    if not (
        isinstance(size, int)
        or (isinstance(size, (tuple, list)) and len(size) == 2)
    ):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if isinstance(size, int):
        w, h = _get_image_size(img, data_format)   # <----- in static mode w and h may be -1; handle with care
        # TODO(Aurelius84): In static mode, w and h will be -1 for dynamic shape.
        # We should consider supporting this case in the future.
        if w <= 0 or h <= 0:
            raise NotImplementedError(
                "Not support while w<=0 or h<=0, but received w={}, h={}".format(
                    w, h
                )
            )
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)
    else:
        oh, ow = size

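    img = img.unsqueeze(0)   # <---- already unified across dynamic/static mode; takes the static append_op branch automatically
    img = F.interpolate(           # <---- already unified across dynamic/static mode; takes the static append_op branch automatically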
        img,
        size=(oh, ow),
        mode=interpolation.lower(),
        data_format='N' + data_format.upper(),
    )

    return img.squeeze(0)    # <---- already unified across dynamic/static mode; takes the static append_op branch automatically
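
The output-size arithmetic for an integer size can be checked in isolation. The sketch below pulls that branch out into a standalone helper; compute_output_size is a hypothetical name used only for illustration and is not part of functional_tensor.py.

```python
def compute_output_size(w, h, size):
    """Return (oh, ow) for an int `size`, or None when no resize is needed.

    Mirrors the int-size branch of resize(): the shorter edge becomes
    `size` and the longer edge is scaled to keep the aspect ratio.
    """
    if w <= 0 or h <= 0:
        # In static mode a dynamic dimension is reported as -1.
        raise NotImplementedError(
            "Not support while w<=0 or h<=0, but received w={}, h={}".format(w, h)
        )
    if (w <= h and w == size) or (h <= w and h == size):
        return None  # shorter edge already matches; image is returned as-is
    if w < h:
        ow = size
        oh = int(size * h / w)
    else:
        oh = size
        ow = int(size * w / h)
    return oh, ow


landscape = compute_output_size(64, 32, 16)  # width is the longer edge
portrait = compute_output_size(32, 64, 16)   # height is the longer edge
unchanged = compute_output_size(16, 48, 16)  # shorter edge already 16
```

With a dynamic shape such as w = h = -1, the helper raises NotImplementedError, which matches the guard in the function above.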

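Step 3: Add unit tests to ensure that static-mode results match dynamic-mode results

Tests can all be added to test_transforms_static.py by inheriting from the TestTransformUnitTestBase base class.
A new test only needs to set the API under test. If new requirements arise, the TestTransformUnitTestBase interface can be extended: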

class TestResize(TestTransformUnitTestBase):
    def set_trans_api(self):
        self.api = transforms.Resize(size=(16, 16))



# Base class interface:
class TestTransformUnitTestBase(unittest.TestCase):
    def setUp(self):
        self.img = (np.random.rand(*self.get_shape()) * 255.0).astype(
            np.float32
        )
        self.set_trans_api()

    def get_shape(self):
        return (64, 64, 3)

    def set_trans_api(self):
        self.api = transforms.Resize(size=16)

    def dynamic_transform(self):
        paddle.seed(SEED)

        img_t = paddle.to_tensor(self.img)
        return self.api(img_t)

    def static_transform(self):
        paddle.enable_static()
        paddle.seed(SEED)

        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program):
            x = paddle.static.data(
                shape=self.get_shape(), dtype=paddle.float32, name='img'
            )
            out = self.api(x)

        exe = paddle.static.Executor()
        res = exe.run(main_program, fetch_list=[out], feed={'img': self.img})

        paddle.disable_static()
        return res[0]

    def test_transform(self):
        dy_res = self.dynamic_transform()
        st_res = self.static_transform()

        np.testing.assert_almost_equal(dy_res, st_res)
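
The base class follows a template-method pattern: subclasses override one hook, and the inherited test compares the two execution paths. A minimal sketch of that pattern with the standard library only is given below; TinyGraph, eager_scale, and ParityTestBase are hypothetical stand-ins and do not model Paddle's Program or Executor.

```python
import unittest


def eager_scale(values, factor):
    """Apply the transform immediately (stand-in for dynamic mode)."""
    return [v * factor for v in values]


class TinyGraph:
    """Hypothetical stand-in for a static program: record ops, run later."""

    def __init__(self):
        self.ops = []

    def append_op(self, fn):
        self.ops.append(fn)

    def run(self, feed):
        out = feed
        for op in self.ops:
            out = op(out)
        return out


class ParityTestBase(unittest.TestCase):
    def setUp(self):
        self.data = [1.0, 2.0, 3.0]
        self.set_factor()

    def set_factor(self):
        # Hook that subclasses override, like set_trans_api above.
        self.factor = 2.0

    def eager_result(self):
        return eager_scale(self.data, self.factor)

    def deferred_result(self):
        graph = TinyGraph()
        graph.append_op(lambda xs: eager_scale(xs, self.factor))
        return graph.run(self.data)

    def test_parity(self):
        for a, b in zip(self.eager_result(), self.deferred_result()):
            self.assertAlmostEqual(a, b)


class TestScaleByThree(ParityTestBase):
    def set_factor(self):
        self.factor = 3.0
```

A new case only overrides the hook, which is why adding coverage for another transform stays a few lines long.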

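Additional notes

1. Dispatch at the API entry point

The entry logic of other vision.transform APIs may be more complex. Consider dispatching between dynamic and static mode in the entry function, for example:

# Original code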
class RandomHorizontalFlip(BaseTransform):

    def __init__(self, prob=0.5, keys=None):
        super().__init__(keys)
        assert 0 <= prob <= 1, "probability must be between 0 and 1"
        self.prob = prob

    def _apply_image(self, img):
        if random.random() < self.prob:
            return F.hflip(img)
        return img


# Suggested change
class RandomHorizontalFlip(BaseTransform):

    def __init__(self, prob=0.5, keys=None):
        super().__init__(keys)
        assert 0 <= prob <= 1, "probability must be between 0 and 1"
        self.prob = prob

    def _apply_image(self, img):
        if in_dynamic_mode():
            return self._dynamic_apply_image(img)
        else:
            return self._static_apply_image(img)

    def _dynamic_apply_image(self, img):
        if random.random() < self.prob:
            return F.hflip(img)
        return img

    def _static_apply_image(self, img):
        return paddle.static.nn.cond(
            paddle.rand([1]) < self.prob, lambda: F.hflip(img), lambda: img
        )
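
A likely reason the static branch needs a graph-level condition is that a Python-level random.random() call would run only once, while the graph is being built, so every later execution would repeat the same decision. The sketch below illustrates this with plain Python closures; build_baked and build_deferred are hypothetical names and the closures only stand in for a built program.

```python
import random


def build_baked(prob, rng):
    # The random draw happens once, at "build" time; every run reuses it.
    flip = rng.random() < prob
    return lambda xs: xs[::-1] if flip else xs


def build_deferred(prob, rng):
    # The random draw happens inside the callable, i.e. on every run.
    return lambda xs: xs[::-1] if rng.random() < prob else xs


rng = random.Random(0)
baked = build_baked(0.5, rng)
deferred = build_deferred(0.5, rng)

# Run each "program" many times and collect the distinct outputs.
baked_outputs = {tuple(baked([1, 2, 3])) for _ in range(200)}
deferred_outputs = {tuple(deferred([1, 2, 3])) for _ in range(200)}
```

The baked version always produces the same output, while the deferred version produces both the flipped and the unflipped result across runs.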

2. In static mode, either adapt to dynamic shapes or raise an error
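
One way to fail early on dynamic shapes is to validate the declared shape before any size arithmetic. This is a sketch under the assumption that dynamic dimensions appear as -1 (or None) in the shape list; check_static_shape is a hypothetical helper, not an existing Paddle function.

```python
def check_static_shape(shape):
    """Return the shape as a tuple, or raise if any dimension is dynamic."""
    dynamic_axes = [i for i, d in enumerate(shape) if d is None or d < 0]
    if dynamic_axes:
        raise NotImplementedError(
            "Dynamic dimensions at axes {} are not supported, got shape {}".format(
                dynamic_axes, list(shape)
            )
        )
    return tuple(shape)


fixed = check_static_shape([64, 64, 3])

try:
    check_static_shape([-1, -1, 3])
    message = None
except NotImplementedError as exc:
    message = str(exc)
```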

@paddle-bot

paddle-bot bot commented Dec 12, 2022

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@luotao1 changed the title from "[API] transforms.Resize Support static mode" to "[Dy2St] transforms.Resize Support static mode" on Dec 13, 2022
Member

@SigureMo SigureMo left a comment


LGTM~~~

@Aurelius84
Contributor Author

#49024 already includes the contents of this PR, so I am closing this one.

@Aurelius84 Aurelius84 closed this Dec 13, 2022