train.py
import tensorflow as tf
import numpy as np
import sys
import os
import data_input
import librosa
from tqdm import tqdm
import argparse
import audio

# save a checkpoint every SAVE_EVERY global steps
SAVE_EVERY = 5000
# global step of a specific checkpoint to restore from; None means use the latest
RESTORE_FROM = None
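
# Example invocations (a sketch; these assume the dataset has already been
# preprocessed into data/<train-set>/ so that data_input.load_meta() and
# data_input.load_from_npy() can find it):
#   python train.py --train-set nancy            # train from scratch
#   python train.py --train-set nancy --restore  # resume from the latest checkpoint
# To resume from a specific step rather than the latest checkpoint, set
# RESTORE_FROM above to that global step (e.g. RESTORE_FROM = 50000 restores
# weights/<save_path>-50000).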

def train(model, config, num_steps=1000000):
    sr = 24000 if 'vctk' in config.data_path else 16000
    meta = data_input.load_meta(config.data_path)
    config.r = meta['r']
    ivocab = meta['vocab']
    config.vocab_size = len(ivocab)

    with tf.Session() as sess:
        inputs, names, num_speakers, stft_mean, stft_std = \
                data_input.load_from_npy(config.data_path)
        config.num_speakers = num_speakers

        # save the mean and std as TensorFlow variables so they are stored
        # in the checkpoint alongside the weights
        tf.Variable(stft_mean, name='stft_mean')
        tf.Variable(stft_std, name='stft_std')

        batch_inputs = data_input.build_dataset(sess, inputs, names)

        # initialize the model
        model = model(config, batch_inputs, train=True)

        train_writer = tf.summary.FileWriter('log/' + config.save_path + '/train', sess.graph)

        tf.global_variables_initializer().run()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        saver = tf.train.Saver(max_to_keep=3, keep_checkpoint_every_n_hours=3)

        if config.restore:
            print('restoring weights')
            latest_ckpt = tf.train.latest_checkpoint(
                'weights/' + config.save_path[:config.save_path.rfind('/')]
            )
            if RESTORE_FROM is None:
                if latest_ckpt is not None:
                    saver.restore(sess, latest_ckpt)
            else:
                saver.restore(sess, 'weights/' + config.save_path + '-' + str(RESTORE_FROM))

        lr = model.config.init_lr
        annealing_rate = model.config.annealing_rate
        for _ in tqdm(range(num_steps)):
            out = sess.run([
                model.train_op,
                model.global_step,
                model.loss,
                model.output,
                model.alignments,
                model.merged,
                batch_inputs
            ], feed_dict={model.lr: lr})
            _, global_step, loss, output, alignments, summary, inputs = out

            train_writer.add_summary(summary, global_step)

            # detect gradient explosion
            if loss > 1e8 and global_step > 500:
                print('loss exploded')
                break

            # anneal the learning rate every 1000 steps
            if global_step % 1000 == 0:
                lr *= annealing_rate

            if global_step % SAVE_EVERY == 0 and global_step != 0:
                print('saving weights')
                if not os.path.exists('weights/' + config.save_path):
                    os.makedirs('weights/' + config.save_path)
                saver.save(sess, 'weights/' + config.save_path, global_step=global_step)

                print('saving sample')
                # store a sample to listen to, undoing the normalization
                # before inverting the spectrogram
                ideal = audio.invert_spectrogram(inputs['stft'][0]*stft_std + stft_mean)
                sample = audio.invert_spectrogram(output[0]*stft_std + stft_mean)
                attention_plot = data_input.generate_attention_plot(alignments[0])
                step = '_' + str(global_step)
                merged = sess.run(tf.summary.merge(
                    [tf.summary.audio('ideal' + step, ideal[None, :], sr),
                     tf.summary.audio('sample' + step, sample[None, :], sr),
                     tf.summary.image('attention' + step, attention_plot)]
                ))
                train_writer.add_summary(merged, global_step)

        coord.request_stop()
        coord.join(threads)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--train-set', default='nancy')
    # NOTE: argparse's type=bool treats any non-empty string as True,
    # so these are plain on/off flags instead
    parser.add_argument('-d', '--debug', action='store_true')
    parser.add_argument('-r', '--restore', action='store_true')
    args = parser.parse_args()

    from models.tacotron import Tacotron, Config
    model = Tacotron
    config = Config()
    config.data_path = 'data/%s/' % args.train_set
    config.restore = args.restore

    if args.debug:
        config.save_path = 'debug'
    else:
        config.save_path = '%s/tacotron' % args.train_set

    print('Building Tacotron')
    train(model, config)
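
# Monitoring (a sketch): loss summaries, audio samples, and attention plots
# are written to log/<save_path>/train by the FileWriter above, so a run can
# be inspected with TensorBoard, e.g.:
#   tensorboard --logdir log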