-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathextract_wav_mel_stftm_tfrecords_within_sess.py
66 lines (53 loc) · 2.46 KB
/
extract_wav_mel_stftm_tfrecords_within_sess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import tensorflow as tf
import os
import argparse
import scipy.io.wavfile as siowav
import numpy as np
import tqdm
def get_arguments():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with two string attributes:
            tfrecord_path -- input TFRecords file (flags ``--tfrecord_path`` / ``-s``).
            wav_root      -- presumably the directory for recovered wavs
                             (flags ``--wav_root`` / ``-d``); NOTE(review): not
                             read by main() in this file.
    """
    cli = argparse.ArgumentParser(
        description="Extract wav from TFRecords file and save.")
    # (flag names, default value) for every option; all are plain strings.
    option_specs = (
        (("--tfrecord_path", "-s"), "./wav_mel_stftm.tfrecords"),
        (("--wav_root", "-d"), "./wav_recover_within_sess"),
    )
    for flags, fallback in option_specs:
        cli.add_argument(*flags, type=str, default=fallback, help="")
    return cli.parse_args()


def parse_single_example(example_proto):
    """Decode one serialized ``tf.train.Example`` record into a dict of tensors.

    Args:
        example_proto: scalar string tensor holding one serialized record
            read from the TFRecords file.

    Returns:
        dict with keys:
            "sr"         -- int32 scalar (cast from the stored int64)
            "key"        -- string scalar, passed through unchanged
            "frames"     -- int32 scalar (cast from the stored int64)
            "wav"        -- 1-D int16 tensor decoded from the raw bytes
            "norm_mel"   -- float32 tensor reshaped to (frames, 80)
            "norm_stftm" -- float32 tensor reshaped to (frames, 513)
    """
    scalar_int = tf.FixedLenFeature([], tf.int64)
    scalar_bytes = tf.FixedLenFeature([], tf.string)
    spec = {
        "sr": scalar_int,
        "key": scalar_bytes,
        "frames": scalar_int,
        "wav_raw": scalar_bytes,
        "norm_mel_raw": scalar_bytes,
        "norm_stftm_raw": scalar_bytes,
    }
    record = tf.parse_single_example(example_proto, features=spec)

    sample_rate = tf.cast(record["sr"], tf.int32)
    num_frames = tf.cast(record["frames"], tf.int32)
    waveform = tf.decode_raw(record["wav_raw"], tf.int16)

    def _as_matrix(field, width):
        # Raw bytes -> flat float32 vector -> (num_frames, width) matrix.
        flat = tf.decode_raw(record[field], tf.float32)
        return tf.reshape(flat, (num_frames, width))

    return {
        "sr": sample_rate,
        "key": record["key"],
        "frames": num_frames,
        "wav": waveform,
        "norm_mel": _as_matrix("norm_mel_raw", 80),
        "norm_stftm": _as_matrix("norm_stftm_raw", 513),
    }


def get_dataset(tfrecord_path, batch_size=3, shuffle_buffer=10000):
    """Build a shuffled, padded-batch ``tf.data`` pipeline over a TFRecords file.

    Args:
        tfrecord_path: file path(s) accepted by ``tf.data.TFRecordDataset``.
        batch_size: examples per batch. Defaults to 3, the value this
            function previously hard-coded.
        shuffle_buffer: size of the shuffle buffer applied to parsed records
            before batching. Defaults to 10000, the previously hard-coded value.

    Returns:
        tf.data.Dataset whose elements are dicts with the keys produced by
        ``parse_single_example``. Each field is batched along a new leading
        axis. Variable-length fields ("wav", and the frame axis of
        "norm_mel" / "norm_stftm") are padded to the longest example in
        the batch. The final batch may be smaller than ``batch_size``.
    """
    # Scalars take shape (); None marks the axis padded per batch.
    padded_shapes = {
        "sr": (),
        "key": (),
        "frames": (),
        "wav": [None],
        "norm_mel": [None, 80],
        "norm_stftm": [None, 513],
    }
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    dataset = dataset.map(parse_single_example)
    dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.padded_batch(batch_size, padded_shapes=padded_shapes)
    return dataset


def main():
    """Iterate once over the TFRecords dataset, printing each batch's mel shape.

    This is a sanity-check driver. It parses the CLI arguments, builds the
    pipeline from ``get_dataset`` and runs it in a TF1 ``tf.Session``. It
    prints the shape of every batch's ``norm_mel`` tensor until the one-shot
    iterator is exhausted, then prints a completion message.

    Exceptions other than end-of-data (e.g. a corrupt record) propagate to
    the caller instead of being printed in an endless loop.
    """
    args = get_arguments()
    # NOTE(review): args.wav_root is parsed but not used here; no wavs are saved.
    dataset = get_dataset(args.tfrecord_path)
    iterator = dataset.make_one_shot_iterator()
    next_item = iterator.get_next()
    # Build the shape op once. Calling tf.shape inside the loop would add a
    # fresh node to the graph on every iteration.
    mel_shape = tf.shape(next_item["norm_mel"])
    with tf.Session() as sess:
        while True:
            try:
                print(sess.run(mel_shape))
            except tf.errors.OutOfRangeError:
                # The one-shot iterator signals end-of-data this way. Stop here
                # instead of retrying forever.
                break
    print("Congratulations!")


def _run_as_script():
    """Echo this script's path to stdout, then hand control to main()."""
    print(__file__)
    main()


if __name__ == "__main__":
    _run_as_script()