From ca3f674ed988e2f83bbc12520521af016ecab50d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?C=C3=A9lina?=
Date: Fri, 13 Dec 2024 13:56:42 +0100
Subject: [PATCH] remove context manager when loading shards and handle mlx
 weights (#2709)

---
 src/huggingface_hub/serialization/_torch.py | 44 ++++++---------------
 1 file changed, 11 insertions(+), 33 deletions(-)

diff --git a/src/huggingface_hub/serialization/_torch.py b/src/huggingface_hub/serialization/_torch.py
index a792ae5a40..52e11932d8 100644
--- a/src/huggingface_hub/serialization/_torch.py
+++ b/src/huggingface_hub/serialization/_torch.py
@@ -18,10 +18,9 @@
 import os
 import re
 from collections import defaultdict, namedtuple
-from contextlib import contextmanager
 from functools import lru_cache
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union
 
 from packaging import version
 
@@ -538,13 +537,15 @@ def _load_sharded_checkpoint(
     for shard_file in shard_files:
         # Load shard into memory
         shard_path = os.path.join(save_directory, shard_file)
-        with _load_shard_into_memory(
+        state_dict = load_state_dict_from_file(
             shard_path,
-            load_fn=load_state_dict_from_file,
-            kwargs={"weights_only": weights_only},
-        ) as state_dict:
-            # Update model with parameters from this shard
-            model.load_state_dict(state_dict, strict=strict)
+            map_location="cpu",
+            weights_only=weights_only,
+        )
+        # Update model with parameters from this shard
+        model.load_state_dict(state_dict, strict=strict)
+        # Explicitly remove the state dict from memory
+        del state_dict
 
     # 4. Return compatibility info
     loaded_keys = set(index["weight_map"].keys())
@@ -630,7 +631,8 @@ def load_state_dict_from_file(
         # Check format of the archive
         with safe_open(checkpoint_file, framework="pt") as f:  # type: ignore[attr-defined]
             metadata = f.metadata()
-        if metadata.get("format") != "pt":
+        # see comment: https://github.com/huggingface/transformers/blob/3d213b57fe74302e5902d68ed9478c3ad1aaa713/src/transformers/modeling_utils.py#L3966
+        if metadata is not None and metadata.get("format") not in ["pt", "mlx"]:
             raise OSError(
                 f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
                 "you save your model with the `save_torch_model` method."
@@ -668,30 +670,6 @@
 # HELPERS
 
 
-@contextmanager
-def _load_shard_into_memory(
-    shard_path: str,
-    load_fn: Callable,
-    kwargs: Optional[Dict[str, Any]] = None,
-):
-    """
-    Context manager to handle loading and cleanup of model shards.
-
-    Args:
-        shard_path: Path to the shard file
-        load_fn: Function to load the shard (either torch.load or safetensors.load)
-
-    Yields:
-        The loaded state dict for this shard
-    """
-    try:
-        state_dict = load_fn(shard_path, **kwargs)  # type: ignore[arg-type]
-        yield state_dict
-    finally:
-        # Explicitly remove the state dict from memory
-        del state_dict
-
-
 def _validate_keys_for_strict_loading(
     model: "torch.nn.Module",
     loaded_keys: Iterable[str],
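
A minimal sketch of the per-shard loading loop as it behaves after this patch. The `TinyModel` class, the shard list, and the file path are hypothetical placeholders; the `load_state_dict_from_file` call with `map_location` and `weights_only` mirrors the new code in `_load_sharded_checkpoint` above.

    import torch

    from huggingface_hub.serialization._torch import load_state_dict_from_file


    class TinyModel(torch.nn.Module):
        # Hypothetical stand-in for whatever model the shards belong to.
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)


    model = TinyModel()
    shard_files = ["model-00001-of-00002.safetensors"]  # hypothetical shard list

    for shard_file in shard_files:
        # Load the shard eagerly; map_location="cpu" keeps tensors on CPU,
        # and weights_only only matters for pickle (.bin) checkpoints.
        state_dict = load_state_dict_from_file(shard_file, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=False)
        # Drop the reference before loading the next shard so two shards never
        # sit in memory at once, which is what the patch does in place of the
        # removed _load_shard_into_memory context manager.
        del state_dict

Since the metadata check now accepts both the "pt" and "mlx" formats (and tolerates a missing metadata dict), the same call also loads safetensors files produced by MLX.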