-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathexport_model.py
75 lines (65 loc) · 3.19 KB
/
export_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
'''Command line interface to export MXNet model.
'''
import os
import json
import zipfile
from mms.arg_parser import ArgParser
# Names of the top-level entries listed as required in a signature file.
# NOTE(review): this list is not referenced by the code visible in this file,
# and 'output_types' does not match the 'output_type' key that
# _check_signature actually validates -- confirm which spelling is intended.
SIG_REQ_ENTRY = ['inputs', 'input_type', 'outputs', 'output_types']
# MIME types that _check_signature accepts for 'input_type' and 'output_type'.
VALID_MIME_TYPE = ['image/jpeg', 'application/json']
def _check_signature(sig_file):
'''Internal helper to check signature error when exporting model with CLI.
'''
with open(sig_file) as js_file:
signature = json.load(js_file)
assert 'input_type' in signature and 'output_type' in signature, \
'input_type and output_type are required in signature.'
assert isinstance(signature['input_type'], basestring) and \
isinstance(signature['output_type'], basestring), \
'Value of input_type and output_type should be string'
assert signature['input_type'] in VALID_MIME_TYPE and \
signature['output_type'] in VALID_MIME_TYPE, \
'Valid type should be picked from %s' % (VALID_MIME_TYPE)
assert 'inputs' in signature and 'outputs' in signature, \
'inputs and outputs are required in signature.'
assert isinstance(signature['inputs'], list) and \
isinstance(signature['outputs'], list), \
'inputs and outputs values must be list.'
for input in signature['inputs']:
assert isinstance(input, dict), 'Each input must be a dictionary.'
assert 'data_name' in input, 'data_name is required for input.'
assert isinstance(input['data_name'], basestring), 'data_name value must be string.'
assert 'data_shape' in input, 'data_shape is required for input.'
assert isinstance(input['data_shape'], list), 'data_shape value must be list.'
for output in signature['outputs']:
assert isinstance(output, dict), 'Each output must be a dictionary.'
assert 'data_name' in output, 'data_name is required for output.'
assert isinstance(output['data_name'], basestring), 'data_name value must be string.'
assert 'data_shape' in output, 'data_shape is required for output.'
assert isinstance(output['data_shape'], list), 'data_shape value must be list.'
def _export_model(args):
    '''Bundle the model files selected on the command line into a zip archive.

    The archive ``<model_name>.zip`` is written to ``args.export_path`` (or
    to the current working directory when that is empty). It contains the
    signature file, every ``.json`` / ``.params`` file found under the model
    path, and the synset file when one is given. All files are stored flat,
    under their base name.

    Parameters
    ----------
    args : object
        Parsed CLI arguments exposing ``signature``, ``model`` (formatted as
        ``<model_name>=<model_path>``), ``export_path`` and ``synset``.

    Raises
    ------
    AssertionError
        If the signature file is rejected by ``_check_signature``.
    ValueError
        If ``args.model`` does not contain ``=``.
    '''
    _check_signature(args.signature)

    # Split on the first '=' only, so that a model path which itself contains
    # '=' is kept intact instead of failing to unpack.
    model_name, separator, model_path = args.model.partition('=')
    if not separator:
        raise ValueError(
            'model must be given as <model_name>=<model_path>, got %r'
            % (args.model,))

    # expanduser returns paths that do not start with '~' unchanged, so no
    # explicit startswith('~') guard is needed.
    model_path = os.path.expanduser(model_path)
    destination = os.path.expanduser(args.export_path or os.getcwd())

    file_list = []
    seen = set()

    def _add(path):
        # The signature is a .json file and may live inside the model
        # directory, in which case the walk below would pick it up a second
        # time. Skip repeats so the archive gets no duplicate entries.
        key = os.path.abspath(path)
        if key not in seen:
            seen.add(key)
            file_list.append(path)

    _add(args.signature)
    for dirpath, _, filenames in os.walk(model_path):
        for file_name in filenames:
            if file_name.endswith(('.json', '.params')):
                _add(os.path.join(dirpath, file_name))
    if args.synset:
        _add(args.synset)

    zip_path = '%s/%s.zip' % (destination, model_name)
    with zipfile.ZipFile(zip_path, 'w') as zip_file:
        for item in file_list:
            # Store each file under its base name (flat archive layout).
            zip_file.write(item, os.path.basename(item))
    print('Successfully exported %s model. Model file is located at %s.'
          % (model_name, zip_path))
def export():
    '''Entry point of the export CLI.

    Parses the export arguments with ``ArgParser.parse_export_args`` and
    hands them to ``_export_model``.
    '''
    _export_model(ArgParser.parse_export_args())
# Run the exporter only when this file is executed as a script, not on import.
if __name__ == '__main__':
    export()