Adding the Latent Shift attribution method #1024
Open
ieee8023 wants to merge 72 commits into pytorch:master from ieee8023:master
+451 −0
Changes from all commits
Commits (72, all authored by ieee8023):
32b38e0  Add Latent Shift
0e57fe4  video
7c5025c  align text
3a89340  cleanup
d0c833a  clean up docs
4dc25bf  add support for colab version
12e78d3  cleanup
2cf44ba  add more docs
0a74565  Merge branch 'master' into master
0ceb34e  cleanup format and add test
9043790  more cleanup
2963ae5  cleanup and add more docs
d4320d2  fix flake8 errors
a039159  fixing flake8 for real
907d2d7  fix format and add opion to limit printing
01bb3b2  fix type error
222128e  flake8
4c588b9  autopep8
c1dd756  make mypy happy
77963d8  ufmt format
b0a08d7  I really think flake8 will pass now
8432ff3  match reference to other references
5597155  small change to kick off tests again
5407086  Merge branch 'master' into master
9b5272e  Merge branch 'master' into master
7a19759  Merge branch 'master' into master
7d64a75  Merge branch 'master' into master
e245cde  Merge branch 'master' into master
7245048  Merge branch 'master' into master
03fe557  Merge branch 'master' into master
298c9e8  add options for extra loops and the cmap value
04f16ca  Merge branch 'master' into master
3d6f842  fix flake8
17bc3af  Add Latent Shift
7615dcc  video
558b429  align text
430888e  cleanup
bf434a9  clean up docs
96e8b42  add support for colab version
34f48f6  cleanup
554db30  add more docs
cefc673  cleanup format and add test
42c2c36  more cleanup
77c574b  cleanup and add more docs
8aa3fec  fix flake8 errors
90ffd8e  fixing flake8 for real
67a576c  fix format and add opion to limit printing
2a9cab7  fix type error
a0f156a  flake8
435bee8  autopep8
2f618ff  make mypy happy
3f9bbdd  ufmt format
cec2237  I really think flake8 will pass now
b387097  match reference to other references
29951a0  small change to kick off tests again
653a67a  add options for extra loops and the cmap value
2292efa  fix flake8
390fee0  Merge branch 'master' of github.com:ieee8023/captum
74af5c8  refactor image writing
e9196ed  refactor for batches and just returning heatmaps
8a24e9b  pep8
f873da3  ufmt
92e93a7  format errors
8f9d8a2  fix typing
7779497  reduce string length
cd2c2f5  remove usage of torchvision in tests
6855d55  Merge branch 'master' into master
10a43c0  add sigmoid param
1d8e9dd  Merge branch 'master' of github.com:ieee8023/captum
c0dfdfb  Merge branch 'master' into master
4eb4c6d  Merge branch 'master' into master
d63a803  Merge branch 'master' into master
@@ -0,0 +1,267 @@
#!/usr/bin/env python3

from typing import Any, Callable, Dict, List, Tuple, Union

import numpy as np
import torch
from captum.attr._utils.attribution import GradientAttribution
from captum.log import log_usage
from torch import Tensor


class LatentShift(GradientAttribution):
    r"""An implementation of the Latent Shift method to generate
    counterfactual explanations. This method uses an autoencoder to constrain
    the possible adversarial examples so that they remain in the data space,
    shifting the latent space of the autoencoder using dy/dz instead of
    dy/dx in order to change the classifier's prediction.

    This class implements a search strategy to determine the lambda needed to
    change the prediction of the classifier by a specific amount, as well as
    the code to generate a video and to construct a heatmap representing the
    image changes for viewing as an image.

    More details regarding the Latent Shift method can be found in the
    original paper:
    https://arxiv.org/abs/2102.09475
    and the original code repository:
    https://github.com/mlmed/gifsplanation
    """

    def __init__(self, forward_func: Callable, autoencoder) -> None:
        r"""
        Args:
            forward_func (callable): The forward function of the model or
                        any modification of it.
            autoencoder: An object with encode and decode functions which
                        keep the computation on the gradient tape (i.e. are
                        differentiable).
        """
        GradientAttribution.__init__(self, forward_func)
        self.ae = autoencoder

        # Check that the autoencoder exposes encode and decode
        assert hasattr(self.ae, "encode")
        assert hasattr(self.ae, "decode")

    @log_usage()
    def attribute(
        self,
        inputs: Tensor,
        target: int,
        fix_range: Union[Tuple, None] = None,
        search_pred_diff: float = 0.8,
        search_step_size: float = 10.0,
        search_max_steps: int = 3000,
        search_max_pixel_diff_pct: float = 0.05,
        lambda_sweep_steps: int = 10,
        heatmap_method: str = "int",
        apply_sigmoid: bool = True,
        verbose: bool = True,
        return_dicts: bool = False,
    ) -> Union[Tensor, List[Dict[str, Any]]]:
        r"""
        This method performs a search to determine the lambda values needed
        to generate the shift. The search steps by `search_step_size` in the
        negative direction until the output of the classifier has changed by
        `search_pred_diff` or the change in the prediction stops going down.
        To avoid artifacts when the shift is too large or in the wrong
        direction, an extra stop condition, `search_max_pixel_diff_pct`,
        halts the search if the change in the image becomes too large. To
        keep the search from taking too long, `search_max_steps` bounds the
        number of search steps.

        Args:

            inputs (tensor): Input for which the counterfactual is computed.
            target (int): Output index for which dy/dz is computed (for
                        classification cases, this is usually the target
                        class).
            fix_range (tuple): Overrides searching and directly specifies the
                        lambda range to use, e.g. [-100, 0].
            search_pred_diff (float): The desired change in the classifier's
                        prediction. For example, if the classifier predicts
                        0.9 and search_pred_diff=0.8, the search will try to
                        generate a counterfactual where the prediction is 0.1.
            search_step_size (float): The initial step size when searching
                        for the right lambda. This is similar to a learning
                        rate: smaller values avoid jumping over the ideal
                        lambda, but the search may take a long time.
            search_max_steps (int): The maximum number of steps to take during
                        the search. Sometimes steps make only a tiny
                        improvement and could go on forever; this bounds the
                        time spent before giving up the search.
            search_max_pixel_diff_pct (float): When searching, stop if the
                        pixel difference is larger than this fraction of the
                        input, i.e. when |img0 - imgx| > |img0| * pct. This
                        prevents large artifacts from being introduced into
                        the image.
            lambda_sweep_steps (int): How many frames to generate for the
                        video.
            heatmap_method: Default: 'int'. Possible methods: 'int': average
                        per-frame differences; 'mean': average difference
                        between the lambda 0 frame and the other lambda
                        frames; 'mm': difference between the first and last
                        frames; 'max': max difference from the lambda 0 frame.
            apply_sigmoid: Default: True. Apply a sigmoid to the output of
                        the model. Set to False to work with regression
                        models or if the model already applies a sigmoid.
            verbose: If True, print debug text.
            return_dicts (bool): Return a list of dicts containing information
                        about each image processed. Default: False.

        Returns:
            attributions or (if return_dicts=True) a list of dicts containing
            the following keys:
            generated_images: A list of images generated at each step along
                the dy/dz vector, from the smallest lambda to the largest. By
                default the smallest lambda represents the counterfactual
                image and the largest lambda is 0 (representing no change).
            lambdas: A list of the lambda values for each generated image.
            preds: A list of the model's predictions for each generated
                image.
            heatmap: A heatmap indicating the pixels which change in the
                video sequence of images.

        Example::

            >>> # Load classifier and autoencoder
            >>> model = classifiers.FaceAttribute()
            >>> ae = autoencoders.VQGAN(weights="faceshq")
            >>>
            >>> # Load image
            >>> x = torch.randn(1, 3, 1024, 1024)
            >>>
            >>> # Defining Latent Shift module
            >>> attr = captum.attr.LatentShift(model, ae)
            >>>
            >>> # Computes counterfactual for class 3.
            >>> output = attr.attribute(x, target=3)

        """

        assert lambda_sweep_steps > 1, "lambda_sweep_steps must be at least 2"

        results = []
        # Cheap batching: process each image in the batch independently
        for idx in range(inputs.shape[0]):
            inp = inputs[idx].unsqueeze(0)
            z = self.ae.encode(inp).detach()
            z.requires_grad = True
            x_lambda0 = self.ae.decode(z)
            pred = self.forward_func(x_lambda0)[:, target]
            if apply_sigmoid:
                pred = torch.sigmoid(pred)
            dzdxp = torch.autograd.grad(pred, z)[0]

            # Cache so we can reuse at sweep stage
            cache = {}

            def compute_shift(lambdax):
                """Compute the shift for a specific lambda"""
                if lambdax not in cache:
                    x_lambdax = self.ae.decode(z + dzdxp * lambdax).detach()
                    pred1 = self.forward_func(x_lambdax)[:, target]
                    if apply_sigmoid:
                        pred1 = torch.sigmoid(pred1)
                    pred1 = pred1.detach().cpu().numpy()
                    cache[lambdax] = x_lambdax, pred1
                return cache[lambdax]

            _, initial_pred = compute_shift(0)

            if fix_range:
                lbound, rbound = fix_range
            else:
                # Left range
                lbound = 0
                last_pred = initial_pred
                pixel_sum = x_lambda0.abs().sum()  # Used for pixel diff
                while True:
                    x_lambdax, cur_pred = compute_shift(lbound)
                    pixel_diff = torch.abs(x_lambda0 - x_lambdax).sum().detach().cpu()
                    if verbose:
                        toprint = [
                            f"Shift: {lbound}",
                            f"Pred: {float(cur_pred)}",
                            f"pixel_diff: {float(pixel_diff)}",
                            f"sum*diff_pct: {pixel_sum * search_max_pixel_diff_pct}",
                        ]
                        print(", ".join(toprint))

                    # If we stop decreasing the prediction
                    if last_pred < cur_pred:
                        break
                    # If the prediction becomes very low
                    if cur_pred < 0.05:
                        break
                    # If we have decreased the prediction by search_pred_diff
                    if initial_pred - search_pred_diff > cur_pred:
                        break
                    # If we are moving in the latent space too much
                    if lbound <= -search_max_steps:
                        break
                    # If we move too far we will distort the image
                    if pixel_diff > (pixel_sum * search_max_pixel_diff_pct):
                        break

                    last_pred = cur_pred
                    # Step further in the negative direction; the lbound // 10
                    # term grows the step size as lambda becomes more negative.
                    lbound = lbound - search_step_size + lbound // 10

                # Right range search not implemented
                rbound = 0

            if verbose:
                print("Selected bounds: ", lbound, rbound)

            # Sweep over the range of lambda values to create a sequence
            lambdas = np.linspace(lbound, rbound, lambda_sweep_steps)
            assert lambda_sweep_steps == len(
                lambdas
            ), "Inconsistent number of lambda steps"

            if verbose:
                print("Lambdas to compute: ", lambdas)

            preds = []
            generated_images = []

            for lam in lambdas:
                x_lambdax, pred = compute_shift(lam)
                generated_images.append(x_lambdax.cpu().numpy()[0])
                preds.append(float(pred))

            params = {}
            params["generated_images"] = np.array(generated_images)
            params["lambdas"] = lambdas
            params["preds"] = preds

            x_lambda0 = x_lambda0.detach().cpu().numpy()
            if heatmap_method == "max":
                # Max difference from lambda 0 frame
                heatmap = np.max(np.abs(x_lambda0 - generated_images), 0)

            elif heatmap_method == "mean":
                # Average difference between 0 and other lambda frames
                heatmap = np.mean(np.abs(x_lambda0 - generated_images), 0)

            elif heatmap_method == "mm":
                # Difference between first and last frames
                heatmap = np.abs(generated_images[0] - generated_images[-1])

            elif heatmap_method == "int":
                # Average per frame differences
                image_changes = []
                for i in range(len(generated_images) - 1):
                    image_changes.append(
                        np.abs(generated_images[i] - generated_images[i + 1])
                    )
                heatmap = np.mean(image_changes, 0)
            else:
                raise Exception("Unknown heatmap_method for 2d image")

            params["heatmap"] = heatmap
            results.append(params)

        if return_dicts:
            return results
        else:
            return torch.tensor([result["heatmap"] for result in results])
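
For readers who want to try the new attribution method, here is a minimal, self-contained sketch of how the class above could be driven end to end. ToyAutoEncoder, ToyClassifier, the input shapes, and the `from captum.attr import LatentShift` import path are hypothetical stand-ins introduced for illustration and are not part of the PR; in practice you would plug in a real autoencoder and classifier, such as the VQGAN/FaceAttribute pair shown in the docstring example.

import torch
import torch.nn as nn
from captum.attr import LatentShift  # assumed import path once the PR is merged


class ToyAutoEncoder(nn.Module):
    """Tiny convolutional autoencoder exposing the encode()/decode() interface
    that LatentShift expects; both methods stay on the autograd graph."""

    def __init__(self) -> None:
        super().__init__()
        self.enc = nn.Conv2d(3, 8, kernel_size=4, stride=2, padding=1)
        self.dec = nn.ConvTranspose2d(8, 3, kernel_size=4, stride=2, padding=1)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.enc(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.dec(z)


class ToyClassifier(nn.Module):
    """Tiny classifier producing logits for 4 classes."""

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=3, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = ToyClassifier().eval()
ae = ToyAutoEncoder().eval()
x = torch.randn(2, 3, 64, 64)  # a batch of two fake images

attr = LatentShift(model, ae)

# Default output: one heatmap per input image, stacked into a tensor.
heatmaps = attr.attribute(x, target=1, verbose=False)
print(heatmaps.shape)  # torch.Size([2, 3, 64, 64]) for this toy setup

# Full output: per-image dicts with generated_images, lambdas, preds, heatmap.
dicts = attr.attribute(x, target=1, verbose=False, return_dicts=True)
print(dicts[0]["lambdas"])

Note that each candidate lambda costs one autoencoder decode plus one classifier forward pass, so with a large autoencoder the search loop dominates the runtime.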
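
The class docstring above mentions generating a video of the lambda sweep, but the attribute method itself only returns the frames (via return_dicts=True). As a hedged sketch of that last step, the helper below writes the frames out as an animated GIF; the imageio dependency, the normalization choice, and the frames_to_gif name are assumptions for illustration, not part of the PR.

import imageio
import numpy as np


def frames_to_gif(generated_images: np.ndarray, path: str = "latent_shift.gif") -> None:
    """Write the lambda-sweep frames (N, 3, H, W floats) to an animated GIF."""
    frames = []
    for img in generated_images:
        img = np.transpose(img, (1, 2, 0))            # CHW -> HWC
        img = img - img.min()
        img = img / (img.max() + 1e-8)                # rescale to [0, 1]
        frames.append((img * 255).astype(np.uint8))   # GIF frames must be uint8
    imageio.mimsave(path, frames, duration=0.2)


# Example using the dicts returned by attribute(..., return_dicts=True):
# frames_to_gif(dicts[0]["generated_images"])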