[Feature]: support Multi-Scale-DeformAttention in deformable-detr (#878)
* add c++ ms_deform_atten

* fix cpp lint

* fix cpp lint

* clang format

* remove cmakefile

* google style

* clang-format precommit

* use clang-format-lint-action

* add transformer base class

* add merge

* add docstr

* add pyargs

* fix according to comments

* register module

* change to use basemodule

* add _ between build function

* split the name

* fix according to comments

* fix lint and fix unit test

* fix cpp lint

* fix bug of deformdetr_atten

* fix dropout

* fix residual

* use CUDA_1D_KERNEL_LOOP
jshilong authored Apr 23, 2021
1 parent 0dd0c49 commit 54a7ebb
Showing 7 changed files with 1,732 additions and 4 deletions.
201 changes: 200 additions & 1 deletion mmcv/cnn/bricks/transformer.py
@@ -1,11 +1,15 @@
import copy
import math
import warnings

import torch
import torch.nn as nn

from mmcv import ConfigDict
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
from mmcv.cnn import (Linear, build_activation_layer, build_norm_layer,
constant_init, xavier_init)
from mmcv.ops.multi_scale_deform_attn import (
MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
from mmcv.runner.base_module import BaseModule
from mmcv.utils import build_from_cfg
from .registry import (ATTENTION, POSITIONAL_ENCODING, TRANSFORMER_LAYER,
@@ -135,6 +139,201 @@ def forward(self,
return residual + self.dropout(out)


@ATTENTION.register_module()
class MultiScaleDeformableAttention(BaseModule):
"""An attention module used in Deformable-Detr. `Deformable DETR:
Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
        num_heads (int): Parallel attention heads. Default: 8.
num_levels (int): The number of feature map used in
Attention. Default: 4.
num_points (int): The number of sampling points for
each query in each head. Default: 4.
im2col_step (int): The step used in image_to_column.
Default: 64.
        dropout (float): A Dropout layer on `inp_residual`.
            Default: 0.1.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""

def __init__(self,
embed_dims=256,
num_heads=8,
num_levels=4,
num_points=4,
im2col_step=64,
dropout=0.1,
norm_cfg=None,
init_cfg=None):
super().__init__(init_cfg)
if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, '
f'but got {embed_dims} and {num_heads}')
dim_per_head = embed_dims // num_heads
self.norm_cfg = norm_cfg
self.init_cfg = init_cfg
self.dropout = nn.Dropout(dropout)

        # dim_per_head is best set to a power of 2,
        # which is more efficient in the CUDA implementation
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError(
'invalid input for _is_power_of_2: {} (type: {})'.format(
n, type(n)))
return (n & (n - 1) == 0) and n != 0

if not _is_power_of_2(dim_per_head):
warnings.warn(
"You'd better set embed_dims in "
'MultiScaleDeformAttention to make '
'the dimension of each attention head a power of 2 '
'which is more efficient in our CUDA implementation.')

self.im2col_step = im2col_step
self.embed_dims = embed_dims
self.num_levels = num_levels
self.num_heads = num_heads
self.num_points = num_points
self.sampling_offsets = nn.Linear(
embed_dims, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dims,
num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dims, embed_dims)
self.output_proj = nn.Linear(embed_dims, embed_dims)
self.init_weight()

def init_weight(self):
"""Default initialization for Parameters of Module."""
constant_init(self.sampling_offsets, 0.)
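        # The bias of sampling_offsets is set below so that, before training,
        # the heads point to evenly spaced directions (one per head), and the
        # i-th sampling point of every head and level lies (i + 1) steps out
        # along that head's direction.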
thetas = torch.arange(
self.num_heads,
dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init /
grid_init.abs().max(-1, keepdim=True)[0]).view(
self.num_heads, 1, 1,
2).repeat(1, self.num_levels, self.num_points, 1)
for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1

self.sampling_offsets.bias.data = grid_init.view(-1)
constant_init(self.attention_weights, val=0., bias=0.)
xavier_init(self.value_proj, distribution='uniform', bias=0.)
xavier_init(self.output_proj, distribution='uniform', bias=0.)

def forward(self,
query,
key,
value,
residual=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
**kwargs):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
key (Tensor): The key tensor with shape
`(num_key, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`.
            residual (Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None, `query`
                will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
            reference_points (Tensor): The normalized reference points
                with shape (bs, num_query, num_levels, 2); all elements
                are in the range [0, 1], with (0, 0) the top-left and
                (1, 1) the bottom-right corner, including the padding
                area. Alternatively, with shape
                (bs, num_query, num_levels, 4), the additional two
                dimensions (w, h) form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
            spatial_shapes (Tensor): Spatial shapes of the features at
                different levels, with shape (num_levels, 2); the last
                dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor with shape (num_levels, ), which can be
                represented as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""

if key is None:
key = query
if value is None:
value = key

        inp_residual = query if residual is None else residual
if query_pos is not None:
query = query + query_pos

        # change to (bs, num_query, embed_dims)
query = query.permute(1, 0, 2)
value = value.permute(1, 0, 2)

bs, num_query, _ = query.shape
bs, num_key, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_key

value = self.value_proj(value)
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], 0.0)
value = value.view(bs, num_key, self.num_heads, -1)
sampling_offsets = self.sampling_offsets(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_levels * self.num_points)
attention_weights = attention_weights.softmax(-1)

attention_weights = attention_weights.view(bs, num_query,
self.num_heads,
self.num_levels,
self.num_points)
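        # For 2-d reference points (x, y), the learned offsets are divided by
        # the (w, h) of each feature level, so an offset of 1.0 moves the
        # sampling location by one pixel of that level in normalized
        # coordinates. For 4-d reference boxes (cx, cy, w, h), the offsets
        # are instead scaled by half of the box size (divided by num_points),
        # so the sampling points stay roughly inside each box.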
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack(
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets \
/ offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.num_points \
* reference_points[:, :, None, :, None, 2:] \
* 0.5
else:
raise ValueError(
f'Last dim of reference_points must be'
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
        if torch.cuda.is_available() and value.is_cuda:
output = MultiScaleDeformableAttnFunction.apply(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
output = self.output_proj(output).permute(1, 0, 2)
        # (num_query, bs, embed_dims)
return self.dropout(output) + inp_residual


class FFN(BaseModule):
"""Implements feed-forward networks (FFNs) with residual connection.
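For reference, a minimal usage sketch of the new module (not part of this commit). It assumes the companion ops added under mmcv/ops in this PR (`MultiScaleDeformableAttnFunction` and `multi_scale_deformable_attn_pytorch`) accept the call signature used in `transformer.py` above, and that input shapes follow the `forward` docstring:

import torch

from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention

device = 'cuda' if torch.cuda.is_available() else 'cpu'

bs, num_query, embed_dims = 2, 100, 256
# Four feature levels, e.g. strides 8/16/32/64 of a 256x256 image.
spatial_shapes = torch.tensor([[32, 32], [16, 16], [8, 8], [4, 4]],
                              device=device)
num_key = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
level_start_index = torch.cat(
    (spatial_shapes.new_zeros(1),
     (spatial_shapes[:, 0] * spatial_shapes[:, 1]).cumsum(0)[:-1]))

attn = MultiScaleDeformableAttention(
    embed_dims=embed_dims, num_heads=8, num_levels=4,
    num_points=4).to(device)

# Shapes follow the forward() docstring: (num_query, bs, embed_dims) for
# query, (num_key, bs, embed_dims) for value, and
# (bs, num_query, num_levels, 2) for normalized reference points.
query = torch.rand(num_query, bs, embed_dims, device=device)
value = torch.rand(num_key, bs, embed_dims, device=device)
reference_points = torch.rand(bs, num_query, 4, 2, device=device)

out = attn(
    query,
    key=None,
    value=value,
    reference_points=reference_points,
    spatial_shapes=spatial_shapes,
    level_start_index=level_start_index)
assert out.shape == (num_query, bs, embed_dims)

Since the class is registered with `@ATTENTION.register_module()`, it can also be built from a config dict through the ATTENTION registry instead of being instantiated directly.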