-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdataloader.py
executable file
·131 lines (108 loc) · 4.05 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch
from torch.utils.data import Dataset, DataLoader
from utils import scan_directory, find_pair, addr2wav
import random
def create_dataloader(opt, mode):
    """Build the ``DataLoader`` for one phase of the pipeline.

    Args:
        opt: Options object. ``opt.batch_size`` is read for the ``'train'``
            and ``'valid'`` loaders; ``opt`` is also forwarded unchanged to
            the dataset constructor.
        mode: One of ``'train'``, ``'valid'`` or ``'test'``.

    Returns:
        A ``torch.utils.data.DataLoader``:

        * ``'train'`` -- ``Wave_Dataset``, shuffled, pinned memory, and the
          last incomplete batch dropped.
        * ``'valid'`` -- ``Wave_Dataset``, in order, no batch dropped.
        * ``'test'``  -- ``Wave_Dataset_for_test``, in order, one item per
          batch (presumably because test items are not cropped to a common
          length and so cannot be stacked -- confirm against the model code).

    Raises:
        ValueError: If ``mode`` is not one of the three supported values.
            Previously this case silently returned ``None``.
    """
    if mode == 'train':
        return DataLoader(
            dataset=Wave_Dataset(opt, mode),
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=0,
            pin_memory=True,
            drop_last=True,
        )
    if mode == 'valid':
        return DataLoader(
            dataset=Wave_Dataset(opt, mode),
            batch_size=opt.batch_size,
            shuffle=False,
            num_workers=0,
        )
    if mode == 'test':
        return DataLoader(
            dataset=Wave_Dataset_for_test(opt, mode),
            batch_size=1,
            shuffle=False,
            num_workers=0,
        )
    # Fail loudly instead of handing the caller a None "loader".
    raise ValueError(
        "Unknown dataloader mode {!r}; expected 'train', 'valid' or 'test'".format(mode)
    )
