Skip to content

Commit

Permalink
Backprop.visualize passes use_gpu flag (#25)
Browse files Browse the repository at this point in the history
* Backprop.visualize passes use_gpu flag

* Bump version
  • Loading branch information
MisaOgura authored Jan 2, 2020
1 parent aa57862 commit d48b391
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 4 deletions.
2 changes: 1 addition & 1 deletion flashtorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.1'
__version__ = '0.1.2'
6 changes: 4 additions & 2 deletions flashtorch/saliency/backprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,13 @@ def visualize(self, input_, target_class, guided=False, use_gpu=False,

gradients = self.calculate_gradients(input_,
target_class,
guided=guided)
guided=guided,
use_gpu=use_gpu)
max_gradients = self.calculate_gradients(input_,
target_class,
guided=guided,
take_max=True)
take_max=True,
use_gpu=use_gpu)

# Setup subplots

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
DOCLINES = (__doc__ or '').split("\n")
long_description = "\n".join(DOCLINES[2:])

version = '0.1.1'
version = '0.1.2'

setup(
name='flashtorch',
Expand Down
39 changes: 39 additions & 0 deletions tests/test_backprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,45 @@ def test_warn_when_prediction_is_wrong(mocker, model):
backprop.calculate_gradients(input_, target_class)


# Test visualize method


def test_visualize_calls_calculate_gradients_twice(mocker, model):
backprop = Backprop(model)
mocker.spy(backprop, 'calculate_gradients')

top_class = 5
target_class = 5
input_ = torch.zeros([1, 3, 224, 224])

target = make_expected_gradient_target(top_class)

mock_output = make_mock_output(mocker, model, target_class)

backprop.visualize(input_, target_class, use_gpu=True)

assert backprop.calculate_gradients.call_count == 2


def test_visualize_passes_gpu_flag(mocker, model):
backprop = Backprop(model)
mocker.spy(backprop, 'calculate_gradients')

top_class = 5
target_class = 5
input_ = torch.zeros([1, 3, 224, 224])

target = make_expected_gradient_target(top_class)

mock_output = make_mock_output(mocker, model, target_class)

backprop.visualize(input_, target_class, use_gpu=True)

_, _, kwargs = backprop.calculate_gradients.mock_calls[0]

assert kwargs['use_gpu']


# Test compatibilities with torchvision models


Expand Down

0 comments on commit d48b391

Please # to comment.