From 6a80afde2fc5e847acd021e31087a698a91407da Mon Sep 17 00:00:00 2001 From: Xiangxu-0103 Date: Wed, 22 Feb 2023 07:54:41 +0000 Subject: [PATCH] add prob argument --- mmdet3d/datasets/transforms/transforms_3d.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mmdet3d/datasets/transforms/transforms_3d.py b/mmdet3d/datasets/transforms/transforms_3d.py index 87cc95046..dbdbf2a45 100644 --- a/mmdet3d/datasets/transforms/transforms_3d.py +++ b/mmdet3d/datasets/transforms/transforms_3d.py @@ -2385,19 +2385,22 @@ class PolarMix(BaseTransform): rotate_paste_ratio (float): Rotate paste ratio. Defaults to 1.0. pre_transform (Sequence[dict], optional): Sequence of transform object or config dict to be composed. Defaults to None. + prob (float): The transformation probability. Defaults to 1.0. """ def __init__(self, instance_classes: List[int], swap_ratio: float = 0.5, rotate_paste_ratio: float = 1.0, - pre_transform: Optional[Sequence[dict]] = None) -> None: + pre_transform: Optional[Sequence[dict]] = None, + prob: float = 1.0) -> None: assert is_list_of(instance_classes, int), \ 'instance_classes should be a list of int' self.instance_classes = instance_classes self.swap_ratio = swap_ratio self.rotate_paste_ratio = rotate_paste_ratio + self.prob = prob if pre_transform is None: self.pre_transform = None else: @@ -2485,6 +2488,8 @@ def transform(self, input_dict: dict) -> dict: Returns: dict: output dict after transformation. """ + if np.random.rand() > self.prob: + return input_dict assert 'dataset' in input_dict, \ '`dataset` is needed to pass through PolarMix, while not found.' @@ -2513,5 +2518,6 @@ def __repr__(self) -> str: repr_str += f'(instance_classes={self.instance_classes}, ' repr_str += f'swap_ratio={self.swap_ratio}, ' repr_str += f'rotate_paste_ratio={self.rotate_paste_ratio}, ' - repr_str += f'pre_transform={self.pre_transform})' + repr_str += f'pre_transform={self.pre_transform}, ' + repr_str += f'prob={self.prob})' return repr_str