Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Commit d958171

Browse files
committed
test and flake
1 parent 0c16162 commit d958171

File tree

2 files changed

+25
-11
lines changed

2 files changed

+25
-11
lines changed

InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py

-2
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,6 @@ def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.Lightning
8686
if accelerator.is_distributed:
8787
if accelerator.use_ddp:
8888
self.evaluator = DistributedDataParallel(self.evaluator, device_ids=[pl_module.device]) # type: ignore
89-
elif accelerator.use_dp:
90-
self.evaluator = DataParallel(self.evaluator, device_ids=[pl_module.device]) # type: ignore
9189
else:
9290
rank_zero_warn("This type of distributed accelerator is not supported. "
9391
"The online evaluator will not synchronize across GPUs.")

Tests/SSL/test_ssl_containers.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pytest
1313
import torch
1414
from pl_bolts.models.self_supervised.resnets import ResNet
15-
from pytorch_lightning import Trainer
15+
from pytorch_lightning import LightningModule, Trainer
1616
from pytorch_lightning.callbacks import ModelCheckpoint
1717
from torch.nn import Module
1818
from torch.nn.parallel import DistributedDataParallel
@@ -344,13 +344,11 @@ def test_online_evaluator_recovery(test_output_dirs: OutputFolderForTests) -> No
344344

345345

346346
@pytest.mark.gpu
347-
def test_online_evaluator_distributed() -> None:
347+
def test_online_evaluator_not_distributed() -> None:
348348
"""
349-
A very basic test to check if the online evaluator uses the DDP flag correctly.
349+
Check if the online evaluator uses the DDP flag correctly when running not distributed
350350
"""
351-
mock_ddp_result = "mock_ddp_result"
352-
with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel",
353-
return_value=mock_ddp_result) as mock_ddp:
351+
with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel") as mock_ddp:
354352
callback = SSLOnlineEvaluatorInnerEye(class_weights=None,
355353
z_dim=1,
356354
num_classes=2,
@@ -361,20 +359,38 @@ def test_online_evaluator_distributed() -> None:
361359

362360
# Standard trainer without DDP
363361
trainer = Trainer()
362+
# Test the flag that the internal logic of on_pretrain_routine_start uses
363+
assert not trainer.accelerator_connector.is_distributed
364364
mock_module = mock.MagicMock(device=torch.device("cpu"))
365365
callback.on_pretrain_routine_start(trainer, mock_module)
366366
assert isinstance(callback.evaluator, Module)
367367
mock_ddp.assert_not_called()
368368

369+
370+
@pytest.mark.gpu
371+
def test_online_evaluator_distributed() -> None:
372+
"""
373+
Check if the online evaluator uses the DDP flag correctly when running distributed.
374+
"""
375+
mock_ddp_result = "mock_ddp_result"
376+
with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel",
377+
return_value=mock_ddp_result) as mock_ddp:
378+
callback = SSLOnlineEvaluatorInnerEye(class_weights=None,
379+
z_dim=1,
380+
num_classes=2,
381+
dataset="foo",
382+
drop_p=0.2,
383+
learning_rate=1e-5)
384+
369385
# Trainer with DDP
370-
mock_device = "fake_device"
371-
mock_module = mock.MagicMock(device=mock_device)
386+
device = torch.device("cuda:0")
387+
mock_module = mock.MagicMock(device=device)
372388
trainer = Trainer(accelerator="ddp", gpus=2)
373389
# Test the two flags that the internal logic of on_pretrain_routine_start uses
374390
assert trainer.accelerator_connector.is_distributed
375391
assert trainer.accelerator_connector.use_ddp
376392
callback.on_pretrain_routine_start(trainer, mock_module)
377393
# Check that the evaluator has been turned into a DDP object
378394
# We still need to mock DDP here because the constructor relies on having a process group available
379-
mock_ddp.assert_called_once_with(callback.evaluator, device_ids=[mock_device])
395+
mock_ddp.assert_called_once_with(callback.evaluator, device_ids=[device])
380396
assert callback.evaluator == mock_ddp_result

0 commit comments

Comments
 (0)