"""Main function for this repo
Author: Zhao Na, 2020
"""
import ast
import argparse
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # data
    parser.add_argument('--phase', type=str, default='mptitrain', choices=['pretrain', 'finetune',
                                                                           'prototrain', 'protoeval',
                                                                           'mptitrain', 'mptieval'])
    parser.add_argument('--dataset', type=str, default='s3dis', help='Dataset name: s3dis|scannet')
    parser.add_argument('--cvfold', type=int, default=0, help='Fold left out for testing in the leave-one-out '
                                                              'setting. Options: {0, 1}')
    parser.add_argument('--data_path', type=str, default='./datasets/S3DIS/blocks_bs1_s1',
                        help='Directory of the source data')
    parser.add_argument('--pretrain_checkpoint_path', type=str, default=None,
                        help='Path to the checkpoint of the pretrained model for resuming')
    parser.add_argument('--model_checkpoint_path', type=str, default=None,
                        help='Path to the checkpoint of the model for resuming')
    parser.add_argument('--save_path', type=str, default='./log_s3dis/',
                        help='Directory to save logs and checkpoints')
    parser.add_argument('--eval_interval', type=int, default=1500,
                        help='Iteration/epoch interval at which to evaluate the model')
    # optimization
    parser.add_argument('--batch_size', type=int, default=32, help='Number of samples/tasks in one batch')
    parser.add_argument('--n_workers', type=int, default=16, help='Number of workers for data loading')
    parser.add_argument('--n_iters', type=int, default=30000, help='Number of iterations/epochs to train')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Model (e.g. ProtoNet or MPTI) learning rate [default: 0.001]')
    parser.add_argument('--step_size', type=int, default=5000, help='Iterations between learning rate decays')
    parser.add_argument('--gamma', type=float, default=0.5, help='Multiplicative factor of learning rate decay')
    parser.add_argument('--pretrain_lr', type=float, default=0.001, help='Pretraining learning rate [default: 0.001]')
    parser.add_argument('--pretrain_weight_decay', type=float, default=0., help='Weight decay for regularization')
    parser.add_argument('--pretrain_step_size', type=int, default=50, help='Period of learning rate decay')
    parser.add_argument('--pretrain_gamma', type=float, default=0.5, help='Multiplicative factor of learning rate decay')
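    # Illustrative note (assumption: the step_size/gamma pairs feed a standard
    # step-decay schedule, which is what their help strings describe): with the
    # defaults above, --lr starts at 1e-3 and is multiplied by gamma=0.5 every
    # 5000 iterations, i.e. 1e-3 -> 5e-4 -> 2.5e-4 -> ..., over the
    # n_iters=30000 training iterations.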
    # few-shot episode setting
    parser.add_argument('--n_way', type=int, default=2, help='Number of classes (ways) for each episode: 2|3')
    parser.add_argument('--k_shot', type=int, default=1, help='Number of samples/shots for each class: 1|5')
    parser.add_argument('--n_queries', type=int, default=1, help='Number of queries for each class')
    parser.add_argument('--n_episode_test', type=int, default=100,
                        help='Number of episodes per configuration during testing')
    # point cloud processing
    parser.add_argument('--pc_npts', type=int, default=2048, help='Number of input points for the feature extractor')
    parser.add_argument('--pc_attribs', default='xyzrgbXYZ',
                        help='Point attributes fed to the network; if empty, all possible attributes are used. '
                             'xyz = coordinates, rgb = color, XYZ = normalized xyz')
    parser.add_argument('--pc_augm', action='store_true',
                        help='Whether to apply training-time augmentation to the input points')
    parser.add_argument('--pc_augm_scale', type=float, default=0,
                        help='Training augmentation: uniformly random scaling in [1/scale, scale]')
    parser.add_argument('--pc_augm_rot', type=int, default=1,
                        help='Training augmentation: bool, random rotation around the z-axis')
    parser.add_argument('--pc_augm_mirror_prob', type=float, default=0,
                        help='Training augmentation: probability of mirroring about the x or y axis')
    parser.add_argument('--pc_augm_jitter', type=int, default=1,
                        help='Training augmentation: bool, Gaussian jittering of all attributes')
    # feature extraction network configuration
    parser.add_argument('--dgcnn_k', type=int, default=20, help='Number of nearest neighbors in EdgeConv')
    parser.add_argument('--edgeconv_widths', default='[[64,64], [64,64], [64,64]]', help='DGCNN EdgeConv widths')
    parser.add_argument('--dgcnn_mlp_widths', default='[512, 256]', help='DGCNN MLP (following stacked EdgeConv) widths')
    parser.add_argument('--base_widths', default='[128, 64]', help='BaseLearner widths')
    parser.add_argument('--output_dim', type=int, default=64,
                        help='Dimension of the final output of the attention learner or linear mapper')
    parser.add_argument('--use_attention', action='store_true', help='Whether to incorporate the attention learner')
    # ProtoNet configuration
    parser.add_argument('--dist_method', default='euclidean',
                        help='Method to compute distance between query feature maps and prototypes '
                             '[Options: cosine|euclidean]')
    # MPTI configuration
    parser.add_argument('--n_subprototypes', type=int, default=100,
                        help='Number of sub-prototypes for each class in the support set')
    parser.add_argument('--k_connect', type=int, default=200,
                        help='Number of nearest neighbors used to construct the local-constrained affinity matrix')
    parser.add_argument('--sigma', type=float, default=1., help='Hyperparameter of the Gaussian similarity function')
    args = parser.parse_args()

    args.edgeconv_widths = ast.literal_eval(args.edgeconv_widths)
    args.dgcnn_mlp_widths = ast.literal_eval(args.dgcnn_mlp_widths)
    args.base_widths = ast.literal_eval(args.base_widths)
    args.pc_in_dim = len(args.pc_attribs)
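    # Illustrative check (not part of the original flow): with the default
    # settings, the width arguments above parse into nested Python lists and
    # the per-point input dimension equals the number of attribute characters:
    #   args.edgeconv_widths  -> [[64, 64], [64, 64], [64, 64]]
    #   args.dgcnn_mlp_widths -> [512, 256]
    #   args.base_widths      -> [128, 64]
    #   args.pc_in_dim        -> 9  (len('xyzrgbXYZ'))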
    # Dispatch to the trainer/evaluator for the selected phase:
    # pretrain, finetune, prototrain, protoeval, mptitrain, mptieval
    if args.phase == 'mptitrain':
        args.log_dir = args.save_path + 'log_mpti_%s_S%d_N%d_K%d_Att%d' % (args.dataset, args.cvfold,
                                                                           args.n_way, args.k_shot,
                                                                           args.use_attention)
        from runs.mpti_train import train
        train(args)
    elif args.phase == 'prototrain':
        args.log_dir = args.save_path + 'log_proto_%s_S%d_N%d_K%d_Att%d' % (args.dataset, args.cvfold,
                                                                            args.n_way, args.k_shot,
                                                                            args.use_attention)
        from runs.proto_train import train
        train(args)
    elif args.phase == 'protoeval' or args.phase == 'mptieval':
        args.log_dir = args.model_checkpoint_path
        from runs.eval import eval
        eval(args)
    elif args.phase == 'pretrain':
        args.log_dir = args.save_path + 'log_pretrain_%s_S%d' % (args.dataset, args.cvfold)
        from runs.pre_train import pretrain
        pretrain(args)
    elif args.phase == 'finetune':
        args.log_dir = args.save_path + 'log_finetune_%s_S%d_N%d_K%d' % (args.dataset, args.cvfold,
                                                                         args.n_way, args.k_shot)
        from runs.fine_tune import finetune
        finetune(args)
    else:
        raise ValueError('Unknown phase: %s' % args.phase)
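
# Example invocations (illustrative only; the dataset path, fold number and
# checkpoint directories below are assumptions and should be adapted to your
# own setup):
#   python main.py --phase pretrain  --dataset s3dis --cvfold 0 \
#                  --data_path ./datasets/S3DIS/blocks_bs1_s1
#   python main.py --phase mptitrain --dataset s3dis --cvfold 0 --n_way 2 --k_shot 1 --use_attention \
#                  --pretrain_checkpoint_path ./log_s3dis/log_pretrain_s3dis_S0
#   python main.py --phase mptieval  --dataset s3dis --cvfold 0 --n_way 2 --k_shot 1 --use_attention \
#                  --model_checkpoint_path ./log_s3dis/log_mpti_s3dis_S0_N2_K1_Att1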