-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict.py
292 lines (240 loc) · 9.93 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import click
import json
import numpy as np
import os
import skvideo.io
import sys
import time
import tensorflow as tf
from PIL import Image
from luminoth.tools.checkpoint import get_checkpoint_config
from luminoth.utils.config import get_config, override_config_params
from luminoth.utils.detector_in_detector_predicting import PredictorNetwork
from luminoth.vis import build_colormap, vis_objects
IMAGE_FORMATS = ['jpg', 'jpeg', 'png']
VIDEO_FORMATS = ['mov', 'mp4', 'avi']  # TODO: check if more formats work
def get_file_type(filename):
    """Classify `filename` as an image or a video by its extension.

    Args:
        filename: File name or path to inspect; only the text after the
            last '.' is considered (case-insensitive).

    Returns:
        'image' or 'video' when the extension matches a known format,
        otherwise `None`.
    """
    extension = filename.split('.')[-1].lower()
    if extension in IMAGE_FORMATS:
        return 'image'
    elif extension in VIDEO_FORMATS:
        return 'video'
    # Make the "unsupported extension" result explicit instead of an
    # implicit fall-through returning None.
    return None
def resolve_files(path_or_dir):
    """Returns the file paths for `path_or_dir`.

    Args:
        path_or_dir: String or list of strings for the paths or directories to
            run predictions in. For directories, will return all the files
            within.

    Returns:
        List of strings with the full path for each file.
    """
    entries = path_or_dir if isinstance(path_or_dir, tuple) else (path_or_dir,)

    resolved = []
    for entry in entries:
        if tf.gfile.IsDirectory(entry):
            # Collect every recognized media file directly inside the dir.
            for name in tf.gfile.ListDirectory(entry):
                if get_file_type(name) in ('image', 'video'):
                    resolved.append(os.path.join(entry, name))
        elif get_file_type(entry) in ('image', 'video'):
            # A plain file: warn and skip when it does not actually exist.
            if not tf.gfile.Exists(entry):
                click.echo('Input {} not found, skipping.'.format(entry))
                continue
            resolved.append(entry)

    return resolved
def filter_classes(objects, only_classes=None, ignore_classes=None):
    """Filter detections by label according to the user's class lists.

    Args:
        objects: List of detection dicts, each carrying a 'label' key.
        only_classes: Optional collection of labels to keep exclusively.
        ignore_classes: Optional collection of labels to drop.

    Returns:
        The filtered list of detections (may be the input list unchanged
        when no filter is given).
    """
    if ignore_classes:
        ignored = set(ignore_classes)
        objects = [obj for obj in objects if obj['label'] not in ignored]

    if only_classes:
        wanted = set(only_classes)
        objects = [obj for obj in objects if obj['label'] in wanted]

    return objects
def predict_image(network, path, only_classes=None, ignore_classes=None,
                  save_path=None):
    """Run `network` over a single image file.

    Args:
        network: Model wrapper exposing `predict_image`.
        path: Path of the image file to predict.
        only_classes: Optional labels to keep exclusively.
        ignore_classes: Optional labels to drop.
        save_path: If set, where to save the image with drawn detections.

    Returns:
        The filtered list of detected objects, or `None` when the image
        could not be read.
    """
    click.echo('Predicting {}...'.format(path), nl=False)

    # Open and read the image to predict.
    with tf.gfile.Open(path, 'rb') as f:
        try:
            image = Image.open(f).convert('RGB')
        except (tf.errors.OutOfRangeError, OSError) as e:
            click.echo()
            click.echo('Error while processing {}: {}'.format(path, e))
            return None

    # Run image through the network, then filter the detections according
    # to the user's class lists.
    objects = filter_classes(
        network.predict_image(image),
        only_classes=only_classes,
        ignore_classes=ignore_classes,
    )

    # Save predicted image.
    if save_path:
        vis_objects(np.array(image), objects).save(save_path)

    click.echo(' done.')
    return objects
def predict_video(network, path, only_classes=None, ignore_classes=None,
                  save_path=None):
    """Run `network` over every frame of the video at `path`.

    Args:
        network: Model wrapper exposing `predict_image`, applied per frame.
        path: Path of the video file to predict.
        only_classes: Optional labels to keep exclusively.
        ignore_classes: Optional labels to drop.
        save_path: If set, where to save the video with drawn detections
            (the extension is forced to `.mp4`).

    Returns:
        List of dicts, one per processed frame, each with the frame index
        and its filtered detections. May be partial when a frame fails.
    """
    if save_path:
        # We hardcode the video output to mp4 for the time being.
        save_path = os.path.splitext(save_path)[0] + '.mp4'
        try:
            writer = skvideo.io.FFmpegWriter(save_path)
        except AssertionError as e:
            tf.logging.error(e)
            tf.logging.error(
                'Please install ffmpeg before making video predictions.'
            )
            # Use `sys.exit(1)` instead of the bare `exit()` helper: `exit` is
            # injected by the `site` module (not guaranteed to exist), and it
            # exited with status 0, reporting success on a fatal error.
            sys.exit(1)
    else:
        click.echo(
            'Video not being saved. Note that for the time being, no JSON '
            'output is being generated. Did you mean to specify `--save-path`?'
        )

    num_of_frames = int(skvideo.io.ffprobe(path)['video']['@nb_frames'])

    video_progress_bar = click.progressbar(
        skvideo.io.vreader(path),
        length=num_of_frames,
        label='Predicting {}'.format(path)
    )

    colormap = build_colormap()

    objects_per_frame = []
    with video_progress_bar as bar:
        try:
            start_time = time.time()
            for idx, frame in enumerate(bar):
                # Run image through network.
                objects = network.predict_image(frame)

                # Filter the results according to the user input.
                objects = filter_classes(
                    objects,
                    only_classes=only_classes,
                    ignore_classes=ignore_classes
                )
                objects_per_frame.append({
                    'frame': idx,
                    'objects': objects
                })

                # Draw the image and write it to the video file.
                if save_path:
                    image = vis_objects(frame, objects, colormap=colormap)
                    writer.writeFrame(np.array(image))
            stop_time = time.time()
            click.echo(
                'fps: {0:.1f}'.format(num_of_frames / (stop_time - start_time))
            )
        except RuntimeError as e:
            click.echo()  # Error prints next to progress bar otherwise.
            click.echo('Error while processing {}: {}'.format(path, e))
            if save_path:
                click.echo(
                    'Partially processed video file saved in {}'.format(
                        save_path
                    )
                )

    if save_path:
        writer.close()

    return objects_per_frame
