Change categories mapping logic #3946

Merged Sep 13, 2024 (2 commits). Showing changes from 1 commit.
10 changes: 9 additions & 1 deletion src/otx/core/data/pre_filtering.py
@@ -88,7 +88,15 @@
         raise ValueError(msg)
     if len(used_labels) == len(original_categories):
         return dataset
+    if data_format == "arrow" and sorted(used_labels)[-1] != len(original_categories) - 1:
+        # We assume that the empty label is always the last one. If it is not explicitly added to the dataset
+        # (i.e. it is not in the used labels), it will be filtered out.
+        mapping = {cat: cat for cat in original_categories[:-1]}
+    elif data_format == "arrow":
+        # This means that some other class wasn't annotated; we don't need to filter the object classes.
+        return dataset
+    else:
+        mapping = {original_categories[idx]: original_categories[idx] for idx in used_labels}
     msg = "There are unused labels in dataset, they will be filtered out before training."
     warnings.warn(msg, stacklevel=2)
-    mapping = {original_categories[idx]: original_categories[idx] for idx in used_labels}
     return dataset.transform("remap_labels", mapping=mapping, default="delete")

Review conversation on the `sorted(used_labels)[-1]` condition: marked as resolved by kprokofi.

Codecov / codecov/patch warnings for src/otx/core/data/pre_filtering.py:
- Added line #L94 was not covered by tests.
- Added line #L97 was not covered by tests.
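
The branching in this hunk can be sketched in isolation, which may also help with the two lines Codecov reports as uncovered. This is a minimal sketch, not the OTX implementation: `build_label_mapping` is a hypothetical helper, the error message is a placeholder (the real one is defined outside the hunk), and it assumes `original_categories` is a list of label names and `used_labels` a list of label indices. It returns `None` where the original returns the dataset unchanged.

```python
from __future__ import annotations


def build_label_mapping(
    original_categories: list[str],
    used_labels: list[int],
    data_format: str,
) -> dict[str, str] | None:
    """Return an identity mapping of the labels to keep, or None if no filtering is needed.

    Hypothetical helper mirroring the branches of the diff; it does not touch a dataset.
    """
    if not used_labels:
        # Placeholder message: the original text is not visible in the hunk.
        msg = "No used labels were found."
        raise ValueError(msg)
    if len(used_labels) == len(original_categories):
        # Every label is used, nothing to filter.
        return None
    if data_format == "arrow":
        if max(used_labels) != len(original_categories) - 1:
            # The last label is assumed to be the empty label; it is unused, so drop only it.
            return {cat: cat for cat in original_categories[:-1]}
        # The last label is used, so some other class is simply unannotated: keep everything.
        return None
    # Other formats: keep exactly the labels that are used.
    return {original_categories[idx]: original_categories[idx] for idx in used_labels}
```

The sketch uses `max(used_labels)` in place of `sorted(used_labels)[-1]`; both give the largest used index. In the original, a non-`None` mapping is what gets passed to `dataset.transform("remap_labels", mapping=mapping, default="delete")`.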