-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathinception_v3_inference.py
55 lines (45 loc) · 1.59 KB
/
inception_v3_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import contextlib
import threading
import time
import torch
import numpy
import task.inception_v3 as inception_v3
import task.common as util
# Identifier string for this Inception v3 inference task.
TASK_NAME = "inception_v3_inference"
@contextlib.contextmanager
def timer(prefix):
    """Print how long the wrapped ``with`` body took.

    When the body exits, ``<prefix> cost <seconds>`` is printed to stdout.
    ``seconds`` is the elapsed time as a float.

    Args:
        prefix: Label printed in front of the elapsed time.
    """
    # perf_counter is monotonic and high-resolution. time.time() follows the
    # system clock, which can be adjusted mid-measurement.
    _start = time.perf_counter()
    try:
        yield
    finally:
        # Report even when the body raises, so slow failures stay visible.
        _end = time.perf_counter()
        print(prefix, 'cost', _end - _start)
def import_data_loader():
    """Return the data loader for this task.

    No loader is provided for this task; the function always returns ``None``.
    """
    return None
def import_model():
    """Build the Inception v3 model and switch it to evaluation mode.

    Returns:
        The object produced by ``inception_v3.import_model()``, after its
        ``eval()`` method has been called on it.
    """
    net = inception_v3.import_model()
    # This task only runs inference. For a torch nn.Module, eval() switches
    # layers such as dropout / batch-norm to inference behaviour
    # (assumes the model is an nn.Module - confirm).
    net.eval()
    return net
def import_func():
    """Return the inference callable for this task.

    The returned ``inference(model, data_b)``:

    * prints the current thread name, a timestamp and ``model.training``;
    * interprets the bytes-like ``data_b`` as raw float32 values and reshapes
      them to ``(-1, 3, 299, 299)``;
    * copies that batch to the current CUDA device, runs ``model`` on it
      under ``torch.no_grad()`` and returns ``output.sum()`` as a Python
      scalar.

    The conversion and the forward pass are timed with ``timer``.
    """
    def inference(model, data_b):
        # threading.currentThread() and Thread.getName() are deprecated
        # aliases (DeprecationWarning since Python 3.10).
        print(threading.current_thread().name,
              'inception_v3 inference >>>>>>>>>>', time.time(),
              'model status', model.training)
        with timer('inception_v3 inference func'):
            # frombuffer shares memory with data_b and does not copy it.
            # Its byte length must be a multiple of 4 (float32), and the
            # element count a multiple of 3*299*299, otherwise
            # frombuffer/view raises.
            data = torch.from_numpy(numpy.frombuffer(data_b, dtype=numpy.float32))
            input_batch = data.view(-1, 3, 299, 299).cuda(non_blocking=True)
            with torch.no_grad():
                output = model(input_batch)
            # .item() copies the scalar to the host, which waits for the
            # queued GPU work, so the timer covers the whole forward pass.
            return output.sum().item()
    return inference
def _inference_groups(model):
    """Partition *model* into groups and drop the auxiliary-classifier ones.

    A group is dropped when its first element's ``fullname`` contains
    ``'AuxLogits'``. Both ``import_task`` and ``import_parameters`` use this
    helper, so their per-group lists are built from the same selection.

    NOTE(review): presumably the AuxLogits head is excluded because it is
    not executed in eval mode (torchvision's Inception3 only runs it while
    training), so its parameters are not needed for inference - confirm.
    """
    return [group for group in inception_v3.partition_model(model)
            if 'AuxLogits' not in group[0].fullname]


def import_task():
    """Build the model, the inference callable and the per-group shapes.

    Returns:
        A ``(model, func, shape_list)`` tuple:
        * ``model`` is the result of ``import_model()``;
        * ``func`` is the result of ``import_func()``;
        * ``shape_list`` holds one ``util.group_to_shape`` result per
          non-AuxLogits parameter group.
    """
    model = import_model()
    func = import_func()
    shape_list = [util.group_to_shape(group) for group in _inference_groups(model)]
    return model, func, shape_list


def import_parameters():
    """Return one ``util.group_to_batch`` result per non-AuxLogits group.

    A fresh model is built with ``import_model()``. It is grouped with the
    same partition/filter that ``import_task`` uses, so the two lists line
    up, assuming ``partition_model`` is deterministic.
    """
    model = import_model()
    return [util.group_to_batch(group) for group in _inference_groups(model)]