-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathresnet152.py
65 lines (53 loc) · 1.83 KB
/
resnet152.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import torch
import task.common as util
# Model name: passed to torch.hub.load as the entry point in import_model()
# and also recorded on the loaded model via util.set_fullname().
MODEL_NAME = 'resnet152'
def import_data(batch_size):
    """Build a dummy classification batch from the PyTorch Hub example image.

    If ``dog.jpg`` is not present in the current working directory it is
    downloaded first. The image is preprocessed with the standard torchvision
    ImageNet pipeline (resize to 256, center-crop to 224, convert to tensor,
    normalize) and replicated ``batch_size`` times.

    Args:
        batch_size: Number of copies of the example image in the batch.

    Returns:
        ``(images, target)`` where ``images`` has shape
        ``(batch_size, 3, 224, 224)`` and ``target`` is an int64 tensor of
        ``batch_size`` zeros (placeholder labels).

    Raises:
        OSError: (including ``urllib.error.URLError``) if the image has to be
            downloaded and the download fails.
    """
    filename = 'dog.jpg'
    # Download an example image from the pytorch website
    if not os.path.isfile(filename):
        # Import the submodule explicitly: a bare ``import urllib`` does not
        # guarantee ``urllib.request`` is loaded.
        import urllib.request
        url = 'https://github.com/pytorch/hub/raw/master/images/dog.jpg'
        # Download to a side file and rename only on success. Otherwise an
        # interrupted download would leave a truncated dog.jpg, which the
        # isfile() check above would then accept forever.
        partial = filename + '.part'
        try:
            urllib.request.urlretrieve(url, partial)
            os.replace(partial, filename)
        finally:
            if os.path.exists(partial):
                os.remove(partial)

    # sample execution (requires torchvision)
    from PIL import Image
    from torchvision import transforms

    # Close the file handle promptly. convert('RGB') loads the pixel data and
    # guarantees the 3 channels that Normalize() below expects.
    with Image.open(filename) as img:
        input_image = img.convert('RGB')

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    image = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model
    # repeat() yields the same values as torch.cat([image] * batch_size), and
    # also works for batch_size == 0, where torch.cat rejects an empty list.
    images = image.repeat(batch_size, 1, 1, 1)
    # Always int64, even when batch_size == 0. torch.tensor([]) would be float.
    target = torch.zeros(batch_size, dtype=torch.long)
    return images, target
def import_model():
    """Fetch the pretrained model named by MODEL_NAME from torch.hub.

    The model comes from the ``pytorch/vision:v0.4.2`` hub repository with
    pretrained weights. It is tagged with its name through
    ``util.set_fullname`` before being returned.

    Returns:
        The loaded model object.
    """
    hub_repo = 'pytorch/vision:v0.4.2'
    net = torch.hub.load(hub_repo, MODEL_NAME, pretrained=True)
    util.set_fullname(net, MODEL_NAME)
    return net
def partition_model(model):
    """Split a model's top-level children into an ordered list of groups.

    Children are visited in ``model.named_children()`` order:

    * Children seen before the first child whose name contains ``'layer'``
      form the first group.
    * Each sub-module of a ``'layer'`` child becomes its own single-element
      group, in order.
    * Non-``'layer'`` children seen after the first ``'layer'`` child form
      the last group.

    The first and last groups are always present, even when empty. For a
    torchvision ResNet this presumably separates the stem, the individual
    residual blocks, and the classifier head (verify against caller).

    Args:
        model: Object exposing ``named_children()`` (e.g. a
            ``torch.nn.Module``).

    Returns:
        A list of lists of modules.
    """
    head, tail, core_groups = [], [], []
    seen_core = False
    for child_name, module in model.named_children():
        if 'layer' in child_name:
            seen_core = True
            core_groups.extend([sub] for _, sub in module.named_children())
        elif seen_core:
            tail.append(module)
        else:
            head.append(module)
    return [head, *core_groups, tail]