diff --git a/docs/reference/high-resolution-pose-estimation.md b/docs/reference/high-resolution-pose-estimation.md
index 63f0cd7566..e219cf9193 100644
--- a/docs/reference/high-resolution-pose-estimation.md
+++ b/docs/reference/high-resolution-pose-estimation.md
@@ -45,7 +45,9 @@ Constructor parameters:
   Specifies the height that the input image will be resized during the heatmap generation procedure.
 - **second_pass_height**: *int, default=540*\
   Specifies the height of the image on the second inference for pose estimation procedure.
-- **percentage_arround_crop**: *float, default=0.3*\
+- **method**: *str, default='adaptive'*\
+  Determines which method (*adaptive* or *primary*) is used for ROI extraction.
+- **percentage_around_crop**: *float, default=0.3*\
   Specifies the percentage of an extra pad arround the cropped image
 - **heatmap_threshold**: *float, default=0.1*\
   Specifies the threshold value that the heatmap elements should have during the first pass in order to trigger the second pass
@@ -74,7 +76,61 @@ Constructor parameters:
 - **half_precision**: *bool, default=False*\
   Enables inference using half (fp16) precision instead of single (fp32) precision. Valid only for GPU-based inference.

+#### High Resolution Pose Estimation using Adaptive ROI Selection
+#### `HighResolutionPoseEstimationLearner.eval_adaptive`
+```python
+HighResolutionPoseEstimationLearner.eval_adaptive(self, dataset, silent, verbose, use_subset, subset_size, upsample_ratio, images_folder_name, annotations_filename)
+```
+
+This method is used to evaluate a trained model on an evaluation dataset.
+Returns a dictionary containing statistics regarding evaluation.
+Parameters:
+
+- **dataset**: *object*\
+  Object that holds the evaluation dataset.
+  Can be of type `ExternalDataset` or a custom dataset inheriting from `DatasetIterator`.
+- **silent**: *bool, default=False*\
+  If set to True, disables all printing of evaluation progress reports and other information to STDOUT.
+- **verbose**: *bool, default=True*\
+  If set to True, enables the maximum verbosity.
+- **use_subset**: *bool, default=True*\
+  If set to True, a subset of the validation dataset is created and used in evaluation.
+- **subset_size**: *int, default=250*\
+  Controls the size of the validation subset.
+- **upsample_ratio**: *int, default=4*\
+  Defines the amount of upsampling to be performed on the heatmaps and PAFs when resizing.
+- **images_folder_name**: *str, default='val2017'*\
+  Folder name that contains the dataset images.
+  This folder should be contained in the dataset path provided.
+  Note that this is a folder name, not a path.
+- **annotations_filename**: *str, default='person_keypoints_val2017.json'*\
+  Filename of the annotations JSON file.
+  This file should be contained in the dataset path provided.
+
+#### `HighResolutionPoseEstimation.infer_adaptive`
+```python
+HighResolutionPoseEstimation.infer_adaptive(self, img, upsample_ratio, stride)
+```
+This method is used to perform pose estimation on an image.
+The predicted poses are estimated through an adaptive ROI selection method that is applied on the high-resolution images.
+The difference from the `HighResolutionPoseEstimation.infer` method is that the adaptive technique tries to separate the
+detected ROIs instead of using their minimum enclosing bounding box, as `infer` does.
+Returns a list of engine.target.Pose objects, where each holds a pose,
+and a heatmap that contains human silhouettes of the input image.
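+It also returns the bounds ([xmin, xmax, ymin, ymax]) of each ROI that was used during the second inference pass, which can be used to draw the detected regions, as done in the provided demos.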
+If no detections were made it returns an empty list for poses and a black frame for heatmap.
+Parameters:
+
+- **img**: *object*\
+  Object of type engine.data.Image.
+- **upsample_ratio**: *int, default=4*\
+  Defines the amount of upsampling to be performed on the heatmaps and PAFs when resizing.
+- **stride**: *int, default=8*\
+  Defines the stride value for creating a padded image.
+
+
+#### High Resolution Pose Estimation using Primary ROI Selection
 #### `HighResolutionPoseEstimationLearner.eval`
 ```python
 HighResolutionPoseEstimationLearner.eval(self, dataset, silent, verbose, use_subset, subset_size, images_folder_name, annotations_filename)
 ```
@@ -170,7 +226,6 @@ Parameters:
 - **img_scale**: *float, default=1/256*\
   Specifies the scale based on which the images are normalized.
-
 #### `HighResolutionPoseEstimation.download`
 ```python
 HighResolutionPoseEstimation.download(self, path, mode, verbose, url)
 ```
@@ -258,7 +313,7 @@ The experiments are conducted on a 1080p image.
 | OpenDR - Full          | 0.2                | 10.8           | 1.4              | 3.1             |

-#### High-Resolution Pose Estimation
+#### High-Resolution Pose Estimation (Primary ROI Selection)
 | Method                 | CPU i7-9700K (FPS) | RTX 2070 (FPS) | Jetson TX2 (FPS) | Xavier NX (FPS) |
 |------------------------|--------------------|----------------|------------------|-----------------|
 | HRPoseEstim - Baseline | 2.3                | 13.6           | 1.4              | 1.8             |
@@ -283,6 +338,28 @@ was used as input to the models.
 The average precision and average recall on the COCO evaluation split is also reported in the tables below:

+#### High-Resolution Pose Estimation (Adaptive ROI Selection)
+| Method                 | CPU i7-9700K (FPS) | RTX 2070 (FPS) | Jetson TX2 (FPS) | Xavier NX (FPS) |
+|------------------------|--------------------|----------------|------------------|-----------------|
+| HRPoseEstim - Baseline | 2.4                | 10.5           | 2.1              | 1.5             |
+| HRPoseEstim - Half     | 2.5                | 11.3           | 2.6              | 1.9             |
+| HRPoseEstim - Stride   | 11.3               | 38.1           | 6.8              | 5.2             |
+| HRPoseEstim - Stages   | 2.8                | 10             | 2.3              | 1.9             |
+| HRPoseEstim - H+S      | 11.4               | 38             | 6.5              | 4.5             |
+| HRPoseEstim - Full     | 11.6               | 48.3           | 7.7              | 6.4             |
+
+As shown in the previous tables, OpenDR Lightweight OpenPose achieves higher FPS when it resizes the input image to 256 pixels,
+since the smaller image is easier to process; however, as the accuracy tables below show, this comes at a severe cost in accuracy, to the point where no detections are made.
+
+We have evaluated the effect of using different inference settings, namely:
+- *HRPoseEstim - Baseline*, which refers to directly using the High Resolution Pose Estimation method (based on Lightweight OpenPose),
+- *HRPoseEstim - Half*, which refers to enabling inference in half (FP16) precision,
+- *HRPoseEstim - Stride*, which refers to increasing stride by two in the input layer of the model,
+- *HRPoseEstim - Stages*, which refers to removing the refinement stages,
+- *HRPoseEstim - H+S*, which uses both half precision and increased stride, and
+- *HRPoseEstim - Full*, which refers to combining all three available optimizations (see the configuration sketch below).
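+
+The following snippet is an indicative sketch of how these settings map to the learner's constructor arguments
+(assuming the usual OpenDR import path; the device and pass heights shown are example values rather than the exact benchmarked configuration):
+
+```python
+from opendr.perception.pose_estimation import HighResolutionPoseEstimationLearner
+
+# "Full" setting: increased stride, no refinement stages and half precision combined
+pose_estimator = HighResolutionPoseEstimationLearner(device="cuda",
+                                                     mobilenet_use_stride=True,   # "Stride"
+                                                     num_refinement_stages=0,     # "Stages"
+                                                     half_precision=True,         # "Half" (GPU-only)
+                                                     first_pass_height=360,
+                                                     second_pass_height=540,
+                                                     method="adaptive")
+```
+
+The individual settings can also be combined independently; for example, *H+S* corresponds to `half_precision=True` together with `mobilenet_use_stride=True`.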
+
+
 #### Lightweight OpenPose with resizing
 | Method            | Average Precision (IoU=0.50) | Average Recall (IoU=0.50) |
 |-------------------|------------------------------|---------------------------|
@@ -300,7 +377,7 @@ The average precision and average recall on the COCO evaluation split is also re


-#### High Resolution Pose Estimation
+#### High Resolution Pose Estimation (Primary ROI Selection)
 | Method                 | Average Precision (IoU=0.50) | Average Recall (IoU=0.50) |
 |------------------------|------------------------------|---------------------------|
 | HRPoseEstim - Baseline | 0.615                        | 0.637                     |
@@ -323,6 +400,21 @@ The average precision and the average recall have been calculated on a 1080p ver
 For measuring the precision and recall we used the standard approach proposed for COCO, using an Intersection of Union (IoU) metric at 0.5.

+#### High Resolution Pose Estimation (Adaptive ROI Selection)
+
+The average precision and the average recall have been calculated on a 1080p version of the COCO2017 validation dataset and the results are reported in the table below:
+
+| Method                 | Average Precision (IoU=0.50) | Average Recall (IoU=0.50) |
+|------------------------|------------------------------|---------------------------|
+| HRPoseEstim - Baseline | 0.594                        | 0.617                     |
+| HRPoseEstim - Half     | 0.586                        | 0.606                     |
+| HRPoseEstim - Stride   | 0.251                        | 0.271                     |
+| HRPoseEstim - Stages   | 0.511                        | 0.534                     |
+| HRPoseEstim - H+S      | 0.251                        | 0.263                     |
+| HRPoseEstim - Full     | 0.229                        | 0.247                     |
+
+For measuring the precision and recall we used the standard approach proposed for COCO, using an Intersection of Union (IoU) metric at 0.5.
+
 #### Notes
diff --git a/projects/python/perception/pose_estimation/high_resolution_pose_estimation/README.md b/projects/python/perception/pose_estimation/high_resolution_pose_estimation/README.md
index 62286ef976..96749e4f07 100644
--- a/projects/python/perception/pose_estimation/high_resolution_pose_estimation/README.md
+++ b/projects/python/perception/pose_estimation/high_resolution_pose_estimation/README.md
@@ -7,6 +7,9 @@ More specifically, the applications provided are:
 1. demos/inference_demo.py: A tool that demonstrates how to perform inference on a single high resolution image and then draw the detected poses.
 2. demos/eval_demo.py: A tool that demonstrates how to perform evaluation using the High Resolution Pose Estimation algorithm on 720p, 1080p and 1440p datasets.
 3. demos/benchmarking_demo.py: A simple benchmarking tool for measuring the performance of High Resolution Pose Estimation in various platforms.
-4. demos/webcam_demo.py: A tool that performs live pose estimation both with high resolution and simple lightweight OpenPose using a webcam.
+4. demos/webcam_demo.py: A tool that performs live pose estimation using the High Resolution Pose Estimation method with a webcam.
+   If `--run-comparison` is enabled, it also shows the differences between Lightweight OpenPose and both the adaptive and primary High Resolution Pose Estimation methods.
+NOTE: All demos can run either in "primary ROI selection" mode or in "adaptive ROI selection" mode.
+Use `--method primary` or `--method adaptive` accordingly.
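+For example, `python3 demos/inference_demo.py --method adaptive` runs the single-image inference demo with the adaptive ROI selection method (the remaining arguments keep their default values).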
diff --git a/projects/python/perception/pose_estimation/high_resolution_pose_estimation/demos/benchmarking_demo.py b/projects/python/perception/pose_estimation/high_resolution_pose_estimation/demos/benchmarking_demo.py index 65f3278c35..d13c44ac51 100644 --- a/projects/python/perception/pose_estimation/high_resolution_pose_estimation/demos/benchmarking_demo.py +++ b/projects/python/perception/pose_estimation/high_resolution_pose_estimation/demos/benchmarking_demo.py @@ -28,11 +28,12 @@ action="store_true") parser.add_argument("--height1", help="Base height of resizing in heatmap generation", default=360) parser.add_argument("--height2", help="Base height of resizing in second inference", default=540) - + parser.add_argument("--method", help="Choose between primary or adaptive ROI selection methodology defaults to adaptive", + default="adaptive", choices=["primary", "adaptive"]) args = parser.parse_args() - device, accelerate, base_height1, base_height2 = args.device, args.accelerate,\ - args.height1, args.height2 + device, accelerate, base_height1, base_height2, method = args.device, args.accelerate, \ + args.height1, args.height2, args.method if device == 'cpu': import torch @@ -51,13 +52,14 @@ pose_estimator = HighResolutionPoseEstimationLearner(device=device, num_refinement_stages=stages, mobilenet_use_stride=stride, half_precision=half_precision, first_pass_height=int(base_height1), - second_pass_height=int(base_height2)) + second_pass_height=int(base_height2), + method=method) pose_estimator.download(path=".", verbose=True) pose_estimator.load("openpose_default") # Download one sample image pose_estimator.download(path=".", mode="test_data") - image_path = join("temp", "dataset", "image", "000000000785_1080.jpg") + image_path = join("temp", "dataset", "image", "000000052591_1080.jpg") img = cv2.imread(image_path) fps_list = [] @@ -65,8 +67,10 @@ for i in tqdm(range(50)): start_time = time.perf_counter() # Perform inference - poses = pose_estimator.infer(img) - + if method == 'primary': + poses, _, _ = pose_estimator.infer(img) + if method == 'adaptive': + poses, _, _ = pose_estimator.infer_adaptive(img) end_time = time.perf_counter() fps_list.append(1.0 / (end_time - start_time)) print("Average FPS: %.2f" % (np.mean(fps_list))) diff --git a/projects/python/perception/pose_estimation/high_resolution_pose_estimation/demos/eval_demo.py b/projects/python/perception/pose_estimation/high_resolution_pose_estimation/demos/eval_demo.py index 1d093a99fd..ecad92d333 100644 --- a/projects/python/perception/pose_estimation/high_resolution_pose_estimation/demos/eval_demo.py +++ b/projects/python/perception/pose_estimation/high_resolution_pose_estimation/demos/eval_demo.py @@ -26,26 +26,29 @@ action="store_true") parser.add_argument("--height1", help="Base height of resizing in first inference", default=360) parser.add_argument("--height2", help="Base height of resizing in second inference", default=540) - + parser.add_argument("--method", help="Choose between primary or adaptive ROI selection " + "methodology defaults to adaptive", + default="adaptive", choices=["primary", "adaptive"]) args = parser.parse_args() - device, accelerate, base_height1, base_height2 = args.device, args.accelerate,\ - args.height1, args.height2 + device, accelerate, base_height1, base_height2, method = args.device, args.accelerate, \ + args.height1, args.height2, args.method if accelerate: stride = True stages = 0 half_precision = True else: - stride = True + stride = False stages = 2 - half_precision = True + 
half_precision = False pose_estimator = HighResolutionPoseEstimationLearner(device=device, num_refinement_stages=stages, mobilenet_use_stride=stride, half_precision=half_precision, first_pass_height=int(base_height1), - second_pass_height=int(base_height2)) + second_pass_height=int(base_height2), + method=method) pose_estimator.download(path=".", verbose=True) pose_estimator.load("openpose_default") @@ -55,8 +58,12 @@ eval_dataset = ExternalDataset(path=join("temp", "dataset"), dataset_type="COCO") t0 = time.time() - results_dict = pose_estimator.eval(eval_dataset, use_subset=False, verbose=True, silent=True, - images_folder_name="image", annotations_filename="annotation.json") + if method == "primary": + results_dict = pose_estimator.eval(eval_dataset, use_subset=False, verbose=True, silent=False, + images_folder_name="image", annotations_filename="annotation.json") + if method == "adaptive": + results_dict = pose_estimator.eval_adaptive(eval_dataset, use_subset=False, verbose=True, silent=False, + images_folder_name="image", annotations_filename="annotation.json") t1 = time.time() print("\n Evaluation time: ", t1 - t0, "seconds") print("Evaluation results = ", results_dict) diff --git a/projects/python/perception/pose_estimation/high_resolution_pose_estimation/demos/inference_demo.py b/projects/python/perception/pose_estimation/high_resolution_pose_estimation/demos/inference_demo.py index 754fe31c99..50877712e0 100644 --- a/projects/python/perception/pose_estimation/high_resolution_pose_estimation/demos/inference_demo.py +++ b/projects/python/perception/pose_estimation/high_resolution_pose_estimation/demos/inference_demo.py @@ -27,11 +27,12 @@ action="store_true") parser.add_argument("--height1", help="Base height of resizing in first inference", default=360) parser.add_argument("--height2", help="Base height of resizing in second inference", default=540) - + parser.add_argument("--method", help="Choose between primary or adaptive ROI selection methodology defaults to adaptive", + default="adaptive", choices=["primary", "adaptive"]) args = parser.parse_args() - device, accelerate, base_height1, base_height2 = args.device, args.accelerate,\ - args.height1, args.height2 + device, accelerate, base_height1, base_height2, method = args.device, args.accelerate, \ + args.height1, args.height2, args.method if accelerate: stride = True @@ -45,21 +46,29 @@ pose_estimator = HighResolutionPoseEstimationLearner(device=device, num_refinement_stages=stages, mobilenet_use_stride=stride, half_precision=half_precision, first_pass_height=int(base_height1), - second_pass_height=int(base_height2)) + second_pass_height=int(base_height2), + method=method) pose_estimator.download(path=".", verbose=True) pose_estimator.load("openpose_default") # Download one sample image pose_estimator.download(path=".", mode="test_data") - image_path = join("temp", "dataset", "image", "000000000785_1080.jpg") - + image_path = join("temp", "dataset", "image", "000000052591_1080.jpg") img = Image.open(image_path) - poses, _ = pose_estimator.infer(img) - + if method == 'primary': + poses, _, bounds = pose_estimator.infer(img) + if method == 'adaptive': + poses, _, bounds = pose_estimator.infer_adaptive(img) img_cv = img.opencv() for pose in poses: draw(img_cv, pose) + + for i in range(len(bounds)): + if bounds[i][0] is not None: + cv2.rectangle(img_cv, (int(bounds[i][0]), int(bounds[i][2])), + (int(bounds[i][1]), int(bounds[i][3])), (0, 0, 255), thickness=2) + img_cv = cv2.resize(img_cv, (1280, 720), interpolation=cv2.INTER_CUBIC) 
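+    # Display the annotated image with the detected poses and the ROI boxes drawn above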
cv2.imshow('Results', img_cv) cv2.waitKey(0) diff --git a/projects/python/perception/pose_estimation/high_resolution_pose_estimation/demos/webcam_demo.py b/projects/python/perception/pose_estimation/high_resolution_pose_estimation/demos/webcam_demo.py index 4f57dd1657..a9260c2d34 100644 --- a/projects/python/perception/pose_estimation/high_resolution_pose_estimation/demos/webcam_demo.py +++ b/projects/python/perception/pose_estimation/high_resolution_pose_estimation/demos/webcam_demo.py @@ -48,15 +48,23 @@ def __next__(self): parser = argparse.ArgumentParser() parser.add_argument("--onnx", help="Use ONNX", default=False, action="store_true") parser.add_argument("--device", help="Device to use (cpu, cuda)", type=str, default="cuda") - parser.add_argument("--accelerate", help="Enables acceleration flags (e.g., stride)", default=True, + parser.add_argument("--accelerate", help="Enables acceleration flags (e.g., stride)", default=False, + action="store_true") + parser.add_argument("--height1", + help="Base height of resizing in first inference, defaults to 420", default=420) + parser.add_argument("--height2", + help="Base height of resizing in second inference, defaults to 360", default=360) + parser.add_argument("--method", + help="Choose between primary or adaptive ROI selection methodology defaults to adaptive", + default="adaptive", choices=["primary", "adaptive"]) + parser.add_argument("--run-comparison", + help="Enables comparison with all HR-pose-estimation methods and Lw-OpenPose", action="store_true") - parser.add_argument("--height1", help="Base height of resizing in first inference, defaults to 420", default=420) - parser.add_argument("--height2", help="Base height of resizing in second inference, defaults to 360", default=360) - args = parser.parse_args() - onnx, device, accelerate = args.onnx, args.device, args.accelerate - base_height1, base_height2 = args.height1, args.height2 + onnx, device, accelerate, run_comparison = args.onnx, args.device, args.accelerate, args.run_comparison + base_height1, base_height2, method = args.height1, args.height2, args.method + if accelerate: stride = True stages = 1 @@ -66,93 +74,202 @@ def __next__(self): stages = 2 half_precision = False - lw_pose_estimator = LightweightOpenPoseLearner(device=device, num_refinement_stages=stages, - mobilenet_use_stride=stride, half_precision=half_precision) + image_provider = VideoReader(0) # Use the first camera available on the system + image_provider = iter(image_provider) + + height = image_provider.cap.get(4) + width = image_provider.cap.get(3) + if run_comparison: + prim_hr_avg_fps = 0 + lw_avg_fps = 0 + adapt_hr_avg_fps = 0 - hr_pose_estimator = HighResolutionPoseEstimationLearner(device=device, num_refinement_stages=stages, - mobilenet_use_stride=stride, half_precision=half_precision, - first_pass_height=base_height1, - second_pass_height=base_height2, - percentage_around_crop=0.1) - hr_pose_estimator.download(path=".", verbose=True) - hr_pose_estimator.load("openpose_default") + lw_pose_estimator = LightweightOpenPoseLearner(device=device, num_refinement_stages=stages, + mobilenet_use_stride=stride, half_precision=half_precision) - if onnx: - hr_pose_estimator.optimize() + hr_pose_estimator = HighResolutionPoseEstimationLearner(device=device, num_refinement_stages=stages, + mobilenet_use_stride=stride, half_precision=half_precision, + first_pass_height=base_height1, + second_pass_height=base_height2, + percentage_around_crop=0.1, + method="primary") - lw_pose_estimator.download(path=".", verbose=True) 
- lw_pose_estimator.load("openpose_default") + adapt_hr_pose_estimator = HighResolutionPoseEstimationLearner(device=device, num_refinement_stages=stages, + mobilenet_use_stride=stride, + half_precision=half_precision, + first_pass_height=base_height1, + second_pass_height=base_height2, + percentage_around_crop=0.1, + method="adaptive") - if onnx: - lw_pose_estimator.optimize() + lw_pose_estimator.download(path=".", verbose=True) - hr_avg_fps = 0 - lw_avg_fps = 0 - # Use the first camera available on the system - image_provider = VideoReader(0) - image_provider = iter(image_provider) + hr_pose_estimator.load("openpose_default") + adapt_hr_pose_estimator.load("openpose_default") + lw_pose_estimator.load("openpose_default") + + if onnx: + hr_pose_estimator.optimize() + adapt_hr_pose_estimator.optimize() + lw_pose_estimator.optimize() + + if width / height == 16 / 9: + size = (2 * 1280, 2 * int(720 / 3)) + elif width / height == 4 / 3: + size = (2 * 1024, 2 * int(768 / 3)) + else: + size = (width, int(height / 3)) - height = image_provider.cap.get(4) - width = image_provider.cap.get(3) - if width / height == 16 / 9: - size = (1280, int(720 / 2)) - elif width / height == 4 / 3: - size = (1024, int(768 / 2)) else: - size = (width, int(height / 2)) + hr_avg_fps = 0 + + hr_pose_estimator = HighResolutionPoseEstimationLearner(device=device, num_refinement_stages=stages, + mobilenet_use_stride=stride, half_precision=half_precision, + first_pass_height=base_height1, + second_pass_height=base_height2, + percentage_around_crop=0.1, + method=method) + + hr_pose_estimator.load("openpose_default") + + if onnx: + hr_pose_estimator.optimize() + + if width / height == 16 / 9: + size = (1280, int(720)) + elif width / height == 4 / 3: + size = (1024, int(768)) + else: + size = (width, int(height / 2)) while True: img = next(image_provider) + if run_comparison: + total_time0 = time.time() + img_copy = np.copy(img) + adapt_img = np.copy(img) + # Perform inference + start_time = time.perf_counter() + hr_poses, heatmap, _ = hr_pose_estimator.infer(img) + hr_time = time.perf_counter() - start_time + + # Perform inference + start_time = time.perf_counter() + lw_poses = lw_pose_estimator.infer(img_copy) + lw_time = time.perf_counter() - start_time + + # Perform inference + start_time = time.perf_counter() + adapt_hr_poses, adapt_heatmap, _ = adapt_hr_pose_estimator.infer_adaptive(img) + adapt_hr_time = time.perf_counter() - start_time + + total_time = time.time() - total_time0 + + for hr_pose in hr_poses: + draw(img, hr_pose) + for lw_pose in lw_poses: + draw(img_copy, lw_pose) + for adapt_hr_pose in adapt_hr_poses: + draw(adapt_img, adapt_hr_pose) + + lw_fps = 1 / (total_time - (hr_time + adapt_hr_time)) + prim_hr_fps = 1 / (total_time - (lw_time + adapt_hr_time)) + adapt_hr_fps = 1 / (total_time - (lw_time + hr_time)) + + # Calculate a running average on FPS + prim_hr_avg_fps = 0.95 * prim_hr_avg_fps + 0.05 * prim_hr_fps + lw_avg_fps = 0.95 * lw_avg_fps + 0.05 * lw_fps + adapt_hr_avg_fps = 0.95 * adapt_hr_avg_fps + 0.05 * adapt_hr_fps + + cv2.putText(img=img, text="OpenDR High Resolution", org=(20, int(height / 10)), + fontFace=cv2.FONT_HERSHEY_TRIPLEX, + fontScale=int(np.ceil(height / 800)), color=(0, 0, 200), + thickness=int(np.ceil(height / 600))) + cv2.putText(img=img, text="Pose Estimation Primary ROI selection", org=(20, int(height / 10) + 50), + fontFace=cv2.FONT_HERSHEY_TRIPLEX, + fontScale=int(np.ceil(height / 800)), color=(0, 0, 200), + thickness=int(np.ceil(height / 600))) + cv2.putText(img=img, 
text='FPS:' + str(int(prim_hr_avg_fps)), org=(20, int(height / 4)), + fontFace=cv2.FONT_HERSHEY_TRIPLEX, + fontScale=int(np.ceil(height / 800)), color=(0, 0, 200), + thickness=int(np.ceil(height / 600))) + + cv2.putText(img=img_copy, text='Lightweight OpenPose ', org=(20, int(height / 10)), + fontFace=cv2.FONT_HERSHEY_TRIPLEX, + fontScale=int(np.ceil(height / 800)), color=(0, 0, 200), + thickness=int(np.ceil(height / 600))) + + cv2.putText(img=img_copy, text='FPS: ' + str(int(lw_avg_fps)), org=(20, int(height / 4)), + fontFace=cv2.FONT_HERSHEY_TRIPLEX, + fontScale=int(np.ceil(height / 800)), color=(0, 0, 200), + thickness=int(np.ceil(height / 600))) + + cv2.putText(img=adapt_img, text="OpenDR High Resolution", org=(20, int(height / 10)), + fontFace=cv2.FONT_HERSHEY_TRIPLEX, + fontScale=int(np.ceil(height / 800)), color=(0, 0, 200), + thickness=int(np.ceil(height / 600))) + cv2.putText(img=adapt_img, text="Pose Estimation Adaptive ROI selection", org=(20, int(height / 10) + 50), + fontFace=cv2.FONT_HERSHEY_TRIPLEX, + fontScale=int(np.ceil(height / 800)), color=(0, 0, 200), + thickness=int(np.ceil(height / 600))) + cv2.putText(img=adapt_img, text='FPS:' + str(int(adapt_hr_avg_fps)), org=(20, int(height / 4)), + fontFace=cv2.FONT_HERSHEY_TRIPLEX, + fontScale=int(np.ceil(height / 800)), color=(0, 0, 200), + thickness=int(np.ceil(height / 600))) + + heatmap = heatmap * 5 + heatmap = cv2.cvtColor(heatmap, cv2.COLOR_GRAY2BGR) + heatmap = cv2.resize(heatmap, (int(img.shape[1] / 4), int(img.shape[0] / 4))) + img[(img.shape[0] - heatmap.shape[0]):img.shape[0], 0:heatmap.shape[1]] = heatmap + + adapt_heatmap = adapt_heatmap * 5 + adapt_heatmap = cv2.cvtColor(adapt_heatmap, cv2.COLOR_GRAY2BGR) + adapt_heatmap = cv2.resize(adapt_heatmap, (int(img.shape[1] / 4), int(img.shape[0] / 4))) + adapt_img[(adapt_img.shape[0] - adapt_heatmap.shape[0]):adapt_img.shape[0], 0:adapt_heatmap.shape[1]]\ + = adapt_heatmap + + output_image = cv2.hconcat([img_copy, img, adapt_img]) + output_image = cv2.resize(output_image, size) + cv2.imshow('Result', output_image) + + key = cv2.waitKey(1) + if key == 27: + exit(0) + else: + total_time0 = time.time() + img_copy = np.copy(img) + # Perform inference + start_time = time.perf_counter() + hr_poses, heatmap, _ = hr_pose_estimator.infer(img) + hr_time = time.perf_counter() - start_time + total_time = time.time() - total_time0 + + for hr_pose in hr_poses: + draw(img, hr_pose) + + hr_fps = 1 / total_time + + # Calculate a running average on FPS + hr_avg_fps = 0.95 * hr_avg_fps + 0.05 * hr_fps + + cv2.putText(img=img, text="OpenDR High Resolution Pose Estimation", org=(20, int(height / 10)), + fontFace=cv2.FONT_HERSHEY_TRIPLEX, + fontScale=int(np.ceil(height / 800)), color=(0, 0, 200), + thickness=int(np.ceil(height / 600))) + + cv2.putText(img=img, text='FPS:' + str(int(hr_avg_fps)), org=(20, int(height / 4)), + fontFace=cv2.FONT_HERSHEY_TRIPLEX, + fontScale=int(np.ceil(height / 800)), color=(0, 0, 200), + thickness=int(np.ceil(height / 600))) + + heatmap = heatmap * 5 + heatmap = cv2.cvtColor(heatmap, cv2.COLOR_GRAY2BGR) + heatmap = cv2.resize(heatmap, (int(img.shape[1] / 4), int(img.shape[0] / 4))) + img[(img.shape[0] - heatmap.shape[0]):img.shape[0], 0:heatmap.shape[1]] = heatmap + + img = cv2.resize(img, size) + cv2.imshow('Result', img) - total_time0 = time.time() - img_copy = np.copy(img) - - # Perform inference - start_time = time.perf_counter() - hr_poses, heatmap = hr_pose_estimator.infer(img) - hr_time = time.perf_counter() - start_time - - # Perform inference - 
start_time = time.perf_counter() - lw_poses = lw_pose_estimator.infer(img_copy) - lw_time = time.perf_counter() - start_time - - total_time = time.time() - total_time0 - - for hr_pose in hr_poses: - draw(img, hr_pose) - for lw_pose in lw_poses: - draw(img_copy, lw_pose) - - lw_fps = 1 / (total_time - hr_time) - hr_fps = 1 / (total_time - lw_time) - # Calculate a running average on FPS - hr_avg_fps = 0.95 * hr_avg_fps + 0.05 * hr_fps - lw_avg_fps = 0.95 * lw_avg_fps + 0.05 * lw_fps - - cv2.putText(img=img, text="OpenDR High Resolution", org=(20, int(height / 10)), - fontFace=cv2.FONT_HERSHEY_TRIPLEX, - fontScale=int(np.ceil(height / 600)), color=(200, 0, 0), - thickness=int(np.ceil(height / 600))) - cv2.putText(img=img, text="Pose Estimation", org=(20, int(height / 10) + 50), - fontFace=cv2.FONT_HERSHEY_TRIPLEX, - fontScale=int(np.ceil(height / 600)), color=(200, 0, 0), - thickness=int(np.ceil(height / 600))) - - cv2.putText(img=img_copy, text='Lightweight OpenPose ', org=(20, int(height / 10)), - fontFace=cv2.FONT_HERSHEY_TRIPLEX, - fontScale=int(np.ceil(height / 600)), color=(200, 0, 0), - thickness=int(np.ceil(height / 600))) - - heatmap = heatmap * 5 - heatmap = cv2.cvtColor(heatmap, cv2.COLOR_GRAY2BGR) - heatmap = cv2.resize(heatmap, (int(img.shape[1] / 4), int(img.shape[0] / 4))) - img[(img.shape[0] - heatmap.shape[0]):img.shape[0], 0:heatmap.shape[1]] = heatmap - - output_image = cv2.hconcat([img_copy, img]) - output_image = cv2.resize(output_image, size) - cv2.imshow('Result', output_image) - - key = cv2.waitKey(1) - if key == 27: - exit(0) + key = cv2.waitKey(1) + if key == 27: + exit(0) diff --git a/src/opendr/perception/pose_estimation/hr_pose_estimation/high_resolution_learner.py b/src/opendr/perception/pose_estimation/hr_pose_estimation/high_resolution_learner.py index 4f59ae9f17..8fb88cf0b2 100644 --- a/src/opendr/perception/pose_estimation/hr_pose_estimation/high_resolution_learner.py +++ b/src/opendr/perception/pose_estimation/hr_pose_estimation/high_resolution_learner.py @@ -45,9 +45,9 @@ class HighResolutionPoseEstimationLearner(LightweightOpenPoseLearner): def __init__(self, device='cuda', backbone='mobilenet', temp_path='temp', mobilenet_use_stride=True, mobilenetv2_width=1.0, shufflenet_groups=3, num_refinement_stages=2, batches_per_iter=1, base_height=256, - first_pass_height=360, second_pass_height=540, percentage_around_crop=0.3, heatmap_threshold=0.1, - experiment_name='default', num_workers=8, weights_only=True, output_name='detections.json', - multiscale=False, scales=None, visualize=False, + first_pass_height=360, second_pass_height=540, method='adaptive', percentage_around_crop=0.3, + heatmap_threshold=0.1, experiment_name='default', num_workers=8, weights_only=True, + output_name='detections.json', multiscale=False, scales=None, visualize=False, img_mean=np.array([128, 128, 128], np.float32), img_scale=np.float32(1 / 256), pad_value=(0, 0, 0), half_precision=False): @@ -67,14 +67,32 @@ def __init__(self, device='cuda', backbone='mobilenet', self.first_pass_height = first_pass_height self.second_pass_height = second_pass_height + self.method = method self.perc = percentage_around_crop self.threshold = heatmap_threshold - self.xmin = None - self.ymin = None - self.xmax = None - self.ymax = None - self.counter = 0 self.prev_heatmap = np.array([]) + self.counter = 0 + if self.method == 'primary': + self.xmin = None + self.ymin = None + self.xmax = None + self.ymax = None + + elif self.method == 'adaptive': + self.xmin = None + self.ymin = None + self.xmax = None + 
self.ymax = None + + self.x1min = None + self.x1max = None + self.y1min = None + self.y1max = None + + self.x2min = None + self.x2max = None + self.y2min = None + self.y2max = None def __first_pass(self, img): """ @@ -177,6 +195,208 @@ def __pooling(self, img, kernel): # Pooling on input image for dimension reduct pool_img = pool_img.squeeze(0).permute(1, 2, 0).cpu().float().numpy() return pool_img + @staticmethod + def __crop_heatmap(heatmap): + """ + This method takes the generated heatmap and crops it around the desirable ROI using its nonzero values. + + :param heatmap: the heatmap that generated from __first_pass function + :type heatmap: numpy.array + + :returns An array that contains the boundaries of the cropped image + :rtype: np.array + """ + detection = False + + if heatmap.nonzero()[0].size > 10 and heatmap.nonzero()[0].size > 10: + detection = True + xmin = min(heatmap.nonzero()[1]) + ymin = min(heatmap.nonzero()[0]) + xmax = max(heatmap.nonzero()[1]) + ymax = max(heatmap.nonzero()[0]) + else: + xmin, ymin, xmax, ymax = 0, 0, 0, 0 + + heatmap_dims = (int(xmin), int(xmax), int(ymin), int(ymax)) + return heatmap_dims, detection + + @staticmethod + def __check_for_split(cropped_heatmap): + """ + This function checks weather or not the cropped heatmap needs further proccessing for extra cropping. + More specifically, returns a boolean for the decision for further crop, the decision depends on the distance between the + target subjects. + + :param cropped_heatmap: the cropped area from the original heatmap + :type cropped_heatmap: np.array + + :returns: A boolean that describes weather is needed to proceed on further cropping + :rtype: bool + """ + sum_rows = cropped_heatmap.sum(axis=1) + sum_col = cropped_heatmap.sum(axis=0) + + heatmap_activation_area = len(sum_col.nonzero()[0]) * len(sum_rows.nonzero()[0]) + crop_total_area = cropped_heatmap.shape[0] * cropped_heatmap.shape[1] # heatmap total area + + if crop_total_area != 0: + crop_rule1 = (heatmap_activation_area / crop_total_area * 100 < 80) and ( + heatmap_activation_area / crop_total_area * 100 > 5) + if crop_rule1: + return True + else: + return False + else: + return False + + @staticmethod + def __split_process(cropped_heatmap): + """ + This function uses the cropped heatmap that crated from __crop_heatmap function and splits it in parts. 
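+        The cut is made along the axis (rows or columns) that shows the largest gap, of at least 5 pixels, between nonzero heatmap values, producing a vertical or a horizontal split.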
+ + :param cropped_heatmap: the cropped area from the original heatmap + :type cropped_heatmap: np.array + + :returns: Returns a list with the new dimensions of the split parts + :rtype: list + """ + max_diff_c, max_diff_r = 0, 0 + y_crop_l = cropped_heatmap.shape[1] + y_crop_r = 0 + sum_col = cropped_heatmap.sum(axis=0) + x_crop_u = 0 + sum_rows = cropped_heatmap.sum(axis=1) + + for ind in range(len(sum_col.nonzero()[0]) - 1): + diff_c = abs(sum_col.nonzero()[0][ind + 1] - sum_col.nonzero()[0][ind]) + if (diff_c > max_diff_c) and (diff_c > 5): # nonzero columns have at least 5px difference in heatmap + max_diff_c = diff_c + y_crop_l = round(sum_col.nonzero()[0][ind]) + y_crop_r = round(sum_col.nonzero()[0][ind + 1]) + + for ind in range(len(sum_rows.nonzero()[0]) - 1): + diff_r = abs(sum_rows.nonzero()[0][ind + 1] - sum_rows.nonzero()[0][ind]) + if (diff_r > max_diff_r) and (diff_r > 5): # nonzero rows have at least 5px difference in heatmap + max_diff_r = diff_r + x_crop_u = round(sum_rows.nonzero()[0][ind]) + + if max_diff_c >= max_diff_r and max_diff_c > 0: # vertical cut + y1_i = 0 + y1_f = cropped_heatmap.shape[0] + x1_i = 0 + x1_f = int(y_crop_l) + + y2_i = 0 + y2_f = cropped_heatmap.shape[0] + x2_i = int(y_crop_r) + x2_f = cropped_heatmap.shape[1] + + crop1 = cropped_heatmap[y1_i:y1_f, x1_i:x1_f] + crop2 = cropped_heatmap[y2_i:y2_f, x2_i:x2_f] + + elif max_diff_r > max_diff_c and max_diff_r > 0: # horizontal cut + y1_i = 0 + y1_f = int(x_crop_u) + x1_i = 0 + x1_f = cropped_heatmap.shape[1] + + y2_i = int(x_crop_u + 3) + y2_f = cropped_heatmap.shape[0] + x2_i = 0 + x2_f = cropped_heatmap.shape[1] + + crop1 = cropped_heatmap[y1_i:y1_f, x1_i:x1_f] + crop2 = cropped_heatmap[y2_i:y2_f, x2_i:x2_f] + + else: + return [[cropped_heatmap, 0, cropped_heatmap.shape[1], 0, cropped_heatmap.shape[0]]] + + crops = [[crop1, x1_i, x1_f, y1_i, y1_f], [crop2, x2_i, x2_f, y2_i, y2_f]] + return crops + + @staticmethod + def __crop_enclosing_bbox(crop): + """ + This function creates the bounding box for each split part + + :param crop: A split part from the original heatmap + :type crop: np.array + + :returns: the dimensions (xmin, xmax, ymin, ymax) for enclosing bounding box + :rtype: int + """ + if crop.nonzero()[0].size > 0 and crop.nonzero()[1].size > 0: + xmin = min(np.unique(crop.nonzero()[1])) + ymin = min(np.unique(crop.nonzero()[0])) + xmax = max(np.unique(crop.nonzero()[1])) + ymax = max(np.unique(crop.nonzero()[0])) + else: + xmin, xmax, ymin, ymax = 0, 0, 0, 0 + return xmin, xmax, ymin, ymax + + @staticmethod + def __crop_image_func(xmin, xmax, ymin, ymax, pool_img, original_image, heatmap, perc): + """ + This function crops the region of interst(ROI) from the original image to use it in next steps + + :param xmin, ymin: top left corner dimensions of the split part in the original heatmap + :type xmin,ymin: int + :param xmax, ymax: bottom right dimensions of the split part in the original heatmap + :type xmin,ymin: int + :param pool_img: the resized pooled input image + :type pool_img: np.array + :param original_image: the original input image + :type original_image: np.array + :param heatmap: the heatmap generated from __first_pass function + :type heatmap: np.array + :param perc: percentage of the image that is needed for adding extra pad + :type perc: float + + :returns: Returns the cropped image part from the original image and the dimensions of the cropped part in the + original image coordinate system + :rtype :numpy.array, int, int, int, int + """ + upscale_factor_x = pool_img.shape[0] 
/ heatmap.shape[0] + upscale_factor_y = pool_img.shape[1] / heatmap.shape[1] + xmin = upscale_factor_x * xmin + xmax = upscale_factor_x * xmax + ymin = upscale_factor_y * ymin + ymax = upscale_factor_y * ymax + + upscale_to_init_img = original_image.shape[0] / pool_img.shape[0] + xmin = upscale_to_init_img * xmin + xmax = upscale_to_init_img * xmax + ymin = upscale_to_init_img * ymin + ymax = upscale_to_init_img * ymax + + extra_pad_x = int(perc * (xmax - xmin)) # Adding an extra pad around cropped image + extra_pad_y = int(perc * (ymax - ymin)) + + if xmin - extra_pad_x > 0: + xmin = xmin - extra_pad_x + else: + xmin = xmin + if xmax + extra_pad_x < original_image.shape[1]: + xmax = xmax + extra_pad_x + else: + xmax = xmax + + if ymin - extra_pad_y > 0: + ymin = ymin - extra_pad_y + else: + ymin = ymin + if ymax + extra_pad_y < original_image.shape[0]: + ymax = ymax + extra_pad_y + else: + ymax = ymax + + if (xmax - xmin) > 40 and (ymax - ymin) > 40: + crop_img = original_image[int(ymin):int(ymax), int(xmin):int(xmax)] + else: + crop_img = original_image + + return crop_img, int(xmin), int(xmax), int(ymin), int(ymax) + def fit(self, dataset, val_dataset=None, logging_path='', logging_flush_secs=30, silent=False, verbose=True, epochs=None, use_val_subset=True, val_subset_size=250, images_folder_name="train2017", annotations_filename="person_keypoints_train2017.json", @@ -278,15 +498,6 @@ def eval(self, dataset, silent=False, verbose=True, use_subset=True, subset_size pbar_desc = "Evaluation progress" pbar_eval = tqdm(desc=pbar_desc, total=len(data), bar_format="{l_bar}%s{bar}{r_bar}" % '\x1b[38;5;231m') - img_height = data[0]['img'].shape[0] - - if img_height in (1080, 1440): - offset = 200 - elif img_height == 720: - offset = 50 - else: - offset = 0 - for sample in data: file_name = sample['file_name'] img = sample['img'] @@ -295,12 +506,10 @@ def eval(self, dataset, silent=False, verbose=True, use_subset=True, subset_size kernel = int(h / self.first_pass_height) if kernel > 0: pool_img = self.__pooling(img, kernel) - else: pool_img = img - # ------- Heatmap Generation ------- - avg_pafs = self.__first_pass(pool_img) + avg_pafs = self.__first_pass(pool_img) # Heatmap Generation avg_pafs = avg_pafs.astype(np.float32) pafs_map = cv2.blur(avg_pafs, (5, 5)) @@ -347,7 +556,7 @@ def eval(self, dataset, silent=False, verbose=True, use_subset=True, subset_size if (xmax - xmin) > 40 and (ymax - ymin) > 40: crop_img = img[ymin:ymax, xmin:xmax] else: - crop_img = img[offset:img.shape[0], offset:img.shape[1]] + crop_img = img[0:img.shape[0], 0:img.shape[1]] h, w, _ = crop_img.shape @@ -368,10 +577,10 @@ def eval(self, dataset, silent=False, verbose=True, use_subset=True, subset_size for i in range(all_keypoints.shape[0]): for j in range(all_keypoints.shape[1]): - if j == 0: # Adjust offset if needed for evaluation on our HR datasets - all_keypoints[i][j] = round((all_keypoints[i][j] + xmin) - offset) - if j == 1: # Adjust offset if needed for evaluation on our HR datasets - all_keypoints[i][j] = round((all_keypoints[i][j] + ymin) - offset) + if j == 0: + all_keypoints[i][j] = round((all_keypoints[i][j] + xmin)) + if j == 1: + all_keypoints[i][j] = round((all_keypoints[i][j] + ymin)) current_poses = [] for n in range(len(pose_entries)): @@ -400,7 +609,7 @@ def eval(self, dataset, silent=False, verbose=True, use_subset=True, subset_size if self.visualize: for keypoints in coco_keypoints: for idx in range(len(keypoints) // 3): - cv2.circle(img, (int(keypoints[idx * 3] + offset), int(keypoints[idx * 3 + 
1]) + offset), + cv2.circle(img, (int(keypoints[idx * 3]), int(keypoints[idx * 3 + 1])), 3, (255, 0, 255), -1) cv2.imshow('keypoints', img) key = cv2.waitKey() @@ -424,34 +633,375 @@ def eval(self, dataset, silent=False, verbose=True, use_subset=True, subset_size print("Evaluation ended with no detections.") return {"average_precision": [0.0 for _ in range(5)], "average_recall": [0.0 for _ in range(5)]} - def infer(self, img, upsample_ratio=4, stride=8, track=True, smooth=True, multiscale=False): + def eval_adaptive(self, dataset, silent=False, verbose=True, use_subset=True, subset_size=250, upsample_ratio=4, + images_folder_name="val2017", annotations_filename="person_keypoints_val2017.json"): + """ + This method is used to evaluate a trained model on an evaluation dataset. + + :param dataset: object that holds the evaluation dataset. + :type dataset: ExternalDataset class object or DatasetIterator class object + :param silent: if set to True, disables all printing of evaluation progress reports and other + information to STDOUT, defaults to 'False' + :type silent: bool, optional + :param verbose: if set to True, enables the maximum verbosity, defaults to 'True' + :type verbose: bool, optional + :param use_subset: If set to True, a subset of the validation dataset is created and used in + evaluation, defaults to 'True' + :type use_subset: bool, optional + :param subset_size: Controls the size of the validation subset, defaults to '250' + :type subset_size: int, optional + param upsample_ratio: Defines the amount of upsampling to be performed on the heatmaps and PAFs + when resizing,defaults to 4 + :type upsample_ratio: int, optional + :param images_folder_name: Folder name that contains the dataset images. This folder should be contained + in the dataset path provided. Note that this is a folder name, not a path, defaults to 'val2017' + :type images_folder_name: str, optional + :param annotations_filename: Filename of the annotations json file. This file should be contained in the + dataset path provided, defaults to 'person_keypoints_val2017.json' + :type annotations_filename: str, optional + + :returns: returns stats regarding evaluation + :rtype: dict """ - This method is used to perform pose estimation on an image. - :param img: image to run inference on - :rtype img: engine.data.Image class object - :param upsample_ratio: Defines the amount of upsampling to be performed on the heatmaps and PAFs - when resizing,defaults to 4 - :type upsample_ratio: int, optional - :param stride: Defines the stride value for creating a padded image - :type stride: int,optional - :param track: If True, infer propagates poses ids from previous frame results to track poses, - defaults to 'True' - :type track: bool, optional - :param smooth: If True, smoothing is performed on pose keypoints between frames, defaults to 'True' - :type smooth: bool, optional - :param multiscale: Specifies whether evaluation will run in the predefined multiple scales setup or not. 
- :type multiscale: bool,optional + data = super(HighResolutionPoseEstimationLearner, # NOQA + self)._LightweightOpenPoseLearner__prepare_val_dataset(dataset, use_subset=use_subset, + subset_name="val_subset.json", + subset_size=subset_size, + images_folder_default_name=images_folder_name, + annotations_filename=annotations_filename, + verbose=verbose and not silent) + # Model initialization if needed + if self.model is None and self.checkpoint_load_iter != 0: + # No model loaded, initializing new + self.init_model() + # User set checkpoint_load_iter, so they want to load a checkpoint + # Try to find the checkpoint_load_iter checkpoint + checkpoint_name = "checkpoint_iter_" + str(self.checkpoint_load_iter) + ".pth" + checkpoints_folder = os.path.join(self.parent_dir, '{}_checkpoints'.format(self.experiment_name)) + full_path = os.path.join(checkpoints_folder, checkpoint_name) + try: + checkpoint = torch.load(full_path, map_location=torch.device(self.device)) + except FileNotFoundError as e: + e.strerror = "File " + checkpoint_name + " not found inside checkpoints_folder, " \ + "provided checkpoint_load_iter (" + \ + str(self.checkpoint_load_iter) + \ + ") doesn't correspond to a saved checkpoint.\nNo such file or directory." + raise e + if not silent and verbose: + print("Loading checkpoint:", full_path) - :return: Returns a list of engine.target.Pose objects, where each holds a pose - and a heatmap that contains human silhouettes of the input image. - If no detections were made returns an empty list for poses and a black frame for heatmap. + load_state(self.model, checkpoint) + elif self.model is None: + raise AttributeError("self.model is None. Please load a model or set checkpoint_load_iter.") - :rtype: poses -> list of engine.target.Pose objects - heatmap -> np.array() + self.model = self.model.eval() # Change model state to evaluation + self.model.to(self.device) + if "cuda" in self.device: + self.model = self.model.to(self.device) + if self.half: + self.model.half() + + if self.multiscale: + self.scales = [0.5, 1.0, 1.5, 2.0] + + coco_result = [] + num_keypoints = Pose.num_kpts + + pbar_eval = None + if not silent: + pbar_desc = "Evaluation progress" + pbar_eval = tqdm(desc=pbar_desc, total=len(data), bar_format="{l_bar}%s{bar}{r_bar}" % '\x1b[38;5;231m') + + for sample in data: + file_name = sample['file_name'] + img = sample['img'] + h, w, _ = img.shape + max_width = w + kernel = int(h / self.first_pass_height) + if kernel > 0: + pool_img = self.__pooling(img, kernel) + else: + pool_img = img + + avg_pafs = self.__first_pass(pool_img) # Heatmap Generation + avg_pafs = avg_pafs.astype(np.float32) + + pafs_map = cv2.blur(avg_pafs, (5, 5)) + pafs_map[pafs_map < self.threshold] = 0 + + heatmap = pafs_map.sum(axis=2) + heatmap = heatmap * 100 + heatmap = heatmap.astype(np.uint8) + heatmap = cv2.blur(heatmap, (5, 5)) + + self.prev_heatmap = heatmap + heatmap_dims, detection = self.__crop_heatmap(heatmap) + + if detection: + cropped_heatmap = heatmap[heatmap_dims[2]:heatmap_dims[3], heatmap_dims[0]:heatmap_dims[1]] + if self.__check_for_split(cropped_heatmap): + crops = self.__split_process(cropped_heatmap) # Split horizontal or vertical + + crop_part = 0 + for crop_params in crops: + crop = crop_params[0] + if crop.size > 0: + crop_part += 1 + + xmin, xmax, ymin, ymax = self.__crop_enclosing_bbox(crop) + + xmin += heatmap_dims[0] + xmax += heatmap_dims[0] + ymin += heatmap_dims[2] + ymax += heatmap_dims[2] + + xmin += crop_params[1] + xmax += crop_params[1] + ymin += crop_params[3] + 
ymax += crop_params[3] + + crop_img, xmin, xmax, ymin, ymax = self.__crop_image_func(xmin, xmax, ymin, ymax, pool_img, img, + heatmap, self.perc) + + if crop_part == 1: + if self.x1min is None: + self.x1min = xmin + self.y1min = ymin + self.x1max = xmax + self.y1max = ymax + else: + a = 0.2 + self.x1min = a * xmin + (1 - a) * self.x1min + self.y1min = a * ymin + (1 - a) * self.y1min + self.y1max = a * ymax + (1 - a) * self.y1max + self.x1max = a * xmax + (1 - a) * self.x1max + + elif crop_part == 2: + if self.x2min is None: + self.x2min = xmin + self.y2min = ymin + self.x2max = xmax + self.y2max = ymax + else: + a = 0.2 + self.x2min = a * xmin + (1 - a) * self.x2min + self.y2min = a * ymin + (1 - a) * self.y2min + self.y2max = a * ymax + (1 - a) * self.y2max + self.x2max = a * xmax + (1 - a) * self.x2max + + h, w, _, = crop_img.shape + if h > self.second_pass_height: + second_pass_height = self.second_pass_height + else: + second_pass_height = h + + # ------- Second pass of the image, inference for pose estimation ------- + avg_heatmaps, avg_pafs, scale, pad = self.__second_pass(crop_img, second_pass_height, + max_width, self.stride, upsample_ratio) + + total_keypoints_num = 0 + all_keypoints_by_type = [] + for kpt_idx in range(18): + total_keypoints_num += extract_keypoints(avg_heatmaps[:, :, kpt_idx], all_keypoints_by_type, + total_keypoints_num) + + pose_entries, all_keypoints = group_keypoints(all_keypoints_by_type, avg_pafs) + + for kpt_id in range(all_keypoints.shape[0]): + all_keypoints[kpt_id, 0] = ((all_keypoints[kpt_id, 0] * + self.stride / upsample_ratio - pad[1]) / scale) + all_keypoints[kpt_id, 1] = ((all_keypoints[kpt_id, 1] * + self.stride / upsample_ratio - pad[0]) / scale) + + for i in range(all_keypoints.shape[0]): + for j in range(all_keypoints.shape[1]): + if j == 0: + all_keypoints[i][j] = round((all_keypoints[i][j] + xmin)) + if j == 1: + all_keypoints[i][j] = round((all_keypoints[i][j] + ymin)) + + current_poses = [] + for n in range(len(pose_entries)): + if len(pose_entries[n]) == 0: + continue + pose_keypoints = np.ones((num_keypoints, 2), dtype=np.int32) * -1 + for kpt_id in range(num_keypoints): + if pose_entries[n][kpt_id] != -1.0: # keypoint was found + pose_keypoints[kpt_id, 0] = int(all_keypoints[int(pose_entries[n][kpt_id]), 0]) + pose_keypoints[kpt_id, 1] = int(all_keypoints[int(pose_entries[n][kpt_id]), 1]) + pose = Pose(pose_keypoints, pose_entries[n][18]) + current_poses.append(pose) + + coco_keypoints, scores = convert_to_coco_format(pose_entries, all_keypoints) + + image_id = int(file_name[0:file_name.rfind('.')]) + + for idx in range(len(coco_keypoints)): + coco_result.append({ + 'image_id': image_id, + 'category_id': 1, # person + 'keypoints': coco_keypoints[idx], + 'score': scores[idx] + }) + + if self.visualize: + for keypoints in coco_keypoints: + for idx in range(len(keypoints) // 3): + cv2.circle(img, (int(keypoints[idx * 3]), int(keypoints[idx * 3 + 1])), + 3, (255, 0, 255), -1) + cv2.imshow('keypoints', img) + key = cv2.waitKey() + if key == 27: # esc + return + else: + xmin = heatmap_dims[0] + xmax = heatmap_dims[1] + ymin = heatmap_dims[2] + ymax = heatmap_dims[3] + + h, w, _ = pool_img.shape + xmin = xmin * int((w / heatmap.shape[1])) * kernel + xmax = xmax * int((w / heatmap.shape[1])) * kernel + ymin = ymin * int((h / heatmap.shape[0])) * kernel + ymax = ymax * int((h / heatmap.shape[0])) * kernel + + extra_pad_x = int(self.perc * (xmax - xmin)) + extra_pad_y = int(self.perc * (ymax - ymin)) + + if xmin - extra_pad_x > 0: + xmin = xmin - 
extra_pad_x + else: + xmin = xmin + + if xmax + extra_pad_x < img.shape[1]: + xmax = xmax + extra_pad_x + else: + xmax = xmax + + if ymin - extra_pad_y > 0: + ymin = ymin - extra_pad_y + else: + ymin = ymin + + if ymax + extra_pad_y < img.shape[0]: + ymax = ymax + extra_pad_y + else: + ymax = ymax + + if (xmax - xmin) > 40 and (ymax - ymin) > 40: + crop_img = img[int(ymin):int(ymax), int(xmin):int(xmax)] + else: + crop_img = img[0:img.shape[0], 0:img.shape[1]] + + h, w, _, = crop_img.shape + if h > self.second_pass_height: + second_pass_height = self.second_pass_height + else: + second_pass_height = h + + # ------- Second pass of the image, inference for pose estimation ------- + avg_heatmaps, avg_pafs, scale, pad = self.__second_pass(crop_img, second_pass_height, + max_width, self.stride, upsample_ratio) + + total_keypoints_num = 0 + all_keypoints_by_type = [] + for kpt_idx in range(18): + total_keypoints_num += extract_keypoints(avg_heatmaps[:, :, kpt_idx], all_keypoints_by_type, + total_keypoints_num) + + pose_entries, all_keypoints = group_keypoints(all_keypoints_by_type, avg_pafs) + + for kpt_id in range(all_keypoints.shape[0]): + all_keypoints[kpt_id, 0] = (all_keypoints[kpt_id, 0] * self.stride / upsample_ratio - pad[1]) / scale + all_keypoints[kpt_id, 1] = (all_keypoints[kpt_id, 1] * self.stride / upsample_ratio - pad[0]) / scale + + for i in range(all_keypoints.shape[0]): + for j in range(all_keypoints.shape[1]): + if j == 0: + all_keypoints[i][j] = round((all_keypoints[i][j] + xmin)) + if j == 1: + all_keypoints[i][j] = round((all_keypoints[i][j] + ymin)) + + current_poses = [] + for n in range(len(pose_entries)): + if len(pose_entries[n]) == 0: + continue + pose_keypoints = np.ones((num_keypoints, 2), dtype=np.int32) * -1 + for kpt_id in range(num_keypoints): + if pose_entries[n][kpt_id] != -1.0: # keypoint was found + pose_keypoints[kpt_id, 0] = int(all_keypoints[int(pose_entries[n][kpt_id]), 0]) + pose_keypoints[kpt_id, 1] = int(all_keypoints[int(pose_entries[n][kpt_id]), 1]) + pose = Pose(pose_keypoints, pose_entries[n][18]) + current_poses.append(pose) + + coco_keypoints, scores = convert_to_coco_format(pose_entries, all_keypoints) + + image_id = int(file_name[0:file_name.rfind('.')]) + + for idx in range(len(coco_keypoints)): + coco_result.append({ + 'image_id': image_id, + 'category_id': 1, # person + 'keypoints': coco_keypoints[idx], + 'score': scores[idx] + }) + + if self.visualize: + for keypoints in coco_keypoints: + for idx in range(len(keypoints) // 3): + cv2.circle(img, (int(keypoints[idx * 3]), int(keypoints[idx * 3 + 1])), + 3, (255, 0, 255), -1) + cv2.imshow('keypoints', img) + key = cv2.waitKey() + if key == 27: # esc + return + + if not silent: + pbar_eval.update(1) + + with open(self.output_name, 'w') as f: + json.dump(coco_result, f, indent=4) + if len(coco_result) != 0: + if use_subset: + result = run_coco_eval(os.path.join(dataset.path, "val_subset.json"), + self.output_name, verbose=not silent) + else: + result = run_coco_eval(os.path.join(dataset.path, annotations_filename), + self.output_name, verbose=not silent) + return {"average_precision": result.stats[0:5], "average_recall": result.stats[5:]} + else: + if not silent and verbose: + print("Evaluation ended with no detections.") + return {"average_precision": [0.0 for _ in range(5)], "average_recall": [0.0 for _ in range(5)]} + + def infer(self, img, upsample_ratio=4, stride=8, track=True, smooth=True, multiscale=False): + """ + This method is used to perform pose estimation on an image. 
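+        The estimation is performed in two passes: a first pass on a pooled, low-resolution version of the image locates the region of interest, and a second, high-resolution pass runs pose estimation only inside that region.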
+ + :param img: image to run inference on + :rtype img: engine.data.Image class object + :param upsample_ratio: Defines the amount of upsampling to be performed on the heatmaps and PAFs + when resizing,defaults to 4 + :type upsample_ratio: int, optional + :param stride: Defines the stride value for creating a padded image + :type stride: int,optional + :param track: If True, infer propagates poses ids from previous frame results to track poses, + defaults to 'True' + :type track: bool, optional + :param smooth: If True, smoothing is performed on pose keypoints between frames, defaults to 'True' + :type smooth: bool, optional + :param multiscale: Specifies whether evaluation will run in the predefined multiple scales setup or not. + :type multiscale: bool,optional + + :return: Returns a list of engine.target.Pose objects, where each holds a pose + and a heatmap that contains human silhouettes of the input image. + If no detections were made returns an empty list for poses and a black frame for heatmap. + + :rtype: poses -> list of engine.target.Pose objects + heatmap -> np.array() """ current_poses = [] - offset = 0 num_keypoints = Pose.num_kpts if not isinstance(img, Image): img = Image(img) @@ -467,12 +1017,10 @@ def infer(self, img, upsample_ratio=4, stride=8, track=True, smooth=True, multis kernel = int(h / self.first_pass_height) if kernel > 0: pool_img = self.__pooling(img, kernel) - else: pool_img = img - # # ------- Heatmap Generation ------- - avg_pafs = self.__first_pass(pool_img) + avg_pafs = self.__first_pass(pool_img) # Heatmap Generation avg_pafs = avg_pafs.astype(np.float32) pafs_map = cv2.blur(avg_pafs, (5, 5)) @@ -543,7 +1091,7 @@ def infer(self, img, upsample_ratio=4, stride=8, track=True, smooth=True, multis if (xmax - xmin) > 40 and (ymax - ymin) > 40: crop_img = img[int(ymin):int(ymax), int(xmin):int(xmax)] else: - crop_img = img[offset:img.shape[0], offset:img.shape[1]] + crop_img = img[0:img.shape[0], 0:img.shape[1]] h, w, _ = crop_img.shape if crop_img.shape[0] < self.second_pass_height: @@ -571,10 +1119,10 @@ def infer(self, img, upsample_ratio=4, stride=8, track=True, smooth=True, multis for i in range(all_keypoints.shape[0]): for j in range(all_keypoints.shape[1]): - if j == 0: # Adjust offset if needed for evaluation on our HR datasets - all_keypoints[i][j] = round((all_keypoints[i][j] + xmin) - offset) - if j == 1: # Adjust offset if needed for evaluation on our HR datasets - all_keypoints[i][j] = round((all_keypoints[i][j] + ymin) - offset) + if j == 0: + all_keypoints[i][j] = round((all_keypoints[i][j] + xmin)) + if j == 1: + all_keypoints[i][j] = round((all_keypoints[i][j] + ymin)) current_poses = [] for n in range(len(pose_entries)): @@ -589,7 +1137,6 @@ def infer(self, img, upsample_ratio=4, stride=8, track=True, smooth=True, multis if np.count_nonzero(pose_keypoints == -1) < 26: pose = Pose(pose_keypoints, pose_entries[n][18]) current_poses.append(pose) - else: if self.xmin is None: self.xmin = xmin @@ -602,7 +1149,6 @@ def infer(self, img, upsample_ratio=4, stride=8, track=True, smooth=True, multis self.ymin = a * ymin + (1 - a) * self.ymin self.ymax = a * ymax + (1 - a) * self.ymax self.xmax = a * xmax + (1 - a) * self.xmax - else: extra_pad_x = int(self.perc * (self.xmax - self.xmin)) # Adding an extra pad around cropped image @@ -612,6 +1158,7 @@ def infer(self, img, upsample_ratio=4, stride=8, track=True, smooth=True, multis xmin = self.xmin - extra_pad_x else: xmin = self.xmin + if self.xmax + extra_pad_x < img.shape[1]: xmax = self.xmax + 
                else:
@@ -621,6 +1168,7 @@ def infer(self, img, upsample_ratio=4, stride=8, track=True, smooth=True, multis
                    ymin = self.ymin - extra_pad_y
                else:
                    ymin = self.ymin
+
                if self.ymax + extra_pad_y < img.shape[0]:
                    ymax = self.ymax + extra_pad_y
                else:
@@ -629,7 +1177,7 @@ def infer(self, img, upsample_ratio=4, stride=8, track=True, smooth=True, multis
            if (xmax - xmin) > 40 and (ymax - ymin) > 40:
                crop_img = img[int(ymin):int(ymax), int(xmin):int(xmax)]
            else:
-                crop_img = img[offset:img.shape[0], offset:img.shape[1]]
+                crop_img = img[0:img.shape[0], 0:img.shape[1]]

            h, w, _ = crop_img.shape
            if crop_img.shape[0] < self.second_pass_height:
@@ -655,10 +1203,10 @@ def infer(self, img, upsample_ratio=4, stride=8, track=True, smooth=True, multis

            for i in range(all_keypoints.shape[0]):
                for j in range(all_keypoints.shape[1]):
-                    if j == 0:  # Adjust offset if needed for evaluation on our HR datasets
-                        all_keypoints[i][j] = round((all_keypoints[i][j] + xmin) - offset)
-                    if j == 1:  # Adjust offset if needed for evaluation on our HR datasets
-                        all_keypoints[i][j] = round((all_keypoints[i][j] + ymin) - offset)
+                    if j == 0:
+                        all_keypoints[i][j] = round((all_keypoints[i][j] + xmin))
+                    if j == 1:
+                        all_keypoints[i][j] = round((all_keypoints[i][j] + ymin))

            current_poses = []
            for n in range(len(pose_entries)):
@@ -681,8 +1229,377 @@ def infer(self, img, upsample_ratio=4, stride=8, track=True, smooth=True, multis
        else:
            heatmap = self.prev_heatmap
        self.counter += 1
+        bounds = ([self.xmin, self.xmax, self.ymin, self.ymax],)
+        return current_poses, heatmap, bounds
+
+    def infer_adaptive(self, img, upsample_ratio=4, stride=8):
+        """
+        This method is used to perform pose estimation on an image, using the adaptive ROI selection approach.
+
+        :param img: image to run inference on
+        :type img: engine.data.Image class object
+        :param upsample_ratio: Defines the amount of upsampling to be performed on the heatmaps and PAFs
+            when resizing, defaults to 4
+        :type upsample_ratio: int, optional
+        :param stride: Defines the stride value for creating a padded image
+        :type stride: int, optional
+
+        :return: Returns a list of engine.target.Pose objects, where each holds a pose, a heatmap that
+            contains human silhouettes of the input image and the bounds of the detected ROIs.
+            If no detections were made, an empty list is returned for the poses and a black frame for the heatmap.
+ + :rtype: poses -> list of engine.target.Pose objects + heatmap -> np.array() + """ + current_poses = [] + num_keypoints = Pose.num_kpts + if not isinstance(img, Image): + img = Image(img) + + # Bring image into the appropriate format for the implementation + img = img.convert(format='channels_last', channel_order='bgr') + h, w, _ = img.shape + max_width = w + xmin, ymin = 0, 0 + ymax, xmax, _ = img.shape + + if self.counter % 2 == 0: + kernel = int(h / self.first_pass_height) + if kernel > 0: + pool_img = self.__pooling(img, kernel) + else: + pool_img = img + + avg_pafs = self.__first_pass(pool_img) # Heatmap Generation + + avg_pafs = avg_pafs.astype(np.float32) + pafs_map = cv2.blur(avg_pafs, (5, 5)) + pafs_map[pafs_map < self.threshold] = 0 + heatmap = pafs_map.sum(axis=2) + heatmap = heatmap * 100 + heatmap = heatmap.astype(np.uint8) + heatmap = cv2.blur(heatmap, (5, 5)) + self.prev_heatmap = heatmap + heatmap_dims, detection = self.__crop_heatmap(heatmap) + + if detection: + self.xmin = heatmap_dims[0] * (img.shape[1] / heatmap.shape[1]) + self.ymin = heatmap_dims[2] * (img.shape[0] / heatmap.shape[0]) + self.xmax = heatmap_dims[1] * (img.shape[1] / heatmap.shape[1]) + self.ymax = heatmap_dims[3] * (img.shape[0] / heatmap.shape[0]) + cropped_heatmap = heatmap[heatmap_dims[2]:heatmap_dims[3], heatmap_dims[0]:heatmap_dims[1]] + if self.__check_for_split(cropped_heatmap): + # Split horizontal or vertical + crops = self.__split_process(cropped_heatmap) + crop_part = 0 + for crop_params in crops: + crop = crop_params[0] + if crop.size > 0: + crop_part += 1 + + xmin, xmax, ymin, ymax = self.__crop_enclosing_bbox(crop) + + xmin += crop_params[1] + xmax += crop_params[1] + ymin += crop_params[3] + ymax += crop_params[3] + + xmin += heatmap_dims[0] + xmax += heatmap_dims[0] + ymin += heatmap_dims[2] + ymax += heatmap_dims[2] + + crop_img, xmin, xmax, ymin, ymax = self.__crop_image_func(xmin, xmax, ymin, ymax, pool_img, img, + heatmap, self.perc) + + if crop_part == 1: + if self.x1min is None: + self.x1min = xmin + self.y1min = ymin + self.x1max = xmax + self.y1max = ymax + else: + a = 0.2 + self.x1min = a * xmin + (1 - a) * self.x1min + self.y1min = a * ymin + (1 - a) * self.y1min + self.y1max = a * ymax + (1 - a) * self.y1max + self.x1max = a * xmax + (1 - a) * self.x1max + elif crop_part == 2: + if self.x2min is None: + self.x2min = xmin + self.y2min = ymin + self.x2max = xmax + self.y2max = ymax + else: + a = 0.2 + self.x2min = a * xmin + (1 - a) * self.x2min + self.y2min = a * ymin + (1 - a) * self.y2min + self.y2max = a * ymax + (1 - a) * self.y2max + self.x2max = a * xmax + (1 - a) * self.x2max + + h, w, _, = crop_img.shape + if h > self.second_pass_height: + second_pass_height = self.second_pass_height + else: + second_pass_height = h + + # ------- Second pass of the image, inference for pose estimation ------- + avg_heatmaps, avg_pafs, scale, pad = self.__second_pass(crop_img, second_pass_height, + max_width, self.stride, upsample_ratio) + + total_keypoints_num = 0 + all_keypoints_by_type = [] + for kpt_idx in range(18): + total_keypoints_num += extract_keypoints(avg_heatmaps[:, :, kpt_idx], all_keypoints_by_type, + total_keypoints_num) + + pose_entries, all_keypoints = group_keypoints(all_keypoints_by_type, avg_pafs) + + for kpt_id in range(all_keypoints.shape[0]): + all_keypoints[kpt_id, 0] = ((all_keypoints[kpt_id, 0] * + self.stride / upsample_ratio - pad[1]) / scale) + all_keypoints[kpt_id, 1] = ((all_keypoints[kpt_id, 1] * + self.stride / upsample_ratio - pad[0]) / scale) + + 
for i in range(all_keypoints.shape[0]): + for j in range(all_keypoints.shape[1]): + if j == 0: + all_keypoints[i][j] = round((all_keypoints[i][j] + xmin)) + if j == 1: + all_keypoints[i][j] = round((all_keypoints[i][j] + ymin)) + + for n in range(len(pose_entries)): + if len(pose_entries[n]) == 0: + continue + pose_keypoints = np.ones((num_keypoints, 2), dtype=np.int32) * -1 + for kpt_id in range(num_keypoints): + if pose_entries[n][kpt_id] != -1.0: # keypoint was found + pose_keypoints[kpt_id, 0] = int(all_keypoints[int(pose_entries[n][kpt_id]), 0]) + pose_keypoints[kpt_id, 1] = int(all_keypoints[int(pose_entries[n][kpt_id]), 1]) + pose = Pose(pose_keypoints, pose_entries[n][18]) + current_poses.append(pose) + + else: + xmin = heatmap_dims[0] + xmax = heatmap_dims[1] + ymin = heatmap_dims[2] + ymax = heatmap_dims[3] + + h, w, _ = pool_img.shape + xmin = xmin * int((w / heatmap.shape[1])) * kernel + xmax = xmax * int((w / heatmap.shape[1])) * kernel + ymin = ymin * int((h / heatmap.shape[0])) * kernel + ymax = ymax * int((h / heatmap.shape[0])) * kernel + + extra_pad_x = int(self.perc * (xmax - xmin)) + extra_pad_y = int(self.perc * (ymax - ymin)) + + if xmin - extra_pad_x > 0: + xmin = xmin - extra_pad_x + else: + xmin = xmin + + if xmax + extra_pad_x < img.shape[1]: + xmax = xmax + extra_pad_x + else: + xmax = xmax + + if ymin - extra_pad_y > 0: + ymin = ymin - extra_pad_y + else: + ymin = ymin + + if ymax + extra_pad_y < img.shape[0]: + ymax = ymax + extra_pad_y + else: + ymax = ymax + + if self.xmin is None: + self.xmin = xmin + self.ymin = ymin + self.xmax = xmax + self.ymax = ymax + self.x1min, self.x1max, self.y1min, self.y1max = xmin, xmax, ymin, ymax + self.x2min, self.x2max, self.y2min, self.y2max = xmin, xmax, ymin, ymax + else: + a = 0.2 + self.xmin = a * xmin + (1 - a) * self.xmin + self.ymin = a * ymin + (1 - a) * self.ymin + self.ymax = a * ymax + (1 - a) * self.ymax + self.xmax = a * xmax + (1 - a) * self.xmax + self.x1min, self.x1max, self.y1min, self.y1max = self.xmin, self.xmax, self.ymin, self.ymax + self.x2min, self.x2max, self.y2min, self.y2max = self.xmin, self.xmax, self.ymin, self.ymax + + if (xmax - xmin) > 40 and (ymax - ymin) > 40: + crop_img = img[int(ymin):int(ymax), int(xmin):int(xmax)] + else: + crop_img = img[0:img.shape[0], 0:img.shape[1]] + + h, w, _, = crop_img.shape + if h > self.second_pass_height: + second_pass_height = self.second_pass_height + else: + second_pass_height = h + + # ------- Second pass of the image, inference for pose estimation ------- + avg_heatmaps, avg_pafs, scale, pad = self.__second_pass(crop_img, second_pass_height, + max_width, self.stride, upsample_ratio) + + total_keypoints_num = 0 + all_keypoints_by_type = [] + for kpt_idx in range(18): + total_keypoints_num += extract_keypoints(avg_heatmaps[:, :, kpt_idx], all_keypoints_by_type, + total_keypoints_num) + + pose_entries, all_keypoints = group_keypoints(all_keypoints_by_type, avg_pafs) + + for kpt_id in range(all_keypoints.shape[0]): + all_keypoints[kpt_id, 0] = (all_keypoints[kpt_id, 0] * stride / upsample_ratio - pad[1]) / scale + all_keypoints[kpt_id, 1] = (all_keypoints[kpt_id, 1] * stride / upsample_ratio - pad[0]) / scale + + for i in range(all_keypoints.shape[0]): + for j in range(all_keypoints.shape[1]): + if j == 0: + all_keypoints[i][j] = round((all_keypoints[i][j] + xmin)) + if j == 1: + all_keypoints[i][j] = round((all_keypoints[i][j] + ymin)) + + current_poses = [] + for n in range(len(pose_entries)): + if len(pose_entries[n]) == 0: + continue + pose_keypoints = 
np.ones((num_keypoints, 2), dtype=np.int32) * -1 + for kpt_id in range(num_keypoints): + if pose_entries[n][kpt_id] != -1.0: # keypoint was found + pose_keypoints[kpt_id, 0] = int(all_keypoints[int(pose_entries[n][kpt_id]), 0]) + pose_keypoints[kpt_id, 1] = int(all_keypoints[int(pose_entries[n][kpt_id]), 1]) + pose = Pose(pose_keypoints, pose_entries[n][18]) + current_poses.append(pose) + else: + if self.xmin is None: + self.xmin = xmin + self.ymin = ymin + self.xmax = xmax + self.ymax = ymax + self.x1min, self.x1max, self.y1min, self.y1max = 0, 0, 0, 0 + self.x2min, self.x2max, self.y2min, self.y2max = 0, 0, 0, 0 + else: + a = 0.8 + self.xmin = a * xmin + (1 - a) * self.xmin + self.ymin = a * ymin + (1 - a) * self.ymin + self.ymax = a * ymax + (1 - a) * self.ymax + self.xmax = a * xmax + (1 - a) * self.xmax + + self.x1min, self.x1max, self.y1min, self.y1max = 0, 0, 0, 0 + self.x2min, self.x2max, self.y2min, self.y2max = 0, 0, 0, 0 + else: + if self.x1min is None: + self.x1min = xmin + self.y1min = ymin + self.x1max = xmax + self.y1max = ymax + if self.x2min is None: + self.x2min = xmin + self.y2min = ymin + self.x2max = xmax + self.y2max = ymax + + boxes = ([self.x1min, self.x1max, self.y1min, self.y1max], [self.x2min, self.x2max, self.y2min, self.y2max]) + for box in boxes: + xmin = box[0] + xmax = box[1] + ymin = box[2] + ymax = box[3] + + extra_pad_x = int(self.perc * (xmax - xmin)) + extra_pad_y = int(self.perc * (ymax - ymin)) + + if (xmin - extra_pad_x > 0) and (xmin > 0): + xmin = xmin - extra_pad_x + else: + xmin = xmin + + if (xmax + extra_pad_x < img.shape[1]) and (xmax < img.shape[1]): + xmax = xmax + extra_pad_x + else: + xmax = xmax + + if (ymin - extra_pad_y > 0) and (ymin > 0): + ymin = ymin - extra_pad_y + else: + ymin = ymin + + if (ymax + extra_pad_y < img.shape[0]) and (ymax < img.shape[0]): + ymax = ymax + extra_pad_y + else: + ymax = ymax + + if (xmax - xmin) > 40 and (ymax - ymin) > 40: + crop_img = img[int(ymin):int(ymax), int(xmin):int(xmax)] + else: + crop_img = img[0:img.shape[0], 0:img.shape[1]] + + h, w, _ = crop_img.shape + if crop_img.shape[0] < self.second_pass_height: + second_pass_height = crop_img.shape[0] + else: + second_pass_height = self.second_pass_height + + # ------- Second pass of the image, inference for pose estimation ------- + avg_heatmaps, avg_pafs, scale, pad = self.__second_pass(crop_img, second_pass_height, + max_width, self.stride, upsample_ratio) + + total_keypoints_num = 0 + all_keypoints_by_type = [] + for kpt_idx in range(18): + total_keypoints_num += extract_keypoints(avg_heatmaps[:, :, kpt_idx], all_keypoints_by_type, + total_keypoints_num) + + pose_entries, all_keypoints = group_keypoints(all_keypoints_by_type, avg_pafs) + + for kpt_id in range(all_keypoints.shape[0]): + all_keypoints[kpt_id, 0] = (all_keypoints[kpt_id, 0] * self.stride / upsample_ratio - pad[1]) / scale + all_keypoints[kpt_id, 1] = (all_keypoints[kpt_id, 1] * self.stride / upsample_ratio - pad[0]) / scale + + for i in range(all_keypoints.shape[0]): + for j in range(all_keypoints.shape[1]): + if j == 0: + all_keypoints[i][j] = round((all_keypoints[i][j] + xmin)) + if j == 1: + all_keypoints[i][j] = round((all_keypoints[i][j] + ymin)) + + current_poses = [] + for n in range(len(pose_entries)): + if len(pose_entries[n]) == 0: + continue + pose_keypoints = np.ones((num_keypoints, 2), dtype=np.int32) * -1 + for kpt_id in range(num_keypoints): + if pose_entries[n][kpt_id] != -1.0: # keypoint was found + pose_keypoints[kpt_id, 0] = 
int(all_keypoints[int(pose_entries[n][kpt_id]), 0])
+                        pose_keypoints[kpt_id, 1] = int(all_keypoints[int(pose_entries[n][kpt_id]), 1])
+
+                if np.count_nonzero(pose_keypoints == -1) < 26:
+                    pose = Pose(pose_keypoints, pose_entries[n][18])
+                    current_poses.append(pose)
+
+        if not np.any(self.prev_heatmap):
+            heatmap = np.zeros((int(img.shape[0] / (int((img.shape[0] / self.first_pass_height))) / 8),
+                                int(img.shape[1] / (int((img.shape[0] / self.first_pass_height))) / 8)),
+                               dtype=np.uint8)
+        else:
+            heatmap = self.prev_heatmap
+        self.counter += 1

-        return current_poses, heatmap
+        bounds = [(self.x1min, self.x1max, self.y1min, self.y1max),
+                  (self.x2min, self.x2max, self.y2min, self.y2max)]
+        return current_poses, heatmap, bounds

    def download(self, path=None, mode="pretrained", verbose=False,
                 url=OPENDR_SERVER_URL + "perception/pose_estimation/lightweight_open_pose/",
diff --git a/tests/sources/tools/perception/pose_estimation/high_resolution_pose_estimation/test_high_resolution_pose_estimation.py b/tests/sources/tools/perception/pose_estimation/high_resolution_pose_estimation/test_high_resolution_pose_estimation.py
index 230b6da5fc..a6a461ccf7 100644
--- a/tests/sources/tools/perception/pose_estimation/high_resolution_pose_estimation/test_high_resolution_pose_estimation.py
+++ b/tests/sources/tools/perception/pose_estimation/high_resolution_pose_estimation/test_high_resolution_pose_estimation.py
@@ -81,15 +81,41 @@ def test_eval(self):
        warnings.simplefilter("default", ResourceWarning)
        warnings.simplefilter("default", DeprecationWarning)

+    def test_eval_adaptive(self):
+        # Test eval_adaptive will issue resource warnings due to some files left open in pycoco tools,
+        # as well as a deprecation warning due to a cast of a float to an integer (hopefully these will be
+        # fixed in a future version)
+        warnings.simplefilter("ignore", ResourceWarning)
+        warnings.simplefilter("ignore", DeprecationWarning)
+
+        eval_dataset = ExternalDataset(path=os.path.join(self.temp_dir, "dataset"), dataset_type="COCO")
+        results_dict = self.pose_estimator.eval_adaptive(eval_dataset, use_subset=False, verbose=True, silent=True,
+                                                         images_folder_name="image", annotations_filename="annotation.json")
+        self.assertNotEqual(len(results_dict['average_precision']), 0,
+                            msg="Eval results dictionary contains empty list.")
+        self.assertNotEqual(len(results_dict['average_recall']), 0,
+                            msg="Eval results dictionary contains empty list.")
+        # Cleanup
+        rmfile(os.path.join(self.temp_dir, "detections.json"))
+        warnings.simplefilter("default", ResourceWarning)
+        warnings.simplefilter("default", DeprecationWarning)
+
    def test_infer(self):
        self.pose_estimator.model = None
        self.pose_estimator.load(os.path.join(self.temp_dir, "openpose_default"))
-
-        img = Image.open(os.path.join(self.temp_dir, "dataset", "image", "000000000785_1080.jpg"))
+        img = Image.open(os.path.join(self.temp_dir, "dataset", "image", "000000052591_1080.jpg"))
-        # Default pretrained mobilenet model detects 18 keypoints on img with id 785
+        # Default pretrained mobilenet model should detect keypoints on the test image (id 52591)
        self.assertGreater(len(self.pose_estimator.infer(img)[0][0].data), 0,
                           msg="Returned pose must have non-zero number of keypoints.")

+    def test_infer_adaptive(self):
+        self.pose_estimator.model = None
+        self.pose_estimator.load(os.path.join(self.temp_dir, "openpose_default"))
+        img = Image.open(os.path.join(self.temp_dir, "dataset", "image", "000000052591_1080.jpg"))
+        # Default pretrained mobilenet model should detect keypoints on the test image (id 52591)
+        self.assertGreater(len(self.pose_estimator.infer_adaptive(img)[0][0].data), 0,
+                           msg="Returned pose 
must have non-zero number of keypoints.") + if __name__ == "__main__": unittest.main()
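For reference, the following is a minimal usage sketch of the adaptive-ROI inference path introduced in this patch. It is not part of the patch itself: the image path is a placeholder, the `device` argument and the download/load steps simply mirror the test code above, and the `draw` helper is the pose-drawing utility exported by the OpenDR pose estimation package.

```python
# Minimal sketch (not part of the patch): run adaptive-ROI high-resolution pose
# estimation on a single image. The image path below is a placeholder.
import cv2

from opendr.engine.data import Image
from opendr.perception.pose_estimation import HighResolutionPoseEstimationLearner, draw

pose_estimator = HighResolutionPoseEstimationLearner(device="cuda")
pose_estimator.download(path=".", mode="pretrained", verbose=True)  # fetch pretrained weights
pose_estimator.load("./openpose_default")

img = Image.open("./high_resolution_image.jpg")  # placeholder path

# infer_adaptive returns the detected poses, the low-resolution heatmap of human
# silhouettes and the bounds of the ROIs used for the second (high-resolution) pass.
poses, heatmap, bounds = pose_estimator.infer_adaptive(img)

# Draw the detected poses on a BGR copy of the input and display the result.
frame = img.convert(format='channels_last', channel_order='bgr')
for pose in poses:
    draw(frame, pose)
cv2.imshow("Adaptive ROI pose estimation", frame)
cv2.waitKey(0)
```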