Skip to content

Commit

Permalink
Update cache directory for loading pretrained weights (Semantic Segme…
Browse files Browse the repository at this point in the history
…ntation) (#3828)

* updated cache dir

* update unit tests
  • Loading branch information
kprokofi authored Aug 12, 2024
1 parent 1a8e10e commit d66eaf7
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 8 deletions.
3 changes: 2 additions & 1 deletion src/otx/algo/segmentation/backbones/dinov2.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def load_pretrained_weights(self, pretrained: str | None = None, prefix: str = "
checkpoint = torch.load(pretrained, "cpu")
print(f"init weight - {pretrained}")
elif pretrained is not None:
checkpoint = load_from_http(pretrained, "cpu")
cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
checkpoint = load_from_http(filename=pretrained, map_location="cpu", model_dir=cache_dir)
print(f"init weight - {pretrained}")
if checkpoint is not None:
load_checkpoint_to_model(self, checkpoint, prefix=prefix)
Expand Down
3 changes: 2 additions & 1 deletion src/otx/algo/segmentation/backbones/litehrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1525,7 +1525,8 @@ def load_pretrained_weights(self, pretrained: str | None = None, prefix: str = "
checkpoint = torch.load(pretrained, "cpu")
print(f"init weight - {pretrained}")
elif pretrained is not None:
checkpoint = load_from_http(pretrained, "cpu")
cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
checkpoint = load_from_http(filename=pretrained, map_location="cpu", model_dir=cache_dir)
print(f"init weight - {pretrained}")
if checkpoint is not None:
load_checkpoint_to_model(self, checkpoint, prefix=prefix)
3 changes: 2 additions & 1 deletion src/otx/algo/segmentation/backbones/mscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,8 @@ def load_pretrained_weights(self, pretrained: str | None = None, prefix: str = "
checkpoint = torch.load(pretrained, "cpu")
print(f"init weight - {pretrained}")
elif pretrained is not None:
checkpoint = load_from_http(pretrained, "cpu")
cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
checkpoint = load_from_http(filename=pretrained, map_location="cpu", model_dir=cache_dir)
print(f"init weight - {pretrained}")
if checkpoint is not None:
load_checkpoint_to_model(self, checkpoint, prefix=prefix)
2 changes: 1 addition & 1 deletion src/otx/algo/utils/mmengine_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def load_checkpoint(
def load_from_http(
filename: str,
map_location: str | None = None,
model_dir: str | None = None,
model_dir: Path | str | None = None,
progress: bool = os.isatty(0),
) -> dict[str, Any]:
"""Loads a checkpoint from an HTTP URL.
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/algo/segmentation/backbones/test_dinov2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from pathlib import Path
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -74,13 +75,13 @@ def mock_torch_load(self, mocker) -> MagicMock:

def test_load_pretrained_weights(self, dino_vit, pretrained_weight, mock_torch_load, mock_load_checkpoint_to_model):
dino_vit.load_pretrained_weights(pretrained=pretrained_weight)

mock_torch_load.assert_called_once_with(pretrained_weight, "cpu")
mock_load_checkpoint_to_model.assert_called_once()

def test_load_pretrained_weights_from_url(self, dino_vit, mock_load_from_http, mock_load_checkpoint_to_model):
pretrained_weight = "www.fake.com/fake.pth"
dino_vit.load_pretrained_weights(pretrained=pretrained_weight)

mock_load_from_http.assert_called_once_with(pretrained_weight, "cpu")
cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
mock_load_from_http.assert_called_once_with(filename=pretrained_weight, map_location="cpu", model_dir=cache_dir)
mock_load_checkpoint_to_model.assert_called_once()
4 changes: 3 additions & 1 deletion tests/unit/algo/segmentation/backbones/test_litehrnet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from copy import deepcopy
from pathlib import Path
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -167,5 +168,6 @@ def test_load_pretrained_weights_from_url(self, extra_cfg, mock_load_from_http,
model = LiteHRNet(extra=extra_cfg)
model.load_pretrained_weights(pretrained=pretrained_weight)

mock_load_from_http.assert_called_once_with(pretrained_weight, "cpu")
cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
mock_load_from_http.assert_called_once_with(filename=pretrained_weight, map_location="cpu", model_dir=cache_dir)
mock_load_checkpoint_to_model.assert_called_once()
4 changes: 3 additions & 1 deletion tests/unit/algo/segmentation/backbones/test_mscan.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -101,5 +102,6 @@ def test_load_pretrained_weights_from_url(self, mock_load_from_http, mock_load_c
pretrained_weight = "www.fake.com/fake.pth"
MSCAN(pretrained_weights=pretrained_weight)

mock_load_from_http.assert_called_once_with(pretrained_weight, "cpu")
cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
mock_load_from_http.assert_called_once_with(filename=pretrained_weight, map_location="cpu", model_dir=cache_dir)
mock_load_checkpoint_to_model.assert_called_once()

0 comments on commit d66eaf7

Please # to comment.