Reviewed revision of the PSROIPooling CPU patch.

The template argument lists and all indentation below were reconstructed,
because the copy this was reviewed from had lost them; the unchanged context
lines therefore need to be checked against the target tree before applying.
The "index" blob hashes are carried over from the original patch.

diff --git a/src/operator/contrib/psroi_pooling.cc b/src/operator/contrib/psroi_pooling.cc
index d3a3871ed004..c3b66a15852b 100644
--- a/src/operator/contrib/psroi_pooling.cc
+++ b/src/operator/contrib/psroi_pooling.cc
@@ -38,25 +38,214 @@
 using std::floor;
 using std::ceil;
 namespace mshadow {
+
+/*!
+ * \brief Input window and channel read by one pooled output element.
+ *
+ * Shared by the CPU forward and backward kernels so that both derive the bin
+ * geometry from exactly the same arithmetic.
+ */
+struct PSROIPoolBin {
+  int batch_ind;  // first value of the ROI row, converted to int
+  int channel;    // input channel: (ctop * group_size + gh) * group_size + gw
+  int hstart;     // clipped row range is [hstart, hend)
+  int hend;
+  int wstart;     // clipped column range is [wstart, wend)
+  int wend;
+  bool is_empty;  // true when the clipped window holds no element
+};
+
+/*!
+ * \brief Compute the pooling bin that belongs to one flat output index.
+ * \param index flat output index; the output is in order (n, ctop, ph, pw)
+ * \param spatial_scale factor applied to the rounded ROI coordinates
+ * \param height number of rows of the input feature map
+ * \param width number of columns of the input feature map
+ * \param pooled_height number of output rows per ROI
+ * \param pooled_width number of output columns per ROI
+ * \param bottom_rois ROI rows of 5 values: batch index, start_w, start_h, end_w, end_h
+ * \param output_dim number of output channels per ROI
+ * \param group_size number of position-sensitive groups along each axis
+ * \return the clipped window, the input channel and the emptiness flag
+ */
+template<typename DType>
+inline PSROIPoolBin PSROIPoolComputeBin(
+    const int index,
+    const DType spatial_scale,
+    const int height, const int width,
+    const int pooled_height, const int pooled_width,
+    const DType* bottom_rois,
+    const int output_dim,
+    const int group_size) {
+  // The output is in order (n, ctop, ph, pw)
+  const int pw = index % pooled_width;
+  const int ph = (index / pooled_width) % pooled_height;
+  const int ctop = (index / pooled_width / pooled_height) % output_dim;
+  const int n = index / pooled_width / pooled_height / output_dim;
+
+  PSROIPoolBin bin;
+  const DType* offset_bottom_rois = bottom_rois + n * 5;
+  bin.batch_ind = static_cast<int>(offset_bottom_rois[0]);
+
+  // [start, end) interval for spatial sampling
+  const DType roi_start_w = static_cast<DType>(round(offset_bottom_rois[1])) * spatial_scale;
+  const DType roi_start_h = static_cast<DType>(round(offset_bottom_rois[2])) * spatial_scale;
+  const DType roi_end_w = static_cast<DType>(round(offset_bottom_rois[3]) + 1.) * spatial_scale;
+  const DType roi_end_h = static_cast<DType>(round(offset_bottom_rois[4]) + 1.) * spatial_scale;
+
+  // Clamp the ROI extent to at least 0.1 so that the bin sizes are never 0
+  const DType roi_width = max(roi_end_w - roi_start_w, static_cast<DType>(0.1));
+  const DType roi_height = max(roi_end_h - roi_start_h, static_cast<DType>(0.1));
+
+  // Extent of one pooling bin in input coordinates
+  const DType bin_size_h = roi_height / static_cast<DType>(pooled_height);
+  const DType bin_size_w = roi_width / static_cast<DType>(pooled_width);
+
+  const int hstart = floor(static_cast<DType>(ph) * bin_size_h + roi_start_h);
+  const int wstart = floor(static_cast<DType>(pw) * bin_size_w + roi_start_w);
+  const int hend = ceil(static_cast<DType>(ph + 1) * bin_size_h + roi_start_h);
+  const int wend = ceil(static_cast<DType>(pw + 1) * bin_size_w + roi_start_w);
+  // Clip the window to the input boundaries
+  bin.hstart = min(max(hstart, 0), height);
+  bin.hend = min(max(hend, 0), height);
+  bin.wstart = min(max(wstart, 0), width);
+  bin.wend = min(max(wend, 0), width);
+  bin.is_empty = (bin.hend <= bin.hstart) || (bin.wend <= bin.wstart);
+
+  // Pick the position-sensitive input channel for this bin
+  int gw = floor(static_cast<DType>(pw) * group_size / pooled_width);
+  int gh = floor(static_cast<DType>(ph) * group_size / pooled_height);
+  gw = min(max(gw, 0), group_size - 1);
+  gh = min(max(gh, 0), group_size - 1);
+  bin.channel = (ctop * group_size + gh) * group_size + gw;
+  return bin;
+}
+
+/*!
+ * \brief CPU forward kernel of position-sensitive ROI pooling.
+ *
+ * For every index in [0, count), top_data[index] receives the mean of the
+ * input elements inside the bin's clipped window, or 0 if that window is empty.
+ */
+template<typename DType>
+inline void PSROIPoolForwardCPU(
+    const int count,
+    const DType* bottom_data,
+    const DType spatial_scale,
+    const int channels,
+    const int height, const int width,
+    const int pooled_height, const int pooled_width,
+    const DType* bottom_rois,
+    const int output_dim,
+    const int group_size,
+    DType* top_data) {
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  // Each iteration writes top_data[index] only, so iterations do not overlap.
+#pragma omp parallel for num_threads(omp_threads)
+  for (int index = 0; index < count; index++) {
+    const PSROIPoolBin bin = PSROIPoolComputeBin<DType>(
+        index, spatial_scale, height, width, pooled_height, pooled_width,
+        bottom_rois, output_dim, group_size);
+    const DType* offset_bottom_data =
+        bottom_data + (bin.batch_ind * channels + bin.channel) * height * width;
+    DType out_sum = 0;
+    for (int h = bin.hstart; h < bin.hend; ++h) {
+      for (int w = bin.wstart; w < bin.wend; ++w) {
+        out_sum += offset_bottom_data[h * width + w];
+      }
+    }
+
+    // The division is skipped for an empty window, so a zero area is never a divisor.
+    const DType bin_area = (bin.hend - bin.hstart) * (bin.wend - bin.wstart);
+    top_data[index] = bin.is_empty ? static_cast<DType>(0.) : out_sum / bin_area;
+  }
+}
+
 template<typename DType>
 inline void PSROIPoolForward(const Tensor<cpu, 4, DType> &out,
                            const Tensor<cpu, 4, DType> &data,
                            const Tensor<cpu, 2, DType> &bbox,
-                           const float spatial_scale_,
+                           const float spatial_scale,
                            const int output_dim_,
                            const int group_size_) {
-  // NOT_IMPLEMENTED;
+  const DType *bottom_data = data.dptr_;
+  const DType *bottom_rois = bbox.dptr_;
+  DType *top_data = out.dptr_;
+  const int count = out.shape_.Size();
+  const int channels = data.size(1);
+  const int height = data.size(2);
+  const int width = data.size(3);
+  const int pooled_height = out.size(2);
+  const int pooled_width = out.size(3);
+  PSROIPoolForwardCPU<DType>(
+      count, bottom_data, static_cast<DType>(spatial_scale), channels, height, width,
+      pooled_height, pooled_width, bottom_rois, output_dim_, group_size_, top_data);
   return;
 }
 
+/*!
+ * \brief CPU backward kernel of position-sensitive ROI pooling.
+ *
+ * For every index in [0, count), top_diff[index] / bin_area is added to each
+ * bottom_diff element inside the bin's clipped window. bottom_diff is only
+ * accumulated into; this function never clears it.
+ *
+ * The loop is kept serial on purpose: different output indices can address
+ * the same bottom_diff element (for instance two ROIs that share a batch
+ * index), so the accumulation would race under the "omp parallel for" that
+ * the forward kernel uses.
+ *
+ * num_rois is accepted but not read.
+ */
+template<typename DType>
+inline void PSROIPoolBackwardAccCPU(
+    const int count,
+    const DType* top_diff,
+    const int num_rois,
+    const DType spatial_scale,
+    const int channels,
+    const int height, const int width,
+    const int pooled_height, const int pooled_width,
+    const int group_size,
+    const int output_dim,
+    DType* bottom_diff,
+    const DType* bottom_rois) {
+  for (int index = 0; index < count; index++) {
+    const PSROIPoolBin bin = PSROIPoolComputeBin<DType>(
+        index, spatial_scale, height, width, pooled_height, pooled_width,
+        bottom_rois, output_dim, group_size);
+    DType* offset_bottom_diff =
+        bottom_diff + (bin.batch_ind * channels + bin.channel) * height * width;
+    const DType bin_area = (bin.hend - bin.hstart) * (bin.wend - bin.wstart);
+    const DType diff_val = bin.is_empty ? static_cast<DType>(0.) : top_diff[index] / bin_area;
+    for (int h = bin.hstart; h < bin.hend; ++h) {
+      for (int w = bin.wstart; w < bin.wend; ++w) {
+        offset_bottom_diff[h * width + w] += diff_val;
+      }
+    }
+  }
+}
+
 template<typename DType>
 inline void PSROIPoolBackwardAcc(const Tensor<cpu, 4, DType> &in_grad,
                             const Tensor<cpu, 4, DType> &out_grad,
                             const Tensor<cpu, 2, DType> &bbox,
-                            const float spatial_scale_,
+                            const float spatial_scale,
                             const int output_dim_,
                             const int group_size_) {
-  // NOT_IMPLEMENTED;
+  const DType *top_diff = out_grad.dptr_;
+  const DType *bottom_rois = bbox.dptr_;
+  DType *bottom_diff = in_grad.dptr_;
+  const int count = out_grad.shape_.Size();
+  const int num_rois = bbox.size(0);
+  const int channels = in_grad.size(1);
+  const int height = in_grad.size(2);
+  const int width = in_grad.size(3);
+  const int pooled_height = out_grad.size(2);
+  const int pooled_width = out_grad.size(3);
+  PSROIPoolBackwardAccCPU<DType>(
+      count, top_diff, num_rois, static_cast<DType>(spatial_scale), channels, height, width,
+      pooled_height, pooled_width, group_size_, output_dim_, bottom_diff, bottom_rois);
   return;
 }
 }  // namespace mshadow
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 44ab326d76d4..19cf70927810 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -5037,10 +5037,33 @@ def test_psroipooling():
                                                      group_size=num_group, pooled_size=num_group,
                                                      output_dim=num_classes, name='test_op')
                     rtol, atol = 1e-2, 1e-3
