From 4ff136163ee4e6436e7f5c44f73c4a4932aa7657 Mon Sep 17 00:00:00 2001 From: chriscarving <129262003+chriscarving@users.noreply.github.com> Date: Wed, 12 Apr 2023 16:20:58 +0800 Subject: [PATCH] [Fix] Update pre-commit-config-zh-cn.yaml and add typehints for PointNet2SAMSG (#2396) --- .pre-commit-config-zh-cn.yaml | 14 ++--- mmdet3d/models/backbones/pointnet2_sa_msg.py | 54 ++++++++++++------- .../models/layers/pointnet_modules/builder.py | 4 +- 3 files changed, 43 insertions(+), 29 deletions(-) diff --git a/.pre-commit-config-zh-cn.yaml b/.pre-commit-config-zh-cn.yaml index 5b7a03e46..1c78ad1d4 100644 --- a/.pre-commit-config-zh-cn.yaml +++ b/.pre-commit-config-zh-cn.yaml @@ -1,4 +1,3 @@ -exclude: ^tests/data/ repos: - repo: https://gitee.com/openmmlab/mirrors-flake8 rev: 5.0.4 @@ -25,6 +24,10 @@ repos: args: ["--remove"] - id: mixed-line-ending args: ["--fix=lf"] + - repo: https://gitee.com/openmmlab/mirrors-codespell + rev: v2.2.1 + hooks: + - id: codespell - repo: https://gitee.com/openmmlab/mirrors-mdformat rev: 0.7.9 hooks: @@ -34,20 +37,11 @@ repos: - mdformat-openmmlab - mdformat_frontmatter - linkify-it-py - - repo: https://gitee.com/openmmlab/mirrors-codespell - rev: v2.2.1 - hooks: - - id: codespell - repo: https://gitee.com/openmmlab/mirrors-docformatter rev: v1.3.1 hooks: - id: docformatter args: ["--in-place", "--wrap-descriptions", "79"] - - repo: https://gitee.com/openmmlab/mirrors-pyupgrade - rev: v3.0.0 - hooks: - - id: pyupgrade - args: ["--py36-plus"] - repo: https://gitee.com/openmmlab/pre-commit-hooks rev: v0.2.0 hooks: diff --git a/mmdet3d/models/backbones/pointnet2_sa_msg.py b/mmdet3d/models/backbones/pointnet2_sa_msg.py index 18bfae769..6675f2848 100644 --- a/mmdet3d/models/backbones/pointnet2_sa_msg.py +++ b/mmdet3d/models/backbones/pointnet2_sa_msg.py @@ -1,12 +1,19 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + import torch from mmcv.cnn import ConvModule from torch import nn as nn from mmdet3d.models.layers.pointnet_modules import build_sa_module from mmdet3d.registry import MODELS +from mmdet3d.utils import OptConfigType from .base_pointnet import BasePointNet +ThreeTupleIntType = Tuple[Tuple[Tuple[int, int, int]]] +TwoTupleIntType = Tuple[Tuple[int, int, int]] +TwoTupleStrType = Tuple[Tuple[str]] + @MODELS.register_module() class PointNet2SAMSG(BasePointNet): @@ -22,7 +29,7 @@ class PointNet2SAMSG(BasePointNet): sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module. aggregation_channels (tuple[int]): Out channels of aggregation multi-scale grouping features. - fps_mods (tuple[int]): Mod of FPS for each SA module. + fps_mods Sequence[Tuple[str]]: Mod of FPS for each SA module. fps_sample_range_lists (tuple[tuple[int]]): The number of sampling points which each SA module samples. dilated_group (tuple[bool]): Whether to use dilated ball query for @@ -38,26 +45,37 @@ class PointNet2SAMSG(BasePointNet): """ def __init__(self, - in_channels, - num_points=(2048, 1024, 512, 256), - radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)), - num_samples=((32, 32, 64), (32, 32, 64), (32, 32, 32)), - sa_channels=(((16, 16, 32), (16, 16, 32), (32, 32, 64)), - ((64, 64, 128), (64, 64, 128), (64, 96, 128)), - ((128, 128, 256), (128, 192, 256), (128, 256, - 256))), - aggregation_channels=(64, 128, 256), - fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')), - fps_sample_range_lists=((-1), (-1), (512, -1)), - dilated_group=(True, True, True), - out_indices=(2, ), - norm_cfg=dict(type='BN2d'), - sa_cfg=dict( + in_channels: int, + num_points: Tuple[int] = (2048, 1024, 512, 256), + radii: Tuple[Tuple[float, float, float]] = ( + (0.2, 0.4, 0.8), + (0.4, 0.8, 1.6), + (1.6, 3.2, 4.8), + ), + num_samples: TwoTupleIntType = ((32, 32, 64), (32, 32, 64), + (32, 32, 32)), + sa_channels: ThreeTupleIntType = (((16, 16, 32), (16, 16, 32), + (32, 32, 64)), + ((64, 64, 128), + (64, 64, 128), (64, 96, + 128)), + ((128, 128, 256), + (128, 192, 256), (128, 256, + 256))), + aggregation_channels: Tuple[int] = (64, 128, 256), + fps_mods: TwoTupleStrType = (('D-FPS'), ('FS'), ('F-FPS', + 'D-FPS')), + fps_sample_range_lists: TwoTupleIntType = ((-1), (-1), (512, + -1)), + dilated_group: Tuple[bool] = (True, True, True), + out_indices: Tuple[int] = (2, ), + norm_cfg: dict = dict(type='BN2d'), + sa_cfg: dict = dict( type='PointSAModuleMSG', pool_mod='max', use_xyz=True, normalize_xyz=False), - init_cfg=None): + init_cfg: OptConfigType = None): super().__init__(init_cfg=init_cfg) self.num_sa = len(sa_channels) self.out_indices = out_indices @@ -123,7 +141,7 @@ def __init__(self, bias=True)) sa_in_channel = cur_aggregation_channel - def forward(self, points): + def forward(self, points: torch.Tensor): """Forward pass. Args: diff --git a/mmdet3d/models/layers/pointnet_modules/builder.py b/mmdet3d/models/layers/pointnet_modules/builder.py index 617a8942e..2274f9c6c 100644 --- a/mmdet3d/models/layers/pointnet_modules/builder.py +++ b/mmdet3d/models/layers/pointnet_modules/builder.py @@ -4,7 +4,9 @@ from mmengine.registry import Registry from torch import nn as nn -SA_MODULES = Registry('point_sa_module') +SA_MODULES = Registry( + name='point_sa_module', + locations=['mmdet3d.models.layers.pointnet_modules']) def build_sa_module(cfg: Union[dict, None], *args, **kwargs) -> nn.Module: