Skip to content

Commit

Permalink
apply review suggestions
Browse files Browse the repository at this point in the history
edited hardcoded values, removed unnecessary values from learner, added docstrings in functions
edited the readme file after the changes
  • Loading branch information
mthodoris committed Dec 21, 2022
1 parent 4234552 commit 4a6c7d2
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 21 deletions.
9 changes: 6 additions & 3 deletions docs/reference/high-resolution-pose-estimation.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ The [HighResolutionPoseEstimationLearner](/src/opendr/perception/pose_estimation

#### `HighResolutionPoseEstimationLearner` constructor
```python
HighResolutionPoseEstimationLearner(self, device, backbone, temp_path, mobilenet_use_stride, mobilenetv2_width, shufflenet_groups, num_refinement_stages, batches_per_iter, base_height, first_pass_height, second_pass_height, img_resol, experiment_name, num_workers, weights_only, output_name, multiscale, scales, visualize, img_mean, img_scale, pad_value, half_precision)
HighResolutionPoseEstimationLearner(self, device, backbone, temp_path, mobilenet_use_stride, mobilenetv2_width, shufflenet_groups, num_refinement_stages, batches_per_iter, base_height, first_pass_height, second_pass_height, percentage_arround_crop, heatmap_threshold, experiment_name, num_workers, weights_only, output_name, multiscale, scales, visualize, img_mean, img_scale, pad_value, half_precision)
```

Constructor parameters:
Expand Down Expand Up @@ -45,6 +45,10 @@ Constructor parameters:
Specifies the height that the input image will be resized during the heatmap generation procedure.
- **second_pass_height**: *int, default=540*\
Specifies the height of the image on the second inference for pose estimation procedure.
- **percentage_arround_crop**: *float, default=0.3*\
Specifies the percentage of an extra pad arround the cropped image
- **heatmap_threshold**: *float, default=0.1*\
Specifies a threshlod value that the heatmap elements should have
- **experiment_name**: *str, default='default'*\
String name to attach to checkpoints.
- **num_workers**: *int, default=8*\
Expand Down Expand Up @@ -217,8 +221,7 @@ Parameters:
pose_estimator = HighResolutionPoseEstimationLearner(device='cuda', num_refinement_stages=2,
mobilenet_use_stride=False, half_precision=False,
first_pass_height=360,
second_pass_height=540,
img_resolution=1080)
second_pass_height=540)
pose_estimator.download() # Download the default pretrained mobilenet model in the temp_path

