Commit e394e2a
CodeCamp #1555 [Feature] Support Mapillary Vistas Dataset (#2484)

## Support `Mapillary Vistas Dataset`

## Motivation

Support the **`Mapillary Vistas Dataset`**.

Dataset paper link: https://ieeexplore.ieee.org/document/9878466/

Download and more information: https://www.mapillary.com/dataset/vistas

```
@InProceedings{Neuhold_2017_ICCV,
  author    = {Neuhold, Gerhard and Ollmann, Tobias and Rota Bulo, Samuel and Kontschieder, Peter},
  title     = {The Mapillary Vistas Dataset for Semantic Understanding of Street Scenes},
  booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
  month     = {Oct},
  year      = {2017}
}
```

## Modification

- Add `Mapillary_dataset` in `mmsegmentation/projects`.
- Add `configs/_base_/mapillary_v1_2.py` and `configs/_base_/mapillary_v2_0.py`.
- Add `configs/deeplabv3plus_r18-d8_4xb2-80k_mapillay-512x1024.py` to test training and testing on Mapillary datasets.
- Add Mapillary Vistas Dataset preparation and structure to `docs/en/user_guides/2_dataset_prepare.md`.
- Add `tools/dataset_converters/mapillary.py` to convert RGB labels to mask labels (the idea is sketched after the commit metadata below).

Co-authored-by: 谢昕辰 <xiexinch@outlook.com>
1 parent f678a5c commit e394e2a
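The last modification bullet above is the label converter: Mapillary ships annotations as RGB label images, while MMSegmentation expects single-channel masks of class ids. Below is a hypothetical minimal sketch of that conversion idea; the `PALETTE` entries and the `rgb_to_mask` helper are illustrative, not the actual `tools/dataset_converters/mapillary.py` added by this commit.

```python
# hypothetical sketch of the RGB-label -> mask-label conversion idea;
# the converter added by this commit is tools/dataset_converters/mapillary.py
import numpy as np
from PIL import Image

# illustrative entries only; the real palettes have 66 (v1.2) / 124 (v2.0) colors
PALETTE = {(165, 42, 42): 0, (0, 192, 0): 1, (196, 196, 196): 2}


def rgb_to_mask(rgb_path: str, mask_path: str) -> None:
    """Write a uint8 mask where each known RGB color becomes its class id."""
    rgb = np.array(Image.open(rgb_path).convert('RGB'))
    mask = np.full(rgb.shape[:2], 255, dtype=np.uint8)  # 255 = ignore index
    for color, class_id in PALETTE.items():
        mask[np.all(rgb == np.array(color), axis=-1)] = class_id
    Image.fromarray(mask).save(mask_path)
```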

File tree

8 files changed: +867 -0 lines changed

Diff for: projects/mapillary_dataset/README.md

+85
@@ -0,0 +1,85 @@
# Mapillary Vistas Dataset

Support **`Mapillary Vistas Dataset`**

## Description

Author: AI-Tianlong

This project adds support for the **`Mapillary Vistas Dataset`**.

### Dataset preparation

Prepare the `Mapillary Vistas Dataset` following the [Mapillary Vistas Dataset Preparing Guide](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/mapillary_dataset/docs/en/user_guides/2_dataset_prepare.md)

```none
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│   ├── mapillary
│   │   ├── training
│   │   │   ├── images
│   │   │   ├── v1.2
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── labels_mask
│   │   │   │   └── panoptic
│   │   │   ├── v2.0
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── labels_mask
│   │   │   │   ├── panoptic
│   │   │   │   └── polygons
│   │   ├── validation
│   │   │   ├── images
│   │   │   ├── v1.2
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── labels_mask
│   │   │   │   └── panoptic
│   │   │   ├── v2.0
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── labels_mask
│   │   │   │   ├── panoptic
│   │   │   │   └── polygons
```
### Training commands with `deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py`

```bash
# Dataset train commands, run from the `mmsegmentation` folder
bash tools/dist_train.sh projects/mapillary_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py 4
```
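The trailing `4` is the number of GPUs handed to `tools/dist_train.sh`. For a quick single-GPU sanity run, MMSegmentation's standard `python tools/train.py <config>` entry point should also work with this config.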
## Checklist

- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.

  - [x] Finish the code

  - [x] Basic docstrings & proper citation

  - [ ] Test-time correctness

  - [x] A full README

- [ ] Milestone 2: Indicates a successful model implementation.

  - [ ] Training-time correctness

- [ ] Milestone 3: Good to be a part of our core package!

  - [ ] Type hints and docstrings

  - [ ] Unit tests

  - [ ] Code polishing

  - [ ] Metafile.yml

- [ ] Move your modules into the core package following the codebase's file hierarchy structure.

- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
Diff for: projects/mapillary_dataset/configs/_base_/datasets/mapillary_v1_2.py

@@ -0,0 +1,69 @@
```python
# dataset settings
dataset_type = 'MapillaryDataset_v1_2'
data_root = 'data/mapillary/'
crop_size = (512, 1024)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    # add loading annotation after ``Resize`` because ground truth
    # does not need to do resize data transform
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=2,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='training/images',
            seg_map_path='training/v1.2/labels_mask'),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='validation/images',
            seg_map_path='validation/v1.2/labels_mask'),
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
```
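Here `dataset_type` points to a custom dataset class that this commit registers under `projects/mapillary_dataset/mmseg/datasets/`. Below is a minimal sketch of what such a registration typically looks like in MMSegmentation 1.x; it is abbreviated and illustrative, since the real `mapillary_v1_2.py` enumerates all 66 class names and palette colors.

```python
# a minimal sketch, assuming the usual MMSegmentation 1.x dataset registration;
# the real file enumerates all 66 Mapillary v1.2 classes and palette colors
from mmseg.datasets import BaseSegDataset
from mmseg.registry import DATASETS


@DATASETS.register_module()
class MapillaryDataset_v1_2(BaseSegDataset):
    """Mapillary Vistas v1.2 with mask labels converted from RGB labels."""

    METAINFO = dict(
        classes=('Bird', 'Ground Animal', 'Curb'),  # truncated; 66 in total
        palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196]])  # truncated

    def __init__(self, img_suffix='.jpg', seg_map_suffix='.png', **kwargs):
        super().__init__(
            img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
```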
Diff for: projects/mapillary_dataset/configs/_base_/datasets/mapillary_v2_0.py

@@ -0,0 +1,69 @@
```python
# dataset settings
dataset_type = 'MapillaryDataset_v2_0'
data_root = 'data/mapillary/'
crop_size = (512, 1024)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    # add loading annotation after ``Resize`` because ground truth
    # does not need to do resize data transform
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=2,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='training/images',
            seg_map_path='training/v2.0/labels_mask'),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='validation/images',
            seg_map_path='validation/v2.0/labels_mask'),
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
```
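Both configs read `labels_mask` directories produced by the converter, so a quick sanity check that a converted mask really holds single-channel class ids can save a failed run. A hypothetical snippet, not part of this commit; `<image_name>` is a placeholder:

```python
# hypothetical sanity check: converted masks should be single-channel class ids
import numpy as np
from PIL import Image

mask = np.array(
    Image.open('data/mapillary/training/v2.0/labels_mask/<image_name>.png'))
assert mask.ndim == 2, 'expected a single-channel mask'
print('class ids present:', np.unique(mask))  # should lie in [0, 123] for v2.0
```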
Diff for: projects/mapillary_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py

@@ -0,0 +1,103 @@
```python
_base_ = ['./_base_/datasets/mapillary_v1_2.py']  # v1.2 labels
# _base_ = ['./_base_/datasets/mapillary_v2_0.py']  # v2.0 labels
custom_imports = dict(imports=[
    'projects.mapillary_dataset.mmseg.datasets.mapillary_v1_2',
    'projects.mapillary_dataset.mmseg.datasets.mapillary_v2_0',
])

norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255,
    size=(512, 1024))

model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    pretrained=None,
    backbone=dict(
        type='ResNet',
        depth=101,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='DepthwiseSeparableASPPHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        dilations=(1, 12, 24, 36),
        c1_in_channels=256,
        c1_channels=48,
        dropout_ratio=0.1,
        num_classes=66,  # v1.2
        # num_classes=124,  # v2.0
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=66,  # v1.2
        # num_classes=124,  # v2.0
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
default_scope = 'mmseg'
env_cfg = dict(
    cudnn_benchmark=True,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='SegLocalVisualizer',
    vis_backends=[dict(type='LocalVisBackend')],
    name='visualizer')
log_processor = dict(by_epoch=False)
log_level = 'INFO'
load_from = None
resume = False
tta_model = dict(type='SegTTAModel')
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
    clip_grad=None)
param_scheduler = [
    dict(
        type='PolyLR',
        eta_min=0.0001,
        power=0.9,
        begin=0,
        end=240000,
        by_epoch=False)
]
train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=240000, val_interval=24000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=24000),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))
```
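To verify that the merged config resolves to the intended dataset version (for example, after switching the `_base_` line and both `num_classes` values to v2.0), mmengine's standard config loader can be used. A small sketch, assuming it is run from the `mmsegmentation` root:

```python
# a small sketch using mmengine's standard Config loader
from mmengine.config import Config

cfg = Config.fromfile(
    'projects/mapillary_dataset/configs/'
    'deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py')
print(cfg.model.decode_head.num_classes)         # 66 for v1.2 labels
print(cfg.train_dataloader.dataset.data_prefix)  # should point at v1.2 labels_mask
```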
