Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Fix empty bbox error for YOLOv9 to 2.3.0 #4026

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/otx/algo/detection/losses/yolov9_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,20 @@ def __call__(self, target: Tensor, predict: tuple[Tensor, Tensor]) -> tuple[Tens
predict (tuple[Tensor, Tensor]): The predicted class and bounding box.

Returns:
tuple[Tensor, Tensor]: The aligned target tensor with (batch, targets, (class + 4)).
tuple[Tensor, Tensor]: The aligned target tensors with (batch, targets, (class + 4)) and (batch, targets).
"""
predict_cls, predict_bbox = predict

# return if target has no gt information.
n_targets = target.shape[1]
if n_targets == 0:
device = predict_bbox.device
align_cls = torch.zeros_like(predict_cls, device=device)
align_bbox = torch.zeros_like(predict_bbox, device=device)
valid_mask = torch.zeros(predict_cls.shape[:2], dtype=bool, device=device)
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
return anchor_matched_targets, valid_mask

target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
target_cls = target_cls.long().clamp(0)

Expand Down Expand Up @@ -341,8 +352,8 @@ def __call__(self, target: Tensor, predict: tuple[Tensor, Tensor]) -> tuple[Tens
normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
align_cls = align_cls * normalize_term * valid_mask[:, :, None]

return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
return anchor_matched_targets, valid_mask


class YOLOv9Criterion(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def test_otx_e2e_cli(
assert (latest_dir / export_case.expected_output).exists()

if task == OTXTaskType.OBJECT_DETECTION_3D:
return # "3D Object Detection is not supported for OV IR inference.
return # "3D Object Detection is not supported for OV IR inference.

# 4) infer of the exported models
ov_output_dir = tmp_path_test / "outputs" / "OPENVINO"
Expand Down Expand Up @@ -319,7 +319,7 @@ def test_otx_explain_e2e_cli(
"rtdetr_50",
"rtdetr_101",
"maskrcnn_r50_tv",
"maskrcnn_r50_tv_tile"
"maskrcnn_r50_tv_tile",
]

if any(model in model_name for model in models_not_supported):
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/algo/detection/losses/test_yolov9_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,21 @@ def test_call(self, box_matcher: BoxMatcher) -> None:
assert align_targets.shape == torch.Size([1, 3, 14])
assert valid_masks.shape == torch.Size([1, 3])

def test_call_with_empty_bbox(self, box_matcher: BoxMatcher) -> None:
target = torch.zeros((1, 0, 5))

predict_cls = torch.rand((1, 8400, 10))
predict_bbox = torch.rand((1, 8400, 4))
predict = (predict_cls, predict_bbox)

align_targets, valid_masks = box_matcher(target, predict)

assert align_targets.shape == (1, 8400, 14)
assert torch.all(align_targets == 0)

assert valid_masks.shape == (1, 8400)
assert torch.all(~valid_masks)


class TestYOLOv9Criterion:
@pytest.fixture()
Expand Down
Loading