Optim-wip: Improve & fix LaplacianImage #967

Open
wants to merge 30 commits into
base: optim-wip
9d92e6b
Improve LaplacianImage
ProGamerGov Jun 4, 2022
a8fa243
Improve NaturalImage docs
ProGamerGov Jun 6, 2022
c8f9cc5
Remove old commented out TransformationRobustness version
ProGamerGov Jun 9, 2022
44203fa
Fix torch.meshgrid warning
ProGamerGov Jun 18, 2022
9e2f953
self.assertEquals -> self.assertEqual
ProGamerGov Jun 18, 2022
f867bf3
Resolve register_backward_hook -> register_full_backward_hook depreci…
ProGamerGov Jun 22, 2022
54b652d
Fix lint errors
ProGamerGov Jun 22, 2022
dd58b75
Improve ImageParameterization docs for Sphinx
ProGamerGov Jun 27, 2022
aafc4f7
Improve GaussianSmoothing
ProGamerGov Jun 28, 2022
5935119
Add underscore to some FFTImage functions
ProGamerGov Jun 28, 2022
c1161c5
Improve NaturalImage docs
ProGamerGov Jun 29, 2022
eb6930e
Improve ImageParameterization docs (#551)
ProGamerGov Jul 2, 2022
55cad28
Improve GaussianSmoothing docs
ProGamerGov Jul 5, 2022
fbaa8eb
Add missing forward docs to PixelImage
ProGamerGov Jul 6, 2022
5457544
Improve FFTImage float dtype support
ProGamerGov Jul 7, 2022
45c8511
Add dtype tests for ImageParameterizations
ProGamerGov Jul 7, 2022
3dc061d
Fix weird error: RuntimeError: expected scalar type Float but found D…
ProGamerGov Jul 7, 2022
90d9fd1
Fix test failures
ProGamerGov Jul 7, 2022
dd13dc6
NaturalImage dtype test fix + transform dtype tests
ProGamerGov Jul 9, 2022
1ffcad4
Fix dtype tests
ProGamerGov Jul 9, 2022
2a0a898
Remove failing test
ProGamerGov Jul 9, 2022
7c833ad
Simplify some image parameterization tests
ProGamerGov Jul 9, 2022
2d81aec
Fix type hints
ProGamerGov Jul 15, 2022
32c4ba5
Fix parameterization doc type hint formatting
ProGamerGov Jul 16, 2022
6259b13
Improve doc types
ProGamerGov Jul 16, 2022
5335a4e
Improve parameterization docs
ProGamerGov Jul 18, 2022
668aff1
Fix NaturalImage docs issue
ProGamerGov Jul 19, 2022
4bab7d7
Fix doc type hint
ProGamerGov Jul 20, 2022
e727712
Simplify reducer tests
ProGamerGov Jul 22, 2022
33f9f66
Fix spelling
ProGamerGov Jul 27, 2022
Improve parameterization docs
ProGamerGov authored Jul 18, 2022

commit 5335a4ed43d8bf04335960151894f6d1a8913635
108 changes: 57 additions & 51 deletions captum/optim/_param/image/images.py
@@ -225,14 +225,14 @@ def __init__(
"""
Args:

- size (Tuple[int, int]): The height & width dimensions to use for the
- parameterized output image tensor.
+ size (tuple of int): The height & width dimensions to use for the
+ parameterized output image tensor, in the format of: (height, width).
channels (int, optional): The number of channels to use for each image.
Default: ``3``
batch (int, optional): The number of images to stack along the batch
dimension.
Default: ``1``
- init (torch.Tensor, optional): Optionally specify a tensor to
+ init (torch.Tensor, optional): Optionally specify a CHW or NCHW tensor to
use instead of creating one.
Default: ``None``
"""
@@ -304,8 +304,8 @@ def _get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]:
torch.fft update.

Returns:
- fft functions (Tuple[Callable, Callable, Callable]): A list of FFT
- functions to use for irfft, rfft, and fftfreq operations.
+ fft_functions (tuple of callable): A list of FFT functions to use for
+ irfft, rfft, and fftfreq operations.
"""

if version.parse(TORCH_VERSION) > version.parse("1.7.0"):
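
The version check that closes this hunk uses a strict `>` comparison, so PyTorch 1.7.0 itself does not take the first branch. A minimal sketch of that gate, assuming a simplified release parser in place of the `version.parse` call used in the source; `parse_release` and `uses_newer_branch` are hypothetical names, not captum functions:

```python
import re
from typing import Tuple


def parse_release(v: str) -> Tuple[int, ...]:
    """Parse the leading dotted-integer release, e.g. '1.12.1+cu113' -> (1, 12, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", v)
    if match is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    # Note: local suffixes such as '+cu113' are dropped here; a full
    # version parser may order them differently.
    return tuple(int(part) for part in match.group(0).split("."))


def uses_newer_branch(torch_version: str) -> bool:
    """Mirror the strict '>' comparison against 1.7.0 shown in the hunk."""
    return parse_release(torch_version) > parse_release("1.7.0")
```

Comparing integer tuples rather than raw strings matters for releases such as 1.10.0, which would sort before 1.7.0 as text.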
@@ -388,14 +388,14 @@ def __init__(
"""
Args:

- size (Tuple[int, int]): The height & width dimensions to use for the
- parameterized output image tensor.
+ size (tuple of int): The height & width dimensions to use for the
+ parameterized output image tensor, in the format of: (height, width).
channels (int, optional): The number of channels to use for each image.
Default: ``3``
batch (int, optional): The number of images to stack along the batch
dimension.
Default: ``1``
- init (torch.Tensor, optional): Optionally specify a tensor to
+ init (torch.Tensor, optional): Optionally specify a CHW or NCHW tensor to
use instead of creating one.
Default: ``None``
"""
@@ -445,7 +445,7 @@ class LaplacianImage(ImageParameterization):

def __init__(
self,
- size: Tuple[int, int] = (224, 225),
+ size: Tuple[int, int] = (224, 224),
channels: int = 3,
batch: int = 1,
init: Optional[torch.Tensor] = None,
@@ -455,15 +455,14 @@ def __init__(
"""
Args:

- size (Tuple[int, int], optional): The height & width dimensions to use for
- the parameterized output image tensor.
- Default: ``(224, 224)``
+ size (tuple of int): The height & width dimensions to use for the
+ parameterized output image tensor, in the format of: (height, width).
channels (int, optional): The number of channels to use for each image.
Default: ``3``
batch (int, optional): The number of images to stack along the batch
dimension.
Default: ``1``
- init (torch.Tensor, optional): Optionally specify a tensor to
+ init (torch.Tensor, optional): Optionally specify a CHW or NCHW tensor to
use instead of creating one.
Default: ``None``
power (float, optional): The desired power value to use.
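
The updated `init` description accepts either a CHW or an NCHW tensor. A minimal sketch of how both layouts can be brought to one form, working on shape tuples so it runs without PyTorch; `as_nchw_shape` is a hypothetical helper, and treating a CHW input as a batch of one is an assumption, since the handling code is outside this hunk:

```python
from typing import Tuple


def as_nchw_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Return the NCHW shape for a CHW or NCHW input shape (hypothetical helper)."""
    if len(shape) == 3:
        # Assumption: a CHW input is a single image, so prepend a batch axis of 1.
        return (1,) + tuple(shape)
    if len(shape) == 4:
        # Already NCHW: leave it unchanged.
        return tuple(shape)
    raise ValueError(f"expected 3 or 4 dimensions, got {len(shape)}")
```

On real tensors the equivalent step would be `init.unsqueeze(0)` when `init.dim() == 3`.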
@@ -585,11 +584,11 @@ def __init__(
"""
Args:

- shapes (List[int] or List[List[int]]): The shapes of the shared
+ shapes (list of int or list of list of int): The shapes of the shared
tensors to use for creating the nn.Parameter tensors.
parameterization (ImageParameterization): An image parameterization
instance.
- offset (int or List[int] or List[List[int]] , optional): The offsets
+ offset (int or list of int or list of list of int, optional): The offsets
to use for the shared tensors.
Default: ``None``
"""
@@ -615,7 +614,7 @@ def _get_offset(self, offset: Union[int, Tuple[int]], n: int) -> List[List[int]]

Args:

- offset (int or List[int] or List[List[int]], optional): The offsets
+ offset (int or list of int or list of list of int, optional): The offsets
to use for the shared tensors.
n (int): The number of tensors needing offset values.

@@ -641,10 +640,10 @@ def _apply_offset(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:

Args:

- x_list (List[torch.Tensor]): list of tensors to offset.
+ x_list (list of torch.Tensor): list of tensors to offset.

Returns:
- A (List[torch.Tensor]): list of offset tensors.
+ A (list of torch.Tensor): list of offset tensors.
"""

A: List[torch.Tensor] = []
@@ -679,8 +678,8 @@ def _interpolate_bilinear(
Args:

x (torch.Tensor): The NCHW tensor to resize.
- size (Tuple[int, int]): The desired output size to resize the input
- to, with a format of: [height, width].
+ size (tuple of int): The desired output size to resize the input to, with
+ a format of: [height, width].

Returns:
x (torch.Tensor): A resized NCHW tensor.
@@ -708,8 +707,8 @@ def _interpolate_trilinear(
Args:

x (torch.Tensor): The NCHW tensor to resize.
- size (Tuple[int, int, int]): The desired output size to resize the input
- to, with a format of: [channels, height, width].
+ size (tuple of int): The desired output size to resize the input to, with
+ a format of: [channels, height, width].

Returns:
x (torch.Tensor): A resized NCHW tensor.
@@ -819,8 +818,8 @@ def __init__(
"""
Args:

- parameterizations (List[Union[ImageParameterization, torch.Tensor]]): A
- list of image parameterizations and tensors to concatenate across a
+ parameterizations (list of ImageParameterization and torch.Tensor): A list
+ of image parameterizations and tensors to concatenate across a
specified dimension.
dim (int, optional): Optionally specify the dim to concatinate
parameterization outputs on. Default is set to the batch dimension.
@@ -912,11 +911,6 @@ class NaturalImage(ImageParameterization):
True
>>> print(image_tensor.shape)
torch.Size([1, 3, 224, 224])
-
- :ivar parameterization: initial value (ImageParameterization): The given image
- parameterization instance given when initializing ``NaturalImage``.
- :ivar decorrelation_module: initial value (nn.Module): The given decorrelation
- module instance given when initializing ``NaturalImage``.
"""

def __init__(
@@ -926,48 +920,60 @@ def __init__(
batch: int = 1,
init: Optional[torch.Tensor] = None,
parameterization: ImageParameterization = FFTImage,
- squash_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+ squash_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = torch.sigmoid,
decorrelation_module: Optional[nn.Module] = ToRGB(transform="klt"),
decorrelate_init: bool = True,
) -> None:
"""
Args:

- size (Tuple[int, int], optional): The height and width to use for the
- nn.Parameter image tensor. This parameter is not used if
- parameterization is an instance.
- Default: ``(224, 224)``
+ size (tuple of int, optional): The height and width to use for the
+ nn.Parameter image tensor, in the format of: (height, width).
+ This parameter is not used if the given ``parameterization`` is an
+ instance.
+ Default: ``(224, 224)`
channels (int, optional): The number of channels to use when creating the
- nn.Parameter tensor. This parameter is not used if parameterization is
- an instance.
+ nn.Parameter tensor. This parameter is not used if the given
+ ``parameterization`` is an instance.
Default: ``3``
batch (int, optional): The number of channels to use when creating the
- nn.Parameter tensor. This parameter is not used if ``parameterization``
- is an instance.
+ nn.Parameter tensor. This parameter is not used if the given
+ ``parameterization`` is an instance.
Default: ``1``
init (torch.Tensor, optional): Optionally specify a tensor to use instead
- of creating one from random noise. This parameter is not used if
- ``parameterization`` is an instance. Set to ``None`` for random init.
+ of creating one from random noise. This parameter is not used if the
+ given ``parameterization`` is an instance. Set to ``None`` for random
+ init.
Default: ``None``
parameterization (ImageParameterization, optional): An image
parameterization class, or instance of an image parameterization class.
- Default: FFTImage
- squash_func (Callable[[torch.Tensor], torch.Tensor]], optional): The squash
- function to use after color recorrelation. A function, lambda function,
- or callable class instance.
- Default: ``None``
+ Default: :class:`.FFTImage`
+ squash_func (callable, optional): The squash function to use after color
+ recorrelation. A function, lambda function, or callable class instance.
+ Any provided squash function should take a single input tensor and
+ return a single output tensor. If set to ``None``, then
+ :class:`torch.nn.Identity` will be used to make it a non op.
+ Default: :func:`torch.sigmoid`
decorrelation_module (nn.Module, optional): A module instance that
recorrelates the colors of an input image. Custom modules can make use
of the ``decorrelate_init`` parameter by having a second ``inverse``
parameter in their forward functions that performs the inverse
- operation when it is set to ``True`` (see
- :class:`captum.optim.transforms.ToRGB` for an example).
- Set to ``None`` for no recorrelation.
- Default: ``ToRGB``
+ operation when it is set to ``True`` (see :class:`.ToRGB` for an
+ example). Set to ``None`` for no recorrelation.
+ Default: :class:`.ToRGB`
decorrelate_init (bool, optional): Whether or not to apply color
decorrelation to the init tensor input. This parameter is not used if
- ``parameterization`` is an instance or if init is ``None``.
+ the given ``parameterization`` is an instance or if init is ``None``.
Default: ``True``
+
+ Attributes:
+
+ parameterization (ImageParameterization): The given image parameterization
+ instance given when initializing ``NaturalImage``.
+ Default: :class:`.FFTImage`
+ decorrelation_module (torch.nn.Module): The given decorrelation module
+ instance given when initializing ``NaturalImage``.
+ Default: :class:`.ToRGB`
"""
super().__init__()
if not isinstance(parameterization, ImageParameterization):
@@ -987,7 +993,7 @@ def __init__(
)
init = self.decorrelate(init, inverse=True).rename(None)

- self.squash_func = torch.sigmoid if squash_func is None else squash_func
+ self.squash_func = squash_func or torch.nn.Identity()
if not isinstance(parameterization, ImageParameterization):
parameterization = parameterization(
size=size, channels=channels, batch=batch, init=init
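
This commit changes what `squash_func=None` means in `NaturalImage`: previously `None` fell back to `torch.sigmoid`, while now `torch.sigmoid` is the signature default and `None` selects an identity. A minimal sketch contrasting the two behaviors, using plain floats and `math` in place of tensors so it runs without PyTorch; `sigmoid`, `identity`, `resolve_old`, and `resolve_new` are hypothetical stand-ins, not captum API:

```python
import math
from typing import Callable, Optional

Squash = Callable[[float], float]


def sigmoid(x: float) -> float:
    """Stand-in for torch.sigmoid applied to a single float."""
    return 1.0 / (1.0 + math.exp(-x))


def identity(x: float) -> float:
    """Stand-in for torch.nn.Identity()."""
    return x


def resolve_old(squash_func: Optional[Squash] = None) -> Squash:
    # Old behavior: None silently means sigmoid.
    return sigmoid if squash_func is None else squash_func


def resolve_new(squash_func: Optional[Squash] = sigmoid) -> Squash:
    # New behavior: sigmoid is the explicit default, and None
    # (or any other falsy value) selects the no-op.
    return squash_func or identity
```

Callers that passed `squash_func=None` explicitly and relied on sigmoid squashing would get unsquashed output after this change, so such call sites may need to drop the argument or pass `torch.sigmoid` directly.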