pose_estimator.load("./parent_dir/openpose_default")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class HighResolutionPoseEstimationLearner(LightweightOpenPoseLearner):
def __init__(self, device='cuda', backbone='mobilenet',
temp_path='temp', mobilenet_use_stride=True, mobilenetv2_width=1.0, shufflenet_groups=3,
num_refinement_stages=2, batches_per_iter=1, base_height=256,
first_pass_height=360, second_pass_height=540, img_resolution=1080,
first_pass_height=360, second_pass_height=540, percentage_arround_crop=0.3, heatmap_threshold=0.1,
experiment_name='default', num_workers=8, weights_only=True, output_name='detections.json',
multiscale=False, scales=None, visualize=False,
img_mean=np.array([128, 128, 128], np.float32), img_scale=np.float32(1 / 256), pad_value=(0, 0, 0),
Expand All @@ -56,7 +56,8 @@ def __init__(self, device='cuda', backbone='mobilenet',
shufflenet_groups=shufflenet_groups,
num_refinement_stages=num_refinement_stages,
batches_per_iter=batches_per_iter,
base_height=base_height, experiment_name=experiment_name,
base_height=base_height,
experiment_name=experiment_name,
num_workers=num_workers, weights_only=weights_only,
output_name=output_name, multiscale=multiscale,
scales=scales, visualize=visualize, img_mean=img_mean,
Expand All @@ -65,9 +66,23 @@ def __init__(self, device='cuda', backbone='mobilenet',

self.first_pass_height = first_pass_height
self.second_pass_height = second_pass_height
self.img_resol = img_resolution # default value for sample image in OpenDR server
self.perc = percentage_arround_crop
self.threshold = heatmap_threshold

def __first_pass(self, net, img):
"""
This method is generating a rough heatmap of the input image in order to specify the approximate location
of humans in the picture.
:param net: the pose estimation model that has been loaded
:type net: opendr.perception.pose_estimation.lightweight_open_pose.\
algorithm.models.with_mobilenet.PoseEstimationWithMobileNet class object
:param img: input image for heatmap generation
:type img: numpy.ndarray
:return: returns the Part Affinity Fields (PAFs) of the humans inside the image
:rtype: numpy.ndarray
"""

if 'cuda' in self.device:
tensor_img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float().cuda()
Expand All @@ -86,6 +101,31 @@ def __first_pass(self, net, img):
def __second_pass(self, net, img, net_input_height_size, max_width, stride, upsample_ratio,
pad_value=(0, 0, 0),
img_mean=np.array([128, 128, 128], np.float32), img_scale=np.float32(1 / 256)):
"""
This method detects the keypoints and estimates the pose of humans using the cropped image from the
previous step (__first_pass_).
param net: the pose estimation model that has been loaded
:type net: opendr.perception.pose_estimation.lightweight_open_pose.\
algorithm.models.with_mobilenet.PoseEstimationWithMobileNet class object
:param img: input image for heatmap generation
:type img: numpy.ndarray
:param net_input_height_size: the height that the input image will be resized for inference
:type net_input_height_size: int
:param max_width: this parameter is the maximum width that the resized image should have. It is introduced to
avoid cropping images with abnormal ratios e.g (30, 800)
:type max_width: int
:param upsample_ratio: Defines the amount of upsampling to be performed on the heatmaps and PAFs when resizing,
defaults to 4
:type upsample_ratio: int, optional
:returns: the heatmap of human figures, the part affinity filed (pafs), the scale of the resized image compred
to the initial and the pad arround the image
:rtype: heatmap, pafs -> numpy.ndarray
scale -> float
pad = -> list
"""

height, width, _ = img.shape
scale = net_input_height_size / height
img_ratio = width / height
Expand Down Expand Up @@ -119,7 +159,14 @@ def __second_pass(self, net, img, net_input_height_size, max_width, stride, upsa

return heatmaps, pafs, scale, pad

def __pooling(self, img, kernel): # Pooling on input image for dim reduction
def __pooling(self, img, kernel): # Pooling on input image for dimension reduction
"""This method applies a pooling filter on an input image in order to resize it in a fixed shape
:param img: input image for resizing
:rtype img: engine.data.Image class object
:param kernel: the kernel size of the pooling filter
:type kernel: int
"""
pool_img = torchvision.transforms.ToTensor()(img)
pool_img = pool_img.unsqueeze(0)
pool_img = torch.nn.functional.avg_pool2d(pool_img, kernel)
Expand All @@ -146,6 +193,34 @@ def save(self, path):

def eval(self, dataset, silent=False, verbose=True, use_subset=True, subset_size=250, upsample_ratio=4,
images_folder_name="val2017", annotations_filename="person_keypoints_val2017.json"):
"""
This method is used to evaluate a trained model on an evaluation dataset.
:param dataset: object that holds the evaluation dataset.
:type dataset: ExternalDataset class object or DatasetIterator class object
:param silent: if set to True, disables all printing of evalutaion progress reports and other
information to STDOUT, defaults to 'False'
:type silent: bool, optional
:param verbose: if set to True, enables the maximum verbosity, defaults to 'True'
:type verbose: bool, optional
:param use_subset: If set to True, a subset of the validation dataset is created and used in
evaluation, defaults to 'True'
:type use_subset: bool, optional
:param subset_size: Controls the size of the validation subset, defaults to '250'
:type subset_size: int, optional
param upsample_ratio: Defines the amount of upsampling to be performed on the heatmaps and PAFs
when resizing,defaults to 4
:type upsample_ratio: int, optional
:param images_folder_name: Folder name that contains the dataset images. This folder should be contained
in the dataset path provided. Note that this is a folder name, not a path, defaults to 'val2017'
:type images_folder_name: str, optional
:param annotations_filename: Filename of the annotations json file. This file should be contained in the
dataset path provided, defaults to 'pesron_keypoints_val2017.json'
:type annotations_filename: str, optional
:returns: returns stats regarding evaluation
:rtype: dict
"""

data = super(HighResolutionPoseEstimationLearner,
self)._LightweightOpenPoseLearner__prepare_val_dataset(dataset, use_subset=use_subset,
Expand Down Expand Up @@ -217,15 +292,12 @@ def eval(self, dataset, silent=False, verbose=True, use_subset=True, subset_siz
else:
pool_img = img

perc = 0.3 # percentage around cropping
threshold = 0.1 # threshold for heatmap

# ------- Heatmap Generation -------
avg_pafs = HighResolutionPoseEstimationLearner.__first_pass(self, self.model, pool_img)
avg_pafs = avg_pafs.astype(np.float32)

pafs_map = cv2.blur(avg_pafs, (5, 5))
pafs_map[pafs_map < threshold] = 0
pafs_map[pafs_map < self.threshold] = 0

heatmap = pafs_map.sum(axis=2)
heatmap = heatmap * 100
Expand Down Expand Up @@ -253,8 +325,8 @@ def eval(self, dataset, silent=False, verbose=True, use_subset=True, subset_siz
ymin = int(np.floor(min(ydim))) * int((h / heatmap.shape[0])) * kernel
ymax = int(np.floor(max(ydim))) * int((h / heatmap.shape[0])) * kernel

extra_pad_x = int(perc * (xmax - xmin)) # Adding an extra pad around cropped image
extra_pad_y = int(perc * (ymax - ymin))
extra_pad_x = int(self.perc * (xmax - xmin)) # Adding an extra pad around cropped image
extra_pad_y = int(self.perc * (ymax - ymin))

if xmin - extra_pad_x > 0:
xmin = xmin - extra_pad_x
Expand Down Expand Up @@ -350,6 +422,28 @@ def eval(self, dataset, silent=False, verbose=True, use_subset=True, subset_siz

def infer(self, img, upsample_ratio=4, stride=8, track=True, smooth=True,
multiscale=False):
"""
This method is used to perform pose estimation on an image.
:param img: image to run inference on
:rtype img: engine.data.Image class object
:param upsample_ratio: Defines the amount of upsampling to be performed on the heatmaps and PAFs
when resizing,defaults to 4
:type upsample_ratio: int, optional
:param stride: Defines the stride value for creating a padded image
:type stride: int,optional
:param track: If True, infer propagates poses ids from previous frame results to track poses,
defaults to 'True'
:type track: bool, optional
:param smooth: If True, smoothing is performed on pose keypoints between frames, defaults to 'True'
:type smooth: bool, optional
:param multiscale: Specifies whether evaluation will run in the predefined multiple scales setup or not.
:type multiscale: bool,optional
:return: Returns a list of engine.target.Pose objects, where each holds a pose, or returns an empty list
if no detections were made.
:rtype: list of engine.target.Pose objects
"""
current_poses = []

offset = 0
Expand All @@ -371,16 +465,12 @@ def infer(self, img, upsample_ratio=4, stride=8, track=True, smooth=True,
else:
pool_img = img

perc = 0.3 # percentage around cropping

threshold = 0.1 # threshold for heatmap

# ------- Heatmap Generation -------
avg_pafs = HighResolutionPoseEstimationLearner.__first_pass(self, self.model, pool_img)
avg_pafs = avg_pafs.astype(np.float32)
pafs_map = cv2.blur(avg_pafs, (5, 5))

pafs_map[pafs_map < threshold] = 0
pafs_map[pafs_map < self.threshold] = 0

heatmap = pafs_map.sum(axis=2)
heatmap = heatmap * 100
Expand All @@ -407,8 +497,8 @@ def infer(self, img, upsample_ratio=4, stride=8, track=True, smooth=True,
ymin = int(np.floor(min(ydim))) * int((h / heatmap.shape[0])) * kernel
ymax = int(np.floor(max(ydim))) * int((h / heatmap.shape[0])) * kernel

extra_pad_x = int(perc * (xmax - xmin)) # Adding an extra pad around cropped image
extra_pad_y = int(perc * (ymax - ymin))
extra_pad_x = int(self.perc * (xmax - xmin)) # Adding an extra pad around cropped image
extra_pad_y = int(self.perc * (ymax - ymin))

if xmin - extra_pad_x > 0:
xmin = xmin - extra_pad_x
Expand Down Expand Up @@ -510,7 +600,7 @@ def download(self, path=None, mode="pretrained", verbose=False,
file_url = os.path.join(url, "dataset", "annotation.json")
urlretrieve(file_url, os.path.join(self.temp_path, "dataset", "annotation.json"))
# Download test image
if image_resolution in(1080, 1440):
if image_resolution in (1080, 1440):
file_url = os.path.join(url, "dataset", "image", "000000000785_" + str(image_resolution) + ".jpg")
urlretrieve(file_url, os.path.join(self.temp_path, "dataset", "image", "000000000785_1080.jpg"))
else:
Expand Down

0 comments on commit 4a6c7d2

Please # to comment.