# ===========================================================
# Webcam Real-Time Object Detection with YOLOv3 / YOLOv3-tiny
# ===========================================================
# RUN WITH EXAMPLE COMMANDS BELOW:
# python YOLO_live.py -y yolov3-tiny
# python YOLO_live.py -y yolov3
import argparse
import os
import time

import cv2
import numpy as np
from imutils import resize  # used by the optional frame-downscale line in get_input
from imutils.video import VideoStream, FPS
"""User inputs through command line (no input & output since this is through device camera)"""
ap = argparse.ArgumentParser()
ap.add_argument("-y", "--yolo", required=True, help="base path to YOLO directory") # pass in either yolov3 or yolov3-tiny
ap.add_argument("-c", "--confidence", type=float, default=0.3, help="minimum probability to filter weak detections")
ap.add_argument("-t", "--threshold", type=float, default=0.3, help="threshold when applying non-maxima suppression")
args = vars(ap.parse_args())
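# e.g. `python YOLO_live.py -y yolov3-tiny` parses to
# args = {"yolo": "yolov3-tiny", "confidence": 0.3, "threshold": 0.3}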
def get_model():
    """Load YOLOv3 or YOLOv3-tiny with OpenCV's built-in DNN module.

    The directory passed with -y is expected to contain yolo.cfg,
    yolo.weights, and coco.names (one class label per line).
    """
    net = cv2.dnn.readNetFromDarknet(os.path.sep.join([args["yolo"], "yolo.cfg"]),
                                     os.path.sep.join([args["yolo"], "yolo.weights"]))
    labels = open(os.path.sep.join([args["yolo"], "coco.names"])).read().strip().split("\n")
    layer_names = net.getLayerNames()
    # getUnconnectedOutLayers() returns 1-based indices; flatten() handles both the
    # Nx1 arrays of older OpenCV releases and the flat arrays of newer ones
    out_layer_names = [layer_names[i - 1] for i in net.getUnconnectedOutLayers().flatten()]
    return labels, net, out_layer_names
def get_color(labels):
"""Initialize random colors."""
np.random.seed(1)
colors = np.random.randint(0, 255, size=(len(labels), 3), dtype="uint8")
return colors
def init_video():
"""Start camera (2 sec warm-up)."""
video_stream = VideoStream(src=0).start()
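    # VideoStream(src=0) wraps the default system camera; pass a different index
    # (e.g. src=1) to use an external webcam. The short sleep below gives the
    # sensor time to warm up before the first frame is read.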
time.sleep(2.0)
fps_record = FPS().start()
return video_stream, fps_record
# -----------------------------------------------------------
# Below functions are all called in the video stream pipeline
# -----------------------------------------------------------
def get_input(video_stream):
"""Grab frames and return dimensions."""
frame = video_stream.read()
# frame = resize(frame, width = 400)
(frame_height, frame_width) = frame.shape[:2]
return frame, frame_width, frame_height
def preprocess_input(net, frame):
"""Augment input and set up for forward pass."""
blob = cv2.dnn.blobFromImage(frame, 1.0 / 255, (416, 416), swapRB=True, crop=False)
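    # blobFromImage scales pixel values by 1/255, resizes to YOLOv3's 416x416
    # input, and swaps BGR to RGB; the resulting blob has NCHW shape (1, 3, 416, 416)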
net.setInput(blob)
def forward_pass(net, out_layer_names):
"""Forward pass, non-max suppression done by default."""
yolo_output = net.forward(out_layer_names)
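    # yolo_output is a list with one detection array per output layer; for the
    # 80-class COCO models used here each row holds 85 values:
    # [center_x, center_y, width, height, objectness, 80 class scores]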
return yolo_output
def filter_output(yolo_output, frame_width, frame_height):
"""Get lists for bounding box."""
boxes = []
confidences = []
classIDs = []
# process output
for output in yolo_output:
for detection in output:
            scores = detection[5:]  # skip the 4 box coordinates and the objectness score
classID = np.argmax(scores)
confidence = scores[classID]
if confidence > args["confidence"]: # filter out low confidence
box_data = detection[:4] * np.array([frame_width, frame_height, frame_width, frame_height])
(center_X, center_Y, box_width, box_height) = box_data.astype("int")
x = int(center_X - (box_width / 2))
y = int(center_Y - (box_height / 2))
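                # e.g. on a 640x480 frame, a detection of [0.5, 0.5, 0.25, 0.5, ...]
                # gives center (320, 240) and a 160x240 box, so (x, y) = (240, 120)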
                # record box, confidence, and class ID for this detection (boxes is a list of 4-element lists)
boxes.append([x, y, int(box_width), int(box_height)])
confidences.append(float(confidence))
classIDs.append(classID)
    # with all boxes collected, apply non-maxima suppression to drop overlapping ones
indices = cv2.dnn.NMSBoxes(boxes, confidences, args["confidence"], args["threshold"])
return boxes, confidences, classIDs, indices
def draw_box(frame, boxes, confidences, classIDs, indices, labels, colors):
"""Draw all bounding boxes."""
if len(indices) > 0:
for i in indices.flatten():
            (x, y, w, h) = boxes[i]
color = [int(c) for c in colors[classIDs[i]]]
cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)
object_name = "{}: {:.4f}".format(labels[classIDs[i]], confidences[i])
cv2.putText(frame, object_name, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
return frame
def show_output(frame, fps_record):
"""Show output as a frame in video stream, press q to exit and update the fps."""
stop = 0
cv2.imshow("Frame", frame)
if cv2.waitKey(25) & 0xFF == ord("q"):
stop = 1
fps_record.update()
return fps_record, stop
def loop_frames(labels, colors, net, out_layer_names, video_stream, fps_record):
"""Loop through frame inputs from camera."""
while True:
frame, frame_width, frame_height = get_input(video_stream)
preprocess_input(net, frame)
yolo_output = forward_pass(net, out_layer_names)
boxes, confidences, classIDs, indices = filter_output(yolo_output, frame_width, frame_height)
output_frame = draw_box(frame, boxes, confidences, classIDs, indices, labels, colors)
fps_record, stop = show_output(output_frame, fps_record)
if stop == 1:
break
return fps_record
def clean_up(fps_record, video_stream):
"""Stop recording fps and display performance data."""
fps_record.stop()
print("Video time: {:.2f}".format(fps_record.elapsed()))
print("Approximate FPS: {:.2f}".format(fps_record.fps()))
cv2.destroyAllWindows()
video_stream.stop()
def run():
"""Organize and call the useful functions."""
labels, net, out_layer_names = get_model()
colors = get_color(labels)
video_stream, fps_record = init_video()
fps_record = loop_frames(labels, colors, net, out_layer_names, video_stream, fps_record)
clean_up(fps_record, video_stream)
if __name__ == "__main__":
    run()
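# Typical console output after quitting with "q" (numbers are illustrative only):
#   Elapsed time: 31.24
#   Approximate FPS: 14.02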