-                    # By now we only have gpu implementation
-                    if default_context().device_type == 'gpu':
-                        check_numeric_gradient(op, [im_data, rois_data], rtol=rtol, atol=atol,
-                                               grad_nodes=grad_nodes, ctx=mx.gpu(0))
+                    check_numeric_gradient(op, [im_data, rois_data], rtol=rtol, atol=atol,
+                                           grad_nodes=grad_nodes)
+
+
+@with_seed()
+def test_psroipooling_with_type():
+    arg_params = {
+        'psroipool_rois': np.array([[0, 10, 22, 161, 173], [0, 20, 15, 154, 160]])}
+
+    # plain psroipooling
+    sym = mx.sym.contrib.PSROIPooling(spatial_scale=0.0625, output_dim=2, pooled_size=3, name='psroipool')
+    ctx_list = [{'ctx': mx.cpu(0),
+                 'psroipool_data': (1, 18, 14, 14),
+                 'psroipool_rois': (2, 5),
+                 'type_dict': {'psroipool_data': np.float64, 'psroipool_rois': np.float64}},
+                {'ctx': mx.cpu(0),
+                 'psroipool_data': (1, 18, 14, 14),
+                 'psroipool_rois': (2, 5),
+                 'type_dict': {'psroipool_data': np.float32, 'psroipool_rois': np.float32}},
+                {'ctx': mx.cpu(0),
+                 'psroipool_data': (1, 18, 14, 14),
+                 'psroipool_rois': (2, 5),
+                 'type_dict': {'psroipool_data': np.float16, 'psroipool_rois': np.float16}},
+                ]
+
+    check_consistency(sym, ctx_list, grad_req={'psroipool_data': 'write',
+                                               'psroipool_rois': 'null'}, arg_params=arg_params)
 
 
 @with_seed()