Skip to content

Commit

Permalink
Refactor empty label workaround in iseg (#4013)
Browse files Browse the repository at this point in the history
* Refactor empty label workaround in iseg

* Del unnecessary code

* Add fallback to the base implementation
  • Loading branch information
sovrasov authored Oct 11, 2024
1 parent 7fa81d3 commit 1d43f7c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
22 changes: 21 additions & 1 deletion src/otx/core/model/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

import copy
import logging as log
import types
from contextlib import contextmanager
Expand All @@ -30,7 +31,7 @@
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.export import TaskLevelExportParameters
from otx.core.types.label import LabelInfoTypes
from otx.core.types.label import LabelInfo, LabelInfoTypes
from otx.core.utils.mask_util import encode_rle, polygon_to_rle
from otx.core.utils.tile_merge import InstanceSegTileMerge

Expand Down Expand Up @@ -274,12 +275,17 @@ def forward_for_tracing(self, inputs: Tensor) -> tuple[Tensor, ...]:
@property
def _export_parameters(self) -> TaskLevelExportParameters:
"""Defines parameters required to export a particular model implementation."""
modified_label_info = copy.deepcopy(self.label_info)
# Instance segmentation needs to add empty label to satisfy MAPI wrapper requirements
modified_label_info.label_names.insert(0, "otx_empty_lbl")

return super()._export_parameters.wrap(
model_type="MaskRCNN",
task_type="instance_segmentation",
confidence_threshold=self.hparams.get("best_confidence_threshold", 0.05),
iou_threshold=0.5,
tile_config=self.tile_config if self.tile_config.enable_tiler else None,
label_info=modified_label_info,
)

def on_load_checkpoint(self, ckpt: dict[str, Any]) -> None:
Expand Down Expand Up @@ -739,3 +745,17 @@ def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwa
best_confidence_threshold = self.hparams.get("best_confidence_threshold", None)
compute_kwargs = {"best_confidence_threshold": best_confidence_threshold}
return super()._log_metrics(meter, key, **compute_kwargs)

def _create_label_info_from_ov_ir(self) -> LabelInfo:
ov_model = self.model.get_model()

if ov_model.has_rt_info(["model_info", "label_info"]):
serialized = ov_model.get_rt_info(["model_info", "label_info"]).value
ir_label_info = LabelInfo.from_json(serialized)
# workaround to hide extra otx_empty_lbl
if ir_label_info.label_names[0] == "otx_empty_lbl":
ir_label_info.label_names.pop(0)
ir_label_info.label_groups[0].pop(0)
return ir_label_info

return super()._create_label_info_from_ov_ir()
18 changes: 5 additions & 13 deletions src/otx/core/types/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,11 @@ def to_metadata(self) -> dict[tuple[str, str], str]:
dict[tuple[str, str], str]: It will be directly delivered to
OpenVINO IR's `rt_info` or ONNX metadata slot.
"""
if self.task_type == "instance_segmentation":
# Instance segmentation needs to add empty label
all_labels = "otx_empty_lbl "
all_label_ids = "None "
for lbl in self.label_info.label_names:
all_labels += lbl.replace(" ", "_") + " "
all_label_ids += lbl.replace(" ", "_") + " "
else:
all_labels = ""
all_label_ids = ""
for lbl in self.label_info.label_names:
all_labels += lbl.replace(" ", "_") + " "
all_label_ids += lbl.replace(" ", "_") + " "
all_labels = ""
all_label_ids = ""
for lbl in self.label_info.label_names:
all_labels += lbl.replace(" ", "_") + " "
all_label_ids += lbl.replace(" ", "_") + " "

metadata = {
# Common
Expand Down

0 comments on commit 1d43f7c

Please # to comment.