-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathtrainSyncNet.py
executable file
·154 lines (102 loc) · 4.97 KB
/
trainSyncNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#!/usr/bin/python
#-*- coding: utf-8 -*-
import numpy
import sys
import time
import os
import argparse
import pdb
import glob
import torch
from SyncNetDist import SyncNet
from tuneThreshold import tuneThresholdfromScore
from sklearn import metrics
from DatasetLoader import DatasetLoader
parser = argparse.ArgumentParser(description = "TrainArgs")

# Every typed option as (flag, type, default, help). The entries are kept in
# the original declaration order so the --help listing and the key order of
# vars(args) (which is echoed to the score file later) are unchanged.
_TYPED_ARGUMENTS = (
    # Data loader
    ('--maxFrames', int, 30, ''),
    ('--nBatchSize', int, 30, ''),
    ('--nTrainPerEpoch', int, 100000, ''),
    ('--nTestPerEpoch', int, 10000, ''),
    ('--nDataLoaderThread', int, 4, ''),
    # Training details
    ('--max_epoch', int, 500, 'Maximum number of epochs'),
    ('--temporal_stride', int, 1, ''),
    # Model definition
    ('--model', str, "", 'Model name'),
    ('--nOut', int, 1024, 'Embedding size in the last FC layer'),
    # Learning rates
    ('--lr', float, 0.001, 'Learning rate'),
    ('--lr_decay', float, 0.95, 'Learning rate decay every epoch'),
    # Joint training params
    ('--alphaC', float, 1.0, 'Sync weight'),
    ('--alphaI', float, 1.0, 'Identity weight'),
    # Load and save
    ('--initial_model', str, "", 'Initial model weights'),
    ('--save_path', str, "./data/exp01", 'Path for model and logs'),
    # Training and test data
    ('--train_list', str, "data/dev.txt", ''),
    ('--verify_list', str, "data/test.txt", ''),
    # Speaker recognition test
    ('--test_list', str, "voxceleb/test_list.txt", 'Evaluation list'),
    ('--test_path', str, "voxceleb/voxceleb1", 'Absolute path to the test set'),
)
for _flag, _type, _default, _help in _TYPED_ARGUMENTS:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)

# For test only: a boolean switch, so it is declared separately from the table.
parser.add_argument('--eval', dest='eval', action='store_true', help='Eval only')

args = parser.parse_args()
# ==================== MAKE DIRECTORIES ====================
# Checkpoints are written under <save_path>/model and the score log under
# <save_path>/result.
model_save_path = os.path.join(args.save_path, "model")
result_save_path = os.path.join(args.save_path, "result")
# exist_ok=True makes this idempotent for resumed runs and removes the race
# between a separate exists() check and makedirs().
for _directory in (model_save_path, result_save_path):
    os.makedirs(_directory, exist_ok=True)
# ==================== LOAD MODEL ====================
# Build the network. Every parsed command-line option is forwarded to the
# SyncNet constructor as a keyword argument.
_model_kwargs = vars(args)
s = SyncNet(**_model_kwargs)
# ==================== RUN CONFIGURATION LOG ====================
# Epoch counter. It is advanced further down when a previous checkpoint is
# resumed.
it = 1
# Append mode, so repeated runs of the same experiment accumulate in one
# log. The handle stays open and is written to by the training loop.
scorefile = open(result_save_path+"/scores.txt", "a+")
# Echo every option to stdout and record it in the score log.
for _name, _value in vars(args).items():
    print(_name, _value)
    scorefile.write('%s %s\n' % (_name, _value))
