diff --git a/mmcv/cnn/bricks/transformer.py b/mmcv/cnn/bricks/transformer.py index ff8b2c2d1b..e5f16d5664 100644 --- a/mmcv/cnn/bricks/transformer.py +++ b/mmcv/cnn/bricks/transformer.py @@ -1,11 +1,15 @@ import copy +import math import warnings import torch import torch.nn as nn from mmcv import ConfigDict -from mmcv.cnn import Linear, build_activation_layer, build_norm_layer +from mmcv.cnn import (Linear, build_activation_layer, build_norm_layer, + constant_init, xavier_init) +from mmcv.ops.multi_scale_deform_attn import ( + MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch) from mmcv.runner.base_module import BaseModule from mmcv.utils import build_from_cfg from .registry import (ATTENTION, POSITIONAL_ENCODING, TRANSFORMER_LAYER, @@ -135,6 +139,201 @@ def forward(self, return residual + self.dropout(out) +@ATTENTION.register_module() +class MultiScaleDeformableAttention(BaseModule): + """An attention module used in Deformable-Detr. `Deformable DETR: + Deformable Transformers for End-to-End Object Detection. + + `_. + + Args: + embed_dims (int): The embedding dimension of Attention. + Default: 256. + num_heads (int): Parallel attention heads. Default: 64. + num_levels (int): The number of feature map used in + Attention. Default: 4. + num_points (int): The number of sampling points for + each query in each head. Default: 4. + im2col_step (int): The step used in image_to_column. + Default: 64. + dropout (float): A Dropout layer on `inp_residual`. + Default: 0.. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims=256, + num_heads=8, + num_levels=4, + num_points=4, + im2col_step=64, + dropout=0.1, + norm_cfg=None, + init_cfg=None): + super().__init__(init_cfg) + if embed_dims % num_heads != 0: + raise ValueError(f'embed_dims must be divisible by num_heads, ' + f'but got {embed_dims} and {num_heads}') + dim_per_head = embed_dims // num_heads + self.norm_cfg = norm_cfg + self.init_cfg = init_cfg + self.dropout = nn.Dropout(dropout) + + # you'd better set dim_per_head to a power of 2 + # which is more efficient in the CUDA implementation + def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError( + 'invalid input for _is_power_of_2: {} (type: {})'.format( + n, type(n))) + return (n & (n - 1) == 0) and n != 0 + + if not _is_power_of_2(dim_per_head): + warnings.warn( + "You'd better set embed_dims in " + 'MultiScaleDeformAttention to make ' + 'the dimension of each attention head a power of 2 ' + 'which is more efficient in our CUDA implementation.') + + self.im2col_step = im2col_step + self.embed_dims = embed_dims + self.num_levels = num_levels + self.num_heads = num_heads + self.num_points = num_points + self.sampling_offsets = nn.Linear( + embed_dims, num_heads * num_levels * num_points * 2) + self.attention_weights = nn.Linear(embed_dims, + num_heads * num_levels * num_points) + self.value_proj = nn.Linear(embed_dims, embed_dims) + self.output_proj = nn.Linear(embed_dims, embed_dims) + self.init_weight() + + def init_weight(self): + """Default initialization for Parameters of Module.""" + constant_init(self.sampling_offsets, 0.) + thetas = torch.arange( + self.num_heads, + dtype=torch.float32) * (2.0 * math.pi / self.num_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init / + grid_init.abs().max(-1, keepdim=True)[0]).view( + self.num_heads, 1, 1, + 2).repeat(1, self.num_levels, self.num_points, 1) + for i in range(self.num_points): + grid_init[:, :, i, :] *= i + 1 + + self.sampling_offsets.bias.data = grid_init.view(-1) + constant_init(self.attention_weights, val=0., bias=0.) + xavier_init(self.value_proj, distribution='uniform', bias=0.) + xavier_init(self.output_proj, distribution='uniform', bias=0.) + + def forward(self, + query, + key, + value, + residual=None, + query_pos=None, + key_padding_mask=None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + **kwargs): + """Forward Function of MultiScaleDeformAttention. + + Args: + query (Tensor): Query of Transformer with shape + (num_query, bs, embed_dims). + key (Tensor): The key tensor with shape + `(num_key, bs, embed_dims)`. + value (Tensor): The value tensor with shape + `(num_key, bs, embed_dims)`. + residual (Tensor): The tensor used for addition, with the + same shape as `x`. Default None. If None, `x` will be used. + query_pos (Tensor): The positional encoding for `query`. + Default: None. + key_pos (Tensor): The positional encoding for `key`. Default + None. + reference_points (Tensor): The normalized reference + points with shape (bs, num_query, num_levels, 2), + all elements is range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area. + or (N, Length_{query}, num_levels, 4), add + additional two dimensions is (w, h) to + form reference boxes. + key_padding_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_key]. + spatial_shapes (Tensor): Spatial shape of features in + different level. With shape (num_levels, 2), + last dimension represent (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. + """ + + if key is None: + key = query + if value is None: + value = key + + if residual is None: + inp_residual = query + if query_pos is not None: + query = query + query_pos + + # change to (bs, num_query ,embed_dims) + query = query.permute(1, 0, 2) + value = value.permute(1, 0, 2) + + bs, num_query, _ = query.shape + bs, num_key, _ = value.shape + assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_key + + value = self.value_proj(value) + if key_padding_mask is not None: + value = value.masked_fill(key_padding_mask[..., None], 0.0) + value = value.view(bs, num_key, self.num_heads, -1) + sampling_offsets = self.sampling_offsets(query).view( + bs, num_query, self.num_heads, self.num_levels, self.num_points, 2) + attention_weights = self.attention_weights(query).view( + bs, num_query, self.num_heads, self.num_levels * self.num_points) + attention_weights = attention_weights.softmax(-1) + + attention_weights = attention_weights.view(bs, num_query, + self.num_heads, + self.num_levels, + self.num_points) + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack( + [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets \ + / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / self.num_points \ + * reference_points[:, :, None, :, None, 2:] \ + * 0.5 + else: + raise ValueError( + f'Last dim of reference_points must be' + f' 2 or 4, but get {reference_points.shape[-1]} instead.') + if torch.cuda.is_available(): + output = MultiScaleDeformableAttnFunction.apply( + value, spatial_shapes, level_start_index, sampling_locations, + attention_weights, self.im2col_step) + else: + output = multi_scale_deformable_attn_pytorch( + value, spatial_shapes, level_start_index, sampling_locations, + attention_weights, self.im2col_step) + output = self.output_proj(output).permute(1, 0, 2) + # (num_query, bs ,embed_dims) + return self.dropout(output) + inp_residual + + class FFN(BaseModule): """Implements feed-forward networks (FFNs) with residual connection. diff --git a/mmcv/ops/csrc/ms_deform_attn_cuda_kernel.cuh b/mmcv/ops/csrc/ms_deform_attn_cuda_kernel.cuh new file mode 100644 index 0000000000..c888fabb3e --- /dev/null +++ b/mmcv/ops/csrc/ms_deform_attn_cuda_kernel.cuh @@ -0,0 +1,807 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ +#ifndef DEFORM_ATTN_CUDA_KERNEL +#define DEFORM_ATTN_CUDA_KERNEL + +#include "common_cuda_helper.hpp" +#include "pytorch_cuda_helper.hpp" + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) { + return (N + num_threads - 1) / num_threads; +} + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear( + const scalar_t *&bottom_data, const int &height, const int &width, + const int &nheads, const int &channels, const scalar_t &h, + const scalar_t &w, const int &m, const int &c) { + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ void ms_deform_attn_col2im_bilinear( + const scalar_t *&bottom_data, const int &height, const int &width, + const int &nheads, const int &channels, const scalar_t &h, + const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad, + const scalar_t &attn_weight, scalar_t *&grad_value, + scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value + ptr1, w1 * top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value + ptr2, w2 * top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value + ptr3, w3 * top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value + ptr4, w4 * top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = width * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; +} + +template +__device__ void ms_deform_attn_col2im_bilinear_gm( + const scalar_t *&bottom_data, const int &height, const int &width, + const int &nheads, const int &channels, const scalar_t &h, + const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad, + const scalar_t &attn_weight, scalar_t *&grad_value, + scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) { + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value + ptr1, w1 * top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value + ptr2, w2 * top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value + ptr3, w3 * top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value + ptr4, w4 * top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + +template +__global__ void ms_deformable_im2col_gpu_kernel( + const int n, const scalar_t *data_value, const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, const int batch_size, + const int spatial_size, const int num_heads, const int channels, + const int num_levels, const int num_query, const int num_point, + scalar_t *data_col) { + CUDA_1D_KERNEL_LOOP(index, n) { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + scalar_t *data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t *data_value_ptr = + data_value + + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, + spatial_w, num_heads, channels, + h_im, w_im, m_col, c_col) * + weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + *data_col_ptr = col; + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1( + const int n, const scalar_t *grad_col, const scalar_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, + const int batch_size, const int spatial_size, const int num_heads, + const int channels, const int num_levels, const int num_query, + const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) { + CUDA_1D_KERNEL_LOOP(index, n) { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + if (tid == 0) { + scalar_t _grad_w = cache_grad_sampling_loc[0], + _grad_h = cache_grad_sampling_loc[1], + _grad_a = cache_grad_attn_weight[0]; + int sid = 2; + for (unsigned int tid = 1; tid < blockSize; ++tid) { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2( + const int n, const scalar_t *grad_col, const scalar_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, + const int batch_size, const int spatial_size, const int num_heads, + const int channels, const int num_levels, const int num_query, + const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) { + CUDA_1D_KERNEL_LOOP(index, n) { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockSize / 2; s > 0; s >>= 1) { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1( + const int n, const scalar_t *grad_col, const scalar_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, + const int batch_size, const int spatial_size, const int num_heads, + const int channels, const int num_levels, const int num_query, + const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) { + CUDA_1D_KERNEL_LOOP(index, n) { + extern __shared__ int _s[]; + scalar_t *cache_grad_sampling_loc = reinterpret_cast(_s); + scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + if (tid == 0) { + scalar_t _grad_w = cache_grad_sampling_loc[0], + _grad_h = cache_grad_sampling_loc[1], + _grad_a = cache_grad_attn_weight[0]; + int sid = 2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2( + const int n, const scalar_t *grad_col, const scalar_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, + const int batch_size, const int spatial_size, const int num_heads, + const int channels, const int num_levels, const int num_query, + const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) { + CUDA_1D_KERNEL_LOOP(index, n) { + extern __shared__ int _s[]; + scalar_t *cache_grad_sampling_loc = reinterpret_cast(_s); + scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; + s >>= 1, spre >>= 1) { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) { + cache_grad_attn_weight[tid] += + cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += + cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks( + const int n, const scalar_t *grad_col, const scalar_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, + const int batch_size, const int spatial_size, const int num_heads, + const int channels, const int num_levels, const int num_query, + const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) { + CUDA_1D_KERNEL_LOOP(index, n) { + extern __shared__ int _s[]; + scalar_t *cache_grad_sampling_loc = reinterpret_cast(_s); + scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight + threadIdx.x) = 0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc + (threadIdx.x << 1), + cache_grad_attn_weight + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; + s >>= 1, spre >>= 1) { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) { + cache_grad_attn_weight[tid] += + cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += + cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += + cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm( + const int n, const scalar_t *grad_col, const scalar_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, + const int batch_size, const int spatial_size, const int num_heads, + const int channels, const int num_levels, const int num_query, + const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) { + CUDA_1D_KERNEL_LOOP(index, n) { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col = 0; l_col < num_levels; ++l_col) { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = + data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col = 0; p_col < num_point; ++p_col) { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, + w_im, m_col, c_col, top_grad, weight, grad_value_ptr, + grad_sampling_loc, grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} +#endif // DEFORM_ATTN_CUDA_KERNEL diff --git a/mmcv/ops/csrc/pytorch/ms_deform_attn.cpp b/mmcv/ops/csrc/pytorch/ms_deform_attn.cpp new file mode 100644 index 0000000000..9bcee5c243 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/ms_deform_attn.cpp @@ -0,0 +1,74 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +Tensor ms_deform_attn_cuda_forward(const Tensor &value, + const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, + const int im2col_step); + +std::vector ms_deform_attn_cuda_backward( + const Tensor &value, const Tensor &spatial_shapes, + const Tensor &level_start_index, const Tensor &sampling_loc, + const Tensor &attn_weight, const Tensor &grad_output, + const int im2col_step); + +#endif + +Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, + const int im2col_step) { + if (value.type().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(value) + CHECK_CUDA_INPUT(spatial_shapes) + CHECK_CUDA_INPUT(level_start_index) + CHECK_CUDA_INPUT(sampling_loc) + CHECK_CUDA_INPUT(attn_weight) + return ms_deform_attn_cuda_forward(value, spatial_shapes, level_start_index, + sampling_loc, attn_weight, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector ms_deform_attn_backward(const Tensor &value, + const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, + const Tensor &grad_output, + const int im2col_step) { + if (value.type().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(value) + CHECK_CUDA_INPUT(spatial_shapes) + CHECK_CUDA_INPUT(level_start_index) + CHECK_CUDA_INPUT(sampling_loc) + CHECK_CUDA_INPUT(attn_weight) + CHECK_CUDA_INPUT(grad_output) + return ms_deform_attn_cuda_backward(value, spatial_shapes, + level_start_index, sampling_loc, + attn_weight, grad_output, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} diff --git a/mmcv/ops/csrc/pytorch/ms_deform_attn_cuda.cu b/mmcv/ops/csrc/pytorch/ms_deform_attn_cuda.cu new file mode 100644 index 0000000000..1cd67403f0 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/ms_deform_attn_cuda.cu @@ -0,0 +1,365 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include +#include +#include +#include + +#include +#include +#include + +template +void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, const int spatial_size, + const int num_heads, const int channels, + const int num_levels, const int num_query, + const int num_point, scalar_t *data_col) { + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const int num_threads = CUDA_NUM_THREADS; + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_value, data_spatial_shapes, data_level_start_index, + data_sampling_loc, data_attn_weight, batch_size, spatial_size, + num_heads, channels, num_levels, num_query, num_point, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +template +void ms_deformable_col2im_cuda( + cudaStream_t stream, const scalar_t *grad_col, const scalar_t *data_value, + const int64_t *data_spatial_shapes, const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight, + const int batch_size, const int spatial_size, const int num_heads, + const int channels, const int num_levels, const int num_query, + const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) { + const int num_threads = + (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > 1024) { + if ((channels & 1023) == 0) { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, + num_query, num_point, grad_value, grad_sampling_loc, + grad_attn_weight); + } else { + ms_deformable_col2im_gpu_kernel_gm + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + } + } else { + switch (channels) { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + case 1024: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, + data_attn_weight, batch_size, spatial_size, num_heads, + channels, num_levels, num_query, num_point, grad_value, + grad_sampling_loc, grad_attn_weight); + break; + default: + if (channels < 64) { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, + num_query, num_point, grad_value, grad_sampling_loc, + grad_attn_weight); + } else { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, grad_col, data_value, data_spatial_shapes, + data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, + num_query, num_point, grad_value, grad_sampling_loc, + grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) { + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), + "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), + "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), + "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), + "attn_weight tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), + "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), + "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), + "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", + batch, im2col_step_); + + auto output = + at::zeros({batch, num_query, num_heads, channels}, value.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view( + {batch / im2col_step_, batch_n, num_query, num_heads, channels}); + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + for (int n = 0; n < batch / im2col_step_; ++n) { + auto columns = output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES( + value.type(), "ms_deform_attn_forward_cuda", ([&] { + ms_deformable_im2col_cuda( + at::cuda::getCurrentCUDAStream(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), level_start_index.data(), + sampling_loc.data() + + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, + num_point, columns.data()); + })); + } + + output = output.view({batch, num_query, num_heads * channels}); + + return output; +} + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, const at::Tensor &grad_output, + const int im2col_step) { + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), + "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), + "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), + "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), + "attn_weight tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), + "grad_output tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), + "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), + "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), + "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", + batch, im2col_step_); + + auto grad_value = at::zeros_like(value); + auto grad_sampling_loc = at::zeros_like(sampling_loc); + auto grad_attn_weight = at::zeros_like(attn_weight); + + const int batch_n = im2col_step_; + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + auto grad_output_n = grad_output.view( + {batch / im2col_step_, batch_n, num_query, num_heads, channels}); + + for (int n = 0; n < batch / im2col_step_; ++n) { + auto grad_output_g = grad_output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES( + value.type(), "ms_deform_attn_backward_cuda", ([&] { + ms_deformable_col2im_cuda( + at::cuda::getCurrentCUDAStream(), grad_output_g.data(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), level_start_index.data(), + sampling_loc.data() + + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, + num_point, + grad_value.data() + n * im2col_step_ * per_value_size, + grad_sampling_loc.data() + + n * im2col_step_ * per_sample_loc_size, + grad_attn_weight.data() + + n * im2col_step_ * per_attn_weight_size); + })); + } + + return {grad_value, grad_sampling_loc, grad_attn_weight}; +} diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 309baf9761..3948fc26e2 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -92,6 +92,19 @@ void modulated_deform_conv_backward( int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, const bool with_bias); +Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, const int im2col_step); + +std::vector ms_deform_attn_backward(const Tensor &value, + const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, + const Tensor &grad_output, + const int im2col_step); + Tensor nms(Tensor boxes, Tensor scores, float iou_threshold, int offset); Tensor softnms(Tensor boxes, Tensor scores, Tensor dets, float iou_threshold, @@ -182,12 +195,12 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order, const Tensor dets_sorted, const float iou_threshold, const int multi_label); -Tensor upfirdn2d(const Tensor& input, const Tensor& kernel, int up_x, int up_y, +Tensor upfirdn2d(const Tensor &input, const Tensor &kernel, int up_x, int up_y, int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0, int pad_y1); -Tensor fused_bias_leakyrelu(const Tensor& input, const Tensor& bias, - const Tensor& refer, int act, int grad, float alpha, +Tensor fused_bias_leakyrelu(const Tensor &input, const Tensor &bias, + const Tensor &refer, int act, int grad, float alpha, float scale); void roi_align_rotated_forward(Tensor input, Tensor rois, Tensor output, @@ -401,4 +414,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("grad_input"), py::arg("pooled_height"), py::arg("pooled_width"), py::arg("spatial_scale"), py::arg("sample_num"), py::arg("aligned"), py::arg("clockwise")); + m.def("ms_deform_attn_forward", &ms_deform_attn_forward, + "forward function of multi-scale deformable attention", + py::arg("value"), py::arg("value_spatial_shapes"), + py::arg("value_level_start_index"), py::arg("sampling_locations"), + py::arg("attention_weights"), py::arg("im2col_step")); + m.def("ms_deform_attn_backward", &ms_deform_attn_backward, + "backward function of multi-scale deformable attention", + py::arg("value"), py::arg("value_spatial_shapes"), + py::arg("value_level_start_index"), py::arg("sampling_locations"), + py::arg("attention_weights"), py::arg("grad_output"), + py::arg("im2col_step")); } diff --git a/mmcv/ops/multi_scale_deform_attn.py b/mmcv/ops/multi_scale_deform_attn.py new file mode 100644 index 0000000000..77919e47ec --- /dev/null +++ b/mmcv/ops/multi_scale_deform_attn.py @@ -0,0 +1,134 @@ +import torch +import torch.nn.functional as F +from torch.autograd.function import Function, once_differentiable + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward']) + + +class MultiScaleDeformableAttnFunction(Function): + + @staticmethod + def forward(ctx, value, value_spatial_shapes, value_level_start_index, + sampling_locations, attention_weights, im2col_step): + """GPU version of multi-scale deformable attention. + + Args: + value (Tensor): The value has shape + (bs, num_keys, mum_heads, embed_dims//num_heads) + value_spatial_shapes (Tensor): Spatial shape of + each feature map, has shape (num_levels, 2), + last dimension 2 represent (h, w) + sampling_locations (Tensor): The location of sampling points, + has shape + (bs ,num_queries, num_heads, num_levels, num_points, 2), + the last dimension 2 represent (x, y). + attention_weights (Tensor): The weight of sampling points used + when calculate the attention, has shape + (bs ,num_queries, num_heads, num_levels, num_points), + im2col_step (Tensor): The step used in image to column. + + Returns: + Tensor: has shape (bs, num_queries, embed_dims) + """ + + ctx.im2col_step = im2col_step + output = ext_module.ms_deform_attn_forward(value, value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ctx.im2col_step) + ctx.save_for_backward(value, value_spatial_shapes, + value_level_start_index, sampling_locations, + attention_weights) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + """GPU version of backward function. + + Args: + grad_output (Tensor): Gradient + of output tensor of forward. + + Returns: + Tuple[Tensor]: Gradient + of input tensors in forward. + """ + value, value_spatial_shapes, value_level_start_index,\ + sampling_locations, attention_weights = ctx.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = \ + ext_module.ms_deform_attn_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + grad_output, + ctx.im2col_step) + + return grad_value, None, None, \ + grad_sampling_loc, grad_attn_weight, None + + +def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes, + sampling_locations, attention_weights): + """CPU version of multi-scale deformable attention. + + Args: + value (Tensor): The value has shape + (bs, num_keys, mum_heads, embed_dims//num_heads) + value_spatial_shapes (Tensor): Spatial shape of + each feature map, has shape (num_levels, 2), + last dimension 2 represent (h, w) + sampling_locations (Tensor): The location of sampling points, + has shape + (bs ,num_queries, num_heads, num_levels, num_points, 2), + the last dimension 2 represent (x, y). + attention_weights (Tensor): The weight of sampling points used + when calculate the attention, has shape + (bs ,num_queries, num_heads, num_levels, num_points), + + Returns: + Tensor: has shape (bs, num_queries, embed_dims) + """ + + bs, _, num_heads, embed_dims = value.shape + _, num_queries, num_heads, num_levels, num_points, _ =\ + sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], + dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level, (H_, W_) in enumerate(value_spatial_shapes): + # bs, H_*W_, num_heads, embed_dims -> + # bs, H_*W_, num_heads*embed_dims -> + # bs, num_heads*embed_dims, H_*W_ -> + # bs*num_heads, embed_dims, H_, W_ + value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape( + bs * num_heads, embed_dims, H_, W_) + # bs, num_queries, num_heads, num_points, 2 -> + # bs, num_heads, num_queries, num_points, 2 -> + # bs*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, + level].transpose(1, 2).flatten(0, 1) + # bs*num_heads, embed_dims, num_queries, num_points + sampling_value_l_ = F.grid_sample( + value_l_, + sampling_grid_l_, + mode='bilinear', + padding_mode='zeros', + align_corners=False) + sampling_value_list.append(sampling_value_l_) + # (bs, num_queries, num_heads, num_levels, num_points) -> + # (bs, num_heads, num_queries, num_levels, num_points) -> + # (bs, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + bs * num_heads, 1, num_queries, num_levels * num_points) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * + attention_weights).sum(-1).view(bs, num_heads * embed_dims, + num_queries) + return output.transpose(1, 2).contiguous() diff --git a/tests/test_ops/test_ms_deformable_attn.py b/tests/test_ops/test_ms_deformable_attn.py new file mode 100644 index 0000000000..39d371fcb3 --- /dev/null +++ b/tests/test_ops/test_ms_deformable_attn.py @@ -0,0 +1,125 @@ +import pytest +import torch +from torch.autograd import gradcheck + +from mmcv.ops.multi_scale_deform_attn import ( + MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch) + + +def test_forward_multi_scale_deformable_attn_pytorch(): + N, M, D = 1, 2, 2 + Lq, L, P = 2, 2, 2 + shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long) + S = sum([(H * W).item() for H, W in shapes]) + + torch.manual_seed(3) + value = torch.rand(N, S, M, D) * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2) + attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5 + attention_weights /= attention_weights.sum( + -1, keepdim=True).sum( + -2, keepdim=True) + + multi_scale_deformable_attn_pytorch(value.double(), shapes, + sampling_locations.double(), + attention_weights.double()).detach() + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_forward_equal_with_pytorch_double(): + N, M, D = 1, 2, 2 + Lq, L, P = 2, 2, 2 + shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() + level_start_index = torch.cat((shapes.new_zeros( + (1, )), shapes.prod(1).cumsum(0)[:-1])) + S = sum([(H * W).item() for H, W in shapes]) + + torch.manual_seed(3) + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum( + -1, keepdim=True).sum( + -2, keepdim=True) + im2col_step = 2 + output_pytorch = multi_scale_deformable_attn_pytorch( + value.double(), shapes, sampling_locations.double(), + attention_weights.double()).detach().cpu() + + output_cuda = MultiScaleDeformableAttnFunction.apply( + value.double(), shapes, level_start_index, sampling_locations.double(), + attention_weights.double(), im2col_step).detach().cpu() + assert torch.allclose(output_cuda, output_pytorch) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / + output_pytorch.abs()).max() + assert max_abs_err < 1e-18 + assert max_rel_err < 1e-15 + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_forward_equal_with_pytorch_float(): + N, M, D = 1, 2, 2 + Lq, L, P = 2, 2, 2 + shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() + level_start_index = torch.cat((shapes.new_zeros( + (1, )), shapes.prod(1).cumsum(0)[:-1])) + S = sum([(H * W).item() for H, W in shapes]) + + torch.manual_seed(3) + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum( + -1, keepdim=True).sum( + -2, keepdim=True) + im2col_step = 2 + output_pytorch = multi_scale_deformable_attn_pytorch( + value, shapes, sampling_locations, attention_weights).detach().cpu() + + output_cuda = MultiScaleDeformableAttnFunction.apply( + value, shapes, level_start_index, sampling_locations, + attention_weights, im2col_step).detach().cpu() + assert torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / + output_pytorch.abs()).max() + assert max_abs_err < 1e-9 + assert max_rel_err < 1e-6 + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +@pytest.mark.parametrize('channels', [4, 30, 32, 64, 71, 1025, 2048, 3096]) +def test_gradient_numerical(channels, + grad_value=True, + grad_sampling_loc=True, + grad_attn_weight=True): + + N, M, _ = 1, 2, 2 + Lq, L, P = 2, 2, 2 + shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() + level_start_index = torch.cat((shapes.new_zeros( + (1, )), shapes.prod(1).cumsum(0)[:-1])) + S = sum([(H * W).item() for H, W in shapes]) + + value = torch.rand(N, S, M, channels).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum( + -1, keepdim=True).sum( + -2, keepdim=True) + im2col_step = 2 + + func = MultiScaleDeformableAttnFunction.apply + + value.requires_grad = grad_value + sampling_locations.requires_grad = grad_sampling_loc + attention_weights.requires_grad = grad_attn_weight + + assert gradcheck( + func, + (value.double(), shapes, level_start_index, + sampling_locations.double(), attention_weights.double(), im2col_step))