@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
-from typing import Dict, List, Tuple
+from typing import Dict
from unittest import TestCase

import pytest
@@ -18,8 +18,6 @@
from mmrazor.models.algorithms.nas.darts import DartsDDP
from mmrazor.registry import MODELS

-# from unittest.mock import Mock
-
MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)
MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True)
MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True)
@@ -72,15 +70,6 @@ def forward(self, batch_inputs, data_samples=None, mode='tensor'):
        return out


-class ToyDataPreprocessor(torch.nn.Module):
-
-    def forward(
-            self,
-            data: Dict,
-            training: bool = True) -> Tuple[torch.Tensor, List[ClsDataSample]]:
-        return data['inputs'], data['data_samples']
-
-
class TestDarts(TestCase):

    def setUp(self) -> None:
@@ -159,7 +148,7 @@ def test_search_subnet(self) -> None:
        self.assertIsInstance(subnet, dict)

    def test_darts_train_step(self) -> None:
-        model = ToyDiffModule(data_preprocessor=ToyDataPreprocessor())
+        model = ToyDiffModule()
        mutator = DiffModuleMutator()
        mutator.prepare_from_supernet(model)
@@ -182,7 +171,7 @@ def test_darts_train_step(self) -> None:
        self.assertIsNotNone(loss)

    def test_darts_with_unroll(self) -> None:
-        model = ToyDiffModule(data_preprocessor=ToyDataPreprocessor())
+        model = ToyDiffModule()
        mutator = DiffModuleMutator()
        mutator.prepare_from_supernet(model)
@@ -205,19 +194,17 @@ def setUpClass(cls) -> None:
        os.environ['MASTER_PORT'] = '12345'

        # initialize the process group
-        if torch.cuda.is_available():
-            backend = 'nccl'
-            cls.device = 'cuda'
-        else:
-            backend = 'gloo'
+        backend = 'nccl' if torch.cuda.is_available() else 'gloo'
        dist.init_process_group(backend, rank=0, world_size=1)

    def prepare_model(self, device_ids=None) -> Darts:
-        model = ToyDiffModule().to(self.device)
-        mutator = DiffModuleMutator().to(self.device)
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+        model = ToyDiffModule()
+        mutator = DiffModuleMutator()
        mutator.prepare_from_supernet(model)

-        algo = Darts(model, mutator)
+        algo = Darts(model, mutator).to(self.device)

        return DartsDDP(
            module=algo, find_unused_parameters=True, device_ids=device_ids)
@@ -230,52 +217,34 @@ def tearDownClass(cls) -> None:
        not torch.cuda.is_available(), reason='cuda device is not avaliable')
    def test_init(self) -> None:
        ddp_model = self.prepare_model()
-        # ddp_model = DartsDDP(module=model, device_ids=[0])
        self.assertIsInstance(ddp_model, DartsDDP)

    def test_dartsddp_train_step(self) -> None:
-        model = ToyDiffModule(data_preprocessor=ToyDataPreprocessor())
-        mutator = DiffModuleMutator()
-        mutator.prepare_from_supernet(model)
-
        # data is tensor
-        algo = Darts(model, mutator)
-        ddp_model = DartsDDP(module=algo, find_unused_parameters=True)
+        ddp_model = self.prepare_model()
        data = self._prepare_fake_data()
        optim_wrapper = build_optim_wrapper(ddp_model, self.OPTIM_WRAPPER_CFG)
        loss = ddp_model.train_step(data, optim_wrapper)

        self.assertIsNotNone(loss)

        # data is tuple or list
-        algo = Darts(model, mutator)
-        ddp_model = DartsDDP(module=algo, find_unused_parameters=True)
+        ddp_model = self.prepare_model()
        data = [self._prepare_fake_data() for _ in range(2)]
        optim_wrapper_dict = OptimWrapperDict(
-            architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
-            mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
+            architecture=OptimWrapper(SGD(ddp_model.parameters(), lr=0.1)),
+            mutator=OptimWrapper(SGD(ddp_model.parameters(), lr=0.01)))
        loss = ddp_model.train_step(data, optim_wrapper_dict)

        self.assertIsNotNone(loss)

    def test_dartsddp_with_unroll(self) -> None:
-        model = ToyDiffModule(data_preprocessor=ToyDataPreprocessor())
-        mutator = DiffModuleMutator()
-        mutator.prepare_from_supernet(model)
-
        # data is tuple or list
-        algo = Darts(model, mutator, unroll=True)
-        ddp_model = DartsDDP(module=algo, find_unused_parameters=True)
-
+        ddp_model = self.prepare_model()
        data = [self._prepare_fake_data() for _ in range(2)]
        optim_wrapper_dict = OptimWrapperDict(
-            architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
-            mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
+            architecture=OptimWrapper(SGD(ddp_model.parameters(), lr=0.1)),
+            mutator=OptimWrapper(SGD(ddp_model.parameters(), lr=0.01)))
        loss = ddp_model.train_step(data, optim_wrapper_dict)

        self.assertIsNotNone(loss)
-
-
-if __name__ == '__main__':
-    import unittest
-    unittest.main()
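For reference, a minimal self-contained sketch of the single-process setup pattern that `setUpClass` and the reworked `prepare_model` rely on: choose `nccl` when CUDA is available, fall back to `gloo` on CPU, move the module to its device, then wrap it in DDP. This uses plain PyTorch only; `nn.Linear` is a hypothetical stand-in for the toy supernet, and the master port simply mirrors the value used in the test above.

```python
import os

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel


def main() -> None:
    # Single-process "distributed" group, mirroring setUpClass above:
    # nccl when a GPU is present, gloo as the CPU-only fallback.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '12345')
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    dist.init_process_group(backend, rank=0, world_size=1)

    # Move the module to its device first, then wrap it in DDP,
    # the same ordering prepare_model now uses for Darts/DartsDDP.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = nn.Linear(4, 2).to(device)  # stand-in for the toy supernet
    ddp_model = DistributedDataParallel(model, find_unused_parameters=True)

    out = ddp_model(torch.randn(3, 4, device=device))
    print(out.shape)  # torch.Size([3, 2])

    dist.destroy_process_group()


if __name__ == '__main__':
    main()
```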