scorefile.flush()
# ==================== LOAD MODEL PARAMS ====================
# Resume from the newest checkpoint in model_save_path if one exists.
# Otherwise, fall back to --initial_model when it was given.
modelfiles = sorted(glob.glob('%s/model0*.model' % model_save_path))
if modelfiles:
    _latest = modelfiles[-1]
    s.loadParameters(_latest)
    print("Model %s loaded from previous state!" % _latest)
    # Checkpoints are saved as model%09d.model. Because the epoch is
    # zero-padded, the lexical sort above is also numeric. Drop the
    # 5-character "model" prefix from the file stem to recover the epoch,
    # then continue from the next one.
    _stem = os.path.splitext(os.path.basename(_latest))[0]
    it = int(_stem[5:]) + 1
elif args.initial_model != "":
    s.loadParameters(args.initial_model)
    print("Model %s loaded!" % args.initial_model)

# The training loop calls updateLearningRate(args.lr_decay) once per finished
# epoch. Replay that many calls so a resumed run picks up the same schedule.
for _ in range(it - 1):
    clr = s.updateLearningRate(args.lr_decay)
# ==================== EVAL ====================
# With --eval: score the evaluation list using the loaded weights, report the
# EER, and exit without training.
if args.eval:
    sc, lab = s.evaluateFromListSave(args.test_list, print_interval=100, test_path=args.test_path)
    # NOTE(review): [1, 0.1] presumably gives the target operating points,
    # and result[1] is reported as the EER. TODO: confirm against
    # tuneThreshold.tuneThresholdfromScore.
    result = tuneThresholdfromScore(sc, lab, [1, 0.1])
    print('EER %2.4f' % result[1])
    # Close the score log before leaving. sys.exit() replaces quit(), which
    # the site module provides for the interactive shell only.
    scorefile.close()
    sys.exit()
# ==================== LOAD DATA LIST ====================
# Build the training and validation loaders. Each gets its own list file
# and per-epoch sample count, plus every parsed option as keyword arguments.
print('Reading data ...')
_loader_kwargs = vars(args)
trainLoader = DatasetLoader(args.train_list, nPerEpoch=args.nTrainPerEpoch, **_loader_kwargs)
valLoader = DatasetLoader(
    args.verify_list,
    nPerEpoch=args.nTestPerEpoch,
    evalmode=True,
    **_loader_kwargs
)
print('Reading done.')
# ==================== TRAINING LOOP ====================
# NOTE(review): updateLearningRate(1) presumably leaves the rate unchanged
# and just returns the current value(s). max(clr) below implies the result
# is iterable. TODO: confirm in SyncNetDist.
clr = s.updateLearningRate(1)
# The loop bound replaces the old in-body quit(). A run restarted after its
# final checkpoint (it > max_epoch) no longer trains and saves an extra
# epoch, and the score log is closed once training ends.
while it <= args.max_epoch:
    print(time.strftime("%Y-%m-%d %H:%M:%S"), it, "Start Iteration")
    # One pass in training mode on the train loader, then one pass in eval
    # mode on the validation loader.
    loss, trainacc = s.train_network(trainLoader, evalmode=False, alpI=args.alphaI, alpC=args.alphaC)
    valloss, valacc = s.train_network(valLoader, evalmode=True)
    print(time.strftime("%Y-%m-%d %H:%M:%S"), "%s: IT %d, LR %f, TACC %2.2f, TLOSS %f, VACC %2.2f, VLOSS %f\n" % (args.save_path, it, max(clr), trainacc, loss, valacc, valloss))
    scorefile.write("IT %d, LR %f, TACC %2.2f, TLOSS %f, VACC %2.2f, VLOSS %f\n" % (it, max(clr), trainacc, loss, valacc, valloss))
    scorefile.flush()

    # ==================== SAVE MODEL ====================
    # Decay the learning rate for the next epoch, then checkpoint this one.
    # The zero-padded name keeps checkpoints sortable when the script resumes.
    clr = s.updateLearningRate(args.lr_decay)
    print(time.strftime("%Y-%m-%d %H:%M:%S"), "Saving model %d" % it)
    s.saveParameters(model_save_path + "/model%09d.model" % it)

    it += 1
    print("")

scorefile.close()