Skip to content

Commit

Permalink
Fix empty bbox error for YOLOv9 (#4024)
Browse files Browse the repository at this point in the history
* Fix empty bbox error

* Add unit test

* precommit
  • Loading branch information
sungchul2 authored Oct 15, 2024
1 parent 08fe9d6 commit fa272d5
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/otx/algo/detection/losses/yolov9_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,20 @@ def __call__(self, target: Tensor, predict: tuple[Tensor, Tensor]) -> tuple[Tens
predict (tuple[Tensor, Tensor]): The predicted class and bounding box.
Returns:
tuple[Tensor, Tensor]: The aligned target tensor with (batch, targets, (class + 4)).
tuple[Tensor, Tensor]: The aligned target tensors with (batch, targets, (class + 4)) and (batch, targets).
"""
predict_cls, predict_bbox = predict

# return if target has no gt information.
n_targets = target.shape[1]
if n_targets == 0:
device = predict_bbox.device
align_cls = torch.zeros_like(predict_cls, device=device)
align_bbox = torch.zeros_like(predict_bbox, device=device)
valid_mask = torch.zeros(predict_cls.shape[:2], dtype=bool, device=device)
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
return anchor_matched_targets, valid_mask

target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
target_cls = target_cls.long().clamp(0)

Expand Down Expand Up @@ -341,8 +352,8 @@ def __call__(self, target: Tensor, predict: tuple[Tensor, Tensor]) -> tuple[Tens
normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
align_cls = align_cls * normalize_term * valid_mask[:, :, None]

return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
return anchor_matched_targets, valid_mask


class YOLOv9Criterion(nn.Module):
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/algo/detection/losses/test_yolov9_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,21 @@ def test_call(self, box_matcher: BoxMatcher) -> None:
assert align_targets.shape == torch.Size([1, 3, 14])
assert valid_masks.shape == torch.Size([1, 3])

def test_call_with_empty_bbox(self, box_matcher: BoxMatcher) -> None:
target = torch.zeros((1, 0, 5))

predict_cls = torch.rand((1, 8400, 10))
predict_bbox = torch.rand((1, 8400, 4))
predict = (predict_cls, predict_bbox)

align_targets, valid_masks = box_matcher(target, predict)

assert align_targets.shape == (1, 8400, 14)
assert torch.all(align_targets == 0)

assert valid_masks.shape == (1, 8400)
assert torch.all(~valid_masks)


class TestYOLOv9Criterion:
@pytest.fixture()
Expand Down

0 comments on commit fa272d5

Please # to comment.