Unified input for resized crop op (#2396)
* [WIP] Unify random resized crop

* Unify input for RandomResizedCrop

* Fixed bugs and updated test

* Added resized crop functional test
- fixed bug with size convention

* Fixed incoherent sampling

* Addressed review remark on torch.randint usage
vfdev-5 authored Jul 7, 2020
1 parent b572d5e commit 9b80465
Showing 5 changed files with 87 additions and 31 deletions.
17 changes: 17 additions & 0 deletions test/test_functional_tensor.py
@@ -331,6 +331,23 @@ def test_resize(self):
                pad_tensor_script = script_fn(tensor, size=script_size, interpolation=interpolation)
                self.assertTrue(resized_tensor.equal(pad_tensor_script), msg="{}, {}".format(size, interpolation))

+    def test_resized_crop(self):
+        # test values of F.resized_crop in several cases:
+        # 1) resize to the same size, crop to the same size => should be identity
+        tensor, _ = self._create_data(26, 36)
+        for i in [0, 2, 3]:
+            out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=i)
+            self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
+
+        # 2) resize by half and crop a TL corner
+        tensor, _ = self._create_data(26, 36)
+        out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=0)
+        expected_out_tensor = tensor[:, :20:2, :30:2]
+        self.assertTrue(
+            expected_out_tensor.equal(out_tensor),
+            msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
+        )
+

 if __name__ == '__main__':
     unittest.main()
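A note on case 2 above: with nearest interpolation (code 0), downscaling a crop by exactly 2x selects every second pixel, which is what the expected strided slice encodes. A minimal standalone sketch of that assertion, assuming a torchvision build that includes this patch:

import torch
import torchvision.transforms.functional as F

tensor = torch.randint(0, 256, size=(3, 26, 36), dtype=torch.uint8)
# Crop the top-left 20x30 region, then nearest-resize it to 10x15:
out = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=0)
# Nearest downscale by 2 picks every second pixel of the crop:
assert out.equal(tensor[:, :20:2, :30:2])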
19 changes: 19 additions & 0 deletions test/test_transforms_tensor.py
@@ -245,6 +245,25 @@ def test_resize(self):
                s_resized_tensor = script_transform(tensor)
                self.assertTrue(s_resized_tensor.equal(resized_tensor))

+    def test_resized_crop(self):
+        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)
+
+        scale = (0.7, 1.2)
+        ratio = (0.75, 1.333)
+
+        for size in [(32, ), [32, ], [32, 32], (32, 32)]:
+            for interpolation in [NEAREST, BILINEAR, BICUBIC]:
+                transform = T.RandomResizedCrop(
+                    size=size, scale=scale, ratio=ratio, interpolation=interpolation
+                )
+                s_transform = torch.jit.script(transform)
+
+                torch.manual_seed(12)
+                out1 = transform(tensor)
+                torch.manual_seed(12)
+                out2 = s_transform(tensor)
+                self.assertTrue(out1.equal(out2))
+

 if __name__ == '__main__':
     unittest.main()
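The test above relies on the transform now being a torch.nn.Module whose sampling goes through the torch RNG, so it can be scripted and reproduced under a fixed seed. A short usage sketch, assuming this patched torchvision:

import torch
import torchvision.transforms as T

transform = T.RandomResizedCrop(size=32, scale=(0.7, 1.2), ratio=(0.75, 1.333))
scripted = torch.jit.script(transform)  # scriptable now that forward() replaces __call__

tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)
torch.manual_seed(12)
out_eager = transform(tensor)
torch.manual_seed(12)
out_scripted = scripted(tensor)
assert out_eager.equal(out_scripted)  # same seed => same crop parameters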
16 changes: 9 additions & 7 deletions torchvision/transforms/functional.py
@@ -439,24 +439,26 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
     return crop(img, crop_top, crop_left, crop_height, crop_width)


-def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINEAR):
-    """Crop the given PIL Image and resize it to desired size.
+def resized_crop(
+        img: Tensor, top: int, left: int, height: int, width: int, size: List[int], interpolation: int = Image.BILINEAR
+) -> Tensor:
+    """Crop the given image and resize it to desired size.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
     Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.

     Args:
-        img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
+        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
         top (int): Vertical component of the top left corner of the crop box.
         left (int): Horizontal component of the top left corner of the crop box.
         height (int): Height of the crop box.
         width (int): Width of the crop box.
         size (sequence or int): Desired output size. Same semantics as ``resize``.
-        interpolation (int, optional): Desired interpolation. Default is
-            ``PIL.Image.BILINEAR``.
+        interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR``.

     Returns:
-        PIL Image: Cropped image.
+        PIL Image or Tensor: Cropped image.
     """
-    assert F_pil._is_pil_image(img), 'img should be PIL Image'
     img = crop(img, top, left, height, width)
     img = resize(img, size, interpolation)
     return img
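With the PIL-only assert removed, the same call now dispatches on the input type. A hedged usage sketch of the tensor path, assuming this patched functional.py:

import torch
import torchvision.transforms.functional as F

img = torch.rand(3, 64, 64)  # a CHW tensor; a PIL Image goes through the same call
out = F.resized_crop(img, top=8, left=8, height=48, width=48, size=[32, 32])
# out.shape == torch.Size([3, 32, 32]); size follows the (h, w) convention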
2 changes: 1 addition & 1 deletion torchvision/transforms/functional_tensor.py
@@ -532,7 +532,7 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
     elif len(size) < 2:
         size_w, size_h = size[0], size[0]
     else:
-        size_w, size_h = size[0], size[1]
+        size_w, size_h = size[1], size[0]  # Convention (h, w)

     if isinstance(size, int) or len(size) < 2:
         if w < h:
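This one-line fix only shows up with non-square targets: size is interpreted as (h, w), matching the PIL path. A sketch of the intended behaviour after the fix (functional_tensor is a module-internal API, so this is illustrative only):

import torch
from torchvision.transforms import functional_tensor as F_t

img = torch.rand(3, 26, 36)
out = F_t.resize(img, size=[10, 15], interpolation=0)
# Expected: out.shape == torch.Size([3, 10, 15]), i.e. (h, w), not (w, h)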
64 changes: 41 additions & 23 deletions torchvision/transforms/transforms.py
@@ -687,40 +687,56 @@ def __repr__(self):
         return self.__class__.__name__ + '(p={})'.format(self.p)


-class RandomResizedCrop(object):
-    """Crop the given PIL Image to random size and aspect ratio.
+class RandomResizedCrop(torch.nn.Module):
+    """Crop the given image to random size and aspect ratio.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
+
     A crop of random size (default: of 0.08 to 1.0) of the original size and a random
     aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
     is finally resized to given size.
     This is popularly used to train the Inception networks.

     Args:
-        size: expected output size of each edge
-        scale: range of size of the origin size cropped
-        ratio: range of aspect ratio of the origin aspect ratio cropped
-        interpolation: Default: PIL.Image.BILINEAR
+        size (int or sequence): expected output size of each edge. If size is an
+            int instead of sequence like (h, w), a square output size ``(size, size)`` is
+            made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
+        scale (tuple of float): range of size of the origin size cropped
+        ratio (tuple of float): range of aspect ratio of the origin aspect ratio cropped.
+        interpolation (int): Desired interpolation. Default: ``PIL.Image.BILINEAR``
     """

     def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
-        if isinstance(size, (tuple, list)):
-            self.size = size
+        super().__init__()
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        elif isinstance(size, Sequence) and len(size) == 1:
+            self.size = (size[0], size[0])
         else:
-            self.size = (size, size)
+            if len(size) != 2:
+                raise ValueError("Please provide only two dimensions (h, w) for size.")
+            self.size = size

+        if not isinstance(scale, (tuple, list)):
+            raise TypeError("Scale should be a sequence")
+        if not isinstance(ratio, (tuple, list)):
+            raise TypeError("Ratio should be a sequence")
         if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
-            warnings.warn("range should be of kind (min, max)")
+            warnings.warn("Scale and ratio should be of kind (min, max)")

         self.interpolation = interpolation
         self.scale = scale
         self.ratio = ratio

     @staticmethod
-    def get_params(img, scale, ratio):
+    def get_params(
+            img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float]
+    ) -> Tuple[int, int, int, int]:
         """Get parameters for ``crop`` for a random sized crop.

         Args:
-            img (PIL Image): Image to be cropped.
-            scale (tuple): range of size of the origin size cropped
+            img (PIL Image or Tensor): Input image.
+            scale (tuple): range of scale of the origin size cropped
             ratio (tuple): range of aspect ratio of the origin aspect ratio cropped

         Returns:
@@ -731,24 +747,26 @@ def get_params(img, scale, ratio):
         area = height * width

         for _ in range(10):
-            target_area = random.uniform(*scale) * area
-            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
-            aspect_ratio = math.exp(random.uniform(*log_ratio))
+            target_area = area * torch.empty(1).uniform_(*scale).item()
+            log_ratio = torch.log(torch.tensor(ratio))
+            aspect_ratio = torch.exp(
+                torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
+            ).item()

             w = int(round(math.sqrt(target_area * aspect_ratio)))
             h = int(round(math.sqrt(target_area / aspect_ratio)))

             if 0 < w <= width and 0 < h <= height:
-                i = random.randint(0, height - h)
-                j = random.randint(0, width - w)
+                i = torch.randint(0, height - h + 1, size=(1,)).item()
+                j = torch.randint(0, width - w + 1, size=(1,)).item()
                 return i, j, h, w

         # Fallback to central crop
         in_ratio = float(width) / float(height)
-        if (in_ratio < min(ratio)):
+        if in_ratio < min(ratio):
             w = width
             h = int(round(w / min(ratio)))
-        elif (in_ratio > max(ratio)):
+        elif in_ratio > max(ratio):
             h = height
             w = int(round(h * max(ratio)))
         else:  # whole image
@@ -758,13 +776,13 @@ def get_params(img, scale, ratio):
             j = (width - w) // 2
         return i, j, h, w

-    def __call__(self, img):
+    def forward(self, img):
         """
         Args:
-            img (PIL Image): Image to be cropped and resized.
+            img (PIL Image or Tensor): Image to be cropped and resized.

         Returns:
-            PIL Image: Randomly cropped and resized image.
+            PIL Image or Tensor: Randomly cropped and resized image.
         """
         i, j, h, w = self.get_params(img, self.scale, self.ratio)
         return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
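Since get_params stays a static method returning (i, j, h, w), a common pattern still works: sample the parameters once and apply the identical crop to paired inputs via F.resized_crop. A hedged sketch (the image/mask pairing is illustrative, not part of this PR):

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F

img = torch.rand(3, 44, 56)
mask = torch.randint(0, 2, size=(1, 44, 56), dtype=torch.uint8)

# One set of crop parameters, applied to both inputs:
i, j, h, w = T.RandomResizedCrop.get_params(img, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.))
img_out = F.resized_crop(img, i, j, h, w, size=[32, 32])
mask_out = F.resized_crop(mask, i, j, h, w, size=[32, 32], interpolation=0)  # nearest keeps labels discrete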