@click.command(help="Obtain a model's predictions.")
@click.argument('path-or-dir', nargs=-1)
@click.option('config_files', '--config', '-c', multiple=True, help='Config to use.')  # noqa
@click.option('--checkpoint', help='Checkpoint to use.')
@click.option('override_params', '--override', '-o', multiple=True, help='Override model config params.')  # noqa
@click.option('output_path', '--output', '-f', default='-', help='Output file with the predictions (for example, JSON bounding boxes).')  # noqa
@click.option('--save-media-to', '-d', help='Directory to store media to.')
@click.option('--min-prob', default=0.5, type=float, help='When drawing, only draw bounding boxes with probability larger than.')  # noqa
@click.option('--max-detections', default=100, type=int, help='Maximum number of detections per image.')  # noqa
@click.option('--only-class', '-k', default=None, multiple=True, help='Class to include when predicting.')  # noqa
@click.option('--ignore-class', '-K', default=None, multiple=True, help='Class to ignore when predicting.')  # noqa
@click.option('--debug', is_flag=True, help='Set debug level logging.')
def predict(path_or_dir, config_files, checkpoint, override_params,
            output_path, save_media_to, min_prob, max_detections, only_class,
            ignore_class, debug):
    """Obtain a model's predictions.

    Receives either `config_files` or `checkpoint` in order to load the
    correct model. Afterwards, runs the model through the inputs specified by
    `path-or-dir`, returning predictions according to the format specified by
    `output`.

    Additional model behavior may be modified with `min-prob`, `only-class`
    and `ignore-class`.
    """
    if debug:
        tf.logging.set_verbosity(tf.logging.DEBUG)
    else:
        tf.logging.set_verbosity(tf.logging.ERROR)

    # The two class filters are mutually exclusive.
    if only_class and ignore_class:
        click.echo(
            "Only one of `only-class` or `ignore-class` may be specified."
        )
        return

    # Process the input and get the actual files to predict.
    files = resolve_files(path_or_dir)
    if not files:
        error = 'No files to predict found. Accepted formats are: {}.'.format(
            ', '.join(IMAGE_FORMATS + VIDEO_FORMATS)
        )
        click.echo(error)
        return
    else:
        click.echo('Found {} files to predict.'.format(len(files)))

    # Open the output stream for the JSON predictions ('-' means stdout).
    if output_path == '-':
        output = sys.stdout
    else:
        output = open(output_path, 'w')

    try:
        # Create `save_media_to` if specified and it doesn't exist.
        if save_media_to:
            tf.gfile.MakeDirs(save_media_to)

        # Resolve the config to use and initialize the model.
        if checkpoint:
            config = get_checkpoint_config(checkpoint)
        elif config_files:
            config = get_config(config_files)
        else:
            click.echo(
                'Neither checkpoint nor config specified, assuming '
                '`accurate`.'
            )
            config = get_checkpoint_config('accurate')

        if override_params:
            config = override_config_params(config, override_params)

        # Filter bounding boxes according to `min_prob` and `max_detections`.
        if config.model.type == 'fasterrcnn':
            if config.model.network.with_rcnn:
                config.model.rcnn.proposals.total_max_detections = \
                    max_detections
            else:
                config.model.rpn.proposals.post_nms_top_n = max_detections
            config.model.rcnn.proposals.min_prob_threshold = min_prob
        elif config.model.type == 'ssd':
            config.model.proposals.total_max_detections = max_detections
            config.model.proposals.min_prob_threshold = min_prob
        else:
            raise ValueError(
                "Model type '{}' not supported".format(config.model.type)
            )

        # Instantiate the model indicated by the config.
        network = PredictorNetwork(config)

        # Iterate over files and run the model on each.
        for file in files:
            # Get the media output path, if media storage is requested.
            save_path = os.path.join(
                save_media_to, 'pred_{}'.format(os.path.basename(file))
            ) if save_media_to else None

            file_type = get_file_type(file)
            predictor = (
                predict_image if file_type == 'image' else predict_video
            )
            objects = predictor(
                network, file,
                only_classes=only_class,
                ignore_classes=ignore_class,
                save_path=save_path,
            )

            # TODO: Not writing jsons for video files for now.
            if objects is not None and file_type == 'image':
                output.write(
                    json.dumps({
                        'file': file,
                        'objects': objects,
                    }) + '\n'
                )
    finally:
        # Close only streams we opened ourselves: the original code closed
        # `sys.stdout` when `--output -` was used, and leaked the file handle
        # when the config/model setup raised before reaching the close.
        if output is not sys.stdout:
            output.close()