Skip to content

Commit

Permalink
[Enhance] Benchmark PAConvCUDA on S3DIS (#847)
Browse files Browse the repository at this point in the history
* add paconv_cuda readme

* add unit test

* fix link

* fix link
  • Loading branch information
Wuziyi616 authored Aug 9, 2021
1 parent 17f3336 commit 057ad6e
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 0 deletions.
2 changes: 2 additions & 0 deletions configs/paconv/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@ We implement PAConv and provide the result and checkpoints on S3DIS dataset.
| Method | Split | Lr schd | Mem (GB) | Inf time (fps) | mIoU (Val set) | Download |
| :-------------------------------------------------------------------------: | :----: | :---------: | :------: | :------------: | :------------: | :----------------------: |
| [PAConv (SSG)](./paconv_ssg_8x8_cosine_150e_s3dis_seg-3d-13class.py) | Area_5 | cosine 150e | 5.8 | | 66.65 | [model](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/paconv/paconv_ssg_8x8_cosine_150e_s3dis_seg-3d-13class/paconv_ssg_8x8_cosine_150e_s3dis_seg-3d-13class_20210729_200615-2147b2d1.pth) | [log](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/paconv/paconv_ssg_8x8_cosine_150e_s3dis_seg-3d-13class/paconv_ssg_8x8_cosine_150e_s3dis_seg-3d-13class_20210729_200615.log.json) |
| [PAConv\* (SSG)](./paconv_cuda_ssg_8x8_cosine_200e_s3dis_seg-3d-13class.py) | Area_5 | cosine 200e | 3.8 | | 65.33 | [model](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/paconv/paconv_cuda_ssg_8x8_cosine_200e_s3dis_seg-3d-13class/paconv_cuda_ssg_8x8_cosine_200e_s3dis_seg-3d-13class_20210802_171802-e5ea9bb9.pth) | [log](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/paconv/paconv_cuda_ssg_8x8_cosine_200e_s3dis_seg-3d-13class/paconv_cuda_ssg_8x8_cosine_200e_s3dis_seg-3d-13class_20210802_171802.log.json) |

**Notes:**

- We use XYZ+Color+Normalized_XYZ as input in all the experiments on S3DIS datasets.
- `Area_5` Split means training the model on Area_1, 2, 3, 4, 6 and testing on Area_5.
- PAConv\* stands for the CUDA implementation of PAConv operations. See the [paper](https://arxiv.org/pdf/2103.14635.pdf) appendix section D for more details. In our experiments, the training of PAConv\* is found to be very unstable. We achieved slightly lower mIoU than the result in the paper, but is consistent with the result obtained by running their [official code](https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg). Besides, although the GPU memory consumption of PAConv\* is significantly lower than PAConv, its training and inference speed are actually slower (by ~10%).

## Indeterminism

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
_base_ = [
'../_base_/datasets/s3dis_seg-3d-13class.py',
'../_base_/models/paconv_cuda_ssg.py',
'../_base_/schedules/seg_cosine_150e.py', '../_base_/default_runtime.py'
]

# data settings
class_names = ('ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door',
'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter')
num_points = 4096
train_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True),
dict(
type='PointSegClassMapping',
valid_cat_ids=tuple(range(len(class_names))),
max_cat_id=13),
dict(
type='IndoorPatchPointSample',
num_points=num_points,
block_size=1.0,
use_normalized_coord=True,
num_try=10000,
enlarge_size=None,
min_unique_num=num_points // 4,
eps=0.0),
dict(type='NormalizePointsColor', color_mean=None),
dict(
type='GlobalRotScaleTrans',
rot_range=[0.0, 6.283185307179586], # [0, 2 * pi]
scale_ratio_range=[0.8, 1.2],
translation_std=[0, 0, 0]),
dict(
type='RandomJitterPoints',
jitter_std=[0.01, 0.01, 0.01],
clip_range=[-0.05, 0.05]),
dict(type='RandomDropPointsColor', drop_ratio=0.2),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'pts_semantic_mask'])
]

