-
Notifications
You must be signed in to change notification settings - Fork 96
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
New utility feature class filtering for Object Detectors 2D (#467)
* Added initial version of class filtering * added parameter conditions for all implemented learners * added parameter conditions for all implemented learners * final implementation of class filtering for object detectors. added inference demo for yolov5. * correction inference demo for yolov5. * small fixes to pass tests * added License Copyright to pass test * Update src/opendr/perception/object_detection_2d/utils/class_filter_wrapper.py Co-authored-by: Kostas Tsampazis <27914645+tsampazk@users.noreply.github.com> * Update projects/python/perception/object_detection_2d/yolov5/filtered_inference_demo.py Co-authored-by: Kostas Tsampazis <27914645+tsampazk@users.noreply.github.com> * Update projects/python/perception/object_detection_2d/yolov5/filtered_inference_demo.py Co-authored-by: Kostas Tsampazis <27914645+tsampazk@users.noreply.github.com> * Update projects/python/perception/object_detection_2d/yolov5/filtered_inference_demo.py Co-authored-by: Kostas Tsampazis <27914645+tsampazk@users.noreply.github.com> * Update projects/python/perception/object_detection_2d/yolov5/filtered_inference_demo.py Co-authored-by: Kostas Tsampazis <27914645+tsampazk@users.noreply.github.com> * Update src/opendr/perception/object_detection_2d/utils/class_filter_wrapper.py Co-authored-by: Kostas Tsampazis <27914645+tsampazk@users.noreply.github.com> * Update src/opendr/perception/object_detection_2d/utils/class_filter_wrapper.py Co-authored-by: Kostas Tsampazis <27914645+tsampazk@users.noreply.github.com> * Removed filtering support for RetinaFace, since it exclusively detects face objects. --------- Co-authored-by: Kostas Tsampazis <27914645+tsampazk@users.noreply.github.com>
- Loading branch information
1 parent
02a055d
commit 1381ca8
Showing
2 changed files
with
185 additions
and
0 deletions.
There are no files selected for viewing
61 changes: 61 additions & 0 deletions
61
projects/python/perception/object_detection_2d/yolov5/filtered_inference_demo.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright 2020-2023 OpenDR European Project | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
|
||
import cv2 | ||
import torch | ||
import numpy as np | ||
|
||
from opendr.engine.data import Image | ||
from opendr.perception.object_detection_2d import YOLOv5DetectorLearner | ||
from opendr.perception.object_detection_2d import draw_bounding_boxes | ||
|
||
from opendr.perception.object_detection_2d.utils.class_filter_wrapper import FilteredLearnerWrapper | ||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--device", help="Device to use (cpu, cuda)", type=str, default="cuda", choices=["cuda", "cpu"]) | ||
parser.add_argument("--classes", help="Classes of interest to detect", type=str, nargs="+", default=['person']) | ||
args = parser.parse_args() | ||
|
||
yolo = YOLOv5DetectorLearner(model_name='yolov5s', device=args.device) | ||
|
||
for f in 'zidane.jpg', 'bus.jpg': | ||
torch.hub.download_url_to_file('https://ultralytics.com/images/' + f, f) # download 2 images | ||
im1 = Image.open('zidane.jpg') # OpenDR image | ||
im2 = cv2.imread('bus.jpg') # OpenCV image (BGR to RGB) | ||
|
||
# By default, only objects of the 'person' class are of interest | ||
filtered_yolo = FilteredLearnerWrapper(yolo, allowed_classes=args.classes) | ||
|
||
# detection before filtering | ||
print("No filtering...") | ||
results = yolo.infer(im1) | ||
draw_bounding_boxes(im1.opencv(), results, yolo.classes, show=True, line_thickness=3) | ||
|
||
# detection after filtering | ||
print("With filtering...") | ||
filtered_results = filtered_yolo.infer(im1) | ||
draw_bounding_boxes(im1.opencv(), filtered_results, filtered_yolo.classes, show=True, line_thickness=3) | ||
|
||
# detection before filtering | ||
print("No filtering...") | ||
results = yolo.infer(im2) | ||
draw_bounding_boxes(np.copy(im2), results, yolo.classes, show=True, line_thickness=3) | ||
|
||
# detection after filtering | ||
print("With filtering...") | ||
filtered_results = filtered_yolo.infer(im2) | ||
draw_bounding_boxes(np.copy(im2), filtered_results, filtered_yolo.classes, show=True, line_thickness=3) |
124 changes: 124 additions & 0 deletions
124
src/opendr/perception/object_detection_2d/utils/class_filter_wrapper.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
# Copyright 2020-2023 OpenDR European Project | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from opendr.engine.target import BoundingBoxList | ||
from opendr.perception.object_detection_2d.centernet.centernet_learner import CenterNetDetectorLearner | ||
from opendr.perception.object_detection_2d.detr.detr_learner import DetrLearner | ||
from opendr.perception.object_detection_2d.ssd.ssd_learner import SingleShotDetectorLearner | ||
from opendr.perception.object_detection_2d.yolov3.yolov3_learner import YOLOv3DetectorLearner | ||
from opendr.perception.object_detection_2d.yolov5.yolov5_learner import YOLOv5DetectorLearner | ||
from opendr.perception.object_detection_2d.nanodet.nanodet_learner import NanodetLearner | ||
|
||
|
||
class FilteredLearnerWrapper: | ||
def __init__(self, learner, allowed_classes=None): | ||
self.learner = learner | ||
self.allowed_classes = allowed_classes if allowed_classes is not None else [] | ||
|
||
if isinstance(self.learner, | ||
(CenterNetDetectorLearner, YOLOv3DetectorLearner, YOLOv5DetectorLearner, NanodetLearner, | ||
SingleShotDetectorLearner)): | ||
self.classes = self.learner.classes | ||
if isinstance(self.learner, DetrLearner): | ||
self.classes = [ | ||
"N/A", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", | ||
"traffic light", "fire hydrant", "N/A", "stop sign", "parking meter", "bench", "bird", "cat", "dog", | ||
"horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "N/A", "backpack", "umbrella", "N/A", | ||
"N/A", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", | ||
"baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "N/A", | ||
"wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", | ||
"broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", | ||
"N/A", "dining table", "N/A", "N/A", "toilet", "N/A", "tv", "laptop", "mouse", "remote", "keyboard", | ||
"cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "N/A", "book", "clock", "vase", | ||
"scissors", "teddy bear", "hair drier", "toothbrush", | ||
] | ||
|
||
# Verify that allowed classes are in the detector's class list | ||
invalid_classes = [cls for cls in self.allowed_classes if cls not in self.classes] | ||
if invalid_classes: | ||
raise ValueError( | ||
f"The following classes are not detected by this detector: {', '.join(invalid_classes)}") | ||
|
||
def infer(self, img=None, threshold=None, keep_size=None, input=None, conf_threshold=None, iou_threshold=None, | ||
nms_max_num=None, size=None, custom_nms=None, nms_thresh=None, nms_topk=None, post_nms=None, | ||
extract_maps=None): | ||
|
||
# match variable names | ||
if isinstance(self.learner, NanodetLearner) and input is not None: | ||
img = input | ||
|
||
if img is None: | ||
raise ValueError( | ||
"An image input is required. Please provide a valid image.") | ||
|
||
if isinstance(self.learner, CenterNetDetectorLearner): | ||
if threshold is None: | ||
threshold = 0.2 | ||
if keep_size is None: | ||
keep_size = True | ||
boxes = self.learner.infer(img, threshold=threshold, keep_size=keep_size) | ||
|
||
elif isinstance(self.learner, YOLOv3DetectorLearner): | ||
if threshold is None: | ||
threshold = 0.1 | ||
if keep_size is None: | ||
keep_size = True | ||
boxes = self.learner.infer(img, threshold=threshold, keep_size=keep_size) | ||
|
||
elif isinstance(self.learner, YOLOv5DetectorLearner): | ||
if size is None: | ||
size = 640 | ||
boxes = self.learner.infer(img, size) | ||
|
||
elif isinstance(self.learner, DetrLearner): | ||
boxes = self.learner.infer(img) | ||
|
||
elif isinstance(self.learner, NanodetLearner): | ||
if conf_threshold is None: | ||
conf_threshold = 0.35 | ||
if iou_threshold is None: | ||
iou_threshold = 0.6 | ||
if nms_max_num is None: | ||
nms_max_num = 100 | ||
boxes = self.learner.infer(img, conf_threshold, iou_threshold, nms_max_num) | ||
|
||
elif isinstance(self.learner, SingleShotDetectorLearner): | ||
if threshold is None: | ||
threshold = 0.2 | ||
if keep_size is None: | ||
keep_size = False | ||
if nms_thresh is None: | ||
nms_thresh = 0.45 | ||
if nms_topk is None: | ||
nms_topk = 400 | ||
if post_nms is None: | ||
post_nms = 100 | ||
if extract_maps is None: | ||
extract_maps = False | ||
boxes = self.learner.infer(img, threshold, keep_size, custom_nms, | ||
nms_thresh, nms_topk, post_nms, extract_maps) | ||
else: | ||
raise ValueError( | ||
"Filtering has not been implemented for the specified detector class." | ||
) | ||
|
||
if not self.allowed_classes: | ||
return boxes | ||
filtered_boxes = BoundingBoxList( | ||
[box for box in boxes if self.classes[int(box.name)] in self.allowed_classes]) | ||
return filtered_boxes | ||
|
||
def __getattr__(self, attr): | ||
return getattr(self.learner, attr) |