-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy pathtest_loader.py
211 lines (185 loc) · 8.38 KB
/
test_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import os
import shutil
import unittest
import numpy as np
import torch
from tests import get_tests_input_path, get_tests_output_path
from torch.utils.data import DataLoader
from TTS.tts.datasets import TTSDataset
from TTS.tts.datasets.preprocess import ljspeech
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
#pylint: disable=unused-variable
# Directory that receives the audio files written by the loader tests;
# created at import time so the tests can write into it directly.
OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/")
os.makedirs(OUTPATH, exist_ok=True)

# Shared test configuration, loaded from the tests' input directory.
c = load_config(os.path.join(get_tests_input_path(), 'test_config.json'))

# Both flags record whether the dataset directory named by the config exists.
# `ok_ljspeech` gates the tests below; `DATA_EXIST` is only reported here.
ok_ljspeech = os.path.exists(c.data_path)
DATA_EXIST = os.path.exists(c.data_path)
print(" > Dynamic data loader test: {}".format(DATA_EXIST))
class TestTTSDataset(unittest.TestCase):
    """Checks on the batches produced by ``TTSDataset.MyDataset``.

    Every test needs the dataset directory at ``c.data_path``. When it is
    missing the test is reported as skipped instead of silently passing.

    Batch layout used throughout (index -> name used by these tests):
    0 text_input, 1 text_lengths, 2 speaker_name, 3 linear_input,
    4 mel_input, 5 mel_lengths, 6 stop_target, 7 item_idx.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Upper bound on the number of batches each test pulls from a loader.
        self.max_loader_iter = 4
        self.ap = AudioProcessor(**c.audio)

    def _require_dataset(self):
        """Skip the running test when the dataset directory is absent."""
        if not ok_ljspeech:
            self.skipTest("dataset not found at {}".format(c.data_path))

    def _create_dataloader(self, batch_size, r, bgs):
        """Build a dataset over ``metadata.csv`` and a DataLoader around it.

        Args:
            batch_size: batch size handed to the ``DataLoader``.
            r: forwarded as the first positional argument of ``MyDataset``
                (presumably the number of output frames per decoder step;
                confirm against ``MyDataset``).
            bgs: forwarded as ``batch_group_size``.

        Returns:
            Tuple ``(dataloader, dataset)``.
        """
        items = ljspeech(c.data_path, 'metadata.csv')
        dataset = TTSDataset.MyDataset(
            r,
            c.text_cleaner,
            compute_linear_spec=True,
            ap=self.ap,
            meta_data=items,
            tp=c.characters if 'characters' in c.keys() else None,
            batch_group_size=bgs,
            min_seq_len=c.min_seq_len,
            max_seq_len=float("inf"),
            use_phonemes=False)
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            # drop_last guarantees every yielded batch has exactly
            # `batch_size` items, which the shape assertions rely on.
            drop_last=True,
            num_workers=c.num_loader_workers)
        return dataloader, dataset

    def test_loader(self):
        """Check token values, batch shapes and mel normalization range."""
        self._require_dataset()
        batch_size = 2
        dataloader, _ = self._create_dataloader(batch_size, c.r, 0)

        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            text_input = data[0]
            speaker_name = data[2]
            linear_input = data[3]
            mel_input = data[4]

            neg_values = text_input[text_input < 0]
            check_count = len(neg_values)
            assert check_count == 0, \
                " !! Negative values in text_input: {}".format(check_count)
            # TODO: more assertion here
            assert isinstance(speaker_name[0], str)
            # Compare against the batch size this loader was built with, not
            # the (unrelated) training batch size stored in the config.
            assert linear_input.shape[0] == batch_size
            assert linear_input.shape[2] == self.ap.fft_size // 2 + 1
            assert mel_input.shape[0] == batch_size
            assert mel_input.shape[2] == c.audio['num_mels']

            # check normalization ranges
            assert mel_input.max() <= self.ap.max_norm
            if self.ap.symmetric_norm:
                assert mel_input.min() >= -self.ap.max_norm  #pylint: disable=invalid-unary-operand-type
                assert mel_input.min() < 0
            else:
                assert mel_input.min() >= 0

    def test_batch_group_shuffle(self):
        """Check that ``sort_items`` changes the item order when grouping is on."""
        self._require_dataset()
        dataloader, dataset = self._create_dataloader(2, c.r, 16)
        last_length = 0
        # Reference to the item list as it is before the explicit re-sort.
        frames = dataset.items

        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            mel_lengths = data[5]
            avg_length = mel_lengths.numpy().mean()
            # NOTE(review): `last_length` is never updated, so this only
            # verifies that the average length is non-negative. Left as is
            # because a strict monotonic check may not hold for shuffled
            # batch groups.
            assert avg_length >= last_length

        dataloader.dataset.sort_items()
        is_items_reordered = any(
            item != frames[idx]
            for idx, item in enumerate(dataloader.dataset.items))
        assert is_items_reordered

    def test_padding_and_spec(self):
        """Check spectrogram consistency, padding and stop targets.

        Runs once with a batch size of 1 (no padding expected) and once with
        a batch size of 2 (the shorter item must be zero padded). Also writes
        reconstructed and reference wav files into ``OUTPATH``.
        """
        self._require_dataset()

        # ---- batch size 1 ----
        dataloader, _ = self._create_dataloader(1, 1, 0)
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            linear_input = data[3]
            mel_input = data[4]
            mel_lengths = data[5]
            stop_target = data[6]
            item_idx = data[7]

            # check mel_spec consistency
            wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
            mel = self.ap.melspectrogram(wav).astype('float32')
            mel = torch.FloatTensor(mel).contiguous()
            mel_dl = mel_input[0]
            # NOTE: Below needs to check == 0 but due to an unknown reason
            # there is a slight difference between two matrices.
            # TODO: Check this assert cond more in detail.
            max_diff = abs(mel.T - mel_dl).max()
            assert max_diff < 1e-5, max_diff

            # check mel-spec correctness
            mel_spec = mel_input[0].cpu().numpy()
            wav = self.ap.inv_melspectrogram(mel_spec.T)
            self.ap.save_wav(
                wav, os.path.join(OUTPATH, 'mel_inv_dataloader.wav'))
            shutil.copy(
                item_idx[0], os.path.join(OUTPATH, 'mel_target_dataloader.wav'))

            # check linear-spec
            linear_spec = linear_input[0].cpu().numpy()
            wav = self.ap.inv_spectrogram(linear_spec.T)
            self.ap.save_wav(
                wav, os.path.join(OUTPATH, 'linear_inv_dataloader.wav'))
            shutil.copy(
                item_idx[0],
                os.path.join(OUTPATH, 'linear_target_dataloader.wav'))

            # With a single item per batch nothing is padded, so the last
            # frames must be non-zero and the stop target must be a single 1
            # on the final frame.
            assert linear_input[0, -1].sum() != 0
            assert linear_input[0, -2].sum() != 0
            assert mel_input[0, -1].sum() != 0
            assert mel_input[0, -2].sum() != 0
            assert stop_target[0, -1] == 1
            assert stop_target[0, -2] == 0
            assert stop_target.sum() == 1
            assert len(mel_lengths.shape) == 1
            assert mel_lengths[0] == linear_input[0].shape[0]
            assert mel_lengths[0] == mel_input[0].shape[0]

        # ---- batch size 2 ----
        dataloader, _ = self._create_dataloader(2, 1, 0)
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            linear_input = data[3]
            mel_input = data[4]
            mel_lengths = data[5]
            stop_target = data[6]

            # `idx` is the item with the longer mel, `other` the padded one.
            idx = 0 if mel_lengths[0] > mel_lengths[1] else 1
            other = 1 - idx

            # check the longer item in the batch: no padding at the end
            assert linear_input[idx, -1].sum() != 0
            assert linear_input[idx, -2].sum() != 0, linear_input
            assert mel_input[idx, -1].sum() != 0
            assert mel_input[idx, -2].sum() != 0, mel_input
            assert stop_target[idx, -1] == 1
            assert stop_target[idx, -2] == 0
            assert stop_target[idx].sum() == 1
            assert len(mel_lengths.shape) == 1
            assert mel_lengths[idx] == mel_input[idx].shape[0]
            assert mel_lengths[idx] == linear_input[idx].shape[0]

            # check the shorter item in the batch: zero padded at the end,
            # stop flag on its own last frame and nothing after it
            assert linear_input[other, -1].sum() == 0
            assert mel_input[other, -1].sum() == 0
            assert stop_target[other, mel_lengths[other] - 1] == 1
            assert stop_target[other, mel_lengths[other]:].sum() == 0
            assert len(mel_lengths.shape) == 1

            # check batch zero-frame conditions (zero-frame disabled)
            # assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
            # assert (mel_input * stop_target.unsqueeze(2)).sum() == 0