data = dict(samples_per_gpu=8, train=dict(pipeline=train_pipeline))
evaluation = dict(interval=1)

# model settings
model = dict(
decode_head=dict(
num_classes=13, ignore_index=13,
loss_decode=dict(class_weight=None)), # S3DIS doesn't use class_weight
test_cfg=dict(
num_points=4096,
block_size=1.0,
sample_rate=0.5,
use_normalized_coord=True,
batch_size=12))

# runtime settings
runner = dict(max_epochs=200)
72 changes: 72 additions & 0 deletions tests/test_models/test_segmentors.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,75 @@ def test_paconv_ssg():
results = self.forward(return_loss=False, **data_dict)
assert results[0]['semantic_mask'].shape == torch.Size([200])
assert results[1]['semantic_mask'].shape == torch.Size([100])


def test_paconv_cuda_ssg():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')

set_random_seed(0, True)
paconv_cuda_ssg_cfg = _get_segmentor_cfg(
'paconv/paconv_cuda_ssg_8x8_cosine_200e_s3dis_seg-3d-13class.py')
# for GPU memory consideration
paconv_cuda_ssg_cfg.backbone.num_points = (256, 64, 16, 4)
paconv_cuda_ssg_cfg.test_cfg.num_points = 32
self = build_segmentor(paconv_cuda_ssg_cfg).cuda()
points = [torch.rand(1024, 9).float().cuda() for _ in range(2)]
img_metas = [dict(), dict()]
gt_masks = [torch.randint(0, 13, (1024, )).long().cuda() for _ in range(2)]

# test forward_train
losses = self.forward_train(points, img_metas, gt_masks)
assert losses['decode.loss_sem_seg'].item() >= 0
assert losses['regularize.loss_regularize'].item() >= 0

# test forward function
set_random_seed(0, True)
data_dict = dict(
points=points, img_metas=img_metas, pts_semantic_mask=gt_masks)
forward_losses = self.forward(return_loss=True, **data_dict)
assert np.allclose(losses['decode.loss_sem_seg'].item(),
forward_losses['decode.loss_sem_seg'].item())
assert np.allclose(losses['regularize.loss_regularize'].item(),
forward_losses['regularize.loss_regularize'].item())

# test loss with ignore_index
ignore_masks = [torch.ones_like(gt_masks[0]) * 13 for _ in range(2)]
losses = self.forward_train(points, img_metas, ignore_masks)
assert losses['decode.loss_sem_seg'].item() == 0

# test simple_test
self.eval()
with torch.no_grad():
scene_points = [
torch.randn(200, 6).float().cuda() * 3.0,
torch.randn(100, 6).float().cuda() * 2.5
]
results = self.simple_test(scene_points, img_metas)
assert results[0]['semantic_mask'].shape == torch.Size([200])
assert results[1]['semantic_mask'].shape == torch.Size([100])

# test forward function calling simple_test
with torch.no_grad():
data_dict = dict(points=[scene_points], img_metas=[img_metas])
results = self.forward(return_loss=False, **data_dict)
assert results[0]['semantic_mask'].shape == torch.Size([200])
assert results[1]['semantic_mask'].shape == torch.Size([100])

# test aug_test
with torch.no_grad():
scene_points = [
torch.randn(2, 200, 6).float().cuda() * 3.0,
torch.randn(2, 100, 6).float().cuda() * 2.5
]
img_metas = [[dict(), dict()], [dict(), dict()]]
results = self.aug_test(scene_points, img_metas)
assert results[0]['semantic_mask'].shape == torch.Size([200])
assert results[1]['semantic_mask'].shape == torch.Size([100])

# test forward function calling aug_test
with torch.no_grad():
data_dict = dict(points=scene_points, img_metas=img_metas)
results = self.forward(return_loss=False, **data_dict)
assert results[0]['semantic_mask'].shape == torch.Size([200])
assert results[1]['semantic_mask'].shape == torch.Size([100])

0 comments on commit 057ad6e

Please # to comment.