-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathLoadData.py
49 lines (36 loc) · 1.35 KB
/
LoadData.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import numpy as np
import tqdm
def _subdirectories(path):
    """Return the names of the sub-directories of *path*, sorted by name.

    Plain files (e.g. macOS ``.DS_Store`` metadata) are skipped so they are
    never treated as class or sample folders (calling ``os.listdir`` on them
    would raise ``NotADirectoryError``). Sorting makes the traversal order,
    and therefore the label numbering, reproducible across filesystems.
    """
    with os.scandir(path) as entries:
        return sorted(entry.name for entry in entries if entry.is_dir())


def _frame_paths(sample_dir):
    """Return the full paths of the visible files in *sample_dir*, sorted by name.

    Hidden files (names starting with ``.``, such as ``.DS_Store``) and
    sub-directories are excluded so only frame files reach the image loader.
    """
    with os.scandir(sample_dir) as entries:
        return sorted(
            entry.path
            for entry in entries
            if entry.is_file() and not entry.name.startswith(".")
        )


def loaddata(img_dir, img3d, nclass, result_dir, color=False, skip=True):
    """Load a directory tree of image sequences as stacked arrays with int labels.

    Expected layout::

        img_dir/<class_name>/<sample_name>/<frame files>

    Each ``<class_name>`` directory is one class, and each ``<sample_name>``
    directory inside it becomes one sample, built by
    ``img3d.img3d(frame_paths, color=color, skip=skip)``. Classes are visited
    in sorted name order. Only the first *nclass* classes that contain at
    least one sample are kept; empty class directories do not use up a slot.
    The kept class names are written one per line to
    ``<result_dir>/classes.txt``, where line *i* names label *i*.

    Args:
        img_dir: Root directory containing one sub-directory per class.
        img3d: Object exposing ``img3d(files, color=..., skip=...)``, which
            turns a list of frame paths into one array.
        nclass: Maximum number of classes to load.
        result_dir: Existing directory in which ``classes.txt`` is written.
        color: Forwarded to ``img3d.img3d``. Also selects the 5-D transpose
            (``True``) instead of the 4-D one (``False``).
        skip: Forwarded to ``img3d.img3d``.

    Returns:
        A tuple ``(X, labels)``. ``X`` is ``np.array`` of all samples with
        axis 1 moved to the end. ``labels`` is a list of ints that index the
        lines of ``classes.txt``.

    Raises:
        ValueError: If no samples were found under *img_dir*.
    """
    volumes = []
    labels = []
    label_index = {}  # class name -> integer label, in first-registered order

    class_names = _subdirectories(img_dir)
    with tqdm.tqdm(total=len(class_names)) as pbar:
        for class_name in class_names:
            pbar.update(1)
            # Directory names are unique, so a class not yet registered when
            # the quota is full would have all of its samples rejected.
            if len(label_index) >= nclass:
                continue
            class_dir = os.path.join(img_dir, class_name)
            for sample_name in _subdirectories(class_dir):
                frames = _frame_paths(os.path.join(class_dir, sample_name))
                # Register the class on its first sample, so a class
                # directory without samples gets no label.
                label = label_index.setdefault(class_name, len(label_index))
                labels.append(label)
                volumes.append(img3d.img3d(frames, color=color, skip=skip))

    classes_path = os.path.join(result_dir, "classes.txt")
    with open(classes_path, "w", encoding="utf-8") as fp:
        fp.writelines(f"{name}\n" for name in label_index)

    if not volumes:
        raise ValueError(f"no samples found under {img_dir!r}")

    # NOTE(review): assumes every sample returned by img3d has the same shape
    # (4-D when color=True, 3-D otherwise); otherwise np.array cannot stack
    # them. Axis 1 (presumably the frame/depth axis) is moved to the end.
    axes = (0, 2, 3, 4, 1) if color else (0, 2, 3, 1)
    return np.array(volumes).transpose(axes), labels