Fix tile CLI (#3204)
* update cli

* fix tiling cli

* update import location

* reinitiate datamodule for OV if tile is enabled

* bring tile-cli-2.0.0 to develop branch

* update style

* fix unit test
eugene123tw authored Apr 4, 2024
1 parent acf03b9 commit 2c43de6
Showing 4 changed files with 23 additions and 6 deletions.
src/otx/cli/cli.py: 2 additions & 2 deletions
@@ -365,7 +365,7 @@ def instantiate_model(self, model_config: Namespace) -> OTXModel:
         Returns:
             tuple: The model and optimizer and scheduler.
         """
-        from otx.core.model.base import OTXModel
+        from otx.core.model.base import OTXModel, OVModel

         skip = set()

@@ -393,7 +393,7 @@ def instantiate_model(self, model_config: Namespace) -> OTXModel:
         self.config_init[self.subcommand]["model"] = model

         # Update tile config due to adaptive tiling
-        if self.datamodule.config.tile_config.enable_tiler:
+        if not isinstance(model, OVModel) and self.datamodule.config.tile_config.enable_tiler:
             if not hasattr(model, "tile_config"):
                 msg = "The model does not have a tile_config attribute. Please check if the model supports tiling."
                 raise AttributeError(msg)
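To make the intent of the new guard concrete, here is a small, self-contained sketch of how the adaptive tile config is now propagated only to torch-based models, while OpenVINO models are skipped because their exported IR already carries tiling via Model API. The stub classes and the maybe_apply_tile_config helper are illustrative names only, not the real OTX classes.

from dataclasses import dataclass, field


@dataclass
class TileConfig:
    enable_tiler: bool = False
    tile_size: int = 400


@dataclass
class TorchModelStub:
    """Stands in for a torch-based OTXModel that owns a tile_config."""

    tile_config: TileConfig = field(default_factory=TileConfig)


class OVModelStub:
    """Stands in for OVModel; the exported IR already embeds tiling via Model API."""


def maybe_apply_tile_config(model, datamodule_tile_config: TileConfig) -> None:
    # Mirrors the new condition in cli.py: skip OpenVINO models entirely, and fail
    # loudly when a torch model does not expose a tile_config attribute.
    if not isinstance(model, OVModelStub) and datamodule_tile_config.enable_tiler:
        if not hasattr(model, "tile_config"):
            msg = "The model does not have a tile_config attribute."
            raise AttributeError(msg)
        model.tile_config = datamodule_tile_config


adaptive = TileConfig(enable_tiler=True, tile_size=512)
torch_model = TorchModelStub()
maybe_apply_tile_config(torch_model, adaptive)    # propagated: tile_size becomes 512
maybe_apply_tile_config(OVModelStub(), adaptive)  # skipped for the OpenVINO stub
assert torch_model.tile_config.tile_size == 512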
src/otx/engine/engine.py: 8 additions & 2 deletions
@@ -352,13 +352,16 @@ def test(

         is_ir_ckpt = Path(str(checkpoint)).suffix in [".xml", ".onnx"]
         if is_ir_ckpt and not isinstance(model, OVModel):
-            datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
             model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)
             if self.device.accelerator != "cpu":
                 msg = "IR model supports inference only on CPU device. The device is changed automatic."
                 warn(msg, stacklevel=1)
                 self.device = DeviceType.cpu  # type: ignore[assignment]

+        # NOTE: Re-initiate datamodule without tiling as model API supports its own tiling mechanism
+        if isinstance(model, OVModel):
+            datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
+
         # NOTE, trainer.test takes only lightning based checkpoint.
         # So, it can't take the OTX1.x checkpoint.
         if checkpoint is not None and not is_ir_ckpt:
@@ -440,9 +443,12 @@ def predict(

         is_ir_ckpt = checkpoint is not None and Path(checkpoint).suffix in [".xml", ".onnx"]
         if is_ir_ckpt and not isinstance(model, OVModel):
-            datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
             model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)

+        # NOTE: Re-initiate datamodule for OVModel as model API supports its own data pipeline.
+        if isinstance(model, OVModel):
+            datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
+
         if checkpoint is not None and not is_ir_ckpt:
             loaded_checkpoint = torch.load(checkpoint)
             model.load_state_dict(loaded_checkpoint)
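Both hunks follow the same pattern. The minimal sketch below (hypothetical helper names, not the real Engine) condenses it: an IR checkpoint is first wrapped into an OV model, and the datamodule is then re-initiated for any OV model, including one passed in directly, so the torch-side tiling and pipeline are replaced by Model API's own handling.

from pathlib import Path


class OVModelStub:
    """Illustrative stand-in for otx.core.model.base.OVModel."""


def prepare_for_inference(model, datamodule, checkpoint, get_ov_model, update_ov_subset_pipeline):
    """Sketch of the preamble now shared by Engine.test() and Engine.predict()."""
    is_ir_ckpt = checkpoint is not None and Path(str(checkpoint)).suffix in [".xml", ".onnx"]
    if is_ir_ckpt and not isinstance(model, OVModelStub):
        model = get_ov_model(checkpoint)

    # Key change: the OV subset pipeline is applied for every OV model, not only
    # when the conversion above happened, so a directly supplied OV model also
    # gets a datamodule without the tiler.
    if isinstance(model, OVModelStub):
        datamodule = update_ov_subset_pipeline(datamodule)

    return model, datamodule, is_ir_ckpt


# Tiny smoke test of the control flow, using throwaway lambdas as the helpers.
model, dm, is_ir = prepare_for_inference(
    model=object(),
    datamodule={"tiler": True},
    checkpoint="exported_model.xml",
    get_ov_model=lambda ckpt: OVModelStub(),
    update_ov_subset_pipeline=lambda d: {"tiler": False},
)
assert isinstance(model, OVModelStub) and dm == {"tiler": False} and is_ir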
src/otx/engine/utils/auto_configurator.py: 5 additions & 2 deletions
@@ -364,12 +364,15 @@ def update_ov_subset_pipeline(self, datamodule: OTXDataModule, subset: str = "te
         data_configuration = datamodule.config
         ov_test_config = self._load_default_config(model_name="openvino_model")["data"]["config"][f"{subset}_subset"]
         subset_config = getattr(data_configuration, f"{subset}_subset")
         subset_config.batch_size = ov_test_config["batch_size"]
         subset_config.transform_lib_type = ov_test_config["transform_lib_type"]
         subset_config.transforms = ov_test_config["transforms"]
         data_configuration.tile_config.enable_tiler = False
         msg = (
-            f"For OpenVINO IR models, Update the following {subset} transforms: {subset_config.transforms}"
-            f"and transform_lib_type: {subset_config.transform_lib_type}"
+            f"For OpenVINO IR models, Update the following {subset} \n"
+            f"\t transforms: {subset_config.transforms} \n"
+            f"\t transform_lib_type: {subset_config.transform_lib_type} \n"
+            f"\t batch_size: {subset_config.batch_size} \n"
+            "And the tiler is disabled."
         )
         warn(msg, stacklevel=1)
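For reference, a standalone snippet (with made-up values) showing how the rewritten implicit f-string concatenation renders. The old one-line message ran the transforms list and "and transform_lib_type" together with no separator; the new message is line-per-setting and also reports the batch size and the disabled tiler.

subset = "test"
transforms = ["Resize", "Normalize"]   # made-up values, for illustration only
transform_lib_type = "OPENVINO"
batch_size = 1

msg = (
    f"For OpenVINO IR models, Update the following {subset} \n"
    f"\t transforms: {transforms} \n"
    f"\t transform_lib_type: {transform_lib_type} \n"
    f"\t batch_size: {batch_size} \n"
    "And the tiler is disabled."
)
print(msg)  # in the real code this string is passed to warnings.warn()
# For OpenVINO IR models, Update the following test
#        transforms: ['Resize', 'Normalize']
#        transform_lib_type: OPENVINO
#        batch_size: 1
# And the tiler is disabled.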
tests/unit/engine/test_engine.py: 8 additions & 0 deletions
@@ -2,11 +2,13 @@
 # SPDX-License-Identifier: Apache-2.0

 from pathlib import Path
+from unittest.mock import create_autospec

 import pytest
 from otx.algo.classification.efficientnet_b0 import EfficientNetB0ForMulticlassCls
 from otx.algo.classification.torchvision_model import OTXTVModel
 from otx.core.config.device import DeviceConfig
+from otx.core.model.base import OVModel
 from otx.core.types.export import OTXExportFormatType
 from otx.core.types.precision import OTXPrecisionType
 from otx.engine import Engine
@@ -123,6 +125,9 @@ def test_testing_with_ov_model(self, fxt_engine, mocker) -> None:
         mock_test.assert_called_once()
         mock_torch_load.assert_not_called()

+        fxt_engine.model = create_autospec(OVModel)
+        fxt_engine.test(checkpoint="path/to/model.xml")
+
     def test_prediction_after_training(self, fxt_engine, mocker) -> None:
         mocker.patch("otx.engine.engine.OTXModel.load_state_dict")
         mock_predict = mocker.patch("otx.engine.engine.Trainer.predict")
@@ -137,6 +142,9 @@ def test_prediction_after_training(self, fxt_engine, mocker) -> None:
         fxt_engine.predict(checkpoint="path/to/new/checkpoint")
         mock_torch_load.assert_called_with("path/to/new/checkpoint")

+        fxt_engine.model = create_autospec(OVModel)
+        fxt_engine.predict(checkpoint="path/to/model.xml")
+
     def test_prediction_with_ov_model(self, fxt_engine, mocker) -> None:
         mock_predict = mocker.patch("otx.engine.engine.Trainer.predict")
         mock_torch_load = mocker.patch("torch.load")
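The new test lines rely on a property of autospec mocks that is easy to miss: a mock created with a spec reports the spec as its __class__, so the isinstance(model, OVModel) checks added in engine.py take the OV branch. A short standalone illustration follows; FakeOVModel is a placeholder class, not the real OVModel.

from unittest.mock import create_autospec


class FakeOVModel:
    """Placeholder for otx.core.model.base.OVModel."""

    def infer(self, inputs):
        return inputs


mock_model = create_autospec(FakeOVModel)

# Spec'd mocks pass isinstance checks against the spec class, which is why
# assigning create_autospec(OVModel) to fxt_engine.model drives the engine
# into its OVModel-specific code paths.
assert isinstance(mock_model, FakeOVModel)

# Autospec also rejects attributes the real class does not define.
try:
    mock_model.does_not_exist
except AttributeError:
    print("attribute access is spec-checked, as expected")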
