Skip to content

Commit

Permalink
add prob argument
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiangxu-0103 committed Feb 22, 2023
1 parent 194a0be commit 6a80afd
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions mmdet3d/datasets/transforms/transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2385,19 +2385,22 @@ class PolarMix(BaseTransform):
rotate_paste_ratio (float): Rotate paste ratio. Defaults to 1.0.
pre_transform (Sequence[dict], optional): Sequence of transform object
or config dict to be composed. Defaults to None.
prob (float): The transformation probability. Defaults to 1.0.
"""

def __init__(self,
instance_classes: List[int],
swap_ratio: float = 0.5,
rotate_paste_ratio: float = 1.0,
pre_transform: Optional[Sequence[dict]] = None) -> None:
pre_transform: Optional[Sequence[dict]] = None,
prob: float = 1.0) -> None:
assert is_list_of(instance_classes, int), \
'instance_classes should be a list of int'
self.instance_classes = instance_classes
self.swap_ratio = swap_ratio
self.rotate_paste_ratio = rotate_paste_ratio

self.prob = prob
if pre_transform is None:
self.pre_transform = None
else:
Expand Down Expand Up @@ -2485,6 +2488,8 @@ def transform(self, input_dict: dict) -> dict:
Returns:
dict: output dict after transformation.
"""
if np.random.rand() > self.prob:
return input_dict

assert 'dataset' in input_dict, \
'`dataset` is needed to pass through PolarMix, while not found.'
Expand Down Expand Up @@ -2513,5 +2518,6 @@ def __repr__(self) -> str:
repr_str += f'(instance_classes={self.instance_classes}, '
repr_str += f'swap_ratio={self.swap_ratio}, '
repr_str += f'rotate_paste_ratio={self.rotate_paste_ratio}, '
repr_str += f'pre_transform={self.pre_transform})'
repr_str += f'pre_transform={self.pre_transform}, '
repr_str += f'prob={self.prob})'
return repr_str

0 comments on commit 6a80afd

Please # to comment.