From 1d43f7c6f7dbe46c37478cb9d53e65e7e79bf1cb Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Fri, 11 Oct 2024 09:36:32 +0200 Subject: [PATCH] Refactor empty label workaround in iseg (#4013) * Refactor empty label workaround in iseg * Del unnecessary code * Add fallback to the base implementation --- src/otx/core/model/instance_segmentation.py | 22 ++++++++++++++++++++- src/otx/core/types/export.py | 18 +++++------------ 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/otx/core/model/instance_segmentation.py b/src/otx/core/model/instance_segmentation.py index d7a3d9372a2..2384e39deed 100644 --- a/src/otx/core/model/instance_segmentation.py +++ b/src/otx/core/model/instance_segmentation.py @@ -4,6 +4,7 @@ from __future__ import annotations +import copy import logging as log import types from contextlib import contextmanager @@ -30,7 +31,7 @@ from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel from otx.core.schedulers import LRSchedulerListCallable from otx.core.types.export import TaskLevelExportParameters -from otx.core.types.label import LabelInfoTypes +from otx.core.types.label import LabelInfo, LabelInfoTypes from otx.core.utils.mask_util import encode_rle, polygon_to_rle from otx.core.utils.tile_merge import InstanceSegTileMerge @@ -274,12 +275,17 @@ def forward_for_tracing(self, inputs: Tensor) -> tuple[Tensor, ...]: @property def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" + modified_label_info = copy.deepcopy(self.label_info) + # Instance segmentation needs to add empty label to satisfy MAPI wrapper requirements + modified_label_info.label_names.insert(0, "otx_empty_lbl") + return super()._export_parameters.wrap( model_type="MaskRCNN", task_type="instance_segmentation", confidence_threshold=self.hparams.get("best_confidence_threshold", 0.05), iou_threshold=0.5, tile_config=self.tile_config if self.tile_config.enable_tiler else None, + label_info=modified_label_info, ) def on_load_checkpoint(self, ckpt: dict[str, Any]) -> None: @@ -739,3 +745,17 @@ def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwa best_confidence_threshold = self.hparams.get("best_confidence_threshold", None) compute_kwargs = {"best_confidence_threshold": best_confidence_threshold} return super()._log_metrics(meter, key, **compute_kwargs) + + def _create_label_info_from_ov_ir(self) -> LabelInfo: + ov_model = self.model.get_model() + + if ov_model.has_rt_info(["model_info", "label_info"]): + serialized = ov_model.get_rt_info(["model_info", "label_info"]).value + ir_label_info = LabelInfo.from_json(serialized) + # workaround to hide extra otx_empty_lbl + if ir_label_info.label_names[0] == "otx_empty_lbl": + ir_label_info.label_names.pop(0) + ir_label_info.label_groups[0].pop(0) + return ir_label_info + + return super()._create_label_info_from_ov_ir() diff --git a/src/otx/core/types/export.py b/src/otx/core/types/export.py index 7f64febe607..cc9c592f3b9 100644 --- a/src/otx/core/types/export.py +++ b/src/otx/core/types/export.py @@ -98,19 +98,11 @@ def to_metadata(self) -> dict[tuple[str, str], str]: dict[tuple[str, str], str]: It will be directly delivered to OpenVINO IR's `rt_info` or ONNX metadata slot. """ - if self.task_type == "instance_segmentation": - # Instance segmentation needs to add empty label - all_labels = "otx_empty_lbl " - all_label_ids = "None " - for lbl in self.label_info.label_names: - all_labels += lbl.replace(" ", "_") + " " - all_label_ids += lbl.replace(" ", "_") + " " - else: - all_labels = "" - all_label_ids = "" - for lbl in self.label_info.label_names: - all_labels += lbl.replace(" ", "_") + " " - all_label_ids += lbl.replace(" ", "_") + " " + all_labels = "" + all_label_ids = "" + for lbl in self.label_info.label_names: + all_labels += lbl.replace(" ", "_") + " " + all_label_ids += lbl.replace(" ", "_") + " " metadata = { # Common