diff --git a/captum/optim/_param/image/__init__.py b/captum/optim/_param/image/__init__.py
index a2311f7c4..5c36c0c80 100755
--- a/captum/optim/_param/image/__init__.py
+++ b/captum/optim/_param/image/__init__.py
@@ -1 +1 @@
-"""(Differentiable) Input Parameterizations. Currently only 3-channel images"""
+"""(Differentiable) Input Parameterizations. Currently only images"""
diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py
index fa313b38a..f3a45346c 100644
--- a/captum/optim/_param/image/images.py
+++ b/captum/optim/_param/image/images.py
@@ -1,4 +1,3 @@
-from copy import deepcopy
 from types import MethodType
 from typing import Callable, List, Optional, Tuple, Type, Union
 
@@ -37,7 +36,12 @@ def __new__(
         Returns:
            x (ImageTensor): An `ImageTensor` instance.
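+
+        Example (a minimal sketch; a non-float32 input keeps its dtype)::
+
+            >>> x = torch.ones(1, 3, 4, 4, dtype=torch.float64)
+            >>> image = opt.images.ImageTensor(x)
+            >>> print(image.dtype)
+            torch.float64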
         """
-        if isinstance(x, torch.Tensor) and x.is_cuda:
+        if isinstance(x, torch.Tensor) and (
+            x.is_cuda or x.dtype != torch.float32
+        ):
             x.show = MethodType(cls.show, x)
             x.export = MethodType(cls.export, x)
             return x
@@ -181,12 +185,32 @@ def forward(self) -> torch.Tensor:
 
 
 class ImageParameterization(InputParameterization):
+    r"""The base class for all Image Parameterizations"""
     pass
 
 
 class FFTImage(ImageParameterization):
     """
     Parameterize an image using inverse real 2D FFT
+
+    Example::
+
+        >>> fft_image = opt.images.FFTImage(size=(224, 224))
+        >>> output_image = fft_image()
+        >>> print(output_image.requires_grad)
+        True
+        >>> print(output_image.shape)
+        torch.Size([1, 3, 224, 224])
+
+    Example for using an initialization tensor::
+
+        >>> init = torch.randn(1, 3, 224, 224)
+        >>> fft_image = opt.images.FFTImage(init=init)
+        >>> output_image = fft_image()
+        >>> print(output_image.requires_grad)
+        True
+        >>> print(output_image.shape)
+        torch.Size([1, 3, 224, 224])
     """
 
     __constants__ = ["size"]
@@ -201,16 +225,16 @@ def __init__(
         """
         Args:
 
-            size (Tuple[int, int]): The height & width dimensions to use for the
-                parameterized output image tensor.
+            size (tuple of int): The height & width dimensions to use for the
+                parameterized output image tensor, in the format of: (height, width).
             channels (int, optional): The number of channels to use for each image.
-                Default: 3
+                Default: ``3``
             batch (int, optional): The number of images to stack along the batch
                 dimension.
-                Default: 1
-            init (torch.tensor, optional): Optionally specify a tensor to
+                Default: ``1``
+            init (torch.Tensor, optional): Optionally specify a CHW or NCHW tensor to
                 use instead of creating one.
-                Default: None
+                Default: ``None``
         """
         super().__init__()
         if init is None:
@@ -221,9 +245,9 @@ def __init__(
             if init.dim() == 3:
                 init = init.unsqueeze(0)
             self.size = (init.size(2), init.size(3))
-        self.torch_rfft, self.torch_irfft, self.torch_fftfreq = self.get_fft_funcs()
+        self.torch_rfft, self.torch_irfft, self.torch_fftfreq = self._get_fft_funcs()
 
-        frequencies = self.rfft2d_freqs(*self.size)
+        frequencies = self._rfft2d_freqs(*self.size)
         scale = 1.0 / torch.max(
             frequencies,
             torch.full_like(frequencies, 1.0 / (max(self.size[0], self.size[1]))),
@@ -250,7 +274,7 @@ def __init__(
         self.register_buffer("spectrum_scale", spectrum_scale)
         self.fourier_coeffs = nn.Parameter(fourier_coeffs)
 
-    def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor:
+    def _rfft2d_freqs(self, height: int, width: int) -> torch.Tensor:
         """
         Computes 2D spectrum frequencies.
 
@@ -260,7 +284,7 @@ def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor:
             width (int): The w dimension of the 2d frequency scale.
 
         Returns:
-            **tensor** (tensor): A 2d frequency scale tensor.
+            tensor (torch.Tensor): A 2d frequency scale tensor.
         """
 
         fy = self.torch_fftfreq(height)[:, None]
@@ -268,20 +292,20 @@ def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor:
         return torch.sqrt((fx * fx) + (fy * fy))
 
     @torch.jit.export
-    def torch_irfftn(self, x: torch.Tensor) -> torch.Tensor:
-        if x.dtype != torch.complex64:
+    def _torch_irfftn(self, x: torch.Tensor) -> torch.Tensor:
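+        # Inputs stored as real views (via torch.view_as_real) are converted
+        # back to complex tensors before applying the inverse FFT.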
+        if not torch.is_complex(x):
             x = torch.view_as_complex(x)
         return torch.fft.irfftn(x, s=self.size)  # type: ignore
 
-    def get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]:
+    def _get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]:
         """
         Support older versions of PyTorch. This function ensures that the same FFT
-        operations are carried regardless of whether your PyTorch version has the
+        operations are carried out regardless of whether your PyTorch version has the
         torch.fft update.
 
         Returns:
-            fft functions (tuple of Callable): A list of FFT functions
-                to use for irfft, rfft, and fftfreq operations.
+            fft_functions (tuple of callable): A list of FFT functions to use for
+                irfft, rfft, and fftfreq operations.
         """
 
         if version.parse(TORCH_VERSION) > version.parse("1.7.0"):
@@ -292,7 +316,7 @@ def get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]:
             def torch_rfft(x: torch.Tensor) -> torch.Tensor:
                 return torch.view_as_real(torch.fft.rfftn(x, s=self.size))
 
-            torch_irfftn = self.torch_irfftn
+            torch_irfftn = self._torch_irfftn
 
             def torch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor:
                 return torch.fft.fftfreq(v, d)
@@ -320,7 +344,7 @@ def torch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor:
     def forward(self) -> torch.Tensor:
         """
         Returns:
-            **output** (torch.tensor): A spatially recorrelated tensor.
+            output (torch.Tensor): A spatially recorrelated NCHW tensor.
         """
 
         scaled_spectrum = self.fourier_coeffs * self.spectrum_scale
@@ -333,6 +357,25 @@ def forward(self) -> torch.Tensor:
 class PixelImage(ImageParameterization):
     """
     Parameterize a simple pixel image tensor that requires no additional transforms.
+
+    Example::
+
+        >>> pixel_image = opt.images.PixelImage(size=(224, 224))
+        >>> output_image = pixel_image()
+        >>> print(output_image.requires_grad)
+        True
+        >>> print(output_image.shape)
+        torch.Size([1, 3, 224, 224])
+
+    Example for using an initialization tensor::
+
+        >>> init = torch.randn(1, 3, 224, 224)
+        >>> pixel_image = opt.images.PixelImage(init=init)
+        >>> output_image = pixel_image()
+        >>> print(output_image.requires_grad)
+        True
+        >>> print(output_image.shape)
+        torch.Size([1, 3, 224, 224])
     """
 
     def __init__(
@@ -345,16 +388,16 @@ def __init__(
         """
         Args:
 
-            size (Tuple[int, int]): The height & width dimensions to use for the
-                parameterized output image tensor.
+            size (tuple of int): The height & width dimensions to use for the
+                parameterized output image tensor, in the format of: (height, width).
             channels (int, optional): The number of channels to use for each image.
-                Default: 3
+                Default: ``3``
             batch (int, optional): The number of images to stack along the batch
                 dimension.
-                Default: 1
-            init (torch.tensor, optional): Optionally specify a tensor to
+                Default: ``1``
+            init (torch.Tensor, optional): Optionally specify a CHW or NCHW tensor to
                 use instead of creating one.
-                Default: None
+                Default: ``None``
         """
         super().__init__()
         if init is None:
@@ -367,6 +410,10 @@ def __init__(
         self.image = nn.Parameter(init)
 
     def forward(self) -> torch.Tensor:
+        """
+        Returns:
+            output (torch.Tensor): An NCHW tensor.
+        """
         if torch.jit.is_scripting():
             return self.image
         return self.image.refine_names("B", "C", "H", "W")
@@ -374,96 +421,101 @@ def forward(self) -> torch.Tensor:
 
 class LaplacianImage(ImageParameterization):
     """
-    TODO: Fix divison by 6 in setup_input when init is not None.
     Parameterize an image tensor with a laplacian pyramid.
+
+    Example::
+
+        >>> laplacian_image = opt.images.LaplacianImage(size=(224, 224))
+        >>> output_image = laplacian_image()
+        >>> print(output_image.requires_grad)
+        True
+        >>> print(output_image.shape)
+        torch.Size([1, 3, 224, 224])
+
+    Example for using an initialization tensor::
+
+        >>> init = torch.randn(1, 3, 224, 224)
+        >>> laplacian_image = opt.images.LaplacianImage(init=init)
+        >>> output_image = laplacian_image()
+        >>> print(output_image.requires_grad)
+        True
+        >>> print(output_image.shape)
+        torch.Size([1, 3, 224, 224])
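+
+    Example with a custom ``scale_list`` (a minimal sketch; the height & width
+    must be divisible by every scale value)::
+
+        >>> scale_list = [1.0, 2.0, 4.0, 8.0]
+        >>> laplacian_image = opt.images.LaplacianImage(
+        ...     size=(64, 64), scale_list=scale_list
+        ... )
+        >>> output_image = laplacian_image()
+        >>> print(output_image.shape)
+        torch.Size([1, 3, 64, 64])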
     """
 
     def __init__(
         self,
-        size: Tuple[int, int] = None,
+        size: Tuple[int, int] = (224, 224),
         channels: int = 3,
         batch: int = 1,
         init: Optional[torch.Tensor] = None,
+        power: float = 0.1,
+        scale_list: List[float] = [1.0, 2.0, 4.0, 8.0, 16.0, 32.0],
     ) -> None:
         """
         Args:
 
-            size (Tuple[int, int]): The height & width dimensions to use for the
-                parameterized output image tensor.
+            size (tuple of int): The height & width dimensions to use for the
+                parameterized output image tensor, in the format of: (height, width).
             channels (int, optional): The number of channels to use for each image.
-                Default: 3
+                Default: ``3``
             batch (int, optional): The number of images to stack along the batch
                 dimension.
-                Default: 1
-            init (torch.tensor, optional): Optionally specify a tensor to
+                Default: ``1``
+            init (torch.Tensor, optional): Optionally specify a CHW or NCHW tensor to
                 use instead of creating one.
-                Default: None
+                Default: ``None``
+            power (float, optional): The exponent used when weighting each pyramid
+                level by its scale value.
+                Default: ``0.1``
+            scale_list (list of float, optional): The desired list of scale values to
+                use in the laplacian pyramid. The height & width dimensions specified
+                in ``size`` or used in the ``init`` tensor should be evenly divisible
+                by every scale value in ``scale_list``. The default ``scale_list``
+                values are set to work with a ``size`` of ``(224, 224)``.
+                Default: ``[1.0, 2.0, 4.0, 8.0, 16.0, 32.0]``
         """
         super().__init__()
-        power = 0.1
-
-        if init is None:
-            tensor_params, self.scaler = self._setup_input(size, channels, power, init)
-
-            self.tensor_params = torch.nn.ModuleList(
-                [deepcopy(tensor_params) for b in range(batch)]
-            )
-        else:
+        if init is not None:
+            assert init.dim() in [3, 4]
             init = init.unsqueeze(0) if init.dim() == 3 else init
-            P = []
-            for b in range(init.size(0)):
-                tensor_params, self.scaler = self._setup_input(
-                    size, channels, power, init[b].unsqueeze(0)
-                )
-                P.append(tensor_params)
-            self.tensor_params = torch.nn.ModuleList(P)
+            size = list(init.shape[2:])
 
-    def _setup_input(
-        self,
-        size: Tuple[int, int],
-        channels: int,
-        power: float = 0.1,
-        init: Optional[torch.Tensor] = None,
-    ) -> Tuple[List[torch.Tensor], List[torch.nn.Upsample]]:
         tensor_params, scaler = [], []
-        scale_list = [1, 2, 4, 8, 16, 32]
         for scale in scale_list:
+            assert size[0] % scale == 0 and size[1] % scale == 0, (
+                "The chosen image height & width dimensions"
+                + " must be divisible by all scale values "
+                + " with no remainder left over."
+            )
+
             h, w = int(size[0] // scale), int(size[1] // scale)
             if init is None:
-                x = torch.randn([1, channels, h, w]) / 10
+                x = torch.randn([batch, channels, h, w]) / 10
             else:
                 x = F.interpolate(init.clone(), size=(h, w), mode="bilinear")
-                x = x / 6  # Prevents output from being all white
+                x = x / 10
             upsample = torch.nn.Upsample(scale_factor=scale, mode="nearest")
-            x = x * (scale**power) / (32**power)
+            x = x * (scale**power) / (max(scale_list) ** power)
             x = torch.nn.Parameter(x)
             tensor_params.append(x)
             scaler.append(upsample)
-        tensor_params = torch.nn.ParameterList(tensor_params)
-        return tensor_params, scaler
+        self.tensor_params = torch.nn.ParameterList(tensor_params)
+        self.scaler = scaler
 
-    def _create_tensor(self, params_list: torch.nn.ParameterList) -> torch.Tensor:
+    def forward(self) -> torch.Tensor:
         """
-        Resize tensor parameters to the target size.
-
-        Args:
-
-            params_list (torch.nn.ParameterList): List of tensors to resize.
-
         Returns:
-            **tensor** (torch.Tensor): The sum of all tensor parameters.
+            output (torch.Tensor): An NCHW tensor created from a laplacian pyramid.
         """
-        A: List[torch.Tensor] = []
-        for xi, upsamplei in zip(params_list, self.scaler):
+        A: List[torch.Tensor] = []
+        for xi, upsamplei in zip(self.tensor_params, self.scaler):
             A.append(upsamplei(xi))
-        return torch.sum(torch.cat(A), 0) + 0.5
+        output = sum(A) + 0.5
 
-    def forward(self) -> torch.Tensor:
-        A: List[torch.Tensor] = []
-        for params_list in self.tensor_params:
-            tensor = self._create_tensor(params_list)
-            A.append(tensor)
-        return torch.stack(A).refine_names("B", "C", "H", "W")
+        if torch.jit.is_scripting():
+            return output
+        return output.refine_names("B", "C", "H", "W")
 
 
 class SimpleTensorParameterization(ImageParameterization):
@@ -484,7 +536,8 @@ def __init__(self, tensor: torch.Tensor = None) -> None:
         """
         Args:
 
-            tensor (torch.tensor): The tensor to return everytime this module is called.
+            tensor (torch.Tensor): The tensor to return every time this module is
+                called.
         """
         super().__init__()
         assert isinstance(tensor, torch.Tensor)
@@ -509,6 +562,17 @@ class SharedImage(ImageParameterization):
 
     Mordvintsev, et al., "Differentiable Image Parameterizations", Distill, 2018.
     https://distill.pub/2018/differentiable-parameterizations/
+
+    Example::
+
+        >>> fft_image = opt.images.FFTImage(size=(224, 224), batch=2)
+        >>> shared_shapes = ((1, 3, 64, 64), (4, 3, 32, 32))
+        >>> shared_image = opt.images.SharedImage(shared_shapes, fft_image)
+        >>> output_image = shared_image()
+        >>> print(output_image.requires_grad)
+        True
+        >>> print(output_image.shape)
+        torch.Size([2, 3, 224, 224])
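+
+    Example with offsets (a minimal sketch; the offset values are illustrative)::
+
+        >>> fft_image = opt.images.FFTImage(size=(224, 224), batch=2)
+        >>> shared_shapes = ((1, 3, 64, 64), (4, 3, 32, 32))
+        >>> offset = ((0, 0, 2, 2), (0, 0, -2, -2))
+        >>> shared_image = opt.images.SharedImage(shared_shapes, fft_image, offset)
+        >>> output_image = shared_image()
+        >>> print(output_image.shape)
+        torch.Size([2, 3, 224, 224])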
     """
 
     __constants__ = ["offset"]
@@ -522,13 +586,13 @@ def __init__(
         """
         Args:
 
-            shapes (list of int or list of list of ints): The shapes of the shared
+            shapes (list of int or list of list of int): The shapes of the shared
                 tensors to use for creating the nn.Parameter tensors.
             parameterization (ImageParameterization): An image parameterization
                 instance.
-            offset (int or list of int or list of list of ints , optional): The offsets
+            offset (int or list of int or list of list of int, optional): The offsets
                 to use for the shared tensors.
-                Default: None
+                Default: ``None``
         """
         super().__init__()
         assert shapes is not None
@@ -552,12 +616,12 @@ def _get_offset(self, offset: Union[int, Tuple[int]], n: int) -> List[List[int]]
 
         Args:
 
-            offset (int or list of int or list of list of ints , optional): The offsets
+            offset (int or list of int or list of list of int, optional): The offsets
                 to use for the shared tensors.
             n (int): The number of tensors needing offset values.
 
         Returns:
-            **offset** (list of list of int): A list of offset values.
+            offset (list of list of int): A list of offset values.
         """
         if type(offset) is tuple or type(offset) is list:
             if type(offset[0]) is tuple or type(offset[0]) is list:
@@ -581,7 +645,7 @@ def _apply_offset(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
             x_list (list of torch.Tensor): list of tensors to offset.
 
         Returns:
-            **A** (list of torch.Tensor): list of offset tensors.
+            A (list of torch.Tensor): list of offset tensors.
         """
 
         A: List[torch.Tensor] = []
@@ -616,8 +680,8 @@ def _interpolate_bilinear(
         Args:
 
             x (torch.Tensor): The NCHW tensor to resize.
-            size (tuple of int): The desired output size to resize the input
-                to, with a format of: [height, width].
+            size (tuple of int): The desired output size to resize the input to, with
+                a format of: [height, width].
 
         Returns:
             x (torch.Tensor): A resized NCHW tensor.
@@ -645,8 +709,8 @@ def _interpolate_trilinear(
         Args:
 
             x (torch.Tensor): The NCHW tensor to resize.
-            size (tuple of int): The desired output size to resize the input
-                to, with a format of: [channels, height, width].
+            size (tuple of int): The desired output size to resize the input to, with
+                a format of: [channels, height, width].
 
         Returns:
             x (torch.Tensor): A resized NCHW tensor.
@@ -678,7 +742,7 @@ def _interpolate_tensor(
             width (int): The width to resize the tensor to.
 
         Returns:
-            **tensor** (torch.Tensor): A resized tensor.
+            tensor (torch.Tensor): A resized tensor.
         """
 
         if x.size(1) == channels:
@@ -721,6 +785,28 @@ def forward(self) -> torch.Tensor:
 class StackImage(ImageParameterization):
     """
     Stack multiple NCHW image parameterizations along their batch dimensions.
+
+    Example::
+
+        >>> fft_image_1 = opt.images.FFTImage(size=(224, 224), batch=1)
+        >>> fft_image_2 = opt.images.FFTImage(size=(224, 224), batch=1)
+        >>> stack_image = opt.images.StackImage([fft_image_1, fft_image_2])
+        >>> output_image = stack_image()
+        >>> print(output_image.requires_grad)
+        True
+        >>> print(output_image.shape)
+        torch.Size([2, 3, 224, 224])
+
+    Example with ``ImageParameterization`` & ``torch.Tensor``::
+
+        >>> fft_image = opt.images.FFTImage(size=(224, 224), batch=1)
+        >>> tensor_image = torch.randn(1, 3, 224, 224)
+        >>> stack_image = opt.images.StackImage([fft_image, tensor_image])
+        >>> output_image = stack_image()
+        >>> print(output_image.requires_grad)
+        True
+        >>> print(output_image.shape)
+        torch.Size([2, 3, 224, 224])
     """
 
     __constants__ = ["dim", "output_device"]
@@ -735,15 +821,16 @@ def __init__(
         Args:
 
             parameterizations (list of ImageParameterization and torch.Tensor): A list
-                 of image parameterizations to stack across their batch dimensions.
-            dim (int, optional): Optionally specify the dim to concatinate
-                 parameterization outputs on. Default is set to the batch dimension.
-                 Default: 0
+                of image parameterizations and tensors to concatenate across a
+                specified dimension.
+            dim (int, optional): Optionally specify the dim to concatenate
+                parameterization outputs on. Default is set to the batch dimension.
+                Default: ``0``
             output_device (torch.device, optional): If the parameterizations are on
                 different devices, then their outputs will be moved to the device
-                specified by this variable. Default is set to None with the expectation
-                that all parameterizations are on the same device.
-                Default: None
+                specified by this variable. Default is set to ``None`` with the
+                expectation that all parameterization outputs are on the same device.
+                Default: ``None``
         """
         super().__init__()
         assert len(parameterizations) > 0
@@ -786,16 +873,46 @@ def forward(self) -> torch.Tensor:
 
 
 class NaturalImage(ImageParameterization):
-    r"""Outputs an optimizable input image.
+    r"""Outputs an optimizable input image wrapped in :class:`.ImageTensor`.
 
-    By convention, single images are CHW and float32s in [0,1].
-    The underlying parameterization can be decorrelated via a ToRGB transform.
-    When used with the (default) FFT parameterization, this results in a fully
-    uncorrelated image parameterization. :-)
+    By convention, single images are CHW and float32s in [0, 1].
+    The underlying parameterization can be decorrelated via a
+    :class:`captum.optim.transforms.ToRGB` transform.
+    When used with the (default) :class:`.FFTImage` parameterization, this results in
+    a fully uncorrelated image parameterization. :-)
 
     If a model requires a normalization step, such as normalizing imagenet RGB values,
-    or rescaling to [0,255], it can perform those steps with the provided transforms or
-    inside its computation.
+    or rescaling to [0, 255], it can perform those steps with the provided transforms
+    or inside its module class.
+
+    Example::
+
+        >>> image = opt.images.NaturalImage(size=(224, 224), channels=3, batch=1)
+        >>> image_tensor = image()
+        >>> print(image_tensor.requires_grad)
+        True
+        >>> print(image_tensor.shape)
+        torch.Size([1, 3, 224, 224])
+
+    Example for using an initialization tensor::
+
+        >>> init = torch.randn(1, 3, 224, 224)
+        >>> image = opt.images.NaturalImage(init=init)
+        >>> image_tensor = image()
+        >>> print(image_tensor.requires_grad)
+        True
+        >>> print(image_tensor.shape)
+        torch.Size([1, 3, 224, 224])
+
+    Example for using a parameterization::
+
+        >>> fft_image = opt.images.FFTImage(size=(224, 224), channels=3, batch=1)
+        >>> image = opt.images.NaturalImage(parameterization=fft_image)
+        >>> image_tensor = image()
+        >>> print(image_tensor.requires_grad)
+        True
+        >>> print(image_tensor.shape)
+        torch.Size([1, 3, 224, 224])
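+
+    Example for using a custom ``squash_func`` (a minimal sketch; any callable
+    mapping one tensor to another works here)::
+
+        >>> def clamp_image(x: torch.Tensor) -> torch.Tensor:
+        ...     return x.clamp(0, 1)
+        >>> image = opt.images.NaturalImage(size=(224, 224), squash_func=clamp_image)
+        >>> image_tensor = image()
+        >>> print(image_tensor.shape)
+        torch.Size([1, 3, 224, 224])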
     """
 
     def __init__(
@@ -805,33 +922,60 @@ def __init__(
         batch: int = 1,
         init: Optional[torch.Tensor] = None,
         parameterization: ImageParameterization = FFTImage,
-        squash_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+        squash_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = torch.sigmoid,
         decorrelation_module: Optional[nn.Module] = ToRGB(transform="klt"),
         decorrelate_init: bool = True,
     ) -> None:
         """
         Args:
 
-            size (Tuple[int, int], optional): The height and width to use for the
-                nn.Parameter image tensor.
-                Default: (224, 224)
+            size (tuple of int, optional): The height and width to use for the
+                nn.Parameter image tensor, in the format of: (height, width).
+                This parameter is not used if the given ``parameterization`` is
+                already an instance.
+                Default: ``(224, 224)``
             channels (int, optional): The number of channels to use when creating the
-                nn.Parameter tensor.
-                Default: 3
+                nn.Parameter tensor. This parameter is not used if the given
+                ``parameterization`` is already an instance.
+                Default: ``3``
-            batch (int, optional): The number of channels to use when creating the
-                nn.Parameter tensor, or stacking init images.
-                Default: 1
+            batch (int, optional): The number of images to stack along the batch
+                dimension. This parameter is not used if the given
+                ``parameterization`` is already an instance.
+                Default: ``1``
+            init (torch.Tensor, optional): Optionally specify a tensor to use instead
+                of creating one from random noise. This parameter is not used if the
+                given ``parameterization`` is already an instance. Set to ``None`` for
+                a random init.
+                Default: ``None``
             parameterization (ImageParameterization, optional): An image
                 parameterization class, or instance of an image parameterization class.
-                Default: FFTImage
-            squash_func (Callable[[torch.Tensor], torch.Tensor]], optional): The squash
-                function to use after color recorrelation. A funtion or lambda function.
-                Default: None
-            decorrelation_module (nn.Module, optional): A ToRGB instance.
-                Default: ToRGB
+                Default: :class:`.FFTImage`
+            squash_func (callable, optional): The squash function to use after color
+                recorrelation. A function, lambda function, or callable class instance.
+                Any provided squash function should take a single input tensor and
+                return a single output tensor. If set to ``None``, then
+                :class:`torch.nn.Identity` will be used, making it a no-op.
+                Default: :func:`torch.sigmoid`
+            decorrelation_module (nn.Module, optional): A module instance that
+                recorrelates the colors of an input image. Custom modules can make use
+                of the ``decorrelate_init`` parameter by having a second ``inverse``
+                parameter in their forward functions that performs the inverse
+                operation when it is set to ``True`` (see :class:`.ToRGB` for an
+                example). Set to ``None`` for no recorrelation.
+                Default: :class:`.ToRGB`
             decorrelate_init (bool, optional): Whether or not to apply color
-                decorrelation to the init tensor input.
-                Default: True
+                decorrelation to the init tensor input. This parameter is not used if
+                the given ``parameterization`` is already an instance or if ``init``
+                is ``None``.
+                Default: ``True``
+
+        Attributes:
+
+            parameterization (ImageParameterization): The image parameterization
+                instance given when initializing ``NaturalImage``.
+                Default: :class:`.FFTImage`
+            decorrelation_module (torch.nn.Module): The decorrelation module instance
+                given when initializing ``NaturalImage``.
+                Default: :class:`.ToRGB`
         """
         super().__init__()
         if not isinstance(parameterization, ImageParameterization):
@@ -851,21 +995,13 @@ def __init__(
                 )
                 init = self.decorrelate(init, inverse=True).rename(None)
 
-            if squash_func is None:
-                squash_func = self._clamp_image
-
-        self.squash_func = torch.sigmoid if squash_func is None else squash_func
+        self.squash_func = squash_func or torch.nn.Identity()
         if not isinstance(parameterization, ImageParameterization):
             parameterization = parameterization(
                 size=size, channels=channels, batch=batch, init=init
             )
         self.parameterization = parameterization
 
-    @torch.jit.export
-    def _clamp_image(self, x: torch.Tensor) -> torch.Tensor:
-        """JIT supported squash function."""
-        return x.clamp(0, 1)
-
     @torch.jit.ignore
     def _to_image_tensor(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -873,14 +1009,20 @@ def _to_image_tensor(self, x: torch.Tensor) -> torch.Tensor:
 
         Args:
 
-            x (torch.tensor): An input tensor.
+            x (torch.Tensor): An input tensor.
 
         Returns:
-            x (ImageTensor): An instance of ImageTensor with the input tensor.
+            x (ImageTensor): An instance of ``ImageTensor`` with the input tensor.
         """
         return ImageTensor(x)
 
-    def forward(self) -> torch.Tensor:
+    def forward(self) -> ImageTensor:
+        """
+        Returns:
+            image_tensor (ImageTensor): The parameterization output wrapped in
+                :class:`.ImageTensor`, with its colors optionally recorrelated.
+        """
         image = self.parameterization()
         if self.decorrelate is not None:
             image = self.decorrelate(image)
diff --git a/captum/optim/_param/image/transforms.py b/captum/optim/_param/image/transforms.py
index 4ec876263..8131f4fc1 100644
--- a/captum/optim/_param/image/transforms.py
+++ b/captum/optim/_param/image/transforms.py
@@ -939,24 +939,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x[:, [2, 1, 0]]
 
 
-# class TransformationRobustness(nn.Module):
-#     def __init__(self, jitter=False, scale=False):
-#         super().__init__()
-#         if jitter:
-#             self.jitter = RandomSpatialJitter(4)
-#         if scale:
-#             self.scale = RandomScale()
-
-#     def forward(self, x):
-#         original_shape = x.shape
-#         if hasattr(self, "jitter"):
-#             x = self.jitter(x)
-#         if hasattr(self, "scale"):
-#             x = self.scale(x)
-#         cropped = center_crop(x, original_shape)
-#         return cropped
-
-
 # class RandomHomography(nn.Module):
 #     def __init__(self):
 #         super().__init__()
@@ -979,7 +961,7 @@ class GaussianSmoothing(nn.Module):
     in the input using a depthwise convolution.
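+
+    Example (a minimal sketch; assumes this class is reachable as
+    ``opt.transforms.GaussianSmoothing``, with illustrative values, and the
+    default ``"same"`` padding preserving the input's shape)::
+
+        >>> smoothing = opt.transforms.GaussianSmoothing(
+        ...     channels=3, kernel_size=3, sigma=2.0
+        ... )
+        >>> x = torch.randn(1, 3, 64, 64)
+        >>> output = smoothing(x)
+        >>> print(output.shape)
+        torch.Size([1, 3, 64, 64])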
     """
 
-    __constants__ = ["groups"]
+    __constants__ = ["groups", "padding"]
 
     def __init__(
         self,
@@ -987,6 +969,7 @@ def __init__(
         kernel_size: Union[int, Sequence[int]],
         sigma: Union[float, Sequence[float]],
         dim: int = 2,
+        padding: Union[str, int, Tuple[int, int]] = "same",
     ) -> None:
         """
         Args:
@@ -996,7 +979,11 @@ def __init__(
             kernel_size (int, sequence): Size of the gaussian kernel.
             sigma (float, sequence): Standard deviation of the gaussian kernel.
             dim (int, optional): The number of dimensions of the data.
-                Default value is 2 (spatial).
+                Default: ``2`` (spatial).
+            padding (str, int, or tuple of int, optional): The desired padding amount
+                or mode to use. One of: ``"valid"``, ``"same"``, a single number, or a
+                tuple in the format of: (padH, padW). Note that the ``"valid"`` &
+                ``"same"`` string modes require a PyTorch version with string padding
+                support in the ``torch.nn.functional.conv*d`` functions.
+                Default: ``"same"``
         """
         super().__init__()
         if isinstance(kernel_size, numbers.Number):
@@ -1007,9 +994,18 @@ def __init__(
         # The gaussian kernel is the product of the
         # gaussian function of each dimension.
         kernel = 1
-        meshgrids = torch.meshgrid(
-            [torch.arange(size, dtype=torch.float32) for size in kernel_size]
-        )
+
+        # PyTorch v1.10.0 adds a new indexing argument
+        if version.parse(torch.__version__) >= version.parse("1.10.0"):
+            meshgrids = torch.meshgrid(
+                [torch.arange(size, dtype=torch.float32) for size in kernel_size],
+                indexing="ij",
+            )
+        else:
+            meshgrids = torch.meshgrid(
+                [torch.arange(size, dtype=torch.float32) for size in kernel_size]
+            )
+
         for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
             mean = (size - 1) / 2
             kernel *= (
@@ -1027,6 +1023,7 @@ def __init__(
 
         self.register_buffer("weight", kernel)
         self.groups = channels
+        self.padding = padding
 
         if dim == 1:
             self.conv = F.conv1d
@@ -1048,9 +1045,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             input (torch.Tensor): Input to apply gaussian filter on.
 
         Returns:
-            **filtered** (torch.Tensor): Filtered output.
+            filtered (torch.Tensor): Filtered output.
         """
-        return self.conv(input, weight=self.weight, groups=self.groups)
+        return self.conv(
+            input, weight=self.weight, groups=self.groups, padding=self.padding
+        )
 
 
 class SymmetricPadding(torch.autograd.Function):
diff --git a/tests/optim/models/test_models_common.py b/tests/optim/models/test_models_common.py
index 5c1006076..11856e44e 100644
--- a/tests/optim/models/test_models_common.py
+++ b/tests/optim/models/test_models_common.py
@@ -6,6 +6,7 @@
 import torch
 import torch.nn.functional as F
 from captum.optim.models import googlenet
+from packaging import version
 from tests.helpers.basic import BaseTest, assertTensorAlmostEqual
 
 
@@ -37,7 +38,10 @@ def check_grad(self, grad_input, grad_output):
 
         rr_layer = model_utils.RedirectedReluLayer()
         x = torch.zeros(1, 3, 4, 4, requires_grad=True)
-        rr_layer.register_backward_hook(check_grad)
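+        # register_full_backward_hook was added in PyTorch 1.8.0, so fall back
+        # to the deprecated register_backward_hook on older versions.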
+        if version.parse(torch.__version__) >= version.parse("1.8.0"):
+            rr_layer.register_full_backward_hook(check_grad)
+        else:
+            rr_layer.register_backward_hook(check_grad)
         rr_loss = rr_layer(x * 1).mean()
         rr_loss.backward()
 
diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py
index 617d34a3a..9d23efd37 100644
--- a/tests/optim/param/test_images.py
+++ b/tests/optim/param/test_images.py
@@ -20,6 +20,17 @@ def test_new(self) -> None:
         test_tensor = images.ImageTensor(x)
         self.assertTrue(torch.is_tensor(test_tensor))
         self.assertEqual(x.shape, test_tensor.shape)
+        self.assertEqual(x.dtype, test_tensor.dtype)
+
+    def test_new_dtype_float64(self) -> None:
+        x = torch.ones(5, dtype=torch.float64)
+        test_tensor = images.ImageTensor(x)
+        self.assertEqual(test_tensor.dtype, torch.float64)
+
+    def test_new_dtype_float16(self) -> None:
+        x = torch.ones(5, dtype=torch.float16)
+        test_tensor = images.ImageTensor(x)
+        self.assertEqual(test_tensor.dtype, torch.float16)
 
     def test_new_numpy(self) -> None:
         x = torch.ones(5).numpy()
@@ -33,6 +44,13 @@ def test_new_list(self) -> None:
         self.assertTrue(torch.is_tensor(test_tensor))
         self.assertEqual(x.shape, test_tensor.shape)
 
+    def test_new_with_grad(self) -> None:
+        x = torch.ones(5, requires_grad=True)
+        test_tensor = images.ImageTensor(x)
+        self.assertTrue(test_tensor.requires_grad)
+        self.assertTrue(torch.is_tensor(test_tensor))
+        self.assertEqual(x.shape, test_tensor.shape)
+
     def test_torch_function(self) -> None:
         x = torch.ones(5)
         image_tensor = images.ImageTensor(x)
@@ -102,7 +120,7 @@ def test_subclass(self) -> None:
 
     def test_pytorch_fftfreq(self) -> None:
         image = images.FFTImage((1, 1))
-        _, _, fftfreq = image.get_fft_funcs()
+        _, _, fftfreq = image._get_fft_funcs()
         assertTensorAlmostEqual(
             self, fftfreq(4, 4), torch.as_tensor(np.fft.fftfreq(4, 4)), mode="max"
         )
@@ -114,7 +132,7 @@ def test_rfft2d_freqs(self) -> None:
 
         assertTensorAlmostEqual(
             self,
-            image.rfft2d_freqs(height, width),
+            image._rfft2d_freqs(height, width),
             torch.tensor([[0.0000, 0.3333], [0.5000, 0.6009]]),
         )
 
@@ -308,6 +326,33 @@ def test_fftimage_forward_init_batch(self) -> None:
             self, fftimage_tensor.detach(), fftimage_array, 25.0, mode="max"
         )
 
+    def test_fftimage_forward_dtype_float64(self) -> None:
+        dtype = torch.float64
+        image_param = images.FFTImage(size=(224, 224)).to(dtype=dtype)
+        output = image_param()
+        self.assertEqual(output.dtype, dtype)
+
+    def test_fftimage_forward_dtype_float32(self) -> None:
+        dtype = torch.float32
+        image_param = images.FFTImage(size=(224, 224)).to(dtype=dtype)
+        output = image_param()
+        self.assertEqual(output.dtype, dtype)
+
+    def test_fftimage_forward_dtype_float16(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.12.0"):
+            raise unittest.SkipTest(
+                "Skipping FFTImage float16 dtype test due to"
+                + " insufficient Torch version."
+            )
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest(
+                "Skipping FFTImage float16 dtype test due to CUDA not being"
+                + " available."
+            )
+        dtype = torch.float16
+        image_param = images.FFTImage(size=(256, 256)).cuda().to(dtype=dtype)
+        output = image_param()
+        self.assertEqual(output.dtype, dtype)
+
 
 class TestPixelImage(BaseTest):
     def test_subclass(self) -> None:
@@ -317,12 +362,7 @@ def test_pixelimage_random(self) -> None:
         size = (224, 224)
         channels = 3
         image_param = images.PixelImage(size=size, channels=channels)
-
-        self.assertEqual(image_param.image.dim(), 4)
-        self.assertEqual(image_param.image.size(0), 1)
-        self.assertEqual(image_param.image.size(1), channels)
-        self.assertEqual(image_param.image.size(2), size[0])
-        self.assertEqual(image_param.image.size(3), size[1])
+        self.assertEqual(list(image_param.image.shape), [1, channels] + list(size))
         self.assertTrue(image_param.image.requires_grad)
 
     def test_pixelimage_init(self) -> None:
@@ -331,11 +371,7 @@ def test_pixelimage_init(self) -> None:
         init_tensor = torch.randn(channels, *size)
         image_param = images.PixelImage(size=size, channels=channels, init=init_tensor)
 
-        self.assertEqual(image_param.image.dim(), 4)
-        self.assertEqual(image_param.image.size(0), 1)
-        self.assertEqual(image_param.image.size(1), channels)
-        self.assertEqual(image_param.image.size(2), size[0])
-        self.assertEqual(image_param.image.size(3), size[1])
+        self.assertEqual(list(image_param.image.shape), [1, channels] + list(size))
         assertTensorAlmostEqual(self, image_param.image, init_tensor[None, :], 0)
         self.assertTrue(image_param.image.requires_grad)
 
@@ -344,12 +380,7 @@ def test_pixelimage_random_forward(self) -> None:
         channels = 3
         image_param = images.PixelImage(size=size, channels=channels)
         test_tensor = image_param.forward().rename(None)
-
-        self.assertEqual(test_tensor.dim(), 4)
-        self.assertEqual(test_tensor.size(0), 1)
-        self.assertEqual(test_tensor.size(1), channels)
-        self.assertEqual(test_tensor.size(2), size[0])
-        self.assertEqual(test_tensor.size(3), size[1])
+        self.assertEqual(list(test_tensor.shape), [1, channels] + list(size))
 
     def test_pixelimage_forward_jit_module(self) -> None:
         if version.parse(torch.__version__) <= version.parse("1.8.0"):
@@ -369,13 +400,33 @@ def test_pixelimage_init_forward(self) -> None:
         image_param = images.PixelImage(size=size, channels=channels, init=init_tensor)
         test_tensor = image_param.forward().rename(None)
 
-        self.assertEqual(test_tensor.dim(), 4)
-        self.assertEqual(test_tensor.size(0), 1)
-        self.assertEqual(test_tensor.size(1), channels)
-        self.assertEqual(test_tensor.size(2), size[0])
-        self.assertEqual(test_tensor.size(3), size[1])
+        self.assertEqual(list(test_tensor.shape), [1, channels] + list(size))
         assertTensorAlmostEqual(self, test_tensor, init_tensor[None, :], 0)
 
+    def test_pixelimage_forward_dtype_float64(self) -> None:
+        dtype = torch.float64
+        image_param = images.PixelImage(size=(224, 224)).to(dtype=dtype)
+        output = image_param()
+        self.assertEqual(output.dtype, dtype)
+
+    def test_pixelimage_forward_dtype_float32(self) -> None:
+        dtype = torch.float32
+        image_param = images.PixelImage(size=(224, 224)).to(dtype=dtype)
+        output = image_param()
+        self.assertEqual(output.dtype, dtype)
+
+    def test_pixelimage_forward_dtype_float16(self) -> None:
+        dtype = torch.float16
+        image_param = images.PixelImage(size=(224, 224)).to(dtype=dtype)
+        output = image_param()
+        self.assertEqual(output.dtype, dtype)
+
+    def test_pixelimage_forward_dtype_bfloat16(self) -> None:
+        dtype = torch.bfloat16
+        image_param = images.PixelImage(size=(224, 224)).to(dtype=dtype)
+        output = image_param()
+        self.assertEqual(output.dtype, dtype)
+
 
 class TestLaplacianImage(BaseTest):
     def test_subclass(self) -> None:
@@ -384,21 +435,67 @@ def test_subclass(self) -> None:
     def test_laplacianimage_random_forward(self) -> None:
         size = (224, 224)
         channels = 3
-        image_param = images.LaplacianImage(size=size, channels=channels)
+        batch = 1
+        image_param = images.LaplacianImage(size=size, channels=channels, batch=batch)
         test_tensor = image_param.forward().rename(None)
+        self.assertEqual(list(test_tensor.shape), [batch, channels, size[0], size[1]])
+        self.assertTrue(test_tensor.requires_grad)
 
-        self.assertEqual(test_tensor.dim(), 4)
-        self.assertEqual(test_tensor.size(0), 1)
-        self.assertEqual(test_tensor.size(1), channels)
-        self.assertEqual(test_tensor.size(2), size[0])
-        self.assertEqual(test_tensor.size(3), size[1])
+    def test_laplacianimage_random_forward_batch_5(self) -> None:
+        size = (224, 224)
+        channels = 3
+        batch = 5
+        image_param = images.LaplacianImage(size=size, channels=channels, batch=batch)
+        test_tensor = image_param.forward().rename(None)
+        self.assertEqual(list(test_tensor.shape), [batch, channels, size[0], size[1]])
 
-    def test_laplacianimage_init(self) -> None:
-        init_t = torch.zeros(1, 224, 224)
-        image_param = images.LaplacianImage(size=(224, 224), channels=3, init=init_t)
+    def test_laplacianimage_random_forward_scale_list(self) -> None:
+        size = (224, 224)
+        channels = 3
+        batch = 1
+        scale_list = [1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 56.0, 112.0]
+        image_param = images.LaplacianImage(
+            size=size, channels=channels, batch=batch, scale_list=scale_list
+        )
+        test_tensor = image_param.forward().rename(None)
+        self.assertEqual(list(test_tensor.shape), [batch, channels, size[0], size[1]])
+
+    def test_laplacianimage_random_forward_scale_list_error(self) -> None:
+        scale_list = [1.0, 2.0, 4.0, 8.0, 16.0, 64.0, 144.0]
+        with self.assertRaises(AssertionError):
+            images.LaplacianImage(
+                size=(224, 224), channels=3, batch=1, scale_list=scale_list
+            )
+
+    def test_laplacianimage_init_tensor(self) -> None:
+        init_tensor = torch.zeros(1, 3, 224, 224)
+        image_param = images.LaplacianImage(init=init_tensor)
         output = image_param.forward().detach().rename(None)
         assertTensorAlmostEqual(self, torch.ones_like(output) * 0.5, output, mode="max")
 
+    def test_laplacianimage_random_forward_cuda(self) -> None:
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest(
+                "Skipping LaplacianImage CUDA test due to not supporting CUDA."
+            )
+        image_param = images.LaplacianImage(size=(224, 224), channels=3, batch=1).cuda()
+        test_tensor = image_param.forward().rename(None)
+        self.assertTrue(test_tensor.is_cuda)
+        self.assertEqual(list(test_tensor.shape), [1, 3, 224, 224])
+        self.assertTrue(test_tensor.requires_grad)
+
+    def test_laplacianimage_forward_dtype_float64(self) -> None:
+        dtype = torch.float64
+        image_param = images.LaplacianImage(size=(224, 224)).to(dtype=dtype)
+        output = image_param()
+        self.assertEqual(output.dtype, dtype)
+
+    def test_laplacianimage_forward_dtype_float32(self) -> None:
+        dtype = torch.float32
+        image_param = images.LaplacianImage(size=(224, 224)).to(dtype=dtype)
+        output = image_param()
+        self.assertEqual(output.dtype, dtype)
+
 
 class TestSimpleTensorParameterization(BaseTest):
     def test_subclass(self) -> None:
@@ -674,12 +771,7 @@ def test_interpolate_tensor(self) -> None:
         output_tensor = image_param._interpolate_tensor(
             test_tensor, batch, channels, size[0], size[1]
         )
-
-        self.assertEqual(output_tensor.dim(), 4)
-        self.assertEqual(output_tensor.size(0), batch)
-        self.assertEqual(output_tensor.size(1), channels)
-        self.assertEqual(output_tensor.size(2), size[0])
-        self.assertEqual(output_tensor.size(3), size[1])
+        self.assertEqual(list(output_tensor.shape), [batch, channels] + list(size))
 
     def test_sharedimage_single_shape_hw_forward(self) -> None:
         shared_shapes = (128 // 2, 128 // 2)
@@ -697,11 +789,7 @@ def test_sharedimage_single_shape_hw_forward(self) -> None:
         self.assertEqual(
             list(image_param.shared_init[0]().shape), [1, 1] + list(shared_shapes)
         )
-        self.assertEqual(test_tensor.dim(), 4)
-        self.assertEqual(test_tensor.size(0), batch)
-        self.assertEqual(test_tensor.size(1), channels)
-        self.assertEqual(test_tensor.size(2), size[0])
-        self.assertEqual(test_tensor.size(3), size[1])
+        self.assertEqual(list(test_tensor.shape), [batch, channels] + list(size))
 
     def test_sharedimage_single_shape_chw_forward(self) -> None:
         shared_shapes = (3, 128 // 2, 128 // 2)
@@ -719,11 +807,7 @@ def test_sharedimage_single_shape_chw_forward(self) -> None:
         self.assertEqual(
             list(image_param.shared_init[0]().shape), [1] + list(shared_shapes)
         )
-        self.assertEqual(test_tensor.dim(), 4)
-        self.assertEqual(test_tensor.size(0), batch)
-        self.assertEqual(test_tensor.size(1), channels)
-        self.assertEqual(test_tensor.size(2), size[0])
-        self.assertEqual(test_tensor.size(3), size[1])
+        self.assertEqual(list(test_tensor.shape), [batch, channels] + list(size))
 
     def test_sharedimage_single_shape_bchw_forward(self) -> None:
         shared_shapes = (1, 3, 128 // 2, 128 // 2)
@@ -739,11 +823,7 @@ def test_sharedimage_single_shape_bchw_forward(self) -> None:
         self.assertIsNone(image_param.offset)
         self.assertEqual(image_param.shared_init[0]().dim(), 4)
         self.assertEqual(list(image_param.shared_init[0]().shape), list(shared_shapes))
-        self.assertEqual(test_tensor.dim(), 4)
-        self.assertEqual(test_tensor.size(0), batch)
-        self.assertEqual(test_tensor.size(1), channels)
-        self.assertEqual(test_tensor.size(2), size[0])
-        self.assertEqual(test_tensor.size(3), size[1])
+        self.assertEqual(list(test_tensor.shape), [batch, channels] + list(size))
 
     def test_sharedimage_multiple_shapes_forward(self) -> None:
         shared_shapes = (
@@ -769,11 +849,7 @@ def test_sharedimage_multiple_shapes_forward(self) -> None:
             self.assertEqual(
                 list(image_param.shared_init[i]().shape), list(shared_shapes[i])
             )
-        self.assertEqual(test_tensor.dim(), 4)
-        self.assertEqual(test_tensor.size(0), batch)
-        self.assertEqual(test_tensor.size(1), channels)
-        self.assertEqual(test_tensor.size(2), size[0])
-        self.assertEqual(test_tensor.size(3), size[1])
+        self.assertEqual(list(test_tensor.shape), [batch, channels] + list(size))
 
     def test_sharedimage_multiple_shapes_diff_len_forward(self) -> None:
         shared_shapes = (
@@ -800,11 +876,7 @@ def test_sharedimage_multiple_shapes_diff_len_forward(self) -> None:
             s_shape = ([1] * (4 - len(s_shape))) + list(s_shape)
             self.assertEqual(list(image_param.shared_init[i]().shape), s_shape)
 
-        self.assertEqual(test_tensor.dim(), 4)
-        self.assertEqual(test_tensor.size(0), batch)
-        self.assertEqual(test_tensor.size(1), channels)
-        self.assertEqual(test_tensor.size(2), size[0])
-        self.assertEqual(test_tensor.size(3), size[1])
+        self.assertEqual(list(test_tensor.shape), [batch, channels] + list(size))
 
     def test_sharedimage_multiple_shapes_diff_len_forward_jit_module(self) -> None:
         if version.parse(torch.__version__) <= version.parse("1.8.0"):
@@ -831,12 +903,7 @@ def test_sharedimage_multiple_shapes_diff_len_forward_jit_module(self) -> None:
         )
         jit_image_param = torch.jit.script(image_param)
         test_tensor = jit_image_param()
-
-        self.assertEqual(test_tensor.dim(), 4)
-        self.assertEqual(test_tensor.size(0), batch)
-        self.assertEqual(test_tensor.size(1), channels)
-        self.assertEqual(test_tensor.size(2), size[0])
-        self.assertEqual(test_tensor.size(3), size[1])
+        self.assertEqual(list(test_tensor.shape), [batch, channels] + list(size))
 
 
 class TestStackImage(BaseTest):
@@ -1071,11 +1138,19 @@ def test_natural_image_init_func_pixelimage(self) -> None:
         self.assertIsInstance(image_param.decorrelate, ToRGB)
         self.assertEqual(image_param.squash_func, torch.sigmoid)
 
-    def test_natural_image_init_func_default_init_tensor(self) -> None:
-        image_param = images.NaturalImage(init=torch.ones(1, 3, 1, 1))
+    def test_natural_image_custom_squash_func(self) -> None:
+        init_tensor = torch.randn(1, 3, 1, 1)
+
+        def clamp_image(x: torch.Tensor) -> torch.Tensor:
+            return x.clamp(0, 1)
+
+        image_param = images.NaturalImage(init=init_tensor, squash_func=clamp_image)
+        image = image_param.forward().detach()
+
         self.assertIsInstance(image_param.parameterization, images.FFTImage)
         self.assertIsInstance(image_param.decorrelate, ToRGB)
-        self.assertEqual(image_param.squash_func, image_param._clamp_image)
+        self.assertEqual(image_param.squash_func, clamp_image)
+        assertTensorAlmostEqual(self, image, init_tensor.clamp(0, 1))
 
     def test_natural_image_init_tensor_pixelimage_sf_sigmoid(self) -> None:
         if version.parse(torch.__version__) <= version.parse("1.8.0"):
@@ -1084,10 +1159,10 @@ def test_natural_image_init_tensor_pixelimage_sf_sigmoid(self) -> None:
                 + " test due to insufficient Torch version."
             )
         image_param = images.NaturalImage(
-            init=torch.ones(1, 3, 1, 1),
+            init=torch.ones(1, 3, 1, 1).float(),
             parameterization=images.PixelImage,
             squash_func=torch.sigmoid,
-        )
+        ).to(dtype=torch.float32)
         output_tensor = image_param()
 
         self.assertEqual(image_param.squash_func, torch.sigmoid)
@@ -1103,9 +1178,10 @@ def test_natural_image_0(self) -> None:
         )
 
     def test_natural_image_1(self) -> None:
-        image_param = images.NaturalImage(init=torch.ones(3, 1, 1))
+        init_tensor = torch.ones(3, 1, 1)
+        image_param = images.NaturalImage(init=init_tensor)
         image = image_param.forward().detach()
-        assertTensorAlmostEqual(self, image, torch.ones_like(image), mode="max")
+        assertTensorAlmostEqual(self, image, torch.sigmoid(init_tensor).unsqueeze(0))
 
     def test_natural_image_cuda(self) -> None:
         if not torch.cuda.is_available():
@@ -1132,10 +1208,11 @@ def test_natural_image_jit_module_init_tensor(self) -> None:
                 "Skipping NaturalImage init tensor JIT module test due to"
                 + " insufficient Torch version."
             )
-        image_param = images.NaturalImage(init=torch.ones(1, 3, 1, 1))
+        init_tensor = torch.ones(1, 3, 1, 1)
+        image_param = images.NaturalImage(init=init_tensor)
         jit_image_param = torch.jit.script(image_param)
         output_tensor = jit_image_param()
-        assertTensorAlmostEqual(self, output_tensor, torch.ones_like(output_tensor))
+        assertTensorAlmostEqual(self, output_tensor, torch.sigmoid(init_tensor))
 
     def test_natural_image_jit_module_init_tensor_pixelimage(self) -> None:
         if version.parse(torch.__version__) <= version.parse("1.8.0"):
@@ -1143,12 +1220,13 @@ def test_natural_image_jit_module_init_tensor_pixelimage(self) -> None:
                 "Skipping NaturalImage PixelImage init tensor JIT module"
                 + " test due to insufficient Torch version."
             )
+        init_tensor = torch.ones(1, 3, 1, 1)
         image_param = images.NaturalImage(
-            init=torch.ones(1, 3, 1, 1), parameterization=images.PixelImage
+            init=init_tensor, parameterization=images.PixelImage
         )
         jit_image_param = torch.jit.script(image_param)
         output_tensor = jit_image_param()
-        assertTensorAlmostEqual(self, output_tensor, torch.ones_like(output_tensor))
+        assertTensorAlmostEqual(self, output_tensor, torch.sigmoid(init_tensor))
 
     def test_natural_image_decorrelation_module_none(self) -> None:
         if version.parse(torch.__version__) <= version.parse("1.8.0"):
@@ -1156,9 +1234,43 @@ def test_natural_image_decorrelation_module_none(self) -> None:
                 "Skipping NaturalImage no decorrelation module"
                 + " test due to insufficient Torch version."
             )
-        image_param = images.NaturalImage(
-            init=torch.ones(1, 3, 4, 4), decorrelation_module=None
-        )
+        init_tensor = torch.ones(1, 3, 1, 1)
+        image_param = images.NaturalImage(init=init_tensor, decorrelation_module=None)
         image = image_param.forward().detach()
         self.assertIsNone(image_param.decorrelate)
-        assertTensorAlmostEqual(self, image, torch.ones_like(image))
+        assertTensorAlmostEqual(self, image, torch.sigmoid(init_tensor))
+
+    def test_natural_image_forward_dtype_float64(self) -> None:
+        dtype = torch.float64
+        image_param = images.NaturalImage(
+            size=(224, 224), decorrelation_module=ToRGB("klt")
+        ).to(dtype=dtype)
+        output = image_param()
+        self.assertEqual(output.dtype, dtype)
+
+    def test_natural_image_forward_dtype_float32(self) -> None:
+        dtype = torch.float32
+        image_param = images.NaturalImage(
+            size=(224, 224), decorrelation_module=ToRGB("klt")
+        ).to(dtype=dtype)
+        output = image_param()
+        self.assertEqual(output.dtype, dtype)
+
+    def test_natural_image_forward_dtype_float16(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.12.0"):
+            raise unittest.SkipTest(
+                "Skipping NaturalImage float16 dtype test due to"
+                + " insufficient Torch version."
+            )
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest(
+                "Skipping NaturalImage float16 dtype test due to not supporting CUDA."
+            )
+        dtype = torch.float16
+        image_param = (
+            images.NaturalImage(size=(256, 256), decorrelation_module=ToRGB("klt"))
+            .cuda()
+            .to(dtype=dtype)
+        )
+        output = image_param()
+        self.assertEqual(output.dtype, dtype)
diff --git a/tests/optim/param/test_transforms.py b/tests/optim/param/test_transforms.py
index 385006a7a..3568b4c53 100644
--- a/tests/optim/param/test_transforms.py
+++ b/tests/optim/param/test_transforms.py
@@ -261,6 +261,24 @@ def test_random_scale_jit_module(self) -> None:
             0,
         )
 
+    def test_random_scale_dtype_float64(self) -> None:
+        dtype = torch.float64
+        scale_module = transforms.RandomScale(scale=[0.975, 1.025, 0.95, 1.05]).to(
+            dtype=dtype
+        )
+        x = torch.ones([1, 3, 224, 224], dtype=dtype)
+        output = scale_module(x)
+        self.assertEqual(output.dtype, dtype)
+
+    def test_random_scale_dtype_float32(self) -> None:
+        dtype = torch.float32
+        scale_module = transforms.RandomScale(scale=[0.975, 1.025, 0.95, 1.05]).to(
+            dtype=dtype
+        )
+        x = torch.ones([1, 3, 224, 224], dtype=dtype)
+        output = scale_module(x)
+        self.assertEqual(output.dtype, dtype)
+
 
 class TestRandomScaleAffine(BaseTest):
     def test_random_scale_affine_init(self) -> None:
@@ -430,6 +448,40 @@ def test_random_scale_affine_jit_module(self) -> None:
             0,
         )
 
+    def test_random_scale_affine_dtype_float64(self) -> None:
+        dtype = torch.float64
+        scale_module = transforms.RandomScaleAffine(
+            scale=[0.975, 1.025, 0.95, 1.05]
+        ).to(dtype=dtype)
+        x = torch.ones([1, 3, 224, 224], dtype=dtype)
+        output = scale_module(x)
+        self.assertEqual(output.dtype, dtype)
+
+    def test_random_scale_affine_dtype_float32(self) -> None:
+        dtype = torch.float32
+        scale_module = transforms.RandomScaleAffine(
+            scale=[0.975, 1.025, 0.95, 1.05]
+        ).to(dtype=dtype)
+        x = torch.ones([1, 3, 224, 224], dtype=dtype)
+        output = scale_module(x)
+        self.assertEqual(output.dtype, dtype)
+
+    def test_random_scale_affine_dtype_float16(self) -> None:
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest(
+                "Skipping RandomScaleAffine float16 dtype test due to not supporting"
+                + " CUDA."
+            )
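+        # half precision affine grids are only exercised on CUDA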
+        dtype = torch.float16
+        scale_module = (
+            transforms.RandomScaleAffine(scale=[0.975, 1.025, 0.95, 1.05])
+            .cuda()
+            .to(dtype=dtype)
+        )
+        x = torch.ones([1, 3, 224, 224], dtype=dtype).cuda()
+        output = scale_module(x)
+        self.assertEqual(output.dtype, dtype)
+
 
 class TestRandomRotation(BaseTest):
     def test_random_rotation_init(self) -> None:
@@ -629,6 +681,37 @@ def test_random_rotation_jit_module(self) -> None:
         )
         assertTensorAlmostEqual(self, test_output, expected_output, 0.005)
 
+    def test_random_rotation_dtype_float64(self) -> None:
+        dtype = torch.float64
+        degrees = list(range(-25, -5)) + list(range(5, 25))
+        rotation_module = transforms.RandomRotation(degrees=degrees).to(dtype=dtype)
+        x = torch.ones([1, 3, 224, 224], dtype=dtype)
+        output = rotation_module(x)
+        self.assertEqual(output.dtype, dtype)
+
+    def test_random_rotation_dtype_float32(self) -> None:
+        dtype = torch.float32
+        degrees = list(range(-25, -5)) + list(range(5, 25))
+        rotation_module = transforms.RandomRotation(degrees=degrees).to(dtype=dtype)
+        x = torch.ones([1, 3, 224, 224], dtype=dtype)
+        output = rotation_module(x)
+        self.assertEqual(output.dtype, dtype)
+
+    def test_random_rotation_dtype_float16(self) -> None:
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest(
+                "Skipping RandomRotation float16 dtype test due to not supporting"
+                + " CUDA."
+            )
+        dtype = torch.float16
+        degrees = list(range(-25, -5)) + list(range(5, 25))
+        rotation_module = (
+            transforms.RandomRotation(degrees=degrees).cuda().to(dtype=dtype)
+        )
+        x = torch.ones([1, 3, 224, 224], dtype=dtype).cuda()
+        output = rotation_module(x)
+        self.assertEqual(output.dtype, dtype)
+
 
 class TestRandomSpatialJitter(BaseTest):
     def test_random_spatial_jitter_init(self) -> None:
@@ -714,6 +797,20 @@ def test_random_spatial_jitter_forward_jit_module(self) -> None:
         jittered_tensor = jit_spatialjitter(test_input)
         self.assertEqual(list(jittered_tensor.shape), list(test_input.shape))
 
+    def test_random_spatial_jitter_dtype_float64(self) -> None:
+        dtype = torch.float64
+        spatialjitter = transforms.RandomSpatialJitter(5).to(dtype=dtype)
+        x = torch.ones([1, 3, 224, 224], dtype=dtype)
+        output = spatialjitter(x)
+        self.assertEqual(output.dtype, dtype)
+
+    def test_random_spatial_jitter_dtype_float32(self) -> None:
+        dtype = torch.float32
+        spatialjitter = transforms.RandomSpatialJitter(5).to(dtype=dtype)
+        x = torch.ones([1, 3, 224, 224], dtype=dtype)
+        output = spatialjitter(x)
+        self.assertEqual(output.dtype, dtype)
+
 
 class TestCenterCrop(BaseTest):
     def test_center_crop_init(self) -> None:
@@ -1574,6 +1671,35 @@ def test_to_rgb_klt_forward_jit_module(self) -> None:
             self, inverse_tensor, torch.ones_like(inverse_tensor.rename(None))
         )
 
+    def test_to_rgb_dtype_float64(self) -> None:
+        dtype = torch.float64
+        to_rgb = transforms.ToRGB(transform="klt").to(dtype=dtype)
+        test_tensor = torch.ones(1, 3, 224, 224, dtype=dtype)
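+        # both the forward and inverse transforms should preserve the dtype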
+        output = to_rgb(test_tensor.refine_names("B", "C", "H", "W"))
+        self.assertEqual(output.dtype, dtype)
+        inverse_output = to_rgb(output, inverse=True)
+        self.assertEqual(inverse_output.dtype, dtype)
+
+    def test_to_rgb_dtype_float32(self) -> None:
+        dtype = torch.float32
+        to_rgb = transforms.ToRGB(transform="klt").to(dtype=dtype)
+        test_tensor = torch.ones(1, 3, 224, 224, dtype=dtype)
+        output = to_rgb(test_tensor.refine_names("B", "C", "H", "W"))
+        self.assertEqual(output.dtype, dtype)
+        inverse_output = to_rgb(output, inverse=True)
+        self.assertEqual(inverse_output.dtype, dtype)
+
+    def test_to_rgb_dtype_float16_cuda(self) -> None:
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest(
+                "Skipping ToRGB float16 dtype test due to not supporting CUDA."
+            )
+        dtype = torch.float16
+        to_rgb = transforms.ToRGB(transform="klt").cuda().to(dtype=dtype)
+        test_tensor = torch.ones(1, 3, 224, 224, dtype=dtype).cuda()
+        output = to_rgb(test_tensor.refine_names("B", "C", "H", "W"))
+        self.assertEqual(output.dtype, dtype)
+
 
 class TestGaussianSmoothing(BaseTest):
     def test_gaussian_smoothing_init_1d(self) -> None:
@@ -1582,11 +1708,17 @@ def test_gaussian_smoothing_init_1d(self) -> None:
         sigma = 2.0
         dim = 1
         smoothening_module = transforms.GaussianSmoothing(
-            channels, kernel_size, sigma, dim
+            channels, kernel_size, sigma, dim, padding=0
         )
         self.assertEqual(smoothening_module.groups, channels)
+        self.assertEqual(smoothening_module.padding, 0)
         weight = torch.tensor([[0.3192, 0.3617, 0.3192]]).repeat(6, 1, 1)
         assertTensorAlmostEqual(self, smoothening_module.weight, weight, 0.001)
 
     def test_gaussian_smoothing_init_2d(self) -> None:
         channels = 3
@@ -1594,7 +1726,11 @@ def test_gaussian_smoothing_init_2d(self) -> None:
         sigma = 2.0
         dim = 2
         smoothening_module = transforms.GaussianSmoothing(
-            channels, kernel_size, sigma, dim
+            channels, kernel_size, sigma, dim, padding=0
         )
         self.assertEqual(smoothening_module.groups, channels)
         weight = torch.tensor(
@@ -1614,7 +1750,11 @@ def test_gaussian_smoothing_init_3d(self) -> None:
         sigma = 1.021
         dim = 3
         smoothening_module = transforms.GaussianSmoothing(
-            channels, kernel_size, sigma, dim
+            channels, kernel_size, sigma, dim, padding=0
         )
         self.assertEqual(smoothening_module.groups, channels)
         weight = torch.tensor(
@@ -1654,7 +1794,11 @@ def test_gaussian_smoothing_1d(self) -> None:
         sigma = 2.0
         dim = 1
         smoothening_module = transforms.GaussianSmoothing(
-            channels, kernel_size, sigma, dim
+            channels, kernel_size, sigma, dim, padding=0
         )
 
         test_tensor = torch.tensor([1.0, 5.0]).repeat(6, 2).unsqueeze(0)
@@ -1671,7 +1815,11 @@ def test_gaussian_smoothing_2d(self) -> None:
         sigma = 2.0
         dim = 2
         smoothening_module = transforms.GaussianSmoothing(
-            channels, kernel_size, sigma, dim
+            channels, kernel_size, sigma, dim, padding=0
         )
 
         test_tensor = torch.tensor([1.0, 5.0]).repeat(3, 6, 3).unsqueeze(0)
@@ -1688,7 +1836,11 @@ def test_gaussian_smoothing_3d(self) -> None:
         sigma = 1.021
         dim = 3
         smoothening_module = transforms.GaussianSmoothing(
-            channels, kernel_size, sigma, dim
+            channels, kernel_size, sigma, dim, padding=0
         )
 
         test_tensor = torch.tensor([1.0, 5.0, 1.0]).repeat(4, 6, 6, 2).unsqueeze(0)
@@ -1712,7 +1864,11 @@ def test_gaussian_smoothing_2d_jit_module(self) -> None:
         sigma = 2.0
         dim = 2
         smoothening_module = transforms.GaussianSmoothing(
-            channels, kernel_size, sigma, dim
+            channels, kernel_size, sigma, dim, padding=0
         )
         jit_smoothening_module = torch.jit.script(smoothening_module)
 
@@ -1801,12 +1957,15 @@ def check_grad(self, grad_input, grad_output):
 
         class SymmetricPaddingLayer(torch.nn.Module):
             def forward(
-                self, x: torch.Tensor, padding: List[List[int]]
+                self, x_input: torch.Tensor, padding: List[List[int]]
             ) -> torch.Tensor:
-                return transforms.SymmetricPadding.apply(x_pt, padding)
+                return transforms.SymmetricPadding.apply(x_input, padding)
 
         sym_pad = SymmetricPaddingLayer()
-        sym_pad.register_backward_hook(check_grad)
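+        # register_backward_hook was deprecated in favor of
+        # register_full_backward_hook in torch 1.8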
+        if version.parse(torch.__version__) >= version.parse("1.8.0"):
+            sym_pad.register_full_backward_hook(check_grad)
+        else:
+            sym_pad.register_backward_hook(check_grad)
         x_out = sym_pad(x_pt, offset_pad)
         (x_out.sum() * 1).backward()
 
@@ -2008,3 +2167,17 @@ def test_transform_robustness_forward_padding_crop_output_jit_module(self) -> No
         test_input = torch.ones(1, 3, 224, 224)
         test_output = transform_robustness(test_input)
         self.assertEqual(test_output.shape, test_input.shape)
+
+    def test_transform_robustness_dtype_float64(self) -> None:
+        dtype = torch.float64
+        transform_robustness = transforms.TransformationRobustness().to(dtype=dtype)
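+        # the composed transform pipeline should preserve the input dtype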
+        x = torch.ones([1, 3, 224, 224], dtype=dtype)
+        output = transform_robustness(x)
+        self.assertEqual(output.dtype, dtype)
+
+    def test_transform_robustness_dtype_float32(self) -> None:
+        dtype = torch.float32
+        transform_robustness = transforms.TransformationRobustness().to(dtype=dtype)
+        x = torch.ones([1, 3, 224, 224], dtype=dtype)
+        output = transform_robustness(x)
+        self.assertEqual(output.dtype, dtype)
diff --git a/tests/optim/utils/test_reducer.py b/tests/optim/utils/test_reducer.py
index f2baa7675..a9fb9cc93 100644
--- a/tests/optim/utils/test_reducer.py
+++ b/tests/optim/utils/test_reducer.py
@@ -34,10 +34,10 @@ def test_channelreducer_pytorch(self) -> None:
         test_input = torch.randn(1, 32, 224, 224).abs()
         c_reducer = reducer.ChannelReducer(n_components=3, max_iter=100)
         test_output = c_reducer.fit_transform(test_input)
-        self.assertEquals(test_output.size(0), 1)
-        self.assertEquals(test_output.size(1), 3)
-        self.assertEquals(test_output.size(2), 224)
-        self.assertEquals(test_output.size(3), 224)
+        self.assertEqual(list(test_output.shape), [1, 3, 224, 224])
 
     def test_channelreducer_pytorch_dim_three(self) -> None:
         try:
@@ -52,9 +52,7 @@ def test_channelreducer_pytorch_dim_three(self) -> None:
         test_input = torch.randn(32, 224, 224).abs()
         c_reducer = reducer.ChannelReducer(n_components=3, max_iter=100)
         test_output = c_reducer.fit_transform(test_input)
-        self.assertEquals(test_output.size(0), 3)
-        self.assertEquals(test_output.size(1), 224)
-        self.assertEquals(test_output.size(2), 224)
+        self.assertEqual(list(test_output.shape), [3, 224, 224])
 
     def test_channelreducer_pytorch_pca(self) -> None:
         try:
@@ -70,10 +68,7 @@ def test_channelreducer_pytorch_pca(self) -> None:
         c_reducer = reducer.ChannelReducer(n_components=3, reduction_alg="PCA")
 
         test_output = c_reducer.fit_transform(test_input)
-        self.assertEquals(test_output.size(0), 1)
-        self.assertEquals(test_output.size(1), 3)
-        self.assertEquals(test_output.size(2), 224)
-        self.assertEquals(test_output.size(3), 224)
+        self.assertEqual(list(test_output.shape), [1, 3, 224, 224])
 
     def test_channelreducer_pytorch_custom_alg(self) -> None:
         test_input = torch.randn(1, 32, 224, 224).abs()
@@ -82,10 +77,7 @@ def test_channelreducer_pytorch_custom_alg(self) -> None:
             n_components=3, reduction_alg=reduction_alg, max_iter=100
         )
         test_output = c_reducer.fit_transform(test_input)
-        self.assertEquals(test_output.size(0), 1)
-        self.assertEquals(test_output.size(1), 3)
-        self.assertEquals(test_output.size(2), 224)
-        self.assertEquals(test_output.size(3), 224)
+        self.assertEqual(list(test_output.shape), [1, 3, 224, 224])
 
     def test_channelreducer_pytorch_custom_alg_components(self) -> None:
         reduction_alg = FakeReductionAlgorithm
@@ -149,10 +141,7 @@ def test_channelreducer_noreshape_pytorch(self) -> None:
         test_input = torch.randn(1, 224, 224, 32).abs()
         c_reducer = reducer.ChannelReducer(n_components=3, max_iter=100)
         test_output = c_reducer.fit_transform(test_input, swap_2nd_and_last_dims=False)
-        self.assertEquals(test_output.size(0), 1)
-        self.assertEquals(test_output.size(1), 224)
-        self.assertEquals(test_output.size(2), 224)
-        self.assertEquals(test_output.size(3), 3)
+        self.assertEqual(list(test_output.shape), [1, 224, 224, 3])
 
     def test_channelreducer_error(self) -> None:
         if not torch.cuda.is_available():