From d66eaf78c6077cc65ea6760886269ef99dd1fc28 Mon Sep 17 00:00:00 2001 From: Prokofiev Kirill Date: Mon, 12 Aug 2024 10:01:19 +0200 Subject: [PATCH] Update cache directory for loading pretrained weights (Semantic Segmentation) (#3828) * updated cache dir * update unit tests --- src/otx/algo/segmentation/backbones/dinov2.py | 3 ++- src/otx/algo/segmentation/backbones/litehrnet.py | 3 ++- src/otx/algo/segmentation/backbones/mscan.py | 3 ++- src/otx/algo/utils/mmengine_utils.py | 2 +- tests/unit/algo/segmentation/backbones/test_dinov2.py | 5 +++-- tests/unit/algo/segmentation/backbones/test_litehrnet.py | 4 +++- tests/unit/algo/segmentation/backbones/test_mscan.py | 4 +++- 7 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/otx/algo/segmentation/backbones/dinov2.py b/src/otx/algo/segmentation/backbones/dinov2.py index 76ace69c5b2..6abf733165a 100644 --- a/src/otx/algo/segmentation/backbones/dinov2.py +++ b/src/otx/algo/segmentation/backbones/dinov2.py @@ -90,7 +90,8 @@ def load_pretrained_weights(self, pretrained: str | None = None, prefix: str = " checkpoint = torch.load(pretrained, "cpu") print(f"init weight - {pretrained}") elif pretrained is not None: - checkpoint = load_from_http(pretrained, "cpu") + cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints" + checkpoint = load_from_http(filename=pretrained, map_location="cpu", model_dir=cache_dir) print(f"init weight - {pretrained}") if checkpoint is not None: load_checkpoint_to_model(self, checkpoint, prefix=prefix) diff --git a/src/otx/algo/segmentation/backbones/litehrnet.py b/src/otx/algo/segmentation/backbones/litehrnet.py index a47a4571bf1..ba98a8b4650 100644 --- a/src/otx/algo/segmentation/backbones/litehrnet.py +++ b/src/otx/algo/segmentation/backbones/litehrnet.py @@ -1525,7 +1525,8 @@ def load_pretrained_weights(self, pretrained: str | None = None, prefix: str = " checkpoint = torch.load(pretrained, "cpu") print(f"init weight - {pretrained}") elif pretrained is not None: - checkpoint = load_from_http(pretrained, "cpu") + cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints" + checkpoint = load_from_http(filename=pretrained, map_location="cpu", model_dir=cache_dir) print(f"init weight - {pretrained}") if checkpoint is not None: load_checkpoint_to_model(self, checkpoint, prefix=prefix) diff --git a/src/otx/algo/segmentation/backbones/mscan.py b/src/otx/algo/segmentation/backbones/mscan.py index 415655bf8ca..cc1bb96db8b 100644 --- a/src/otx/algo/segmentation/backbones/mscan.py +++ b/src/otx/algo/segmentation/backbones/mscan.py @@ -445,7 +445,8 @@ def load_pretrained_weights(self, pretrained: str | None = None, prefix: str = " checkpoint = torch.load(pretrained, "cpu") print(f"init weight - {pretrained}") elif pretrained is not None: - checkpoint = load_from_http(pretrained, "cpu") + cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints" + checkpoint = load_from_http(filename=pretrained, map_location="cpu", model_dir=cache_dir) print(f"init weight - {pretrained}") if checkpoint is not None: load_checkpoint_to_model(self, checkpoint, prefix=prefix) diff --git a/src/otx/algo/utils/mmengine_utils.py b/src/otx/algo/utils/mmengine_utils.py index b7b90818bbc..8059d5aae5b 100644 --- a/src/otx/algo/utils/mmengine_utils.py +++ b/src/otx/algo/utils/mmengine_utils.py @@ -72,7 +72,7 @@ def load_checkpoint( def load_from_http( filename: str, map_location: str | None = None, - model_dir: str | None = None, + model_dir: Path | str | None = None, progress: bool = os.isatty(0), ) -> dict[str, Any]: """Loads a checkpoint from an HTTP URL. diff --git a/tests/unit/algo/segmentation/backbones/test_dinov2.py b/tests/unit/algo/segmentation/backbones/test_dinov2.py index 12f2c8ba2ee..8774767f61a 100644 --- a/tests/unit/algo/segmentation/backbones/test_dinov2.py +++ b/tests/unit/algo/segmentation/backbones/test_dinov2.py @@ -1,5 +1,6 @@ from __future__ import annotations +from pathlib import Path from unittest.mock import MagicMock import pytest @@ -74,7 +75,6 @@ def mock_torch_load(self, mocker) -> MagicMock: def test_load_pretrained_weights(self, dino_vit, pretrained_weight, mock_torch_load, mock_load_checkpoint_to_model): dino_vit.load_pretrained_weights(pretrained=pretrained_weight) - mock_torch_load.assert_called_once_with(pretrained_weight, "cpu") mock_load_checkpoint_to_model.assert_called_once() @@ -82,5 +82,6 @@ def test_load_pretrained_weights_from_url(self, dino_vit, mock_load_from_http, m pretrained_weight = "www.fake.com/fake.pth" dino_vit.load_pretrained_weights(pretrained=pretrained_weight) - mock_load_from_http.assert_called_once_with(pretrained_weight, "cpu") + cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints" + mock_load_from_http.assert_called_once_with(filename=pretrained_weight, map_location="cpu", model_dir=cache_dir) mock_load_checkpoint_to_model.assert_called_once() diff --git a/tests/unit/algo/segmentation/backbones/test_litehrnet.py b/tests/unit/algo/segmentation/backbones/test_litehrnet.py index 03c0f835d38..eddac529ed0 100644 --- a/tests/unit/algo/segmentation/backbones/test_litehrnet.py +++ b/tests/unit/algo/segmentation/backbones/test_litehrnet.py @@ -1,4 +1,5 @@ from copy import deepcopy +from pathlib import Path from unittest.mock import MagicMock import pytest @@ -167,5 +168,6 @@ def test_load_pretrained_weights_from_url(self, extra_cfg, mock_load_from_http, model = LiteHRNet(extra=extra_cfg) model.load_pretrained_weights(pretrained=pretrained_weight) - mock_load_from_http.assert_called_once_with(pretrained_weight, "cpu") + cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints" + mock_load_from_http.assert_called_once_with(filename=pretrained_weight, map_location="cpu", model_dir=cache_dir) mock_load_checkpoint_to_model.assert_called_once() diff --git a/tests/unit/algo/segmentation/backbones/test_mscan.py b/tests/unit/algo/segmentation/backbones/test_mscan.py index b6686276477..a991b9ba8c2 100644 --- a/tests/unit/algo/segmentation/backbones/test_mscan.py +++ b/tests/unit/algo/segmentation/backbones/test_mscan.py @@ -1,3 +1,4 @@ +from pathlib import Path from unittest.mock import MagicMock import pytest @@ -101,5 +102,6 @@ def test_load_pretrained_weights_from_url(self, mock_load_from_http, mock_load_c pretrained_weight = "www.fake.com/fake.pth" MSCAN(pretrained_weights=pretrained_weight) - mock_load_from_http.assert_called_once_with(pretrained_weight, "cpu") + cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints" + mock_load_from_http.assert_called_once_with(filename=pretrained_weight, map_location="cpu", model_dir=cache_dir) mock_load_checkpoint_to_model.assert_called_once()