class Wave_Dataset(Dataset):
    """Paired noisy/clean waveform dataset that yields fixed-length chunks.

    The list of noisy files comes from ``scan_directory`` and the matching
    clean files from ``find_pair``. Every item is brought to exactly
    ``opt.chunk_size`` samples: recordings shorter than that are repeated
    end-to-end, longer ones are cropped at a random offset. The random crop
    is applied in every mode, including ``'valid'`` and ``'test'``.

    Attributes:
        mode: The mode string passed to the constructor.
        chunk_size: Number of samples in every returned tensor.
        noisy_dirs: Noisy file list returned by ``scan_directory``.
        clean_dirs: Clean file list returned by ``find_pair``; indexed in
            lockstep with ``noisy_dirs``.
    """

    # mode -> (console banner, name of the ``opt`` attribute that is handed
    # to ``scan_directory``).
    _MODE_TABLE = {
        'train': ('<Training dataset>', 'noisy_dirs_for_train'),
        'valid': ('<Validation dataset>', 'noisy_dirs_for_valid'),
        'test': ('<Test dataset>', 'test_database'),
    }

    def __init__(self, opt, mode):
        """Collect the noisy/clean file lists for ``mode``.

        Args:
            opt: Options object providing ``chunk_size`` and the directory
                attribute for the selected mode (see ``_MODE_TABLE``).
            mode: ``'train'``, ``'valid'`` or ``'test'``.

        Raises:
            ValueError: If ``mode`` is not supported. Previously the
                instance was left without ``noisy_dirs``/``clean_dirs`` and
                failed later with ``AttributeError``.
        """
        self.mode = mode
        self.chunk_size = opt.chunk_size

        if mode not in self._MODE_TABLE:
            raise ValueError(
                "Unknown dataset mode {!r}; expected one of {}".format(
                    mode, sorted(self._MODE_TABLE)
                )
            )

        banner, source_attr = self._MODE_TABLE[mode]
        print(banner)
        print('Load the data...')
        # Only the attribute belonging to the selected mode is read from opt.
        self.noisy_dirs = scan_directory(getattr(opt, source_attr))
        self.clean_dirs = find_pair(self.noisy_dirs)

    def __len__(self):
        """Return the number of noisy/clean pairs."""
        return len(self.noisy_dirs)

    @staticmethod
    def _loop_to_length(wav, length):
        """Repeat ``wav`` end-to-end and cut it to exactly ``length`` samples.

        ``wav`` must be non-empty; the caller checks this.
        """
        repeats, remainder = divmod(length, len(wav))
        return torch.cat([wav] * repeats + [wav[:remainder]], dim=-1)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return it as two ``chunk_size``-long tensors.

        Returns:
            ``(inputs, targets)``: the noisy and clean tensors, both clamped
            to ``[-1, 1]``.

        Raises:
            ValueError: If the two recordings differ in length, or if the
                recording is empty and therefore cannot be looped up to
                ``chunk_size``.
        """
        # read the wav and convert from numpy to torch
        inputs = torch.from_numpy(addr2wav(self.noisy_dirs[idx]))
        targets = torch.from_numpy(addr2wav(self.clean_dirs[idx]))

        wav_len = len(inputs)
        # Explicit raise rather than ``assert`` so the check survives ``-O``.
        if wav_len != len(targets):
            raise ValueError(
                "Length mismatch between noisy ({}) and clean ({}) wav at "
                "index {}: {}".format(wav_len, len(targets), idx, self.noisy_dirs[idx])
            )

        if wav_len < self.chunk_size:
            if wav_len == 0:
                # Would otherwise divide by zero while computing repeats.
                raise ValueError(
                    "Empty wav cannot be extended to chunk_size at index "
                    "{}: {}".format(idx, self.noisy_dirs[idx])
                )
            inputs = self._loop_to_length(inputs, self.chunk_size)
            targets = self._loop_to_length(targets, self.chunk_size)
        else:
            # Same random offset for both so the pair stays aligned.
            start = random.randint(0, wav_len - self.chunk_size)
            stop = start + self.chunk_size
            inputs = inputs[start:stop]
            targets = targets[start:stop]

        # Limit the amplitude range to [-1, 1] (in place).
        inputs = torch.clamp_(inputs, -1, 1)
        targets = torch.clamp_(targets, -1, 1)
        return inputs, targets
class Wave_Dataset_for_test(Dataset):
    """Noisy/clean waveform pairs for the test phase, returned uncropped.

    Each item is the complete recording (no chunking or looping), with
    amplitudes clamped to ``[-1, 1]``. ``chunk_size`` is stored from ``opt``
    but not used by this class itself.
    """

    def __init__(self, opt, mode):
        """Collect the test file lists; any mode other than ``'test'`` raises."""
        self.mode = mode
        self.chunk_size = opt.chunk_size

        # This class only serves the test split.
        if mode != 'test':
            raise Exception("Mode error!")

        print('<Test dataset>')
        print('Load the data...')
        # Gather noisy file addresses, then look up their clean counterparts.
        noisy_list = scan_directory(opt.test_database)
        self.noisy_dirs = noisy_list
        self.clean_dirs = find_pair(noisy_list)

    def __len__(self):
        """Return the number of noisy/clean pairs."""
        return len(self.noisy_dirs)

    def __getitem__(self, idx):
        """Return the ``(noisy, clean)`` tensors for pair ``idx``."""
        # Load both recordings first, then wrap the numpy arrays as tensors.
        noisy_wav = addr2wav(self.noisy_dirs[idx])
        clean_wav = addr2wav(self.clean_dirs[idx])
        noisy = torch.from_numpy(noisy_wav)
        clean = torch.from_numpy(clean_wav)

        # The pair must be sample-aligned.
        n_samples = len(noisy)
        assert n_samples == len(clean)

        # Limit the amplitude range to [-1, 1] (in place).
        return noisy.clamp_(-1, 1), clean.clamp_(-1, 1)