forked from PaddlePaddle/PaddleSlim
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexport_model.py
112 lines (97 loc) · 3.88 KB
/
export_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from __future__ import division
from __future__ import print_function
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
sys.path.append(
os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
import paddleslim
from paddleslim.common import get_logger
import paddle.vision.models as models
from utility import add_arguments, print_arguments
from paddle.jit import to_static
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64 * 4, "Minibatch size.")
add_arg('model', str, "mobilenet_v1", "The target model.")
add_arg('data', str, "imagenet", "Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.")
add_arg('test_period', int, 10, "Test period in epoches.")
add_arg('checkpoint', str, None, "The path of checkpoint which is used for eval.")
add_arg('pruned_ratio', float, None, "The ratios to be pruned.")
add_arg('output_path', str, None, "The path of checkpoint which is used for eval.")
# yapf: enable
model_list = models.__all__
def get_pruned_params(args, model):
params = []
if args.model == "mobilenet_v1":
skip_vars = ['linear_0.b_0',
'conv2d_0.w_0'] # skip the first conv2d and last linear
for sublayer in model.sublayers():
for param in sublayer.parameters(include_sublayers=False):
if isinstance(
sublayer, paddle.nn.Conv2D
) and sublayer._groups == 1 and param.name not in skip_vars:
params.append(param.name)
elif args.model == "mobilenet_v2":
for sublayer in model.sublayers():
for param in sublayer.parameters(include_sublayers=False):
if isinstance(sublayer, paddle.nn.Conv2D):
params.append(param.name)
return params
elif args.model == "resnet34":
for sublayer in model.sublayers():
for param in sublayer.parameters(include_sublayers=False):
if isinstance(sublayer, paddle.nn.Conv2D):
params.append(param.name)
return params
else:
raise NotImplementedError(
"Current demo only support for mobilenet_v1, mobilenet_v2, resnet34")
return params
def export(args):
paddle.set_device('cpu')
test_reader = None
if args.data == "cifar10":
class_dim = 10
image_shape = [3, 224, 224]
elif args.data == "imagenet":
class_dim = 1000
image_shape = [3, 224, 224]
else:
raise ValueError("{} is not supported.".format(args.data))
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
model_list)
# model definition
net = models.__dict__[args.model](pretrained=False, num_classes=class_dim)
pruner = paddleslim.dygraph.L1NormFilterPruner(net, [1] + image_shape)
params = get_pruned_params(args, net)
ratios = {}
for param in params:
ratios[param] = args.pruned_ratio
print("ratios: {}".format(ratios))
pruner.prune_vars(ratios, [0])
param_state_dict = paddle.load(args.checkpoint + ".pdparams")
net.set_dict(param_state_dict)
net.eval()
model = to_static(
net,
input_spec=[
paddle.static.InputSpec(
shape=[None] + image_shape, dtype='float32', name="image")
])
paddle.jit.save(net, args.output_path)
def main():
args = parser.parse_args()
print_arguments(args)
export(args)
if __name__ == '__main__':
main()