-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] support mscoco dataset (#1520)
* support reading mscoco dataset with caption * fix lint * add unit test * fix isort * fix lint * fix lint
- Loading branch information
Showing
7 changed files
with
143 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import os | ||
import random | ||
from typing import Optional, Sequence, Union | ||
|
||
import mmengine | ||
from mmengine import FileClient | ||
|
||
from mmedit.registry import DATASETS | ||
from .basic_conditional_dataset import BasicConditionalDataset | ||
|
||
|
||
@DATASETS.register_module() | ||
@DATASETS.register_module('MSCOCO') | ||
class MSCoCoDataset(BasicConditionalDataset): | ||
"""MSCoCo 2014 dataset. | ||
Args: | ||
ann_file (str): Annotation file path. Defaults to ''. | ||
metainfo (dict, optional): Meta information for dataset, such as class | ||
information. Defaults to None. | ||
data_root (str): The root directory for ``data_prefix`` and | ||
``ann_file``. Defaults to ''. | ||
drop_caption_rate (float, optional): Rate of dropping caption, | ||
used for training. Defaults to 0.0. | ||
phase (str, optional): Subdataset used for certain phase, can be set | ||
to `train`, `test` and `val`. Defaults to 'train'. | ||
year (int, optional): Version of CoCo dataset, can be set to 2014 | ||
and 2017. Defaults to 2014. | ||
data_prefix (str | dict): Prefix for the data. Defaults to ''. | ||
extensions (Sequence[str]): A sequence of allowed extensions. Defaults | ||
to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'). | ||
lazy_init (bool): Whether to load annotation during instantiation. | ||
In some cases, such as visualization, only the meta information of | ||
the dataset is needed, which is not necessary to load annotation | ||
file. ``Basedataset`` can skip load annotations to save time by set | ||
``lazy_init=False``. Defaults to False. | ||
**kwargs: Other keyword arguments in :class:`BaseDataset`. | ||
""" | ||
METAINFO = dict(dataset_type='text_image_dataset', task_name='editing') | ||
|
||
def __init__(self, | ||
ann_file: str = '', | ||
metainfo: Optional[dict] = None, | ||
data_root: str = '', | ||
drop_caption_rate=0.0, | ||
phase='train', | ||
year=2014, | ||
data_prefix: Union[str, dict] = '', | ||
extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm', | ||
'.bmp', '.pgm', '.tif'), | ||
lazy_init: bool = False, | ||
classes: Union[str, Sequence[str], None] = None, | ||
**kwargs): | ||
ann_file = os.path.join('annotations', 'captions_' + phase + | ||
f'{year}.json') if ann_file == '' else ann_file | ||
self.image_prename = 'COCO_' + phase + f'{year}_' | ||
self.phase = phase | ||
self.drop_rate = drop_caption_rate | ||
self.year = year | ||
assert self.year == 2014, 'We only support CoCo2014 now.' | ||
|
||
super().__init__( | ||
ann_file=ann_file, | ||
metainfo=metainfo, | ||
data_root=data_root, | ||
data_prefix=data_prefix, | ||
extensions=extensions, | ||
lazy_init=lazy_init, | ||
classes=classes, | ||
**kwargs) | ||
|
||
def load_data_list(self): | ||
"""Load image paths and gt_labels.""" | ||
if self.img_prefix: | ||
file_client = FileClient.infer_client(uri=self.img_prefix) | ||
json_file = mmengine.fileio.io.load(self.ann_file) | ||
|
||
def add_prefix(filename, prefix=''): | ||
if not prefix: | ||
return filename | ||
else: | ||
return file_client.join_path(prefix, filename) | ||
|
||
data_list = [] | ||
for item in json_file['annotations']: | ||
image_name = self.image_prename + str( | ||
item['image_id']).zfill(12) + '.jpg' | ||
img_path = add_prefix( | ||
os.path.join(self.phase + str(self.year), image_name), | ||
self.img_prefix) | ||
caption = item['caption'].lower() | ||
info = { | ||
'img_path': | ||
img_path, | ||
'gt_label': | ||
caption if (self.phase != 'train' or self.drop_rate < 1e-6 | ||
or random.random() >= self.drop_rate) else '' | ||
} | ||
data_list.append(info) | ||
return data_list |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
{ | ||
"annotations": [{"image_id": 9, "caption": "a good meal"}] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
{ | ||
"annotations": [{"image_id": 42, "caption": "a pair of slippers"}] | ||
} |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import os | ||
from pathlib import Path | ||
|
||
from mmedit.datasets import MSCoCoDataset | ||
|
||
|
||
class TestMSCoCoDatasets: | ||
|
||
@classmethod | ||
def setup_class(cls): | ||
cls.data_root = Path(__file__).parent.parent / 'data' / 'coco' | ||
|
||
def test_mscoco(self): | ||
|
||
# test basic usage | ||
dataset = MSCoCoDataset(data_root=self.data_root, pipeline=[]) | ||
assert dataset[0] == dict( | ||
gt_label='a good meal', | ||
img_path=os.path.join(self.data_root, 'train2014', | ||
'COCO_train2014_000000000009.jpg'), | ||
sample_idx=0) | ||
|
||
# test with different phase | ||
dataset = MSCoCoDataset( | ||
data_root=self.data_root, phase='val', pipeline=[]) | ||
assert dataset[0] == dict( | ||
gt_label='a pair of slippers', | ||
img_path=os.path.join(self.data_root, 'val2014', | ||
'COCO_val2014_000000000042.jpg'), | ||
sample_idx=0) |