diff --git a/captum/optim/_param/image/__init__.py b/captum/optim/_param/image/__init__.py index a2311f7c4..5c36c0c80 100755 --- a/captum/optim/_param/image/__init__.py +++ b/captum/optim/_param/image/__init__.py @@ -1 +1 @@ -"""(Differentiable) Input Parameterizations. Currently only 3-channel images""" +"""(Differentiable) Input Parameterizations. Currently only images""" diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py index fa313b38a..f3a45346c 100644 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -1,4 +1,3 @@ -from copy import deepcopy from types import MethodType from typing import Callable, List, Optional, Tuple, Type, Union @@ -37,7 +36,12 @@ def __new__( Returns: x (ImageTensor): An `ImageTensor` instance. """ - if isinstance(x, torch.Tensor) and x.is_cuda: + if ( + isinstance(x, torch.Tensor) + and x.is_cuda + or isinstance(x, torch.Tensor) + and x.dtype != torch.float32 + ): x.show = MethodType(cls.show, x) x.export = MethodType(cls.export, x) return x @@ -181,12 +185,32 @@ def forward(self) -> torch.Tensor: class ImageParameterization(InputParameterization): + r"""The base class for all Image Parameterizations""" pass class FFTImage(ImageParameterization): """ Parameterize an image using inverse real 2D FFT + + Example:: + + >>> fft_image = opt.images.FFTImage(size=(224, 224)) + >>> output_image = fft_image() + >>> print(output_image.required_grad) + True + >>> print(output_image.shape) + torch.Size([1, 3, 224, 224]) + + Example for using an initialization tensor:: + + >>> init = torch.randn(1, 3, 224, 224) + >>> fft_image = opt.images.FFTImage(init=init) + >>> output_image = fft_image() + >>> print(output_image.required_grad) + True + >>> print(output_image.shape) + torch.Size([1, 3, 224, 224]) """ __constants__ = ["size"] @@ -201,16 +225,16 @@ def __init__( """ Args: - size (Tuple[int, int]): The height & width dimensions to use for the - parameterized output image tensor. + size (tuple of int): The height & width dimensions to use for the + parameterized output image tensor, in the format of: (height, width). channels (int, optional): The number of channels to use for each image. - Default: 3 + Default: ``3`` batch (int, optional): The number of images to stack along the batch dimension. - Default: 1 - init (torch.tensor, optional): Optionally specify a tensor to + Default: ``1`` + init (torch.Tensor, optional): Optionally specify a CHW or NCHW tensor to use instead of creating one. - Default: None + Default: ``None`` """ super().__init__() if init is None: @@ -221,9 +245,9 @@ def __init__( if init.dim() == 3: init = init.unsqueeze(0) self.size = (init.size(2), init.size(3)) - self.torch_rfft, self.torch_irfft, self.torch_fftfreq = self.get_fft_funcs() + self.torch_rfft, self.torch_irfft, self.torch_fftfreq = self._get_fft_funcs() - frequencies = self.rfft2d_freqs(*self.size) + frequencies = self._rfft2d_freqs(*self.size) scale = 1.0 / torch.max( frequencies, torch.full_like(frequencies, 1.0 / (max(self.size[0], self.size[1]))), @@ -250,7 +274,7 @@ def __init__( self.register_buffer("spectrum_scale", spectrum_scale) self.fourier_coeffs = nn.Parameter(fourier_coeffs) - def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor: + def _rfft2d_freqs(self, height: int, width: int) -> torch.Tensor: """ Computes 2D spectrum frequencies. @@ -260,7 +284,7 @@ def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor: width (int): The w dimension of the 2d frequency scale. Returns: - **tensor** (tensor): A 2d frequency scale tensor. + tensor (torch.Tensor): A 2d frequency scale tensor. """ fy = self.torch_fftfreq(height)[:, None] @@ -268,20 +292,20 @@ def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor: return torch.sqrt((fx * fx) + (fy * fy)) @torch.jit.export - def torch_irfftn(self, x: torch.Tensor) -> torch.Tensor: - if x.dtype != torch.complex64: + def _torch_irfftn(self, x: torch.Tensor) -> torch.Tensor: + if not torch.is_complex(x): x = torch.view_as_complex(x) return torch.fft.irfftn(x, s=self.size) # type: ignore - def get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]: + def _get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]: """ Support older versions of PyTorch. This function ensures that the same FFT operations are carried regardless of whether your PyTorch version has the torch.fft update. Returns: - fft functions (tuple of Callable): A list of FFT functions - to use for irfft, rfft, and fftfreq operations. + fft_functions (tuple of callable): A list of FFT functions to use for + irfft, rfft, and fftfreq operations. """ if version.parse(TORCH_VERSION) > version.parse("1.7.0"): @@ -292,7 +316,7 @@ def get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]: def torch_rfft(x: torch.Tensor) -> torch.Tensor: return torch.view_as_real(torch.fft.rfftn(x, s=self.size)) - torch_irfftn = self.torch_irfftn + torch_irfftn = self._torch_irfftn def torch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor: return torch.fft.fftfreq(v, d) @@ -320,7 +344,7 @@ def torch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor: def forward(self) -> torch.Tensor: """ Returns: - **output** (torch.tensor): A spatially recorrelated tensor. + output (torch.Tensor): A spatially recorrelated NCHW tensor. """ scaled_spectrum = self.fourier_coeffs * self.spectrum_scale @@ -333,6 +357,25 @@ def forward(self) -> torch.Tensor: class PixelImage(ImageParameterization): """ Parameterize a simple pixel image tensor that requires no additional transforms. + + Example:: + + >>> pixel_image = opt.images.PixelImage(size=(224, 224)) + >>> output_image = pixel_image() + >>> print(output_image.required_grad) + True + >>> print(output_image.shape) + torch.Size([1, 3, 224, 224]) + + Example for using an initialization tensor:: + + >>> init = torch.randn(1, 3, 224, 224) + >>> pixel_image = opt.images.PixelImage(init=init) + >>> output_image = pixel_image() + >>> print(output_image.required_grad) + True + >>> print(output_image.shape) + torch.Size([1, 3, 224, 224]) """ def __init__( @@ -345,16 +388,16 @@ def __init__( """ Args: - size (Tuple[int, int]): The height & width dimensions to use for the - parameterized output image tensor. + size (tuple of int): The height & width dimensions to use for the + parameterized output image tensor, in the format of: (height, width). channels (int, optional): The number of channels to use for each image. - Default: 3 + Default: ``3`` batch (int, optional): The number of images to stack along the batch dimension. - Default: 1 - init (torch.tensor, optional): Optionally specify a tensor to + Default: ``1`` + init (torch.Tensor, optional): Optionally specify a CHW or NCHW tensor to use instead of creating one. - Default: None + Default: ``None`` """ super().__init__() if init is None: @@ -367,6 +410,10 @@ def __init__( self.image = nn.Parameter(init) def forward(self) -> torch.Tensor: + """ + Returns: + output (torch.Tensor): An NCHW tensor. + """ if torch.jit.is_scripting(): return self.image return self.image.refine_names("B", "C", "H", "W") @@ -374,96 +421,101 @@ def forward(self) -> torch.Tensor: class LaplacianImage(ImageParameterization): """ - TODO: Fix divison by 6 in setup_input when init is not None. Parameterize an image tensor with a laplacian pyramid. + + Example:: + + >>> laplacian_image = opt.images.LaplacianImage(size=(224, 224)) + >>> output_image = laplacian_image() + >>> print(output_image.required_grad) + True + >>> print(output_image.shape) + torch.Size([1, 3, 224, 224]) + + Example for using an initialization tensor:: + + >>> init = torch.randn(1, 3, 224, 224) + >>> laplacian_image = opt.images.LaplacianImage(init=init) + >>> output_image = laplacian_image() + >>> print(output_image.required_grad) + True + >>> print(output_image.shape) + torch.Size([1, 3, 224, 224]) """ def __init__( self, - size: Tuple[int, int] = None, + size: Tuple[int, int] = (224, 224), channels: int = 3, batch: int = 1, init: Optional[torch.Tensor] = None, + power: float = 0.1, + scale_list: List[float] = [1.0, 2.0, 4.0, 8.0, 16.0, 32.0], ) -> None: """ Args: - size (Tuple[int, int]): The height & width dimensions to use for the - parameterized output image tensor. + size (tuple of int): The height & width dimensions to use for the + parameterized output image tensor, in the format of: (height, width). channels (int, optional): The number of channels to use for each image. - Default: 3 + Default: ``3`` batch (int, optional): The number of images to stack along the batch dimension. - Default: 1 - init (torch.tensor, optional): Optionally specify a tensor to + Default: ``1`` + init (torch.Tensor, optional): Optionally specify a CHW or NCHW tensor to use instead of creating one. - Default: None + Default: ``None`` + power (float, optional): The desired power value to use. + Default: ``0.1`` + scale_list (list of float, optional): The desired list of scale values to + use in the laplacian pyramid. The height & width dimensions specified + in ``size`` or used in the ``init`` tensor should be divisible by every + scale value in the scale list with no remainder left over. The default + ``scale_list`` values are set to work with a ``size`` of + ``(224, 224)``. + Default: ``[1.0, 2.0, 4.0, 8.0, 16.0, 32.0]`` """ super().__init__() - power = 0.1 - - if init is None: - tensor_params, self.scaler = self._setup_input(size, channels, power, init) - - self.tensor_params = torch.nn.ModuleList( - [deepcopy(tensor_params) for b in range(batch)] - ) - else: + if init is not None: + assert init.dim() in [3, 4] init = init.unsqueeze(0) if init.dim() == 3 else init - P = [] - for b in range(init.size(0)): - tensor_params, self.scaler = self._setup_input( - size, channels, power, init[b].unsqueeze(0) - ) - P.append(tensor_params) - self.tensor_params = torch.nn.ModuleList(P) + size = list(init.shape[2:]) - def _setup_input( - self, - size: Tuple[int, int], - channels: int, - power: float = 0.1, - init: Optional[torch.Tensor] = None, - ) -> Tuple[List[torch.Tensor], List[torch.nn.Upsample]]: tensor_params, scaler = [], [] - scale_list = [1, 2, 4, 8, 16, 32] for scale in scale_list: + assert size[0] % scale == 0 and size[1] % scale == 0, ( + "The chosen image height & width dimensions" + + " must be divisible by all scale values " + + " with no remainder left over." + ) + h, w = int(size[0] // scale), int(size[1] // scale) if init is None: - x = torch.randn([1, channels, h, w]) / 10 + x = torch.randn([batch, channels, h, w]) / 10 else: x = F.interpolate(init.clone(), size=(h, w), mode="bilinear") - x = x / 6 # Prevents output from being all white + x = x / 10 upsample = torch.nn.Upsample(scale_factor=scale, mode="nearest") - x = x * (scale**power) / (32**power) + x = x * (scale**power) / (max(scale_list) ** power) x = torch.nn.Parameter(x) tensor_params.append(x) scaler.append(upsample) - tensor_params = torch.nn.ParameterList(tensor_params) - return tensor_params, scaler + self.tensor_params = torch.nn.ParameterList(tensor_params) + self.scaler = scaler - def _create_tensor(self, params_list: torch.nn.ParameterList) -> torch.Tensor: + def forward(self) -> torch.Tensor: """ - Resize tensor parameters to the target size. - - Args: - - params_list (torch.nn.ParameterList): List of tensors to resize. - Returns: - **tensor** (torch.Tensor): The sum of all tensor parameters. + output (torch.Tensor): An NCHW tensor created from a laplacian pyramid. """ - A: List[torch.Tensor] = [] - for xi, upsamplei in zip(params_list, self.scaler): + A = [] + for xi, upsamplei in zip(self.tensor_params, self.scaler): A.append(upsamplei(xi)) - return torch.sum(torch.cat(A), 0) + 0.5 + output = sum(A) + 0.5 - def forward(self) -> torch.Tensor: - A: List[torch.Tensor] = [] - for params_list in self.tensor_params: - tensor = self._create_tensor(params_list) - A.append(tensor) - return torch.stack(A).refine_names("B", "C", "H", "W") + if torch.jit.is_scripting(): + return output + return output.refine_names("B", "C", "H", "W") class SimpleTensorParameterization(ImageParameterization): @@ -484,7 +536,8 @@ def __init__(self, tensor: torch.Tensor = None) -> None: """ Args: - tensor (torch.tensor): The tensor to return everytime this module is called. + tensor (torch.Tensor): The tensor to return every time this module is + called. """ super().__init__() assert isinstance(tensor, torch.Tensor) @@ -509,6 +562,17 @@ class SharedImage(ImageParameterization): Mordvintsev, et al., "Differentiable Image Parameterizations", Distill, 2018. https://distill.pub/2018/differentiable-parameterizations/ + + Example:: + + >>> fft_image = opt.images.FFTImage(size=(224, 224), batch=2) + >>> shared_shapes = ((1, 3, 64, 64), (4, 3, 32, 32)) + >>> shared_image = opt.images.SharedImage(shared_shapes, fft_image) + >>> output_image = shared_image() + >>> print(output_image.required_grad) + True + >>> print(output_image.shape) + torch.Size([2, 3, 224, 224]) """ __constants__ = ["offset"] @@ -522,13 +586,13 @@ def __init__( """ Args: - shapes (list of int or list of list of ints): The shapes of the shared + shapes (list of int or list of list of int): The shapes of the shared tensors to use for creating the nn.Parameter tensors. parameterization (ImageParameterization): An image parameterization instance. - offset (int or list of int or list of list of ints , optional): The offsets + offset (int or list of int or list of list of int, optional): The offsets to use for the shared tensors. - Default: None + Default: ``None`` """ super().__init__() assert shapes is not None @@ -552,12 +616,12 @@ def _get_offset(self, offset: Union[int, Tuple[int]], n: int) -> List[List[int]] Args: - offset (int or list of int or list of list of ints , optional): The offsets + offset (int or list of int or list of list of int, optional): The offsets to use for the shared tensors. n (int): The number of tensors needing offset values. Returns: - **offset** (list of list of int): A list of offset values. + offset (List[List[int]]): A list of offset values. """ if type(offset) is tuple or type(offset) is list: if type(offset[0]) is tuple or type(offset[0]) is list: @@ -581,7 +645,7 @@ def _apply_offset(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: x_list (list of torch.Tensor): list of tensors to offset. Returns: - **A** (list of torch.Tensor): list of offset tensors. + A (list of torch.Tensor): list of offset tensors. """ A: List[torch.Tensor] = [] @@ -616,8 +680,8 @@ def _interpolate_bilinear( Args: x (torch.Tensor): The NCHW tensor to resize. - size (tuple of int): The desired output size to resize the input - to, with a format of: [height, width]. + size (tuple of int): The desired output size to resize the input to, with + a format of: [height, width]. Returns: x (torch.Tensor): A resized NCHW tensor. @@ -645,8 +709,8 @@ def _interpolate_trilinear( Args: x (torch.Tensor): The NCHW tensor to resize. - size (tuple of int): The desired output size to resize the input - to, with a format of: [channels, height, width]. + size (tuple of int): The desired output size to resize the input to, with + a format of: [channels, height, width]. Returns: x (torch.Tensor): A resized NCHW tensor. @@ -678,7 +742,7 @@ def _interpolate_tensor( width (int): The width to resize the tensor to. Returns: - **tensor** (torch.Tensor): A resized tensor. + tensor (torch.Tensor): A resized tensor. """ if x.size(1) == channels: @@ -721,6 +785,28 @@ def forward(self) -> torch.Tensor: class StackImage(ImageParameterization): """ Stack multiple NCHW image parameterizations along their batch dimensions. + + Example:: + + >>> fft_image_1 = opt.images.FFTImage(size=(224, 224), batch=1) + >>> fft_image_2 = opt.images.FFTImage(size=(224, 224), batch=1) + >>> stack_image = opt.images.StackImage([fft_image_1, fft_image_2]) + >>> output_image = stack_image() + >>> print(output_image.required_grad) + True + >>> print(output_image.shape) + torch.Size([2, 3, 224, 224]) + + Example with ``ImageParameterization`` & ``torch.Tensor``:: + + >>> fft_image = opt.images.FFTImage(size=(224, 224), batch=1) + >>> tensor_image = torch.randn(1, 3, 224, 224) + >>> stack_image = opt.images.StackImage([fft_image, tensor_image]) + >>> output_image = stack_image() + >>> print(output_image.required_grad) + True + >>> print(output_image.shape) + torch.Size([2, 3, 224, 224]) """ __constants__ = ["dim", "output_device"] @@ -735,15 +821,16 @@ def __init__( Args: parameterizations (list of ImageParameterization and torch.Tensor): A list - of image parameterizations to stack across their batch dimensions. - dim (int, optional): Optionally specify the dim to concatinate - parameterization outputs on. Default is set to the batch dimension. - Default: 0 + of image parameterizations and tensors to concatenate across a + specified dimension. + dim (int, optional): Optionally specify the dim to concatenate + parameterization outputs on. Default is set to the batch dimension. + Default: ``0`` output_device (torch.device, optional): If the parameterizations are on different devices, then their outputs will be moved to the device - specified by this variable. Default is set to None with the expectation - that all parameterizations are on the same device. - Default: None + specified by this variable. Default is set to ``None`` with the + expectation that all parameterization outputs are on the same device. + Default: ``None`` """ super().__init__() assert len(parameterizations) > 0 @@ -786,16 +873,46 @@ def forward(self) -> torch.Tensor: class NaturalImage(ImageParameterization): - r"""Outputs an optimizable input image. + r"""Outputs an optimizable input image wrapped in :class:`.ImageTensor`. - By convention, single images are CHW and float32s in [0,1]. - The underlying parameterization can be decorrelated via a ToRGB transform. - When used with the (default) FFT parameterization, this results in a fully - uncorrelated image parameterization. :-) + By convention, single images are CHW and float32s in [0, 1]. + The underlying parameterization can be decorrelated via a + :class:`captum.optim.transforms.ToRGB` transform. + When used with the (default) :class:`.FFTImage` parameterization, this results in + a fully uncorrelated image parameterization. :-) If a model requires a normalization step, such as normalizing imagenet RGB values, - or rescaling to [0,255], it can perform those steps with the provided transforms or - inside its computation. + or rescaling to [0, 255], it can perform those steps with the provided transforms + or inside its module class. + + Example:: + + >>> image = opt.images.NaturalImage(size=(224, 224), channels=3, batch=1) + >>> image_tensor = image() + >>> print(image_tensor.required_grad) + True + >>> print(image_tensor.shape) + torch.Size([1, 3, 224, 224]) + + Example for using an initialization tensor:: + + >>> init = torch.randn(1, 3, 224, 224) + >>> image = opt.images.NaturalImage(init=init) + >>> image_tensor = image() + >>> print(image_tensor.required_grad) + True + >>> print(image_tensor.shape) + torch.Size([1, 3, 224, 224]) + + Example for using a parameterization:: + + >>> fft_image = opt.images.FFTImage(size=(224, 224), channels=3, batch=1) + >>> image = opt.images.NaturalImage(parameterization=fft_image) + >>> image_tensor = image() + >>> print(image_tensor.required_grad) + True + >>> print(image_tensor.shape) + torch.Size([1, 3, 224, 224]) """ def __init__( @@ -805,33 +922,60 @@ def __init__( batch: int = 1, init: Optional[torch.Tensor] = None, parameterization: ImageParameterization = FFTImage, - squash_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + squash_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = torch.sigmoid, decorrelation_module: Optional[nn.Module] = ToRGB(transform="klt"), decorrelate_init: bool = True, ) -> None: """ Args: - size (Tuple[int, int], optional): The height and width to use for the - nn.Parameter image tensor. - Default: (224, 224) + size (tuple of int, optional): The height and width to use for the + nn.Parameter image tensor, in the format of: (height, width). + This parameter is not used if the given ``parameterization`` is an + instance. + Default: ``(224, 224)`` channels (int, optional): The number of channels to use when creating the - nn.Parameter tensor. - Default: 3 + nn.Parameter tensor. This parameter is not used if the given + ``parameterization`` is an instance. + Default: ``3`` batch (int, optional): The number of channels to use when creating the - nn.Parameter tensor, or stacking init images. - Default: 1 + nn.Parameter tensor. This parameter is not used if the given + ``parameterization`` is an instance. + Default: ``1`` + init (torch.Tensor, optional): Optionally specify a tensor to use instead + of creating one from random noise. This parameter is not used if the + given ``parameterization`` is an instance. Set to ``None`` for random + init. + Default: ``None`` parameterization (ImageParameterization, optional): An image parameterization class, or instance of an image parameterization class. - Default: FFTImage - squash_func (Callable[[torch.Tensor], torch.Tensor]], optional): The squash - function to use after color recorrelation. A funtion or lambda function. - Default: None - decorrelation_module (nn.Module, optional): A ToRGB instance. - Default: ToRGB + Default: :class:`.FFTImage` + squash_func (callable, optional): The squash function to use after color + recorrelation. A function, lambda function, or callable class instance. + Any provided squash function should take a single input tensor and + return a single output tensor. If set to ``None``, then + :class:`torch.nn.Identity` will be used to make it a non op. + Default: :func:`torch.sigmoid` + decorrelation_module (nn.Module, optional): A module instance that + recorrelates the colors of an input image. Custom modules can make use + of the ``decorrelate_init`` parameter by having a second ``inverse`` + parameter in their forward functions that performs the inverse + operation when it is set to ``True`` (see :class:`.ToRGB` for an + example). Set to ``None`` for no recorrelation. + Default: :class:`.ToRGB` decorrelate_init (bool, optional): Whether or not to apply color - decorrelation to the init tensor input. - Default: True + decorrelation to the init tensor input. This parameter is not used if + the given ``parameterization`` is an instance or if init is ``None``. + Default: ``True`` + + Attributes: + + parameterization (ImageParameterization): The given image parameterization + instance given when initializing ``NaturalImage``. + Default: :class:`.FFTImage` + decorrelation_module (torch.nn.Module): The given decorrelation module + instance given when initializing ``NaturalImage``. + Default: :class:`.ToRGB` """ super().__init__() if not isinstance(parameterization, ImageParameterization): @@ -851,21 +995,13 @@ def __init__( ) init = self.decorrelate(init, inverse=True).rename(None) - if squash_func is None: - squash_func = self._clamp_image - - self.squash_func = torch.sigmoid if squash_func is None else squash_func + self.squash_func = squash_func or torch.nn.Identity() if not isinstance(parameterization, ImageParameterization): parameterization = parameterization( size=size, channels=channels, batch=batch, init=init ) self.parameterization = parameterization - @torch.jit.export - def _clamp_image(self, x: torch.Tensor) -> torch.Tensor: - """JIT supported squash function.""" - return x.clamp(0, 1) - @torch.jit.ignore def _to_image_tensor(self, x: torch.Tensor) -> torch.Tensor: """ @@ -873,14 +1009,20 @@ def _to_image_tensor(self, x: torch.Tensor) -> torch.Tensor: Args: - x (torch.tensor): An input tensor. + x (torch.Tensor): An input tensor. Returns: - x (ImageTensor): An instance of ImageTensor with the input tensor. + x (ImageTensor): An instance of ``ImageTensor`` with the input tensor. """ return ImageTensor(x) - def forward(self) -> torch.Tensor: + def forward(self) -> ImageTensor: + """ + Returns: + image_tensor (ImageTensor): The parameterization output wrapped in + :class:`.ImageTensor`, that has optionally had its colors + recorrelated. + """ image = self.parameterization() if self.decorrelate is not None: image = self.decorrelate(image) diff --git a/captum/optim/_param/image/transforms.py b/captum/optim/_param/image/transforms.py index 4ec876263..8131f4fc1 100644 --- a/captum/optim/_param/image/transforms.py +++ b/captum/optim/_param/image/transforms.py @@ -939,24 +939,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x[:, [2, 1, 0]] -# class TransformationRobustness(nn.Module): -# def __init__(self, jitter=False, scale=False): -# super().__init__() -# if jitter: -# self.jitter = RandomSpatialJitter(4) -# if scale: -# self.scale = RandomScale() - -# def forward(self, x): -# original_shape = x.shape -# if hasattr(self, "jitter"): -# x = self.jitter(x) -# if hasattr(self, "scale"): -# x = self.scale(x) -# cropped = center_crop(x, original_shape) -# return cropped - - # class RandomHomography(nn.Module): # def __init__(self): # super().__init__() @@ -979,7 +961,7 @@ class GaussianSmoothing(nn.Module): in the input using a depthwise convolution. """ - __constants__ = ["groups"] + __constants__ = ["groups", "padding"] def __init__( self, @@ -987,6 +969,7 @@ def __init__( kernel_size: Union[int, Sequence[int]], sigma: Union[float, Sequence[float]], dim: int = 2, + padding: Union[str, int, Tuple[int, int]] = "same", ) -> None: """ Args: @@ -996,7 +979,11 @@ def __init__( kernel_size (int, sequence): Size of the gaussian kernel. sigma (float, sequence): Standard deviation of the gaussian kernel. dim (int, optional): The number of dimensions of the data. - Default value is 2 (spatial). + Default value is ``2`` for (spatial) + padding (str, int or list of tuple, optional): The desired padding amount + or mode to use. One of; ``"valid"``, ``"same"``, a single number, or a + tuple in the format of: (padH, padW). + Default: ``"same"`` """ super().__init__() if isinstance(kernel_size, numbers.Number): @@ -1007,9 +994,18 @@ def __init__( # The gaussian kernel is the product of the # gaussian function of each dimension. kernel = 1 - meshgrids = torch.meshgrid( - [torch.arange(size, dtype=torch.float32) for size in kernel_size] - ) + + # PyTorch v1.10.0 adds a new indexing argument + if version.parse(torch.__version__) >= version.parse("1.10.0"): + meshgrids = torch.meshgrid( + [torch.arange(size, dtype=torch.float32) for size in kernel_size], + indexing="ij", + ) + else: + meshgrids = torch.meshgrid( + [torch.arange(size, dtype=torch.float32) for size in kernel_size] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): mean = (size - 1) / 2 kernel *= ( @@ -1027,6 +1023,7 @@ def __init__( self.register_buffer("weight", kernel) self.groups = channels + self.padding = padding if dim == 1: self.conv = F.conv1d @@ -1048,9 +1045,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: input (torch.Tensor): Input to apply gaussian filter on. Returns: - **filtered** (torch.Tensor): Filtered output. + filtered (torch.Tensor): Filtered output. """ - return self.conv(input, weight=self.weight, groups=self.groups) + return self.conv( + input, weight=self.weight, groups=self.groups, padding=self.padding + ) class SymmetricPadding(torch.autograd.Function): diff --git a/tests/optim/models/test_models_common.py b/tests/optim/models/test_models_common.py index 5c1006076..11856e44e 100644 --- a/tests/optim/models/test_models_common.py +++ b/tests/optim/models/test_models_common.py @@ -6,6 +6,7 @@ import torch import torch.nn.functional as F from captum.optim.models import googlenet +from packaging import version from tests.helpers.basic import BaseTest, assertTensorAlmostEqual @@ -37,7 +38,10 @@ def check_grad(self, grad_input, grad_output): rr_layer = model_utils.RedirectedReluLayer() x = torch.zeros(1, 3, 4, 4, requires_grad=True) - rr_layer.register_backward_hook(check_grad) + if version.parse(torch.__version__) >= version.parse("1.8.0"): + rr_layer.register_full_backward_hook(check_grad) + else: + rr_layer.register_backward_hook(check_grad) rr_loss = rr_layer(x * 1).mean() rr_loss.backward() diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index 617d34a3a..9d23efd37 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -20,6 +20,17 @@ def test_new(self) -> None: test_tensor = images.ImageTensor(x) self.assertTrue(torch.is_tensor(test_tensor)) self.assertEqual(x.shape, test_tensor.shape) + self.assertEqual(x.dtype, test_tensor.dtype) + + def test_new_dtype_float64(self) -> None: + x = torch.ones(5, dtype=torch.float64) + test_tensor = images.ImageTensor(x) + self.assertEqual(test_tensor.dtype, torch.float64) + + def test_new_dtype_float16(self) -> None: + x = torch.ones(5, dtype=torch.float16) + test_tensor = images.ImageTensor(x) + self.assertEqual(test_tensor.dtype, torch.float16) def test_new_numpy(self) -> None: x = torch.ones(5).numpy() @@ -33,6 +44,13 @@ def test_new_list(self) -> None: self.assertTrue(torch.is_tensor(test_tensor)) self.assertEqual(x.shape, test_tensor.shape) + def test_new_with_grad(self) -> None: + x = torch.ones(5, requires_grad=True) + test_tensor = images.ImageTensor(x) + self.assertTrue(test_tensor.requires_grad) + self.assertTrue(torch.is_tensor(test_tensor)) + self.assertEqual(x.shape, test_tensor.shape) + def test_torch_function(self) -> None: x = torch.ones(5) image_tensor = images.ImageTensor(x) @@ -102,7 +120,7 @@ def test_subclass(self) -> None: def test_pytorch_fftfreq(self) -> None: image = images.FFTImage((1, 1)) - _, _, fftfreq = image.get_fft_funcs() + _, _, fftfreq = image._get_fft_funcs() assertTensorAlmostEqual( self, fftfreq(4, 4), torch.as_tensor(np.fft.fftfreq(4, 4)), mode="max" ) @@ -114,7 +132,7 @@ def test_rfft2d_freqs(self) -> None: assertTensorAlmostEqual( self, - image.rfft2d_freqs(height, width), + image._rfft2d_freqs(height, width), torch.tensor([[0.0000, 0.3333], [0.5000, 0.6009]]), ) @@ -308,6 +326,33 @@ def test_fftimage_forward_init_batch(self) -> None: self, fftimage_tensor.detach(), fftimage_array, 25.0, mode="max" ) + def test_fftimage_forward_dtype_float64(self) -> None: + dtype = torch.float64 + image_param = images.FFTImage(size=(224, 224)).to(dtype=dtype) + output = image_param() + self.assertEqual(output.dtype, dtype) + + def test_fftimage_forward_dtype_float32(self) -> None: + dtype = torch.float32 + image_param = images.FFTImage(size=(224, 224)).to(dtype=dtype) + output = image_param() + self.assertEqual(output.dtype, dtype) + + def test_fftimage_forward_dtype_float16(self) -> None: + if version.parse(torch.__version__) <= version.parse("1.12.0"): + raise unittest.SkipTest( + "Skipping FFTImage float16 dtype test due to" + + " insufficient Torch version." + ) + dtype = torch.float16 + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping FFTImage float16 dtype test due to not supporting CUDA." + ) + image_param = images.FFTImage(size=(256, 256)).cuda().to(dtype=dtype) + output = image_param() + self.assertEqual(output.dtype, dtype) + class TestPixelImage(BaseTest): def test_subclass(self) -> None: @@ -317,12 +362,7 @@ def test_pixelimage_random(self) -> None: size = (224, 224) channels = 3 image_param = images.PixelImage(size=size, channels=channels) - - self.assertEqual(image_param.image.dim(), 4) - self.assertEqual(image_param.image.size(0), 1) - self.assertEqual(image_param.image.size(1), channels) - self.assertEqual(image_param.image.size(2), size[0]) - self.assertEqual(image_param.image.size(3), size[1]) + self.assertEqual(list(image_param.image.shape), [1, channels] + list(size)) self.assertTrue(image_param.image.requires_grad) def test_pixelimage_init(self) -> None: @@ -331,11 +371,7 @@ def test_pixelimage_init(self) -> None: init_tensor = torch.randn(channels, *size) image_param = images.PixelImage(size=size, channels=channels, init=init_tensor) - self.assertEqual(image_param.image.dim(), 4) - self.assertEqual(image_param.image.size(0), 1) - self.assertEqual(image_param.image.size(1), channels) - self.assertEqual(image_param.image.size(2), size[0]) - self.assertEqual(image_param.image.size(3), size[1]) + self.assertEqual(list(image_param.image.shape), [1, channels] + list(size)) assertTensorAlmostEqual(self, image_param.image, init_tensor[None, :], 0) self.assertTrue(image_param.image.requires_grad) @@ -344,12 +380,7 @@ def test_pixelimage_random_forward(self) -> None: channels = 3 image_param = images.PixelImage(size=size, channels=channels) test_tensor = image_param.forward().rename(None) - - self.assertEqual(test_tensor.dim(), 4) - self.assertEqual(test_tensor.size(0), 1) - self.assertEqual(test_tensor.size(1), channels) - self.assertEqual(test_tensor.size(2), size[0]) - self.assertEqual(test_tensor.size(3), size[1]) + self.assertEqual(list(test_tensor.shape), [1, channels] + list(size)) def test_pixelimage_forward_jit_module(self) -> None: if version.parse(torch.__version__) <= version.parse("1.8.0"): @@ -369,13 +400,33 @@ def test_pixelimage_init_forward(self) -> None: image_param = images.PixelImage(size=size, channels=channels, init=init_tensor) test_tensor = image_param.forward().rename(None) - self.assertEqual(test_tensor.dim(), 4) - self.assertEqual(test_tensor.size(0), 1) - self.assertEqual(test_tensor.size(1), channels) - self.assertEqual(test_tensor.size(2), size[0]) - self.assertEqual(test_tensor.size(3), size[1]) + self.assertEqual(list(test_tensor.shape), [1, channels] + list(size)) assertTensorAlmostEqual(self, test_tensor, init_tensor[None, :], 0) + def test_pixelimage_forward_dtype_float64(self) -> None: + dtype = torch.float64 + image_param = images.PixelImage(size=(224, 224)).to(dtype=dtype) + output = image_param() + self.assertEqual(output.dtype, torch.float64) + + def test_pixelimage_forward_dtype_float32(self) -> None: + dtype = torch.float32 + image_param = images.PixelImage(size=(224, 224)).to(dtype=dtype) + output = image_param() + self.assertEqual(output.dtype, torch.float32) + + def test_pixelimage_forward_dtype_float16(self) -> None: + dtype = torch.float16 + image_param = images.PixelImage(size=(224, 224)).to(dtype) + output = image_param() + self.assertEqual(output.dtype, dtype) + + def test_pixelimage_forward_dtype_bfloat16(self) -> None: + dtype = torch.bfloat16 + image_param = images.PixelImage(size=(224, 224)).to(dtype=dtype) + output = image_param() + self.assertEqual(output.dtype, dtype) + class TestLaplacianImage(BaseTest): def test_subclass(self) -> None: @@ -384,21 +435,67 @@ def test_subclass(self) -> None: def test_laplacianimage_random_forward(self) -> None: size = (224, 224) channels = 3 - image_param = images.LaplacianImage(size=size, channels=channels) + batch = 1 + image_param = images.LaplacianImage(size=size, channels=channels, batch=batch) test_tensor = image_param.forward().rename(None) + self.assertEqual(list(test_tensor.shape), [batch, channels, size[0], size[1]]) + self.assertTrue(test_tensor.requires_grad) - self.assertEqual(test_tensor.dim(), 4) - self.assertEqual(test_tensor.size(0), 1) - self.assertEqual(test_tensor.size(1), channels) - self.assertEqual(test_tensor.size(2), size[0]) - self.assertEqual(test_tensor.size(3), size[1]) + def test_laplacianimage_random_forward_batch_5(self) -> None: + size = (224, 224) + channels = 3 + batch = 5 + image_param = images.LaplacianImage(size=size, channels=channels, batch=batch) + test_tensor = image_param.forward().rename(None) + self.assertEqual(list(test_tensor.shape), [batch, channels, size[0], size[1]]) - def test_laplacianimage_init(self) -> None: - init_t = torch.zeros(1, 224, 224) - image_param = images.LaplacianImage(size=(224, 224), channels=3, init=init_t) + def test_laplacianimage_random_forward_scale_list(self) -> None: + size = (224, 224) + channels = 3 + batch = 1 + scale_list = [1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 56.0, 112.0] + image_param = images.LaplacianImage( + size=size, channels=channels, batch=batch, scale_list=scale_list + ) + test_tensor = image_param.forward().rename(None) + self.assertEqual(list(test_tensor.shape), [batch, channels, size[0], size[1]]) + + def test_laplacianimage_random_forward_scale_list_error(self) -> None: + scale_list = [1.0, 2.0, 4.0, 8.0, 16.0, 64.0, 144.0] + with self.assertRaises(AssertionError): + images.LaplacianImage( + size=(224, 224), channels=3, batch=1, scale_list=scale_list + ) + + def test_laplacianimage_init_tensor(self) -> None: + init_tensor = torch.zeros(1, 3, 224, 224) + image_param = images.LaplacianImage(init=init_tensor) output = image_param.forward().detach().rename(None) assertTensorAlmostEqual(self, torch.ones_like(output) * 0.5, output, mode="max") + def test_laplacianimage_random_forward_cuda(self) -> None: + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping LaplacianImage CUDA test due to not supporting CUDA." + ) + image_param = images.LaplacianImage(size=(224, 224), channels=3, batch=1).cuda() + test_tensor = image_param.forward().rename(None) + self.assertTrue(test_tensor.is_cuda) + self.assertEqual(list(test_tensor.shape), [1, 3, 224, 224]) + self.assertTrue(test_tensor.requires_grad) + + def test_laplcianimage_forward_dtype_float64(self) -> None: + dtype = torch.float64 + image_param = images.LaplacianImage(size=(224, 224)).to(dtype=dtype) + output = image_param() + self.assertEqual(output.dtype, dtype) + + def test_laplcianimage_forward_dtype_float32(self) -> None: + dtype = torch.float32 + image_param = images.LaplacianImage(size=(224, 224)).to(dtype=dtype) + output = image_param() + self.assertEqual(output.dtype, dtype) + class TestSimpleTensorParameterization(BaseTest): def test_subclass(self) -> None: @@ -674,12 +771,7 @@ def test_interpolate_tensor(self) -> None: output_tensor = image_param._interpolate_tensor( test_tensor, batch, channels, size[0], size[1] ) - - self.assertEqual(output_tensor.dim(), 4) - self.assertEqual(output_tensor.size(0), batch) - self.assertEqual(output_tensor.size(1), channels) - self.assertEqual(output_tensor.size(2), size[0]) - self.assertEqual(output_tensor.size(3), size[1]) + self.assertEqual(list(output_tensor.shape), [batch, channels] + list(size)) def test_sharedimage_single_shape_hw_forward(self) -> None: shared_shapes = (128 // 2, 128 // 2) @@ -697,11 +789,7 @@ def test_sharedimage_single_shape_hw_forward(self) -> None: self.assertEqual( list(image_param.shared_init[0]().shape), [1, 1] + list(shared_shapes) ) - self.assertEqual(test_tensor.dim(), 4) - self.assertEqual(test_tensor.size(0), batch) - self.assertEqual(test_tensor.size(1), channels) - self.assertEqual(test_tensor.size(2), size[0]) - self.assertEqual(test_tensor.size(3), size[1]) + self.assertEqual(list(test_tensor.shape), [batch, channels] + list(size)) def test_sharedimage_single_shape_chw_forward(self) -> None: shared_shapes = (3, 128 // 2, 128 // 2) @@ -719,11 +807,7 @@ def test_sharedimage_single_shape_chw_forward(self) -> None: self.assertEqual( list(image_param.shared_init[0]().shape), [1] + list(shared_shapes) ) - self.assertEqual(test_tensor.dim(), 4) - self.assertEqual(test_tensor.size(0), batch) - self.assertEqual(test_tensor.size(1), channels) - self.assertEqual(test_tensor.size(2), size[0]) - self.assertEqual(test_tensor.size(3), size[1]) + self.assertEqual(list(test_tensor.shape), [batch, channels] + list(size)) def test_sharedimage_single_shape_bchw_forward(self) -> None: shared_shapes = (1, 3, 128 // 2, 128 // 2) @@ -739,11 +823,7 @@ def test_sharedimage_single_shape_bchw_forward(self) -> None: self.assertIsNone(image_param.offset) self.assertEqual(image_param.shared_init[0]().dim(), 4) self.assertEqual(list(image_param.shared_init[0]().shape), list(shared_shapes)) - self.assertEqual(test_tensor.dim(), 4) - self.assertEqual(test_tensor.size(0), batch) - self.assertEqual(test_tensor.size(1), channels) - self.assertEqual(test_tensor.size(2), size[0]) - self.assertEqual(test_tensor.size(3), size[1]) + self.assertEqual(list(test_tensor.shape), [batch, channels] + list(size)) def test_sharedimage_multiple_shapes_forward(self) -> None: shared_shapes = ( @@ -769,11 +849,7 @@ def test_sharedimage_multiple_shapes_forward(self) -> None: self.assertEqual( list(image_param.shared_init[i]().shape), list(shared_shapes[i]) ) - self.assertEqual(test_tensor.dim(), 4) - self.assertEqual(test_tensor.size(0), batch) - self.assertEqual(test_tensor.size(1), channels) - self.assertEqual(test_tensor.size(2), size[0]) - self.assertEqual(test_tensor.size(3), size[1]) + self.assertEqual(list(test_tensor.shape), [batch, channels] + list(size)) def test_sharedimage_multiple_shapes_diff_len_forward(self) -> None: shared_shapes = ( @@ -800,11 +876,7 @@ def test_sharedimage_multiple_shapes_diff_len_forward(self) -> None: s_shape = ([1] * (4 - len(s_shape))) + list(s_shape) self.assertEqual(list(image_param.shared_init[i]().shape), s_shape) - self.assertEqual(test_tensor.dim(), 4) - self.assertEqual(test_tensor.size(0), batch) - self.assertEqual(test_tensor.size(1), channels) - self.assertEqual(test_tensor.size(2), size[0]) - self.assertEqual(test_tensor.size(3), size[1]) + self.assertEqual(list(test_tensor.shape), [batch, channels] + list(size)) def test_sharedimage_multiple_shapes_diff_len_forward_jit_module(self) -> None: if version.parse(torch.__version__) <= version.parse("1.8.0"): @@ -831,12 +903,7 @@ def test_sharedimage_multiple_shapes_diff_len_forward_jit_module(self) -> None: ) jit_image_param = torch.jit.script(image_param) test_tensor = jit_image_param() - - self.assertEqual(test_tensor.dim(), 4) - self.assertEqual(test_tensor.size(0), batch) - self.assertEqual(test_tensor.size(1), channels) - self.assertEqual(test_tensor.size(2), size[0]) - self.assertEqual(test_tensor.size(3), size[1]) + self.assertEqual(list(test_tensor.shape), [batch, channels] + list(size)) class TestStackImage(BaseTest): @@ -1071,11 +1138,19 @@ def test_natural_image_init_func_pixelimage(self) -> None: self.assertIsInstance(image_param.decorrelate, ToRGB) self.assertEqual(image_param.squash_func, torch.sigmoid) - def test_natural_image_init_func_default_init_tensor(self) -> None: - image_param = images.NaturalImage(init=torch.ones(1, 3, 1, 1)) + def test_natural_image_custom_squash_func(self) -> None: + init_tensor = torch.randn(1, 3, 1, 1) + + def clamp_image(x: torch.Tensor) -> torch.Tensor: + return x.clamp(0, 1) + + image_param = images.NaturalImage(init=init_tensor, squash_func=clamp_image) + image = image_param.forward().detach() + self.assertIsInstance(image_param.parameterization, images.FFTImage) self.assertIsInstance(image_param.decorrelate, ToRGB) - self.assertEqual(image_param.squash_func, image_param._clamp_image) + self.assertEqual(image_param.squash_func, clamp_image) + assertTensorAlmostEqual(self, image, init_tensor.clamp(0, 1)) def test_natural_image_init_tensor_pixelimage_sf_sigmoid(self) -> None: if version.parse(torch.__version__) <= version.parse("1.8.0"): @@ -1084,10 +1159,10 @@ def test_natural_image_init_tensor_pixelimage_sf_sigmoid(self) -> None: + " test due to insufficient Torch version." ) image_param = images.NaturalImage( - init=torch.ones(1, 3, 1, 1), + init=torch.ones(1, 3, 1, 1).float(), parameterization=images.PixelImage, squash_func=torch.sigmoid, - ) + ).to(dtype=torch.float32) output_tensor = image_param() self.assertEqual(image_param.squash_func, torch.sigmoid) @@ -1103,9 +1178,10 @@ def test_natural_image_0(self) -> None: ) def test_natural_image_1(self) -> None: - image_param = images.NaturalImage(init=torch.ones(3, 1, 1)) + init_tensor = torch.ones(3, 1, 1) + image_param = images.NaturalImage(init=init_tensor) image = image_param.forward().detach() - assertTensorAlmostEqual(self, image, torch.ones_like(image), mode="max") + assertTensorAlmostEqual(self, image, torch.sigmoid(init_tensor).unsqueeze(0)) def test_natural_image_cuda(self) -> None: if not torch.cuda.is_available(): @@ -1132,10 +1208,11 @@ def test_natural_image_jit_module_init_tensor(self) -> None: "Skipping NaturalImage init tensor JIT module test due to" + " insufficient Torch version." ) - image_param = images.NaturalImage(init=torch.ones(1, 3, 1, 1)) + init_tensor = torch.ones(1, 3, 1, 1) + image_param = images.NaturalImage(init=init_tensor) jit_image_param = torch.jit.script(image_param) output_tensor = jit_image_param() - assertTensorAlmostEqual(self, output_tensor, torch.ones_like(output_tensor)) + assertTensorAlmostEqual(self, output_tensor, torch.sigmoid(init_tensor)) def test_natural_image_jit_module_init_tensor_pixelimage(self) -> None: if version.parse(torch.__version__) <= version.parse("1.8.0"): @@ -1143,12 +1220,13 @@ def test_natural_image_jit_module_init_tensor_pixelimage(self) -> None: "Skipping NaturalImage PixelImage init tensor JIT module" + " test due to insufficient Torch version." ) + init_tensor = torch.ones(1, 3, 1, 1) image_param = images.NaturalImage( - init=torch.ones(1, 3, 1, 1), parameterization=images.PixelImage + init=init_tensor, parameterization=images.PixelImage ) jit_image_param = torch.jit.script(image_param) output_tensor = jit_image_param() - assertTensorAlmostEqual(self, output_tensor, torch.ones_like(output_tensor)) + assertTensorAlmostEqual(self, output_tensor, torch.sigmoid(init_tensor)) def test_natural_image_decorrelation_module_none(self) -> None: if version.parse(torch.__version__) <= version.parse("1.8.0"): @@ -1156,9 +1234,43 @@ def test_natural_image_decorrelation_module_none(self) -> None: "Skipping NaturalImage no decorrelation module" + " test due to insufficient Torch version." ) - image_param = images.NaturalImage( - init=torch.ones(1, 3, 4, 4), decorrelation_module=None - ) + init_tensor = torch.ones(1, 3, 1, 1) + image_param = images.NaturalImage(init=init_tensor, decorrelation_module=None) image = image_param.forward().detach() self.assertIsNone(image_param.decorrelate) - assertTensorAlmostEqual(self, image, torch.ones_like(image)) + assertTensorAlmostEqual(self, image, torch.sigmoid(init_tensor)) + + def test_natural_image_forward_dtype_float64(self) -> None: + dtype = torch.float64 + image_param = images.NaturalImage( + size=(224, 224), decorrelation_module=ToRGB("klt") + ).to(dtype=dtype) + output = image_param() + self.assertEqual(output.dtype, dtype) + + def test_natural_image_forward_dtype_float32(self) -> None: + dtype = torch.float32 + image_param = images.NaturalImage( + size=(224, 224), decorrelation_module=ToRGB("klt") + ).to(dtype=dtype) + output = image_param() + self.assertEqual(output.dtype, dtype) + + def test_fftimage_forward_dtype_float16(self) -> None: + if version.parse(torch.__version__) <= version.parse("1.12.0"): + raise unittest.SkipTest( + "Skipping NaturalImage float16 dtype test due to" + + " insufficient Torch version." + ) + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping NaturalImage float16 dtype test due to not supporting CUDA." + ) + dtype = torch.float16 + image_param = ( + images.NaturalImage(size=(256, 256), decorrelation_module=ToRGB("klt")) + .cuda() + .to(dtype=dtype) + ) + output = image_param() + self.assertEqual(output.dtype, dtype) diff --git a/tests/optim/param/test_transforms.py b/tests/optim/param/test_transforms.py index 385006a7a..3568b4c53 100644 --- a/tests/optim/param/test_transforms.py +++ b/tests/optim/param/test_transforms.py @@ -261,6 +261,24 @@ def test_random_scale_jit_module(self) -> None: 0, ) + def test_random_scale_dtype_float64(self) -> None: + dtype = torch.float64 + scale_module = transforms.RandomScale(scale=[0.975, 1.025, 0.95, 1.05]).to( + dtype=dtype + ) + x = torch.ones([1, 3, 224, 224], dtype=dtype) + output = scale_module(x) + self.assertEqual(output.dtype, dtype) + + def test_random_scale_dtype_float32(self) -> None: + dtype = torch.float32 + scale_module = transforms.RandomScale(scale=[0.975, 1.025, 0.95, 1.05]).to( + dtype=dtype + ) + x = torch.ones([1, 3, 224, 224], dtype=dtype) + output = scale_module(x) + self.assertEqual(output.dtype, dtype) + class TestRandomScaleAffine(BaseTest): def test_random_scale_affine_init(self) -> None: @@ -430,6 +448,40 @@ def test_random_scale_affine_jit_module(self) -> None: 0, ) + def test_random_scale_affine_dtype_float64(self) -> None: + dtype = torch.float64 + scale_module = transforms.RandomScaleAffine( + scale=[0.975, 1.025, 0.95, 1.05] + ).to(dtype=dtype) + x = torch.ones([1, 3, 224, 224], dtype=dtype) + output = scale_module(x) + self.assertEqual(output.dtype, dtype) + + def test_random_scale_affine_dtype_float32(self) -> None: + dtype = torch.float32 + scale_module = transforms.RandomScaleAffine( + scale=[0.975, 1.025, 0.95, 1.05] + ).to(dtype=dtype) + x = torch.ones([1, 3, 224, 224], dtype=dtype) + output = scale_module(x) + self.assertEqual(output.dtype, dtype) + + def test_random_scale_affine_dtype_float16(self) -> None: + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping RandomScaleAffine float16 dtype test due to not supporting" + + " CUDA." + ) + dtype = torch.float16 + scale_module = ( + transforms.RandomScaleAffine(scale=[0.975, 1.025, 0.95, 1.05]) + .cuda() + .to(dtype=dtype) + ) + x = torch.ones([1, 3, 224, 224], dtype=dtype).cuda() + output = scale_module(x) + self.assertEqual(output.dtype, dtype) + class TestRandomRotation(BaseTest): def test_random_rotation_init(self) -> None: @@ -629,6 +681,37 @@ def test_random_rotation_jit_module(self) -> None: ) assertTensorAlmostEqual(self, test_output, expected_output, 0.005) + def test_random_rotation_dtype_float64(self) -> None: + dtype = torch.float64 + degrees = list(range(-25, -5)) + list(range(5, 25)) + rotation_module = transforms.RandomRotation(degrees=degrees).to(dtype=dtype) + x = torch.ones([1, 3, 224, 224], dtype=dtype) + output = rotation_module(x) + self.assertEqual(output.dtype, dtype) + + def test_random_rotation_dtype_float32(self) -> None: + dtype = torch.float32 + degrees = list(range(-25, -5)) + list(range(5, 25)) + rotation_module = transforms.RandomRotation(degrees=degrees).to(dtype=dtype) + x = torch.ones([1, 3, 224, 224], dtype=dtype) + output = rotation_module(x) + self.assertEqual(output.dtype, dtype) + + def test_random_rotation_dtype_float16(self) -> None: + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping RandomRotation float16 dtype test due to not supporting" + + " CUDA." + ) + dtype = torch.float16 + degrees = list(range(-25, -5)) + list(range(5, 25)) + rotation_module = ( + transforms.RandomRotation(degrees=degrees).cuda().to(dtype=dtype) + ) + x = torch.ones([1, 3, 224, 224], dtype=dtype).cuda() + output = rotation_module(x) + self.assertEqual(output.dtype, dtype) + class TestRandomSpatialJitter(BaseTest): def test_random_spatial_jitter_init(self) -> None: @@ -714,6 +797,20 @@ def test_random_spatial_jitter_forward_jit_module(self) -> None: jittered_tensor = jit_spatialjitter(test_input) self.assertEqual(list(jittered_tensor.shape), list(test_input.shape)) + def test_random_spatial_jitter_dtype_float64(self) -> None: + dtype = torch.float64 + spatialjitter = transforms.RandomSpatialJitter(5).to(dtype=dtype) + x = torch.ones([1, 3, 224, 224], dtype=dtype) + output = spatialjitter(x) + self.assertEqual(output.dtype, dtype) + + def test_random_spatial_jitter_dtype_float32(self) -> None: + dtype = torch.float32 + spatialjitter = transforms.RandomSpatialJitter(5).to(dtype=dtype) + x = torch.ones([1, 3, 224, 224], dtype=dtype) + output = spatialjitter(x) + self.assertEqual(output.dtype, dtype) + class TestCenterCrop(BaseTest): def test_center_crop_init(self) -> None: @@ -1574,6 +1671,35 @@ def test_to_rgb_klt_forward_jit_module(self) -> None: self, inverse_tensor, torch.ones_like(inverse_tensor.rename(None)) ) + def test_to_rgb_dtype_float64(self) -> None: + dtype = torch.float64 + to_rgb = transforms.ToRGB(transform="klt").to(dtype=dtype) + test_tensor = torch.ones(1, 3, 224, 224, dtype=dtype) + output = to_rgb(test_tensor.refine_names("B", "C", "H", "W")) + self.assertEqual(output.dtype, dtype) + inverse_output = to_rgb(output, inverse=True) + self.assertEqual(inverse_output.dtype, dtype) + + def test_to_rgb_dtype_float32(self) -> None: + dtype = torch.float32 + to_rgb = transforms.ToRGB(transform="klt").to(dtype=dtype) + test_tensor = torch.ones(1, 3, 224, 224, dtype=dtype) + output = to_rgb(test_tensor.refine_names("B", "C", "H", "W")) + self.assertEqual(output.dtype, dtype) + inverse_output = to_rgb(output, inverse=True) + self.assertEqual(inverse_output.dtype, dtype) + + def test_to_rgb_dtype_float16_cuda(self) -> None: + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping ToRGB float16 dtype test due to not supporting CUDA." + ) + dtype = torch.float16 + to_rgb = transforms.ToRGB(transform="klt").cuda().to(dtype=dtype) + test_tensor = torch.ones(1, 3, 224, 224, dtype=dtype).cuda() + output = to_rgb(test_tensor.refine_names("B", "C", "H", "W")) + self.assertEqual(output.dtype, dtype) + class TestGaussianSmoothing(BaseTest): def test_gaussian_smoothing_init_1d(self) -> None: @@ -1582,11 +1708,17 @@ def test_gaussian_smoothing_init_1d(self) -> None: sigma = 2.0 dim = 1 smoothening_module = transforms.GaussianSmoothing( - channels, kernel_size, sigma, dim + channels, + kernel_size, + sigma, + dim, + padding=0, ) self.assertEqual(smoothening_module.groups, channels) + self.assertEqual(smoothening_module.padding, 0) weight = torch.tensor([[0.3192, 0.3617, 0.3192]]).repeat(6, 1, 1) assertTensorAlmostEqual(self, smoothening_module.weight, weight, 0.001) + self.assertFalse(smoothening_module.padding) def test_gaussian_smoothing_init_2d(self) -> None: channels = 3 @@ -1594,7 +1726,11 @@ def test_gaussian_smoothing_init_2d(self) -> None: sigma = 2.0 dim = 2 smoothening_module = transforms.GaussianSmoothing( - channels, kernel_size, sigma, dim + channels, + kernel_size, + sigma, + dim, + padding=0, ) self.assertEqual(smoothening_module.groups, channels) weight = torch.tensor( @@ -1614,7 +1750,11 @@ def test_gaussian_smoothing_init_3d(self) -> None: sigma = 1.021 dim = 3 smoothening_module = transforms.GaussianSmoothing( - channels, kernel_size, sigma, dim + channels, + kernel_size, + sigma, + dim, + padding=0, ) self.assertEqual(smoothening_module.groups, channels) weight = torch.tensor( @@ -1654,7 +1794,11 @@ def test_gaussian_smoothing_1d(self) -> None: sigma = 2.0 dim = 1 smoothening_module = transforms.GaussianSmoothing( - channels, kernel_size, sigma, dim + channels, + kernel_size, + sigma, + dim, + padding=0, ) test_tensor = torch.tensor([1.0, 5.0]).repeat(6, 2).unsqueeze(0) @@ -1671,7 +1815,11 @@ def test_gaussian_smoothing_2d(self) -> None: sigma = 2.0 dim = 2 smoothening_module = transforms.GaussianSmoothing( - channels, kernel_size, sigma, dim + channels, + kernel_size, + sigma, + dim, + padding=0, ) test_tensor = torch.tensor([1.0, 5.0]).repeat(3, 6, 3).unsqueeze(0) @@ -1688,7 +1836,11 @@ def test_gaussian_smoothing_3d(self) -> None: sigma = 1.021 dim = 3 smoothening_module = transforms.GaussianSmoothing( - channels, kernel_size, sigma, dim + channels, + kernel_size, + sigma, + dim, + padding=0, ) test_tensor = torch.tensor([1.0, 5.0, 1.0]).repeat(4, 6, 6, 2).unsqueeze(0) @@ -1712,7 +1864,11 @@ def test_gaussian_smoothing_2d_jit_module(self) -> None: sigma = 2.0 dim = 2 smoothening_module = transforms.GaussianSmoothing( - channels, kernel_size, sigma, dim + channels, + kernel_size, + sigma, + dim, + padding=0, ) jit_smoothening_module = torch.jit.script(smoothening_module) @@ -1801,12 +1957,15 @@ def check_grad(self, grad_input, grad_output): class SymmetricPaddingLayer(torch.nn.Module): def forward( - self, x: torch.Tensor, padding: List[List[int]] + self, x_input: torch.Tensor, padding: List[List[int]] ) -> torch.Tensor: - return transforms.SymmetricPadding.apply(x_pt, padding) + return transforms.SymmetricPadding.apply(x_input, padding) sym_pad = SymmetricPaddingLayer() - sym_pad.register_backward_hook(check_grad) + if version.parse(torch.__version__) >= version.parse("1.8.0"): + sym_pad.register_full_backward_hook(check_grad) + else: + sym_pad.register_backward_hook(check_grad) x_out = sym_pad(x_pt, offset_pad) (x_out.sum() * 1).backward() @@ -2008,3 +2167,17 @@ def test_transform_robustness_forward_padding_crop_output_jit_module(self) -> No test_input = torch.ones(1, 3, 224, 224) test_output = transform_robustness(test_input) self.assertEqual(test_output.shape, test_input.shape) + + def test_transform_robustness_dtype_float64(self) -> None: + dtype = torch.float64 + transform_robustness = transforms.TransformationRobustness().to(dtype=dtype) + x = torch.ones([1, 3, 224, 224], dtype=dtype) + output = transform_robustness(x) + self.assertEqual(output.dtype, dtype) + + def test_transform_robustness_dtype_float32(self) -> None: + dtype = torch.float32 + transform_robustness = transforms.TransformationRobustness().to(dtype=dtype) + x = torch.ones([1, 3, 224, 224], dtype=dtype) + output = transform_robustness(x) + self.assertEqual(output.dtype, dtype) diff --git a/tests/optim/utils/test_reducer.py b/tests/optim/utils/test_reducer.py index f2baa7675..a9fb9cc93 100644 --- a/tests/optim/utils/test_reducer.py +++ b/tests/optim/utils/test_reducer.py @@ -34,10 +34,10 @@ def test_channelreducer_pytorch(self) -> None: test_input = torch.randn(1, 32, 224, 224).abs() c_reducer = reducer.ChannelReducer(n_components=3, max_iter=100) test_output = c_reducer.fit_transform(test_input) - self.assertEquals(test_output.size(0), 1) - self.assertEquals(test_output.size(1), 3) - self.assertEquals(test_output.size(2), 224) - self.assertEquals(test_output.size(3), 224) + self.assertEqual(test_output.size(0), 1) + self.assertEqual(test_output.size(1), 3) + self.assertEqual(test_output.size(2), 224) + self.assertEqual(test_output.size(3), 224) def test_channelreducer_pytorch_dim_three(self) -> None: try: @@ -52,9 +52,7 @@ def test_channelreducer_pytorch_dim_three(self) -> None: test_input = torch.randn(32, 224, 224).abs() c_reducer = reducer.ChannelReducer(n_components=3, max_iter=100) test_output = c_reducer.fit_transform(test_input) - self.assertEquals(test_output.size(0), 3) - self.assertEquals(test_output.size(1), 224) - self.assertEquals(test_output.size(2), 224) + self.assertEqual(list(test_output.shape), [3, 224, 224]) def test_channelreducer_pytorch_pca(self) -> None: try: @@ -70,10 +68,7 @@ def test_channelreducer_pytorch_pca(self) -> None: c_reducer = reducer.ChannelReducer(n_components=3, reduction_alg="PCA") test_output = c_reducer.fit_transform(test_input) - self.assertEquals(test_output.size(0), 1) - self.assertEquals(test_output.size(1), 3) - self.assertEquals(test_output.size(2), 224) - self.assertEquals(test_output.size(3), 224) + self.assertEqual(list(test_output.shape), [1, 3, 224, 224]) def test_channelreducer_pytorch_custom_alg(self) -> None: test_input = torch.randn(1, 32, 224, 224).abs() @@ -82,10 +77,7 @@ def test_channelreducer_pytorch_custom_alg(self) -> None: n_components=3, reduction_alg=reduction_alg, max_iter=100 ) test_output = c_reducer.fit_transform(test_input) - self.assertEquals(test_output.size(0), 1) - self.assertEquals(test_output.size(1), 3) - self.assertEquals(test_output.size(2), 224) - self.assertEquals(test_output.size(3), 224) + self.assertEqual(list(test_output.shape), [1, 3, 224, 224]) def test_channelreducer_pytorch_custom_alg_components(self) -> None: reduction_alg = FakeReductionAlgorithm @@ -149,10 +141,7 @@ def test_channelreducer_noreshape_pytorch(self) -> None: test_input = torch.randn(1, 224, 224, 32).abs() c_reducer = reducer.ChannelReducer(n_components=3, max_iter=100) test_output = c_reducer.fit_transform(test_input, swap_2nd_and_last_dims=False) - self.assertEquals(test_output.size(0), 1) - self.assertEquals(test_output.size(1), 224) - self.assertEquals(test_output.size(2), 224) - self.assertEquals(test_output.size(3), 3) + self.assertEqual(list(test_output.shape), [1, 224, 224, 3]) def test_channelreducer_error(self) -> None: if not torch.cuda.is_available():