From 7122362fb025f42dcc1bab40a4211ee626116ce4 Mon Sep 17 00:00:00 2001
From: Jih-neng Lin
Date: Fri, 1 Dec 2023 08:10:17 +0800
Subject: [PATCH] Add safetensors support (#4659)

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Jeff Rasley
Co-authored-by: Michael Wyatt
---
 .../v2/checkpoint/huggingface_engine.py | 44 +++++++++++++++----
 1 file changed, 35 insertions(+), 9 deletions(-)

diff --git a/deepspeed/inference/v2/checkpoint/huggingface_engine.py b/deepspeed/inference/v2/checkpoint/huggingface_engine.py
index 029e3f7774c0..6b64ed3185a2 100644
--- a/deepspeed/inference/v2/checkpoint/huggingface_engine.py
+++ b/deepspeed/inference/v2/checkpoint/huggingface_engine.py
@@ -8,6 +8,7 @@
 import torch
 
 from .base_engine import CheckpointEngineBase
 from typing import Iterable, Tuple
+from functools import partial
 
 from ..logging import inference_logger
@@ -28,6 +29,7 @@ def __init__(self, model_name_or_path: str, auth_token: str = None) -> None:
         else:
             self.model_config.max_seq_length = self.generation_config.max_length
 
+        self._local_checkpoint_dir = None
         self._all_ckpt_paths = self._fetch_checkpoint_files()
 
     def _fetch_checkpoint_files(self):
@@ -41,17 +43,30 @@
         # NOTE(jeff): allow_patterns here are explicitly not using safetensors or other
         # checkpoint files that may be present. Example of all files in the llama-2-7b
         # repo here: https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main
-        from huggingface_hub import snapshot_download
+        from huggingface_hub import snapshot_download, list_files_info
+
+        def model_has_safetensors(model_name_or_path: str) -> bool:
+            if os.path.isdir(model_name_or_path):
+                file_list = os.listdir(model_name_or_path)
+            else:
+                file_list = [rf.rfilename for rf in list_files_info(model_name_or_path)]
+            for f in file_list:
+                if f.endswith(".safetensors"):
+                    return True
+            return False
 
         if os.path.isdir(self.model_name_or_path):
             self._local_checkpoint_dir = self.model_name_or_path
         else:
+            # We need to download the checkpoint files from HF
+            if model_has_safetensors(self.model_name_or_path):
+                # Prioritize downloading safetensors if they are available
+                allow_patterns = ["*.safetensors", "*.json", "*.pt"]
+            else:
+                # Fallback to bin files when safetensors are not present
+                allow_patterns = ["*.bin", "*.json", "*.pt"]
             self._local_checkpoint_dir = snapshot_download(self.model_name_or_path,
-                                                           allow_patterns=[
-                                                               "*.bin",
-                                                               "*.json",
-                                                               "*.pt",
-                                                           ],
+                                                           allow_patterns=allow_patterns,
                                                            revision=None,
                                                            token=self.auth_token)
 
@@ -59,11 +74,22 @@
         assert os.path.isdir(
             self._local_checkpoint_dir
         ), f"Checkpoint dir {self._local_checkpoint_dir} is not a directory, cannot load checkpoint."
 
-        model_param_json = os.path.join(self._local_checkpoint_dir, "pytorch_model.bin.index.json")
+        # Set the appropriate file names based on whether we have safetensors or not
+        if model_has_safetensors(self._local_checkpoint_dir):
+            from safetensors.torch import load_file
+            model_param_json_fname = "model.safetensors.index.json"
+            model_file_fname = "model.safetensors"
+            self._checkpoint_load_fn = load_file
+        else:
+            model_param_json_fname = "pytorch_model.bin.index.json"
+            model_file_fname = "pytorch_model.bin"
+            self._checkpoint_load_fn = partial(torch.load, map_location="cpu")
+
+        model_param_json = os.path.join(self._local_checkpoint_dir, model_param_json_fname)
 
         if not os.path.isfile(model_param_json):
             # We don't need any json as all such HF models will have pytorch_model.bin
-            all_checkpoint_files = [os.path.join(self._local_checkpoint_dir, 'pytorch_model.bin')]
+            all_checkpoint_files = [os.path.join(self._local_checkpoint_dir, model_file_fname)]
         else:
             param_map = json.load(open(model_param_json, "r"))
@@ -84,7 +110,7 @@ def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
         """
         for checkpoint in self._all_ckpt_paths:
             inference_logger().info(f"Loading checkpoint: {checkpoint}")
-            checkpoint_sd = torch.load(checkpoint, map_location='cpu')
+            checkpoint_sd = self._checkpoint_load_fn(checkpoint)
             param_keys = list(checkpoint_sd.keys())
             for param_name in param_keys:
                 param = checkpoint_sd[param_name]
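
For context, the following is a minimal standalone sketch of the loader-selection pattern this patch applies inside the checkpoint engine: prefer safetensors' load_file when .safetensors files are present in a local checkpoint directory, otherwise fall back to torch.load on CPU. The helper name select_checkpoint_load_fn and the ckpt_dir argument are illustrative assumptions, not symbols from the diff.

    # Illustrative sketch, not part of the patch: reproduces the patch's choice
    # between safetensors.torch.load_file and torch.load for a local directory.
    import os
    from functools import partial

    import torch


    def select_checkpoint_load_fn(ckpt_dir: str):
        """Return a callable mapping a checkpoint file path to a state dict."""
        if any(f.endswith(".safetensors") for f in os.listdir(ckpt_dir)):
            # Safetensors files are read without unpickling arbitrary code,
            # which is the main motivation for preferring them here.
            from safetensors.torch import load_file
            return load_file
        # Mirrors the patch's fallback: keep tensors on CPU at load time so
        # the engine can place them onto devices afterwards.
        return partial(torch.load, map_location="cpu")


    # Example usage with a hypothetical local checkpoint directory:
    #   load_fn = select_checkpoint_load_fn("/models/llama-2-7b")
    #   state_dict = load_fn("/models/llama-2-7b/model.safetensors")

One portability note: list_files_info has been deprecated in newer huggingface_hub releases, so code written against current versions of the library would typically perform the remote check with list_repo_files or list_repo_tree instead.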