-
Notifications
You must be signed in to change notification settings - Fork 6
/
eval_img2pc.py
64 lines (52 loc) · 2.47 KB
/
eval_img2pc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import argparse
import os
from tools.Trainer import ImageToPCTrainer
from tools.PointCloudDataset import ImageToPointCloudDataset
from models.AutoEncoder import PointCloudVAE
from models.AutoEncoder import ChamferLoss
from models.AutoEncoder import ChamferWithNormalLoss
from models.AutoEncoder import L2WithNormalLoss
from models.ImageToShape import MultiResImageToShape
from models.ImageToShape import SingleResImageToShape
from models.ImageToShape import FCImageToShape
def _build_parser():
    """Assemble the command-line parser for this script.

    Four string-valued options (experiment name, encoder architecture,
    pretrained flag, category code) are registered first, followed by the
    boolean ``--train`` switch, which is off unless given.
    """
    cli = argparse.ArgumentParser(
        description='MultiResolution image to shape model.')

    # (short flag, long flag, help text, default), registered in this order.
    string_options = (
        ("-n", "--name", "Name of the experiment.", "MRI2PC"),
        ("-a", "--arch", "Encoder architecture.", "vgg"),
        # Kept as text ("True"/"False"), not a real boolean option.
        ("-pt", "--pretrained", "Use pretrained net", "True"),
        ("-c", "--category", "Category code (all is possible)", "all"),
    )
    for short_flag, long_flag, help_text, fallback in string_options:
        cli.add_argument(short_flag, long_flag, type=str,
                         help=help_text, default=fallback)

    cli.add_argument("--train", dest='train', action='store_true')
    cli.set_defaults(train=False)
    return cli


parser = _build_parser()

# Hard-coded, machine-specific absolute dataset roots.
pc_datapath = "/media/mgadelha/hd2/shapenet_4k"
image_datapath = "/media/mgadelha/hd2/ShapenetRenderings"
if __name__ == '__main__':
    # Script entry point: build the image-to-shape model for the requested
    # experiment, call its load('checkpoint'), wrap it in an ImageToPCTrainer
    # and run the trainer's evaluate().
    args = parser.parse_args()

    # --pretrained arrives as text. Only the exact strings "True" and
    # "False" are recognised; any other value leaves the flag as None.
    ptrain = {"True": True, "False": False}.get(args.pretrained)

    # The composed name is used both as the model's name and as the
    # log sub-directory.
    full_name = "{}_{}_{}_{}".format(args.name, args.category, args.arch, ptrain)
    # Parenthesised single-argument print: valid on Python 3 and prints
    # the same text on Python 2 (the bare `print x` statement is a
    # SyntaxError on Python 3).
    print(full_name)

    mri2pc = MultiResImageToShape(size=4096, dim=3, batch_size=1,
                                  name=full_name, pretrained=ptrain,
                                  arch=args.arch)
    # NOTE(review): presumably restores previously saved weights for this
    # experiment name -- confirm against MultiResImageToShape.load.
    mri2pc.load('checkpoint')

    # The optimizer and the training loader are built because the trainer
    # constructor takes them; this script only calls evaluate().
    optimizer = optim.Adam(mri2pc.parameters(), lr=0.001)

    train_dataset = ImageToPointCloudDataset(image_datapath, pc_datapath,
                                             category=args.category,
                                             train_mode=True)
    test_dataset = ImageToPointCloudDataset(image_datapath, pc_datapath,
                                            category=args.category,
                                            train_mode=False)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1,
                                               shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                              shuffle=True, num_workers=2)

    # Make sure the per-experiment log directory exists before handing it
    # to the trainer.
    log_dir = os.path.join("log", full_name)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    trainer = ImageToPCTrainer(mri2pc, train_loader, test_loader,
                               optimizer, ChamferLoss(), log_dir=log_dir)
    trainer.evaluate()