Skip to content

Commit

Permalink
[Dy2St] transforms.RandomErasing Support static mode (#49617)
Browse files Browse the repository at this point in the history
* static.nn.cond ten

* add unitest

* update code style
  • Loading branch information
DrRyanHuang authored Jan 9, 2023
1 parent d4b3bfa commit e9df6fc
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 5 deletions.
32 changes: 32 additions & 0 deletions python/paddle/tests/test_transforms_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,38 @@ def set_trans_api(self):
self.api = transforms.RandomRotation(degree_tuple, expand=True, fill=3)


class TestRandomErasing(TestTransformUnitTestBase):
def set_trans_api(self):

self.value = 100
self.scale = (0.02, 0.33)
self.ratio = (0.3, 3.3)
self.api = transforms.RandomErasing(
prob=1, value=self.value, scale=self.scale, ratio=self.ratio
)

def test_transform(self):
dy_res = self.dynamic_transform()
if isinstance(dy_res, paddle.Tensor):
dy_res = dy_res.numpy()
st_res = self.static_transform()

self.assert_test_erasing(dy_res)
self.assert_test_erasing(st_res)

def assert_test_erasing(self, arr):

_, h, w = arr.shape
area = h * w

height = (arr[2] == self.value).cumsum(1)[:, -1].max()
width = (arr[2] == self.value).cumsum(0)[-1].max()
erasing_area = height * width

assert self.ratio[0] < height / width < self.ratio[1]
assert self.scale[0] < erasing_area / area < self.scale[1]


class TestRandomResizedCrop(TestTransformUnitTestBase):
def set_trans_api(self, eps=10e-5):
c, h, w = self.get_shape()
Expand Down
138 changes: 133 additions & 5 deletions python/paddle/vision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1914,8 +1914,8 @@ def __init__(
self.value = value
self.inplace = inplace

def _get_param(self, img, scale, ratio, value):
"""Get parameters for ``erase`` for a random erasing.
def _dynamic_get_param(self, img, scale, ratio, value):
"""Get parameters for ``erase`` for a random erasing in dynamic mode.
Args:
img (paddle.Tensor | np.array | PIL.Image): Image to be erased.
Expand Down Expand Up @@ -1964,13 +1964,104 @@ def _get_param(self, img, scale, ratio, value):

return 0, 0, h, w, img

def _apply_image(self, img):
def _static_get_param(self, img, scale, ratio, value):
"""Get parameters for ``erase`` for a random erasing in static mode.
Args:
img (paddle.static.Variable): Image to be erased.
scale (sequence, optional): The proportional range of the erased area to the input image.
ratio (sequence, optional): Aspect ratio range of the erased area.
value (sequence | None): The value each pixel in erased area will be replaced with.
If value is a sequence with length 3, the R, G, B channels will be ereased
respectively. If value is None, each pixel will be erased with random values.
Returns:
tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erase.
"""

c, h, w = img.shape[-3], img.shape[-2], img.shape[-1]

img_area = h * w
log_ratio = np.log(np.array(ratio))

def cond(counter, ten, erase_h, erase_w):
return counter < ten and (erase_h >= h or erase_w >= w)

def body(counter, ten, erase_h, erase_w):

erase_area = (
paddle.uniform([1], min=scale[0], max=scale[1]) * img_area
)
aspect_ratio = paddle.exp(
paddle.uniform([1], min=log_ratio[0], max=log_ratio[1])
)
erase_h = paddle.round(paddle.sqrt(erase_area * aspect_ratio)).cast(
"int32"
)
erase_w = paddle.round(paddle.sqrt(erase_area / aspect_ratio)).cast(
"int32"
)

counter += 1

return [counter, ten, erase_h, erase_w]

h = paddle.assign([h]).astype("int32")
w = paddle.assign([w]).astype("int32")
erase_h, erase_w = h.clone(), w.clone()
counter = paddle.full(
shape=[1], fill_value=0, dtype='int32'
) # loop counter
ten = paddle.full(
shape=[1], fill_value=10, dtype='int32'
) # loop length
counter, ten, erase_h, erase_w = paddle.static.nn.while_loop(
cond, body, [counter, ten, erase_h, erase_w]
)

if value is None:
v = paddle.normal(shape=[c, erase_h, erase_w]).astype(img.dtype)
else:
v = value[:, None, None]

zero = paddle.zeros([1]).astype("int32")
top = paddle.static.nn.cond(
erase_h < h and erase_w < w,
lambda: paddle.uniform(
shape=[1], min=0, max=h - erase_h + 1
).astype("int32"),
lambda: zero,
)

left = paddle.static.nn.cond(
erase_h < h and erase_w < w,
lambda: paddle.uniform(
shape=[1], min=0, max=w - erase_w + 1
).astype("int32"),
lambda: zero,
)

erase_h = paddle.static.nn.cond(
erase_h < h and erase_w < w, lambda: erase_h, lambda: h
)

erase_w = paddle.static.nn.cond(
erase_h < h and erase_w < w, lambda: erase_w, lambda: w
)

v = paddle.static.nn.cond(
erase_h < h and erase_w < w, lambda: v, lambda: img
)

return top, left, erase_h, erase_w, v, counter

def _dynamic_apply_image(self, img):
"""
Args:
img (paddle.Tensor | np.array | PIL.Image): Image to be Erased.
Returns:
output (paddle.Tensor np.array | PIL.Image): A random erased image.
output (paddle.Tensor | np.array | PIL.Image): A random erased image.
"""

if random.random() < self.prob:
Expand All @@ -1984,8 +2075,45 @@ def _apply_image(self, img):
raise ValueError(
"Value should be a single number or a sequence with length equals to image's channel."
)
top, left, erase_h, erase_w, v = self._get_param(
top, left, erase_h, erase_w, v = self._dynamic_get_param(
img, self.scale, self.ratio, value
)
return F.erase(img, top, left, erase_h, erase_w, v, self.inplace)
return img

def _static_apply_image(self, img):
"""
Args:
img (paddle.static.Variable): Image to be Erased.
Returns:
output (paddle.static.Variable): A random erased image.
"""

if isinstance(self.value, numbers.Number):
value = paddle.assign([self.value]).astype(img.dtype)
elif isinstance(self.value, str):
value = None
else:
value = paddle.assign(self.value).astype(img.dtype)
if value is not None and not (
value.shape[0] == 1 or value.shape[0] == 3
):
raise ValueError(
"Value should be a single number or a sequence with length equals to image's channel."
)

top, left, erase_h, erase_w, v, counter = self._static_get_param(
img, self.scale, self.ratio, value
)
return F.erase(img, top, left, erase_h, erase_w, v, self.inplace)

def _apply_image(self, img):
if paddle.in_dynamic_mode():
return self._dynamic_apply_image(img)
else:
return paddle.static.nn.cond(
paddle.rand([1]) < self.prob,
lambda: self._static_apply_image(img),
lambda: img,
)

0 comments on commit e9df6fc

Please # to comment.