Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

BC improvement #4154

Merged
merged 8 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/otx/core/data/dataset/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def __init__(self, dataset: OTXDataset, tile_config: TileConfig) -> None:
dataset.image_color_channel,
dataset.stack_images,
dataset.to_tv_image,
data_format=dataset.data_format,
)
self.tile_config = tile_config
self._dataset = dataset
Expand Down
5 changes: 5 additions & 0 deletions src/otx/core/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,11 @@ def load_state_dict_incrementally(self, ckpt: dict[str, Any], *args, **kwargs) -
msg = "Checkpoint should have `label_info`."
raise ValueError(msg, ckpt_label_info)

if not hasattr(ckpt_label_info, "label_ids"):
msg = "Loading checkpoint from OTX < 2.2.1, label_ids are assigned automatically"
logger.info(msg)
ckpt_label_info.label_ids = [str(i) for i, _ in enumerate(ckpt_label_info.label_names)]

if ckpt_label_info != self.label_info:
msg = (
"Load model state dictionary incrementally: "
Expand Down
2 changes: 2 additions & 0 deletions src/otx/core/types/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dataclasses import dataclass, fields
from enum import Enum

import otx
from otx.core.config.data import TileConfig
from otx.core.types.label import HLabelInfo, LabelInfo

Expand Down Expand Up @@ -122,6 +123,7 @@ def to_metadata(self) -> dict[tuple[str, str], str]:
("model_info", "labels"): all_labels.strip(),
("model_info", "label_ids"): all_label_ids.strip(),
("model_info", "optimization_config"): json.dumps(self.optimization_config),
("model_info", "otx_version"): otx.__version__,
}

if isinstance(self.label_info, HLabelInfo):
Expand Down
1 change: 1 addition & 0 deletions tests/unit/core/types/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,4 @@ def test_wrap(fxt_label_info, task_type):
assert ("model_info", "tile_size") in metadata
assert ("model_info", "tiles_overlap") in metadata
assert ("model_info", "max_pred_number") in metadata
assert ("model_info", "otx_version") in metadata
Loading