diff --git a/src/otx/core/data/pre_filtering.py b/src/otx/core/data/pre_filtering.py index 459ef7be6f5..f78d8fe1db2 100644 --- a/src/otx/core/data/pre_filtering.py +++ b/src/otx/core/data/pre_filtering.py @@ -88,7 +88,15 @@ def remove_unused_labels(dataset: DmDataset, data_format: str, ignore_index: int raise ValueError(msg) if len(used_labels) == len(original_categories): return dataset + if data_format == "arrow" and max(used_labels) != len(original_categories) - 1: + # we assume that empty label is always the last one. If it is not explicitly added to the dataset, + # (not in the used labels) it will be filtered out. + mapping = {cat: cat for cat in original_categories[:-1]} + elif data_format == "arrow": + # this mean that some other class wasn't annotated, we don't need to filter the object classes + return dataset + else: + mapping = {original_categories[idx]: original_categories[idx] for idx in used_labels} msg = "There are unused labels in dataset, they will be filtered out before training." warnings.warn(msg, stacklevel=2) - mapping = {original_categories[idx]: original_categories[idx] for idx in used_labels} return dataset.transform("remap_labels", mapping=mapping, default="delete")