import pytest
import torch
from pl_bolts.models.self_supervised.resnets import ResNet
-from pytorch_lightning import Trainer
+from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel
@@ -344,13 +344,11 @@ def test_online_evaluator_recovery(test_output_dirs: OutputFolderForTests) -> None:


@pytest.mark.gpu
-def test_online_evaluator_distributed() -> None:
+def test_online_evaluator_not_distributed() -> None:
    """
-    A very basic test to check if the online evaluator uses the DDP flag correctly.
+    Check if the online evaluator uses the DDP flag correctly when not running distributed.
    """
-    mock_ddp_result = "mock_ddp_result"
-    with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel",
-                    return_value=mock_ddp_result) as mock_ddp:
+    with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel") as mock_ddp:
        callback = SSLOnlineEvaluatorInnerEye(class_weights=None,
                                              z_dim=1,
                                              num_classes=2,
@@ -361,20 +359,38 @@ def test_online_evaluator_distributed() -> None:

        # Standard trainer without DDP
        trainer = Trainer()
+        # Test the flag that the internal logic of on_pretrain_routine_start uses
+        assert not trainer.accelerator_connector.is_distributed
        mock_module = mock.MagicMock(device=torch.device("cpu"))
        callback.on_pretrain_routine_start(trainer, mock_module)
        assert isinstance(callback.evaluator, Module)
        mock_ddp.assert_not_called()

+
+@pytest.mark.gpu
+def test_online_evaluator_distributed() -> None:
+    """
+    Check if the online evaluator uses the DDP flag correctly when running distributed.
+    """
+    mock_ddp_result = "mock_ddp_result"
+    with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel",
+                    return_value=mock_ddp_result) as mock_ddp:
+        callback = SSLOnlineEvaluatorInnerEye(class_weights=None,
+                                              z_dim=1,
+                                              num_classes=2,
+                                              dataset="foo",
+                                              drop_p=0.2,
+                                              learning_rate=1e-5)
+
        # Trainer with DDP
-        mock_device = "fake_device"
-        mock_module = mock.MagicMock(device=mock_device)
+        device = torch.device("cuda:0")
+        mock_module = mock.MagicMock(device=device)
        trainer = Trainer(accelerator="ddp", gpus=2)
        # Test the two flags that the internal logic of on_pretrain_routine_start uses
        assert trainer.accelerator_connector.is_distributed
        assert trainer.accelerator_connector.use_ddp
        callback.on_pretrain_routine_start(trainer, mock_module)
        # Check that the evaluator has been turned into a DDP object
        # We still need to mock DDP here because the constructor relies on having a process group available
-        mock_ddp.assert_called_once_with(callback.evaluator, device_ids=[mock_device])
+        mock_ddp.assert_called_once_with(callback.evaluator, device_ids=[device])
        assert callback.evaluator == mock_ddp_result
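
For reference, a minimal sketch of the wrapping behaviour these two tests exercise: when the trainer runs distributed with DDP, the online evaluator is wrapped in DistributedDataParallel on the module's device, otherwise it is left as a plain Module. This is not the InnerEye implementation; the helper name and its exact structure are assumptions, only the flags and the DDP call mirror what the tests assert.

    import torch
    from pytorch_lightning import LightningModule, Trainer
    from torch.nn.parallel import DistributedDataParallel


    def wrap_evaluator(evaluator: torch.nn.Module,
                       trainer: Trainer,
                       pl_module: LightningModule) -> torch.nn.Module:
        # Hypothetical helper mirroring what on_pretrain_routine_start is expected to do.
        accelerator = trainer.accelerator_connector
        if accelerator.is_distributed and accelerator.use_ddp:
            # Distributed run: the evaluator's gradients must be synchronised across
            # processes, so wrap it in DDP on the device of the Lightning module.
            return DistributedDataParallel(evaluator, device_ids=[pl_module.device])
        # Single-process run: keep the evaluator as-is.
        return evaluator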