-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathLoadBatches1D.py
57 lines (46 loc) · 1.84 KB
/
LoadBatches1D.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 21 13:46:44 2019
@author: Winham
LoadBatches1D.py: iteratively generates batches for training
Implementation adapted from: https://github.com/divamgupta/image-segmentation-keras/blob/master/LoadBatches.py
"""
import os
import itertools
import numpy as np
from sklearn import preprocessing as prep
def getSigArr(path, sigNorm='scale'):
    """Load a signal with ``np.load``, optionally normalize it, and add an axis.

    Args:
        path: Path to a file readable by ``np.load``.
        sigNorm: ``'scale'`` standardizes the signal with
            ``sklearn.preprocessing.scale``; ``'minmax'`` rescales it with a
            default ``MinMaxScaler``; any other value leaves it unnormalized.

    Returns:
        The (possibly normalized) array with a new axis inserted at
        position 1, so a 1-D signal of length L becomes shape ``(L, 1)``.
    """
    sig = np.load(path)
    if sigNorm == 'scale':
        sig = prep.scale(sig)
    elif sigNorm == 'minmax':
        min_max_scaler = prep.MinMaxScaler()
        if sig.ndim == 1:
            # MinMaxScaler only accepts 2-D (n_samples, n_features) input and
            # raises ValueError on a 1-D array: treat the signal as a single
            # feature column, then flatten back to 1-D.
            sig = min_max_scaler.fit_transform(sig.reshape(-1, 1)).ravel()
        else:
            sig = min_max_scaler.fit_transform(sig)
    return np.expand_dims(sig, axis=1)
def getSegmentationArr(path, nClasses=3, output_length=1800, class_value=(0, 0.5, 1)):
    """Load a label array with ``np.load`` and one-hot encode it by value.

    Args:
        path: Path to a label array readable by ``np.load``.
        nClasses: Number of output columns; the first ``nClasses`` entries of
            ``class_value`` are used, one per column.
        output_length: Number of rows in the result; the loaded labels must
            broadcast to this length.
        class_value: Label value for each class, in column order. A tuple
            default is used so no mutable object is shared between calls.

    Returns:
        Float array of shape ``(output_length, nClasses)`` holding 1.0 where
        the label equals ``class_value[i]`` and 0.0 elsewhere.
    """
    # class_value is defined in generate_labels.py (background 0, normal 0.5,
    # premature beat 1); these values must be kept in sync with that script.
    seg_labels = np.zeros([output_length, nClasses])
    seg = np.load(path)
    for i in range(nClasses):
        seg_labels[:, i] = (seg == class_value[i]).astype(float)
    return seg_labels
def SigSegmentationGenerator(sigs_path, segs_path, batch_size, n_classes,
                             output_length=1800, sigNorm='scale'):
    """Yield ``(X, Y)`` training batches forever, cycling over paired files.

    The entries of ``sigs_path`` and ``segs_path`` are paired by sorted
    filename, and each pair must share the same stem (text before the first
    '.'). Validation runs on the first ``next()`` call.

    Args:
        sigs_path: Directory of signal files loaded with ``getSigArr``.
        segs_path: Directory of label files loaded with ``getSegmentationArr``.
            Either directory may be given with or without a trailing separator.
        batch_size: Number of pairs stacked into each yielded batch.
        n_classes: Passed to ``getSegmentationArr`` as ``nClasses``.
        output_length: Passed to ``getSegmentationArr``.
        sigNorm: Normalization mode passed to ``getSigArr``; the default
            ``'scale'`` matches the previous hard-coded behavior.

    Yields:
        ``(np.array(X), np.array(Y))`` stacking ``batch_size`` outputs of
        ``getSigArr`` and ``getSegmentationArr`` respectively.

    Raises:
        ValueError: If the directories are empty, hold different numbers of
            entries, or a sorted pair's stems differ.
    """
    sig_names = sorted(os.listdir(sigs_path))
    seg_names = sorted(os.listdir(segs_path))
    # Validate counts before pairing anything up; an explicit exception
    # (unlike assert) is not stripped under `python -O`.
    if len(sig_names) != len(seg_names):
        raise ValueError(
            'signal/label file count mismatch: %d in %r vs %d in %r'
            % (len(sig_names), sigs_path, len(seg_names), segs_path))
    if not sig_names:
        # itertools.cycle over nothing would end this generator with an
        # opaque RuntimeError on the first batch.
        raise ValueError('no files found in %r / %r' % (sigs_path, segs_path))
    for sig_name, seg_name in zip(sig_names, seg_names):
        if sig_name.split('.')[0] != seg_name.split('.')[0]:
            raise ValueError(
                'unpaired files: %r vs %r' % (sig_name, seg_name))
    # os.path.join works whether or not the directory ends with a separator.
    pairs = [(os.path.join(sigs_path, sig_name), os.path.join(segs_path, seg_name))
             for sig_name, seg_name in zip(sig_names, seg_names)]
    zipped = itertools.cycle(pairs)
    while True:
        X = []
        Y = []
        for _ in range(batch_size):
            sig, seg = next(zipped)
            X.append(getSigArr(sig, sigNorm=sigNorm))
            Y.append(getSegmentationArr(seg, n_classes, output_length))
        yield np.array(X), np.array(Y)