-
Notifications
You must be signed in to change notification settings - Fork 42
/
Copy pathevaluate_verification.py
108 lines (86 loc) · 3.29 KB
/
evaluate_verification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import numpy as np
import os
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
from models.model import Network
from models import resnet
from config import cfg, update_config
from utils import create_logger, Genotype
from data_objects.VoxcelebTestset import VoxcelebTestset
from functions import validate_verification
def parse_args(argv=None):
    """Parse the command-line options of the verification evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point
            relies on.

    Returns:
        argparse.Namespace with the attributes ``cfg``, ``opts``,
        ``load_path`` and ``text_arch``.
    """
    # The description used to say "Train", which was misleading in the
    # ``--help`` output of an evaluation-only script.
    parser = argparse.ArgumentParser(
        description='Evaluate autospeech network')

    # general
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=True,
                        type=str)

    # Everything left over after the recognised options is collected here.
    # NOTE(review): presumably consumed by ``update_config`` as config
    # overrides -- confirm against config.py.
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)

    parser.add_argument('--load_path',
                        help="The path to resumed dir",
                        default=None)

    parser.add_argument('--text_arch',
                        help="The path to arch",
                        default=None)

    return parser.parse_args(argv)
def main():
    """Evaluate a saved checkpoint on the verification test set.

    Reads the command line, applies it to the global ``cfg``, restores the
    model weights from ``--load_path`` and hands the model plus a test
    data loader to ``validate_verification``.

    Raises:
        AttributeError: if ``--load_path`` was not given.
        AssertionError: if ``--load_path`` does not name an existing path.
    """
    args = parse_args()
    update_config(cfg, args)

    if args.load_path is None:
        raise AttributeError("Please specify load path.")

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # Set the random seed manually for reproducibility.
    np.random.seed(cfg.SEED)
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)

    # Fail early, before building anything, when the checkpoint is missing.
    # ``os.path.exists('')`` is False, so an empty path is rejected too.
    if not os.path.exists(args.load_path):
        raise AssertionError('Please specify the model to evaluate')

    # Deserialise the checkpoint exactly once; both the architecture
    # (genotype) and the weights are read from the same object.
    # NOTE(review): ``torch.load`` unpickles arbitrary objects -- only load
    # checkpoints from a trusted source. The checkpoint also stores
    # non-tensor entries ('genotype', 'path_helper'), so PyTorch releases
    # that default to ``weights_only=True`` may need that flag disabled
    # here -- confirm against the pinned torch version.
    checkpoint = torch.load(args.load_path)

    # model
    if cfg.MODEL.NAME == 'model':
        model = Network(
            cfg.MODEL.INIT_CHANNELS,
            cfg.MODEL.NUM_CLASSES,
            cfg.MODEL.LAYERS,
            checkpoint['genotype'],
        )
        model.drop_path_prob = 0.0
    else:
        # Look the constructor up by name instead of eval()-ing a string
        # built from config values; an unknown name still raises
        # AttributeError.
        build_resnet = getattr(resnet, cfg.MODEL.NAME)
        model = build_resnet(num_classes=cfg.MODEL.NUM_CLASSES)
    model = model.cuda()

    # load checkpoint && make logger next to the checkpoint file
    model.load_state_dict(checkpoint['state_dict'])
    args.path_helper = checkpoint['path_helper']

    logger = create_logger(os.path.dirname(args.load_path))
    logger.info("=> loaded checkpoint '{}'".format(args.load_path))

    logger.info(args)
    logger.info(cfg)

    # dataloader
    test_dataset_verification = VoxcelebTestset(
        Path(cfg.DATASET.DATA_DIR), cfg.DATASET.PARTIAL_N_FRAMES
    )
    test_loader_verification = torch.utils.data.DataLoader(
        dataset=test_dataset_verification,
        batch_size=1,
        num_workers=cfg.DATASET.NUM_WORKERS,
        pin_memory=True,
        shuffle=False,
        drop_last=False,
    )

    validate_verification(cfg, model, test_loader_verification)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()