Showing 5 changed files with 256 additions and 3 deletions.
@@ -0,0 +1,81 @@
# data settings

data_preprocessor = dict(
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=384,
        interpolation='bicubic',
        backend='pillow'),
    dict(type='CleanCaption', keys=['question', 'gt_answer']),
    dict(
        type='PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=[],
    ),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(480, 480),
        interpolation='bicubic',
        backend='pillow'),
    dict(type='CleanCaption', keys=['question', 'gt_answer']),
    dict(
        type='PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=[],
    ),
]

train_dataloader = dict(
    batch_size=16,
    num_workers=8,
    dataset=dict(
        type='OCRVQA',
        data_root='data/ocrvqa',
        data_prefix='images',
        ann_file='annotations/dataset.json',
        split='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
    persistent_workers=True,
    drop_last=True,
)

val_dataloader = dict(
    batch_size=64,
    num_workers=8,
    dataset=dict(
        type='OCRVQA',
        data_root='data/ocrvqa',
        data_prefix='images',
        ann_file='annotations/dataset.json',
        split='val',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='VQAAcc')

test_dataloader = dict(
    batch_size=64,
    num_workers=8,
    dataset=dict(
        type='OCRVQA',
        data_root='data/ocrvqa',
        data_prefix='images',
        ann_file='annotations/dataset.json',
        split='test',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
test_evaluator = dict(type='VQAAcc')
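This base file defines the preprocessing, the train/val/test dataloaders, and the VQAAcc evaluators for OCR-VQA; the model config that follows pulls it in through its `_base_` list as '../_base_/datasets/ocrvqa.py'. As a quick sanity check it can also be loaded on its own — a minimal sketch, assuming the file is saved as configs/_base_/datasets/ocrvqa.py (the commit view does not show the actual path) and that mmengine is installed:

from mmengine.config import Config

# Hypothetical path - the diff does not show where this file lives.
cfg = Config.fromfile('configs/_base_/datasets/ocrvqa.py')
print(cfg.train_dataloader.batch_size)  # 16
print(cfg.val_evaluator)                # {'type': 'VQAAcc'}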
@@ -0,0 +1,75 @@
_base_ = [
    '../_base_/datasets/ocrvqa.py',
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    type='BlipVQA',
    tokenizer=dict(type='BlipTokenizer', name_or_path='bert-base-uncased'),
    vision_backbone=dict(
        type='VisionTransformer',
        arch='b',
        img_size=480,
        patch_size=16,
        out_type='raw'),
    multimodal_backbone=dict(
        type='XBertEncoder',
        med_config=dict(
            architectures=['BertModel'],
            attention_probs_dropout_prob=0.1,
            hidden_act='gelu',
            hidden_dropout_prob=0.1,
            hidden_size=768,
            initializer_range=0.02,
            intermediate_size=3072,
            layer_norm_eps=1e-12,
            max_position_embeddings=512,
            model_type='bert',
            num_attention_heads=12,
            num_hidden_layers=12,
            pad_token_id=0,
            add_type_embeddings=False,
            vocab_size=30524,
            encoder_width=768,
            add_cross_attention=True),
    ),
    head=dict(
        type='VQAGenerationHead',
        decoder=dict(
            type='XBertLMHeadDecoder',
            med_config=dict(
                architectures=['BertModel'],
                attention_probs_dropout_prob=0.1,
                hidden_act='gelu',
                hidden_dropout_prob=0.1,
                hidden_size=768,
                initializer_range=0.02,
                intermediate_size=3072,
                layer_norm_eps=1e-12,
                max_position_embeddings=512,
                model_type='bert',
                num_attention_heads=12,
                num_hidden_layers=12,
                pad_token_id=0,
                add_type_embeddings=False,
                vocab_size=30524,
                encoder_width=768,
                add_cross_attention=True),
        ),
        inference_method='generate',
    ),
)

# schedule settings
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=0.05)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)

param_scheduler = [dict(type='CosineAnnealingLR', by_epoch=True)]

train_cfg = dict(max_epochs=10, by_epoch=True)
val_cfg = dict()
test_cfg = dict()

# runtime settings
randomness = dict(seed=42)
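Together with the dataset file above, this config describes a full BLIP VQA fine-tuning run: a ViT-B/16 vision backbone at 480x480 input, BERT-style multimodal encoder and generative answer decoder, AdamW at 2e-5 with cosine annealing for 10 epochs. Below is a minimal sketch of how such a run would typically be launched with mmengine's Runner, assuming the merged config resolves cleanly and the OCR-VQA data plus the 'bert-base-uncased' tokenizer weights are available locally; the config file name used here is hypothetical, since the commit view does not show it:

from mmengine.config import Config
from mmengine.runner import Runner

# Hypothetical file name - the diff does not show this config's path.
config_path = 'configs/blip/blip-base_8xb16_ocrvqa.py'

cfg = Config.fromfile(config_path)
cfg.work_dir = 'work_dirs/blip_ocrvqa'  # Runner needs a work_dir for logs and checkpoints
runner = Runner.from_cfg(cfg)
runner.train()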
@@ -0,0 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class OCRVQA(BaseDataset):
    """OCR-VQA dataset.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str): Annotation file path for training and validation.
        split (str): 'train', 'val' or 'test'.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self, data_root: str, data_prefix: str, ann_file: str,
                 split: str, **kwargs):

        assert split in ['train', 'val', 'test'], \
            '`split` must be train, val or test'
        self.split = split
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list."""

        split_dict = {1: 'train', 2: 'val', 3: 'test'}

        annotations = mmengine.load(self.ann_file)

        # ann example
        # "761183272": {
        #     "imageURL": \
        #         "http://ecx.images-amazon.com/images/I/61Y5cOdHJbL.jpg",
        #     "questions": [
        #         "Who wrote this book?",
        #         "What is the title of this book?",
        #         "What is the genre of this book?",
        #         "Is this a games related book?",
        #         "What is the year printed on this calendar?"],
        #     "answers": [
        #         "Sandra Boynton",
        #         "Mom's Family Wall Calendar 2016",
        #         "Calendars",
        #         "No",
        #         "2016"],
        #     "title": "Mom's Family Wall Calendar 2016",
        #     "authorName": "Sandra Boynton",
        #     "genre": "Calendars",
        #     "split": 1
        # },

        data_list = []

        for key, ann in annotations.items():
            if self.split != split_dict[ann['split']]:
                continue

            extension = osp.splitext(ann['imageURL'])[1]
            if extension not in ['.jpg', '.png']:
                continue
            img_path = mmengine.join_path(self.data_prefix['img_path'],
                                          key + extension)
            for question, answer in zip(ann['questions'], ann['answers']):
                data_info = {}
                data_info['img_path'] = img_path
                data_info['question'] = question
                data_info['gt_answer'] = answer
                data_info['gt_answer_weight'] = [1.0]

                data_info['imageURL'] = ann['imageURL']
                data_info['title'] = ann['title']
                data_info['authorName'] = ann['authorName']
                data_info['genre'] = ann['genre']

                data_list.append(data_info)

        return data_list
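Because the class registers itself with DATASETS, it can be built straight from a config dict to verify that annotation parsing works — a minimal sketch, assuming the OCR-VQA images and the annotations/dataset.json file from the dataset config have already been placed under data/ocrvqa:

from mmpretrain.registry import DATASETS

# Building by type name relies on OCRVQA being registered, i.e. the module
# above being importable through mmpretrain.datasets.
dataset = DATASETS.build(
    dict(
        type='OCRVQA',
        data_root='data/ocrvqa',
        data_prefix='images',
        ann_file='annotations/dataset.json',
        split='val',
        pipeline=[],  # skip transforms so the raw data_info dicts come back
    ))

print(len(dataset))  # one entry per (image, question) pair
sample = dataset[0]
print(sample['question'], '->', sample['gt_answer'])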