Official code for using / reproducing TRIM from the paper Transformation Importance with Applications to Cosmology (ICLR 2020 Workshop). This code shows examples and provides useful wrappers for calculating importance in a transformed feature space.

This repo is actively maintained. For any questions, please file an issue.


examples/documentation

  • dependencies: requires the pip-installable acd package
  • examples: the different example folders (e.g. ex_cosmology, ex_fake_news, ex_mnist, ex_urban_sound) contain examples for using TRIM in different settings
  • src: the core code is in the trim folder, containing wrappers and code for different transformations
  • requirements: tested with python 3.7 and pytorch > 1.0
Example figures:
  • Attribution to different scales in cosmological images
  • Fake news attribution to different topics
  • Attribution to different NMF components in MNIST classification
  • Attribution to different frequencies in audio classification

sample usage

import torch
import torch.nn as nn
from trim import TrimModel
from functools import partial

# set up a TRIM model
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))  # original model
transform = partial(torch.rfft, signal_ndim=1, onesided=False)  # fft
inv_transform = partial(torch.irfft, signal_ndim=1, onesided=False)  # inverse fft
model_trim = TrimModel(model=model, inv_transform=inv_transform)  # trim model

# get a data point and its representation in the fft space
x = torch.randn(1, 10)
s = transform(x)

# can now use any attribution method on the trim model,
# e.g. input-x-gradient attribution in the fft space
s.requires_grad_(True)
model_trim(s).backward()
input_x_gradient = s.grad * s
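Note that torch.rfft and torch.irfft were removed in PyTorch 1.8, so the snippet above requires an older PyTorch. A rough equivalent transform pair using the modern torch.fft module (a sketch, not part of this repo) represents the complex spectrum as a real tensor with a trailing (real, imag) dimension:

```python
import torch

# sketch of torch.rfft/irfft replacements with onesided=False, for PyTorch >= 1.8
transform = lambda x: torch.view_as_real(torch.fft.fft(x))          # real tensor, shape (..., n, 2)
inv_transform = lambda s: torch.fft.ifft(torch.view_as_complex(s)).real

x = torch.randn(1, 10)
s = transform(x)          # shape (1, 10, 2)
x_rec = inv_transform(s)  # reconstructs x up to floating-point error
```

These lambdas can be passed to TrimModel in place of the partial-based versions above.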

  • see the notebooks for more detailed usage
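Conceptually, the wrapper evaluates the original model on the inverse-transformed input, f(T⁻¹(s)), so gradients with respect to s land in the transformed feature space. A minimal sketch of such a wrapper (hypothetical, not the repo's actual TrimModel implementation) looks like:

```python
import torch
import torch.nn as nn

class TrimSketch(nn.Module):
    """Hypothetical sketch: compute f(T^{-1}(s)) so gradients flow back to s."""
    def __init__(self, model, inv_transform):
        super().__init__()
        self.model = model
        self.inv_transform = inv_transform

    def forward(self, s):
        # map the transformed representation back to input space, then score it
        return self.model(self.inv_transform(s))

# usage with a trivial identity "transform" for illustration
model = nn.Linear(10, 1)
wrapped = TrimSketch(model, inv_transform=lambda s: s)
s = torch.randn(1, 10, requires_grad=True)
wrapped(s).backward()
grad_x_input = s.grad * s  # input-x-gradient attribution in the transformed space
```

Any gradient-based attribution method can then be pointed at the wrapped module exactly as it would be at the original model.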

related work

  • ACD (ICLR 2019 pdf, github) - extends CD to CNNs / arbitrary DNNs, and aggregates explanations into a hierarchy
  • CDEP (ICML 2020 pdf, github) - penalizes CD / ACD scores during training to make models generalize better
  • DAC (arXiv 2019 pdf, github) - finds disentangled interpretations for random forests
  • PDR framework (PNAS 2019 pdf) - an overarching framework for guiding and framing interpretable machine learning

reference

  • feel free to use/share this code openly
  • if you find this code useful for your research, please cite the following:
@article{singh2020transformation,
    title={Transformation Importance with Applications to Cosmology},
    author={Singh, Chandan and Ha, Wooseok and Lanusse, Francois and Boehm, Vanessa and Liu, Jia and Yu, Bin},
    journal={arXiv preprint arXiv:2003.01926},
    year={2020},
    url={https://arxiv.org/abs/2003.01926},
}