Skip to content

Commit

Permalink
Merge 859d8d3 into dbfb84c
Browse files Browse the repository at this point in the history
  • Loading branch information
InvincibleWyq authored Jun 8, 2023
2 parents dbfb84c + 859d8d3 commit 0e0c42f
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 3 deletions.
81 changes: 81 additions & 0 deletions configs/_base_/datasets/ocrvqa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# data settings

data_preprocessor = dict(
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
interpolation='bicubic',
backend='pillow'),
dict(type='CleanCaption', keys=['question', 'gt_answer']),
dict(
type='PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=[],
),
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(480, 480),
interpolation='bicubic',
backend='pillow'),
dict(type='CleanCaption', keys=['question', 'gt_answer']),
dict(
type='PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=[],
),
]

train_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='OCRVQA',
data_root='data/ocrvqa',
data_prefix='images',
ann_file='annotations/dataset.json',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
drop_last=True,
)

val_dataloader = dict(
batch_size=64,
num_workers=8,
dataset=dict(
type='OCRVQA',
data_root='data/ocrvqa',
data_prefix='images',
ann_file='annotations/dataset.json',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='VQAAcc')

test_dataloader = dict(
batch_size=64,
num_workers=8,
dataset=dict(
type='OCRVQA',
data_root='data/ocrvqa',
data_prefix='images',
ann_file='annotations/dataset.json',
split='test',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
test_evaluator = dict(type='VQAAcc')
6 changes: 6 additions & 0 deletions configs/blip/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ python tools/test.py configs/blip/blip-base_8xb32_caption.py https://download.op
| :------------------------- | :--------: | :------: | :----------------------------------: | :-------------------------------------------------------------------------------------------------------------------: |
| `blip-base_3rdparty_vqa`\* | 361.48 | 40.59# | [config](./blip-base_8xb32_okvqa.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty-capflit_vqa_20230505-81488941.pth) |

### Visual Question Answering on OCR-VQA

| Model | Params (M) | Accuracy | Config | Download |
| :------------------------- | :--------: | :------: | :-----------------------------------: | :-------------------------------------------------------------------------------------------------------------------: |
| `blip-base_3rdparty_vqa`\* | 361.48 | 28.30# | [config](./blip-base_8xb32_ocrvqa.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty-capflit_vqa_20230505-81488941.pth) |

### Image-To-Text Retrieval on COCO

| Model | Params (M) | Recall@1 | Recall@5 | Config | Download |
Expand Down
75 changes: 75 additions & 0 deletions configs/blip/blip-base_8xb32_ocrvqa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
_base_ = [
'../_base_/datasets/ocrvqa.py',
'../_base_/default_runtime.py',
]

# model settings
model = dict(
type='BlipVQA',
tokenizer=dict(type='BlipTokenizer', name_or_path='bert-base-uncased'),
vision_backbone=dict(
type='VisionTransformer',
arch='b',
img_size=480,
patch_size=16,
out_type='raw'),
multimodal_backbone=dict(
type='XBertEncoder',
med_config=dict(
architectures=['BertModel'],
attention_probs_dropout_prob=0.1,
hidden_act='gelu',
hidden_dropout_prob=0.1,
hidden_size=768,
initializer_range=0.02,
intermediate_size=3072,
layer_norm_eps=1e-12,
max_position_embeddings=512,
model_type='bert',
num_attention_heads=12,
num_hidden_layers=12,
pad_token_id=0,
add_type_embeddings=False,
vocab_size=30524,
encoder_width=768,
add_cross_attention=True),
),
head=dict(
type='VQAGenerationHead',
decoder=dict(
type='XBertLMHeadDecoder',
med_config=dict(
architectures=['BertModel'],
attention_probs_dropout_prob=0.1,
hidden_act='gelu',
hidden_dropout_prob=0.1,
hidden_size=768,
initializer_range=0.02,
intermediate_size=3072,
layer_norm_eps=1e-12,
max_position_embeddings=512,
model_type='bert',
num_attention_heads=12,
num_hidden_layers=12,
pad_token_id=0,
add_type_embeddings=False,
vocab_size=30524,
encoder_width=768,
add_cross_attention=True),
),
inference_method='generate',
),
)

# schedule settings
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=0.05)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)

param_scheduler = [dict(type='CosineAnnealingLR', by_epoch=True)]

train_cfg = dict(max_epochs=10, by_epoch=True)
val_cfg = dict()
test_cfg = dict()

# runtime settings
randomness = dict(seed=42)
6 changes: 3 additions & 3 deletions mmpretrain/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@
from .flamingo import FlamingoEvalCOCOCaption, FlamingoEvalCOCOVQA
from .gqa_dataset import GQA
from .nocaps import NoCaps
from .ocr_vqa import OCRVQA
from .refcoco import RefCOCO
from .scienceqa import ScienceQA
from .textvqa import TextVQA
from .visual_genome import VisualGenomeQA

__all__.extend([
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
'FlamingoEvalCOCOVQA', 'RefCOCO', 'VisualGenomeQA', 'ScienceQA',
'NoCaps'
'GQA', 'TextVQA'
'FlamingoEvalCOCOVQA', 'OCRVQA', 'RefCOCO', 'VisualGenomeQA',
'ScienceQA', 'NoCaps', 'GQA', 'TextVQA'
])
91 changes: 91 additions & 0 deletions mmpretrain/datasets/ocr_vqa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class OCRVQA(BaseDataset):
"""OCR-VQA dataset.
Args:
data_root (str): The root directory for ``data_prefix``, ``ann_file``
and ``question_file``.
data_prefix (str): The directory of images.
ann_file (str): Annotation file path for training and validation.
split (str): 'train', 'val' or 'test'.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""

def __init__(self, data_root: str, data_prefix: str, ann_file: str,
split: str, **kwarg):

assert split in ['train', 'val', 'test'], \
'`split` must be train, val or test'
self.split = split
super().__init__(
data_root=data_root,
data_prefix=dict(img_path=data_prefix),
ann_file=ann_file,
**kwarg,
)

def load_data_list(self) -> List[dict]:
"""Load data list."""

split_dict = {1: 'train', 2: 'val', 3: 'test'}

annotations = mmengine.load(self.ann_file)

# ann example
# "761183272": {
# "imageURL": \
# "http://ecx.images-amazon.com/images/I/61Y5cOdHJbL.jpg",
# "questions": [
# "Who wrote this book?",
# "What is the title of this book?",
# "What is the genre of this book?",
# "Is this a games related book?",
# "What is the year printed on this calendar?"],
# "answers": [
# "Sandra Boynton",
# "Mom's Family Wall Calendar 2016",
# "Calendars",
# "No",
# "2016"],
# "title": "Mom's Family Wall Calendar 2016",
# "authorName": "Sandra Boynton",
# "genre": "Calendars",
# "split": 1
# },

data_list = []

for key, ann in annotations.items():
if self.split != split_dict[ann['split']]:
continue

extension = osp.splitext(ann['imageURL'])[1]
if extension not in ['.jpg', '.png']:
continue
img_path = mmengine.join_path(self.data_prefix['img_path'],
key + extension)
for question, answer in zip(ann['questions'], ann['answers']):
data_info = {}
data_info['img_path'] = img_path
data_info['question'] = question
data_info['gt_answer'] = answer
data_info['gt_answer_weight'] = [1.0]

data_info['imageURL'] = ann['imageURL']
data_info['title'] = ann['title']
data_info['authorName'] = ann['authorName']
data_info['genre'] = ann['genre']

data_list.append(data_info)

return data_list

0 comments on commit 0e0c42f

Please # to comment.