-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathinput_pred_data.py
125 lines (89 loc) · 4.21 KB
/
input_pred_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import random
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.python.framework import dtypes
from tensorflow.python.framework.ops import convert_to_tensor
# Seed handed to Python's ``random`` module by ``Data._prepare_data_index``.
RANDOM_SEED = 888
# Nominal image height/width (presumably pixels). ``DataLoader`` takes its
# own ``img_size`` argument rather than reading these.
HEIGHT, WIDTH = 256, 256
class Data(object):
    """Index the prediction images stored under ``data_dir``.

    Every immediate subdirectory ``<data_dir>/<id>`` contributes one entry
    ``{'image': '<data_dir>/<id>/images/<id>.png'}`` to
    ``self.data_index['prediction']``.
    """

    def __init__(self, data_dir):
        self.data_dir = data_dir
        self._prepare_data_index()

    def get_data(self):
        """Return the list of ``{'image': path}`` dicts for prediction."""
        return self.data_index['prediction']

    def get_size(self):
        """Return the number of indexed prediction samples."""
        return len(self.data_index['prediction'])

    @staticmethod
    def _image_path(sample_dir):
        """Return ``<sample_dir>/images/<basename(sample_dir)>.png``."""
        sample_id = os.path.basename(sample_dir)
        return os.path.join(sample_dir, 'images', sample_id) + '.png'

    def _prepare_data_index(self):
        """Populate ``self.data_index['prediction']`` from ``data_dir``."""
        # Seeds the global ``random`` module. Nothing in this class draws
        # from it; the call is kept for backward compatibility.
        random.seed(RANDOM_SEED)
        self.data_index = {'prediction': []}
        # Look through all sample subfolders to find their images. Sorting
        # makes the index order independent of filesystem enumeration order.
        # Stray files are skipped because they would otherwise yield a
        # nonexistent ``<file>/images/<file>.png`` path.
        search_path = os.path.join(self.data_dir, '*')
        for sample_dir in sorted(gfile.Glob(search_path)):
            if not gfile.IsDirectory(sample_dir):
                continue
            self.data_index['prediction'].append(
                {'image': self._image_path(sample_dir)})
class DataLoader(object):
    """
    Wrapper class around the TensorFlow ``tf.data`` pipeline for prediction.

    Turns a list of ``{'image': path}`` dicts (the format produced by
    ``Data.get_data()``) into a batched ``tf.data.Dataset`` of
    ``(image, image_name)`` pairs, exposed as ``self.dataset``.
    """

    def __init__(self, data, img_size, batch_size, num_parallel_calls=8):
        """Build the prediction dataset.

        Args:
            data: list of dicts, each with an ``'image'`` key holding a PNG path.
            img_size: side length images are resized to (output is square).
            batch_size: number of images per batch.
            num_parallel_calls: parallelism of the decode/resize ``map``
                (defaults to 8, the previously hard-coded value).
        """
        # Not read anywhere in this class; kept as a public attribute.
        self.data_info = {}
        self.data_size = len(data)
        images_path, images_name = self._get_data(data)
        # Must be set before ``map`` below, which traces ``_parse_function``
        # and reads ``self.img_size``.
        self.img_size = img_size
        dataset = tf.data.Dataset.from_tensor_slices((images_path, images_name))
        dataset = dataset.map(self._parse_function,
                              num_parallel_calls=num_parallel_calls)
        dataset = dataset.batch(batch_size)
        # Prefetch whole batches at the end of the pipeline. Ten batches
        # buffer the same number of images as the former
        # ``10 * batch_size`` element prefetch placed before ``batch``.
        self.dataset = dataset.prefetch(buffer_size=10)

    @staticmethod
    def _paths_and_names(data):
        """Return ``(paths, names)`` as index-aligned lists sorted by path.

        Each name is the image file's basename without its extension.
        Names are derived from the already-sorted paths, so ``names[i]``
        always belongs to ``paths[i]``. Sorting the two sequences
        independently can misalign them when one id is a prefix of another,
        e.g. ``ab`` vs ``ab-c``.
        """
        paths = sorted(entry['image'] for entry in data)
        names = [os.path.splitext(os.path.basename(p))[0] for p in paths]
        return paths, names

    def _get_data(self, data):
        """Return ``(image_paths, image_names)`` as 1-D string tensors.

        An empty ``data`` list yields two empty string tensors.
        """
        image_paths, image_names = self._paths_and_names(data)
        return (convert_to_tensor(image_paths, dtype=dtypes.string),
                convert_to_tensor(image_names, dtype=dtypes.string))

    def _parse_function(self, image_path, image_name):
        """Load one PNG, resize it and convert it to float32.

        Args:
            image_path: scalar string tensor holding a PNG file path.
            image_name: scalar string tensor, passed through unchanged.

        Returns:
            ``(image, image_name)``, where ``image`` is a float32 tensor of
            shape ``[img_size, img_size, 3]``.
        """
        image_string = tf.read_file(image_path)
        # channels=3 forces a 3-channel RGB decode.
        image_decoded = tf.image.decode_png(image_string, channels=3)
        # NEAREST_NEIGHBOR keeps the uint8 dtype. That is what makes
        # convert_image_dtype below rescale 0..255 to [0, 1]; a bilinear
        # resize would already return float32 and skip that scaling.
        image_resized = tf.image.resize_images(
            image_decoded,
            [self.img_size, self.img_size],
            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        image = tf.image.convert_image_dtype(image_resized, dtype=tf.float32)
        return image, image_name