From aa0ce9fb6ffe2acb83d0e4db80df9bbf6be1b987 Mon Sep 17 00:00:00 2001 From: liqikai Date: Wed, 26 Oct 2022 15:23:36 +0800 Subject: [PATCH] [Fix] Avoid mapping unnecessary key to albu format --- .../datasets/transforms/common_transforms.py | 33 +++++++------------ 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/mmpose/datasets/transforms/common_transforms.py b/mmpose/datasets/transforms/common_transforms.py index a29e163df4..b923fc7041 100644 --- a/mmpose/datasets/transforms/common_transforms.py +++ b/mmpose/datasets/transforms/common_transforms.py @@ -618,7 +618,6 @@ def __init__(self, } else: self.keymap_to_albu = keymap - self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()} def albu_builder(self, cfg: dict) -> albumentations: """Import a module from albumentations. @@ -654,23 +653,6 @@ def albu_builder(self, cfg: dict) -> albumentations: return obj_cls(**args) - @staticmethod - def mapper(d: dict, keymap: dict) -> dict: - """Dictionary mapper. - - Renames keys according to keymap provided. - - Args: - d (dict): old dict - keymap (dict): key mapping like {'old_key': 'new_key'}. - - Returns: - dict: new dict. - """ - - updated_dict = {keymap.get(k, k): v for k, v in d.items()} - return updated_dict - def transform(self, results: dict) -> dict: """The transform function of :class:`Albumentation` to apply albumentations transforms. @@ -684,11 +666,18 @@ def transform(self, results: dict) -> dict: dict: updated result dict. """ # map result dict to albumentations format - results = self.mapper(results, self.keymap_to_albu) + results_albu = {} + for k, v in self.keymap_to_albu.items(): + assert k in results, \ + f'The `{k}` is required to perform albumentations transforms' + results_albu[v] = results[k] + # Apply albumentations transforms - results = self.aug(**results) - # map result dict back to the original format - results = self.mapper(results, self.keymap_back) + results_albu = self.aug(**results_albu) + + # map the albu results back to the original format + for k, v in self.keymap_to_albu.items(): + results[k] = results_albu[v] return results