-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathutils.py
78 lines (78 loc) · 2.99 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Maps a lower-case network name to (module path, factory attribute name).
# Modules are resolved lazily inside choose_nets so that only the selected
# architecture's module is imported.
_NET_FACTORIES = {
    'vgg11': ('models.VGG', 'VGG11'),
    'vgg13': ('models.VGG', 'VGG13'),
    'vgg16': ('models.VGG', 'VGG16'),
    'vgg19': ('models.VGG', 'VGG19'),
    'resnet18': ('models.ResNet', 'ResNet18'),
    'resnet34': ('models.ResNet', 'ResNet34'),
    'resnet50': ('models.ResNet', 'ResNet50'),
    'resnet101': ('models.ResNet', 'ResNet101'),
    'resnet152': ('models.ResNet', 'ResNet152'),
    'googlenet': ('models.GoogLeNet', 'GoogLeNet'),
    'inceptionv3': ('models.InceptionV3', 'inceptionv3'),
    'mobilenet': ('models.MobileNet', 'mobilenet'),
    'mobilenetv2': ('models.MobileNetV2', 'mobilenetv2'),
    'seresnet18': ('models.SEResNet', 'seresnet18'),
    'seresnet34': ('models.SEResNet', 'seresnet34'),
    'seresnet50': ('models.SEResNet', 'seresnet50'),
    'seresnet101': ('models.SEResNet', 'seresnet101'),
    'seresnet152': ('models.SEResNet', 'seresnet152'),
    'densenet121': ('models.DenseNet', 'densenet121'),
    'densenet161': ('models.DenseNet', 'densenet161'),
    'densenet169': ('models.DenseNet', 'densenet169'),
    'densenet201': ('models.DenseNet', 'densenet201'),
    'squeezenet': ('models.SqueezeNet', 'squeezenet'),
    'inceptionv4': ('models.InceptionV4', 'inceptionv4'),
    'inception-resnet-v2': ('models.InceptionV4', 'inception_resnet_v2'),
}


def choose_nets(nets_name, num_classes=100):
    """Build the network registered under ``nets_name``.

    The name is matched case-insensitively against the supported
    architectures (the VGG, ResNet, SE-ResNet, DenseNet, Inception and
    MobileNet families, GoogLeNet and SqueezeNet). The module that defines the
    chosen architecture is imported on demand, and its factory is called with
    ``num_classes`` as the only positional argument.

    Args:
        nets_name: Architecture name, e.g. ``'vgg16'`` or ``'ResNet50'``.
        num_classes: Value forwarded to the factory; defaults to 100.

    Returns:
        Whatever the selected factory returns (presumably a model instance;
        the factories live in the ``models`` package, outside this file).

    Raises:
        NotImplementedError: If ``nets_name`` is not a supported name.
    """
    # Function-scope import keeps this block self-contained.
    import importlib

    key = nets_name.lower()
    try:
        module_path, factory_name = _NET_FACTORIES[key]
    except KeyError:
        # Same exception type as before, now with a diagnostic message;
        # "from None"-style chaining is skipped to stay version-agnostic.
        raise NotImplementedError(
            'Unsupported network {!r}; expected one of: {}'.format(
                nets_name, ', '.join(sorted(_NET_FACTORIES))
            )
        )

    factory = getattr(importlib.import_module(module_path), factory_name)
    return factory(num_classes)