Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check symlinks support per directory instead of globally #1077

Merged
merged 3 commits into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 34 additions & 15 deletions src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,33 +172,51 @@ def get_jinja_version():
return _jinja_version


_are_symlinks_supported: Optional[bool] = None
_are_symlinks_supported_in_dir: Dict[str, bool] = {}


def are_symlinks_supported() -> bool:
# Check symlink compatibility only once at first time use
global _are_symlinks_supported
def are_symlinks_supported(cache_dir: Union[str, Path, None] = None) -> bool:
"""Return whether the symlinks are supported on the machine.

if _are_symlinks_supported is None:
_are_symlinks_supported = True
Since symlinks support can change depending on the mounted disk, we need to check
on the precise cache folder. By default, the default HF cache directory is checked.

with tempfile.TemporaryDirectory() as tmpdir:
Args:
cache_dir (`str`, `Path`, *optional*):
Path to the folder where cached files are stored.

Returns: [bool] Whether symlinks are supported in the directory.
"""
# Defaults to HF cache
if cache_dir is None:
cache_dir = HUGGINGFACE_HUB_CACHE
cache_dir = str(Path(cache_dir).expanduser().resolve()) # make it unique

# Check symlink compatibility only once (per cache directory) at first time use
if cache_dir not in _are_symlinks_supported_in_dir:
_are_symlinks_supported_in_dir[cache_dir] = True

os.makedirs(cache_dir, exist_ok=True)
with tempfile.TemporaryDirectory(dir=cache_dir) as tmpdir:
src_path = Path(tmpdir) / "dummy_file_src"
src_path.touch()
dst_path = Path(tmpdir) / "dummy_file_dst"

# Relative source path as in `_create_relative_symlink``
relative_src = os.path.relpath(src_path, start=os.path.dirname(dst_path))
try:
os.symlink(src_path, dst_path)
os.symlink(relative_src, dst_path)
except OSError:
# Likely running on Windows
_are_symlinks_supported = False
_are_symlinks_supported_in_dir[cache_dir] = False

if not os.environ.get("DISABLE_SYMLINKS_WARNING"):
message = (
"`huggingface_hub` cache-system uses symlinks by default to"
" efficiently store duplicated files but your machine doesn't"
" support them. Caching files will still work but in a degraded"
" version that might require more space on your disk. This"
" warning can be disabled by setting the"
" efficiently store duplicated files but your machine does not"
f" support them in {cache_dir}. Caching files will still work"
" but in a degraded version that might require more space on"
" your disk. This warning can be disabled by setting the"
" `DISABLE_SYMLINKS_WARNING` environment variable. For more"
" details, see"
" https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations."
Expand All @@ -213,7 +231,7 @@ def are_symlinks_supported() -> bool:
)
warnings.warn(message)

return _are_symlinks_supported
return _are_symlinks_supported_in_dir[cache_dir]


# Return value when trying to load a file from cache but the file does not exist in the distant repo.
Expand Down Expand Up @@ -920,7 +938,8 @@ def _create_relative_symlink(src: str, dst: str, new_blob: bool = False) -> None
except OSError:
pass

if are_symlinks_supported():
cache_dir = os.path.dirname(os.path.commonpath([src, dst]))
if are_symlinks_supported(cache_dir=cache_dir):
os.symlink(relative_src, dst)
elif new_blob:
os.replace(src, dst)
Expand Down
33 changes: 22 additions & 11 deletions tests/test_cache_no_symlinks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

from huggingface_hub import hf_hub_download, scan_cache_dir
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.constants import CONFIG_NAME, HUGGINGFACE_HUB_CACHE
from huggingface_hub.file_download import are_symlinks_supported

from .testing_constants import TOKEN
Expand All @@ -18,24 +18,35 @@
class TestCacheLayoutIfSymlinksNotSupported(unittest.TestCase):
cache_dir: Path

@patch("huggingface_hub.file_download._are_symlinks_supported", None)
def test_are_symlinks_supported_normal(self) -> None:
@patch(
"huggingface_hub.file_download._are_symlinks_supported_in_dir",
{HUGGINGFACE_HUB_CACHE: True},
)
def test_are_symlinks_supported_default(self) -> None:
self.assertTrue(are_symlinks_supported())

@patch("huggingface_hub.file_download.os.symlink") # Symlinks not supported
@patch("huggingface_hub.file_download._are_symlinks_supported", None) # first use
def test_are_symlinks_supported_windows(self, mock_symlink: Mock) -> None:
mock_symlink.side_effect = OSError()
@patch("huggingface_hub.file_download.os.symlink")
@patch("huggingface_hub.file_download._are_symlinks_supported_in_dir", {})
def test_are_symlinks_supported_windows_specific_dir(
self, mock_symlink: Mock
) -> None:
mock_symlink.side_effect = [OSError(), None] # First dir not supported then yes
this_dir = Path(__file__).parent

# First time: warning is raised
# First time in `this_dir`: warning is raised
with self.assertWarns(UserWarning):
self.assertFalse(are_symlinks_supported())
self.assertFalse(are_symlinks_supported(this_dir))

# Afterward: value is cached (no warning raised)
with warnings.catch_warnings():
# Assert no warnings raised
# Taken from https://stackoverflow.com/a/45671804
warnings.simplefilter("error")
self.assertFalse(are_symlinks_supported())

# Second time in `this_dir` but with absolute path: value is still cached
self.assertFalse(are_symlinks_supported(this_dir.absolute()))

# Try with another directory: symlinks are supported, no warnings
self.assertTrue(are_symlinks_supported()) # True

@patch("huggingface_hub.file_download.are_symlinks_supported")
def test_download_no_symlink_new_file(
Expand Down