Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Revert "adapt boxes_overlap_bev & box_iou_rotated (#3162)" #3232

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mmcv/ops/box_iou_rotated.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ def box_iou_rotated(bboxes1: torch.Tensor,
flip_mat[-1] = -1
bboxes1 = bboxes1 * flip_mat
bboxes2 = bboxes2 * flip_mat
if bboxes1.device.type == 'npu':
scale_mat = bboxes1.new_ones(bboxes1.shape[-1])
scale_mat[-1] = 1.0 / 0.01745329252
bboxes1 = bboxes1 * scale_mat
bboxes2 = bboxes2 * scale_mat
bboxes1 = bboxes1.contiguous()
bboxes2 = bboxes2.contiguous()
ext_module.box_iou_rotated(
Expand Down
37 changes: 33 additions & 4 deletions mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,40 @@ void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,

void box_iou_rotated_npu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned) {
TORCH_CHECK(boxes1.size(1) == 5, "boxes1 must be 2D tensor (N, 5)");
TORCH_CHECK(boxes1.size(1) == 5, "boxes1 must be 2D tensor (N, 5)");
at::Tensor boxes = at::ones_like(boxes1);
at::Tensor query_boxes = at::ones_like(boxes2);
boxes = boxes1.transpose(0, 1).unsqueeze(0);
query_boxes = boxes2.transpose(0, 1).unsqueeze(0);

EXEC_NPU_CMD(aclnnBoxIou, boxes1, boxes2, mode_flag, aligned, ious);
return;
bool is_trans = false;
string modeStr = "iou";
if (mode_flag == 1) {
modeStr = "iof";
}
bool is_cross = true;
if (aligned) {
is_cross = false;
}
float v_threshold = 0;
float e_threshold = 0;

OpCommand cmd;
cmd.Name("RotatedIou")
.Input(boxes)
.Input(query_boxes)
.Output(ious)
.Attr("trans", is_trans)
.Attr("mode", modeStr)
.Attr("is_cross", is_cross)
.Attr("v_threshold", v_threshold)
.Attr("e_threshold", e_threshold)
.Run();

if (is_cross) {
ious = ious.view({boxes1.size(0), boxes2.size(0)});
} else {
ious = ious.view({boxes1.size(0), 1});
}
}

REGISTER_NPU_IMPL(box_iou_rotated_impl, box_iou_rotated_npu);
6 changes: 1 addition & 5 deletions tests/test_ops/test_iou3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
])
def test_boxes_overlap_bev(device):
np_boxes1 = np.asarray([[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 0.0],
Expand Down
Loading