
Commit 5c2aab1

Author: gaoyang07 (committed)
fix ut bugs
1 parent 0f28970 commit 5c2aab1

File tree

1 file changed: +16 -47 lines changed


tests/test_models/test_algorithms/test_darts.py

+16 -47
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os
-from typing import Dict, List, Tuple
+from typing import Dict
 from unittest import TestCase
 
 import pytest
@@ -18,8 +18,6 @@
 from mmrazor.models.algorithms.nas.darts import DartsDDP
 from mmrazor.registry import MODELS
 
-# from unittest.mock import Mock
-
 MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)
 MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True)
 MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True)
@@ -72,15 +70,6 @@ def forward(self, batch_inputs, data_samples=None, mode='tensor'):
         return out
 
 
-class ToyDataPreprocessor(torch.nn.Module):
-
-    def forward(
-            self,
-            data: Dict,
-            training: bool = True) -> Tuple[torch.Tensor, List[ClsDataSample]]:
-        return data['inputs'], data['data_samples']
-
-
 class TestDarts(TestCase):
 
     def setUp(self) -> None:
@@ -159,7 +148,7 @@ def test_search_subnet(self) -> None:
         self.assertIsInstance(subnet, dict)
 
     def test_darts_train_step(self) -> None:
-        model = ToyDiffModule(data_preprocessor=ToyDataPreprocessor())
+        model = ToyDiffModule()
         mutator = DiffModuleMutator()
         mutator.prepare_from_supernet(model)
 
@@ -182,7 +171,7 @@ def test_darts_train_step(self) -> None:
         self.assertIsNotNone(loss)
 
     def test_darts_with_unroll(self) -> None:
-        model = ToyDiffModule(data_preprocessor=ToyDataPreprocessor())
+        model = ToyDiffModule()
         mutator = DiffModuleMutator()
         mutator.prepare_from_supernet(model)
 
@@ -205,19 +194,17 @@ def setUpClass(cls) -> None:
         os.environ['MASTER_PORT'] = '12345'
 
         # initialize the process group
-        if torch.cuda.is_available():
-            backend = 'nccl'
-            cls.device = 'cuda'
-        else:
-            backend = 'gloo'
+        backend = 'nccl' if torch.cuda.is_available() else 'gloo'
         dist.init_process_group(backend, rank=0, world_size=1)
 
     def prepare_model(self, device_ids=None) -> Darts:
-        model = ToyDiffModule().to(self.device)
-        mutator = DiffModuleMutator().to(self.device)
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+        model = ToyDiffModule()
+        mutator = DiffModuleMutator()
         mutator.prepare_from_supernet(model)
 
-        algo = Darts(model, mutator)
+        algo = Darts(model, mutator).to(self.device)
 
         return DartsDDP(
             module=algo, find_unused_parameters=True, device_ids=device_ids)
@@ -230,52 +217,34 @@ def tearDownClass(cls) -> None:
         not torch.cuda.is_available(), reason='cuda device is not avaliable')
     def test_init(self) -> None:
         ddp_model = self.prepare_model()
-        # ddp_model = DartsDDP(module=model, device_ids=[0])
         self.assertIsInstance(ddp_model, DartsDDP)
 
     def test_dartsddp_train_step(self) -> None:
-        model = ToyDiffModule(data_preprocessor=ToyDataPreprocessor())
-        mutator = DiffModuleMutator()
-        mutator.prepare_from_supernet(model)
-
         # data is tensor
-        algo = Darts(model, mutator)
-        ddp_model = DartsDDP(module=algo, find_unused_parameters=True)
+        ddp_model = self.prepare_model()
         data = self._prepare_fake_data()
         optim_wrapper = build_optim_wrapper(ddp_model, self.OPTIM_WRAPPER_CFG)
         loss = ddp_model.train_step(data, optim_wrapper)
 
         self.assertIsNotNone(loss)
 
         # data is tuple or list
-        algo = Darts(model, mutator)
-        ddp_model = DartsDDP(module=algo, find_unused_parameters=True)
+        ddp_model = self.prepare_model()
         data = [self._prepare_fake_data() for _ in range(2)]
         optim_wrapper_dict = OptimWrapperDict(
-            architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
-            mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
+            architecture=OptimWrapper(SGD(ddp_model.parameters(), lr=0.1)),
+            mutator=OptimWrapper(SGD(ddp_model.parameters(), lr=0.01)))
         loss = ddp_model.train_step(data, optim_wrapper_dict)
 
         self.assertIsNotNone(loss)
 
     def test_dartsddp_with_unroll(self) -> None:
-        model = ToyDiffModule(data_preprocessor=ToyDataPreprocessor())
-        mutator = DiffModuleMutator()
-        mutator.prepare_from_supernet(model)
-
         # data is tuple or list
-        algo = Darts(model, mutator, unroll=True)
-        ddp_model = DartsDDP(module=algo, find_unused_parameters=True)
-
+        ddp_model = self.prepare_model()
        data = [self._prepare_fake_data() for _ in range(2)]
         optim_wrapper_dict = OptimWrapperDict(
-            architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
-            mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
+            architecture=OptimWrapper(SGD(ddp_model.parameters(), lr=0.1)),
+            mutator=OptimWrapper(SGD(ddp_model.parameters(), lr=0.01)))
         loss = ddp_model.train_step(data, optim_wrapper_dict)
 
         self.assertIsNotNone(loss)
-
-
-if __name__ == '__main__':
-    import unittest
-    unittest.main()
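The reworked prepare_model follows the usual PyTorch DDP recipe: pick the backend from CUDA availability, move the whole module to the target device before wrapping it in DDP, and build the optimizer from the wrapped module's parameters. A minimal sketch of that recipe in plain PyTorch (a single-process group as in the test's setUpClass, with nn.Linear standing in for the Darts algorithm; not part of the commit):

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD

# Single-process group: nccl when CUDA is available, gloo otherwise.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '12345')
backend = 'nccl' if torch.cuda.is_available() else 'gloo'
dist.init_process_group(backend, rank=0, world_size=1)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move the module to the device *before* DDP wrapping, mirroring
# `Darts(model, mutator).to(self.device)` in the diff above.
model = nn.Linear(8, 2).to(device)
ddp_model = DistributedDataParallel(model, find_unused_parameters=True)

# Build the optimizer from the DDP-wrapped module's parameters, mirroring
# `SGD(ddp_model.parameters(), lr=0.1)` in the diff above.
optimizer = SGD(ddp_model.parameters(), lr=0.1)

# One training step on fake data.
inputs = torch.randn(4, 8, device=device)
loss = ddp_model(inputs).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

dist.destroy_process_group()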

0 commit comments
