-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerateFeatures_dep.py
77 lines (51 loc) · 2.6 KB
/
generateFeatures_dep.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import sys
sys.path.append("./stqft")
sys.path.append("./qcnn")
import os
# Make the CUDA library directories visible by appending them to
# LD_LIBRARY_PATH while keeping whatever value is already set. Python does not
# expand shell variables such as "$LD_LIBRARY_PATH", so the current value is
# read from os.environ explicitly; empty entries are skipped because an empty
# path component means "current directory" to the loader.
# NOTE(review): the dynamic loader typically reads LD_LIBRARY_PATH once at
# process start, so setting it here likely only affects child processes, not
# libraries loaded later by this interpreter — confirm; exporting it in the
# launching shell is the reliable way to expose the CUDA libraries.
os.environ["LD_LIBRARY_PATH"] = ":".join(
    entry
    for entry in (
        os.environ.get("LD_LIBRARY_PATH", ""),
        "/usr/local/cuda/lib64/",
        "/usr/lib64",
        "/usr/local/cuda/extras/CUPTI/lib64",
        "/usr/local/cuda-11.2/lib64",
        "/usr/local/cuda/targets/x86_64-linux/lib/",
    )
    if entry
)
import glob
import numpy as np
import time
import pickle
from stqft.frontend import signal, transform
from stqft.stqft import stqft_framework
from qcnn.small_qsr import gen_train_from_wave, labels
# All input and output locations of this script live under one storage root.
_CEPH_ROOT = "/ceph/mstrobl"

# Input audio: gen_train globs "<datasetPath>/<label>/*.wav".
datasetPath = f"{_CEPH_ROOT}/dataset"
# Directory where gen_train pickles its collected waveforms and labels.
waveformPath = f"{_CEPH_ROOT}/waveforms"
# Output location handed to gen_train_from_wave (note the trailing slash).
featurePath = f"{_CEPH_ROOT}/features/"
# NOTE(review): not referenced anywhere in this file — confirm it is unused.
av = 0
def gen_mel(speechFile, sr=16000):
    """Compute the mel-scaled STQFT output for a single audio file.

    The file is loaded through the ``signal`` frontend and transformed with
    ``stqft_framework`` using 2048 shots, 1024-sample windows, an overlap
    factor of 0.875 and window type ``'hamm'``. The result is then
    post-processed onto 60 mel bands between 40 Hz and half the sampling
    rate, with normalization enabled. The wall time of the whole call is
    printed.

    Args:
        speechFile: Path of the audio file to load.
        sr: Sampling rate passed to the ``signal`` loader.

    Returns:
        The first value returned by ``postProcess`` (the mel-scaled,
        normalized transform output).
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    start = time.perf_counter()
    y = signal(samplingRate=sr, signalType='file', path=speechFile)
    stqft = transform(stqft_framework, numOfShots=2048, suppressPrint=True, signalFilter=True)
    y_hat_stqft, f, t = stqft.forward(y, nSamplesWindow=1024, overlapFactor=0.875, windowType='hamm')
    # Only the transformed values are used; the two other outputs (presumably
    # the rescaled frequency/time axes — confirm) are discarded.
    y_hat_stqft_p, _, _ = stqft.postProcess(
        y_hat_stqft,
        f,
        t,
        scale='mel',
        normalize=True,
        samplingRate=y.samplingRate,
        nMels=60,
        fmin=40.0,
        fmax=y.samplingRate / 2,
    )
    diff = time.perf_counter() - start
    print(f"Iteration took {diff} s")
    return y_hat_stqft_p
def gen_train(labels, train_audio_path, outputPath, sr=16000, port=1):
    """Generate mel features for every label, cache them, and export them.

    For each label, the ``.wav`` files in ``<train_audio_path>/<label>/`` are
    listed in sorted order and every ``port``-th file is kept (``port=1`` keeps
    all of them). Each kept file is converted with ``gen_mel``. The first
    column of the result is dropped and a trailing axis is appended before the
    result is collected. The collected waves and labels are pickled into
    ``waveformPath`` under one shared timestamp. They are then passed to
    ``gen_train_from_wave``.

    Args:
        labels: Iterable of label names; each one is a sub-directory of
            ``train_audio_path``.
        train_audio_path: Root directory of the labelled audio files.
        outputPath: Output location forwarded to ``gen_train_from_wave``.
        sr: Sampling rate forwarded to ``gen_mel``.
        port: Keep every ``port``-th file of each label; must be >= 1.

    Returns:
        Whatever ``gen_train_from_wave`` returns.

    Raises:
        ValueError: If ``port`` is smaller than 1.
    """
    if port < 1:
        raise ValueError(f"port must be a positive integer, got {port}")

    all_wave = []
    all_label = []
    for label in labels:
        # glob yields files in arbitrary filesystem order; sorting makes the
        # every-`port`-th subsample reproducible across runs and machines.
        datasetLabelFiles = sorted(glob.glob(f"{train_audio_path}/{label}/*.wav"))
        portDatasetLabelFiles = datasetLabelFiles[::port]
        nSelected = len(portDatasetLabelFiles)
        print(f"Using {nSelected} out of {len(datasetLabelFiles)} files for label '{label}'")
        for it, datasetLabelFile in enumerate(portDatasetLabelFiles, start=1):
            print(f"Processing '{datasetLabelFile}' in label '{label}' [{it}/{nSelected}]")
            wave = gen_mel(datasetLabelFile, sr)
            # Drop column 0 and add a trailing singleton axis (axis=2).
            # NOTE(review): which dimension column 0 belongs to is not visible
            # here — confirm why it is discarded.
            all_wave.append(np.expand_dims(wave[:, 1:], axis=2))
            all_label.append(label)

    # One timestamp for both cache files so the waveforms and labels written
    # by the same run can be matched up later.
    stamp = time.time()
    print(f"Finished generating waveforms at {stamp}")
    with open(f"{waveformPath}/waveforms{stamp}.pckl", 'wb') as fid:
        pickle.dump(all_wave, fid, pickle.HIGHEST_PROTOCOL)
    with open(f"{waveformPath}/labels{stamp}.pckl", 'wb') as fid:
        pickle.dump(all_label, fid, pickle.HIGHEST_PROTOCOL)
    print("Finished dumping cache. Starting Feature export")
    return gen_train_from_wave(all_wave=all_wave, all_label=all_label, output=outputPath)
if __name__ == '__main__':
    # Report how many .wav files exist anywhere below the dataset root, then
    # build features from every 10th file of each label directory.
    wavFilesFound = glob.glob(f"{datasetPath}/**/*.wav", recursive=True)
    print(f"Found {len(wavFilesFound)} files in the dataset")
    gen_train(
        labels=labels,
        train_audio_path=datasetPath,
        outputPath=featurePath,
        port=10,
    )