Skip to content

Commit

Permalink
Fix wrong indices setting in HLabelInfo (#4044)
Browse files Browse the repository at this point in the history
* Fix wrong indices setting in label_info

* Add unit-test & update for releases
  • Loading branch information
harimkang authored Oct 17, 2024
1 parent d6458e6 commit b817d1b
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4016>)
- Fix multilabel_accuracy of MixedHLabelAccuracy
(<https://github.com/openvinotoolkit/training_extensions/pull/4042>)
- Fix wrong indices setting in HLabelInfo
(<https://github.com/openvinotoolkit/training_extensions/pull/4044>)

## \[v2.1.0\]

Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ In addition to the examples above, please refer to the documentation for tutoria
- Fix num_trials calculation on dataset length less than num_class
- Fix out_features in HierarchicalCBAMClsHead
- Fix multilabel_accuracy of MixedHLabelAccuracy
- Fix wrong indices setting in HLabelInfo

### Known issues

Expand Down
1 change: 1 addition & 0 deletions docs/source/guide/release_notes/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Bug fixes
- Fix num_trials calculation on dataset length less than num_class
- Fix out_features in HierarchicalCBAMClsHead
- Fix multilabel_accuracy of MixedHLabelAccuracy
- Fix wrong indices setting in HLabelInfo

v2.1.0 (2024.07)
----------------
Expand Down
2 changes: 1 addition & 1 deletion src/otx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

__version__ = "2.2.0rc8"
__version__ = "2.2.0rc9"

import os
from pathlib import Path
Expand Down
4 changes: 3 additions & 1 deletion src/otx/core/types/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,8 @@ def convert_labels_if_needed(
single_label_group_info["class_to_idx"],
)

label_to_idx = {lbl: i for i, lbl in enumerate(merged_class_to_idx.keys())}

return HLabelInfo(
label_names=label_names,
label_groups=all_groups,
Expand All @@ -273,7 +275,7 @@ def convert_labels_if_needed(
num_single_label_classes=exclusive_group_info["num_single_label_classes"],
class_to_group_idx=merged_class_to_idx,
all_groups=all_groups,
label_to_idx=dm_label_categories._indices, # noqa: SLF001
label_to_idx=label_to_idx,
label_tree_edges=get_label_tree_edges(dm_label_categories.items),
empty_multiclass_head_indices=[], # consider the label removing case
)
Expand Down
36 changes: 35 additions & 1 deletion tests/unit/core/types/test_label.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from otx.core.types.label import NullLabelInfo, SegLabelInfo
from datumaro import LabelCategories
from datumaro.components.annotation import GroupType
from otx.core.types.label import HLabelInfo, NullLabelInfo, SegLabelInfo


def test_as_json(fxt_label_info):
Expand All @@ -18,3 +21,34 @@ def test_seg_label_info():
)
assert SegLabelInfo.from_num_classes(1) == SegLabelInfo(["background", "label_0"], [["background", "label_0"]])
assert SegLabelInfo.from_num_classes(0) == NullLabelInfo()


# Unit test
def test_hlabel_info():
labels = [
LabelCategories.Category(name="car", parent="vehicle"),
LabelCategories.Category(name="truck", parent="vehicle"),
LabelCategories.Category(name="plush toy", parent="plush toy"),
LabelCategories.Category(name="No class"),
]
label_groups = [
LabelCategories.LabelGroup(
name="Detection labels___vehicle",
labels=["car", "truck"],
group_type=GroupType.EXCLUSIVE,
),
LabelCategories.LabelGroup(
name="Detection labels___plush toy",
labels=["plush toy"],
group_type=GroupType.EXCLUSIVE,
),
LabelCategories.LabelGroup(name="No class", labels=["No class"], group_type=GroupType.RESTRICTED),
]
dm_label_categories = LabelCategories(items=labels, label_groups=label_groups)

hlabel_info = HLabelInfo.from_dm_label_groups(dm_label_categories)

# Check if class_to_group_idx and label_to_idx have the same keys
assert list(hlabel_info.class_to_group_idx.keys()) == list(
hlabel_info.label_to_idx.keys(),
), "class_to_group_idx and label_to_idx keys do not match"

0 comments on commit b817d1b

Please # to comment.