From f5eea6ad009cf86d559203fd80dbca8fabd35d7c Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 14 Apr 2022 18:10:05 -0700 Subject: [PATCH] [modeling utils] revamp `from_pretrained(..., low_cpu_mem_usage=True)` + tests (#16657) * add low_cpu_mem_usage tests * wip: revamping * wip * install /usr/bin/time * wip * cleanup * cleanup * cleanup * cleanup * cleanup * fix assert * put the wrapper back * cleanup; switch to bert-base-cased * Trigger CI * Trigger CI --- .circleci/config.yml | 2 +- src/transformers/modeling_utils.py | 236 +++++++++++++++++++---------- src/transformers/testing_utils.py | 48 ++++++ tests/test_modeling_common.py | 51 +++++++ 4 files changed, 254 insertions(+), 83 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 310aecc7b83644..869406fc0b637a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -217,7 +217,7 @@ jobs: keys: - v0.4-torch-{{ checksum "setup.py" }} - v0.4-{{ checksum "setup.py" }} - - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng + - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng time - run: pip install --upgrade pip - run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm] - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a6b2aa1e78ba27..a721f550a9f4ea 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -400,6 +400,95 @@ def load(module: nn.Module, prefix=""): return error_msgs +def find_submodule_and_param_name(model, long_key, start_prefix): + """ + A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied it'll be removed + from the start of the key + """ + + if len(start_prefix) > 0 and long_key.startswith(start_prefix): + long_key = ".".join(long_key.split(".")[1:]) + + split_key = long_key.split(".") + submodule = model + while len(split_key) > 1: + if hasattr(submodule, split_key[0]): + submodule = getattr(submodule, split_key[0]) + del split_key[0] + else: + submodule = None + break + if submodule == model: + submodule = None + return submodule, split_key[0] + + +def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix): + """ + Moves `loaded_state_dict_keys` in model to meta device which frees up the memory taken by those params. + + `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in + `bert.pooler.dense.weight` + + """ + + # meta device was added in pt=1.9 + require_version_core("torch>=1.9") + + # dematerialize param storage for keys that are going to be replaced by state_dict, by + # putting those on the meta device + for k in loaded_state_dict_keys: + submodule, param_name = find_submodule_and_param_name(model, k, start_prefix) + if submodule is not None: + # selectively switch to the meta device only those params/buffers that will + # be next replaced from state_dict. This a complex way to do p.to_("meta") + # since we have no in-place to_ for tensors. 
+ new_val = getattr(submodule, param_name) + if isinstance(new_val, torch.nn.Parameter): + # isinstance returns False for Params on meta device, so switch after the check + new_val = torch.nn.Parameter(new_val.to("meta")) + else: + new_val = new_val.to("meta") + setattr(submodule, param_name, new_val) + + +def _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix): + """ + This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its + params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the + params back to the normal device, but only for `loaded_state_dict_keys`. + + `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in + `bert.pooler.dense.weight` + + """ + + # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model + # - deepspeed zero 3 support + # - need to copy metadata if any - see _load_state_dict_into_model + # - handling error_msgs - mimicking the error handling in module._load_from_state_dict() + # - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case + # they won't get loaded. + + if is_deepspeed_zero3_enabled(): + raise ValueError("low_cpu_mem_usage arg cannot currently be used with DeepSpeed ZeRO-3") + + error_msgs = [] + + # materialize state_dict entries one by one on CPU + for k in loaded_state_dict_keys: + if k in state_dict: + submodule, param_name = find_submodule_and_param_name(model, k, start_prefix) + if submodule is not None: + param_dtype = getattr(submodule, param_name).dtype + new_val = state_dict[k].to(param_dtype) + if isinstance(getattr(submodule, param_name), torch.nn.Parameter): + new_val = torch.nn.Parameter(new_val) + setattr(submodule, param_name, new_val) + + return error_msgs + + class ModuleUtilsMixin: """ A few utilities for `torch.nn.Modules`, to be used as a mixin. @@ -1529,7 +1618,24 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P >>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config) >>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower) >>> model = BertModel.from_pretrained("bert-base-uncased", from_flax=True) - ```""" + ``` + + * `low_cpu_mem_usage` algorithm: + + This is an experimental function that loads the model using ~1x model size CPU memory + + Here is how it works: + + 1. save which state_dict keys we have + 2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory + 3. after the model has been instantiated switch to the meta device all params/buffers that + are going to be replaced from the loaded state_dict + 4. load state_dict 2nd time + 5. replace the params/buffers from the state_dict + + Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors + + """ config = kwargs.pop("config", None) state_dict = kwargs.pop("state_dict", None) cache_dir = kwargs.pop("cache_dir", None) @@ -1778,6 +1884,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if not is_sharded and state_dict is None: # Time to load the checkpoint state_dict = load_state_dict(resolved_archive_file) + # set dtype to instantiate the model under: # 1. If torch_dtype is not None, we use that dtype # 2. 
If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first @@ -1801,13 +1908,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) dtype_orig = cls._set_default_torch_dtype(torch_dtype) + if is_sharded: + loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] + else: + loaded_state_dict_keys = [k for k in state_dict.keys()] if low_cpu_mem_usage: - # save the keys - if is_sharded: - loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] - else: - loaded_state_dict_keys = [k for k in state_dict.keys()] - del state_dict # free CPU memory - will reload again later + state_dict = None config.name_or_path = pretrained_model_name_or_path @@ -1825,11 +1931,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P with no_init_weights(_enable=_fast_init): model = cls(config, *model_args, **model_kwargs) - if from_pt: - # restore default dtype - if dtype_orig is not None: - torch.set_default_dtype(dtype_orig) - if from_tf: if resolved_archive_file.endswith(".index"): # Load from a TensorFlow 1.X checkpoint - provided by original authors @@ -1859,18 +1960,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P raise elif from_pt: - if low_cpu_mem_usage: - cls._load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file) - else: - model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( - model, - state_dict, - resolved_archive_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=ignore_mismatched_sizes, - sharded_metadata=sharded_metadata, - _fast_init=_fast_init, - ) + # restore default dtype + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( + model, + state_dict, + loaded_state_dict_keys, # XXX: rename? + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + sharded_metadata=sharded_metadata, + _fast_init=_fast_init, + low_cpu_mem_usage=low_cpu_mem_usage, + ) # make sure token embedding weights are still tied if needed model.tie_weights() @@ -1894,16 +1998,17 @@ def _load_pretrained_model( cls, model, state_dict, + loaded_keys, resolved_archive_file, pretrained_model_name_or_path, ignore_mismatched_sizes=False, sharded_metadata=None, _fast_init=True, + low_cpu_mem_usage=False, ): # Retrieve missing & unexpected_keys model_state_dict = model.state_dict() expected_keys = list(model_state_dict.keys()) - loaded_keys = list(state_dict.keys()) if state_dict is not None else sharded_metadata["all_checkpoint_keys"] prefix = model.base_model_prefix def _fix_key(key): @@ -1994,9 +2099,12 @@ def _find_mismatched_keys( (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) ) del state_dict[checkpoint_key] - return mismatched_keys + if low_cpu_mem_usage: + model_state_dict = None # free references to model's params to allow memory freeing + _move_model_to_meta(model, loaded_keys, start_prefix) + if state_dict is not None: # Whole checkpoint mismatched_keys = _find_mismatched_keys( @@ -2009,7 +2117,8 @@ def _find_mismatched_keys( ) error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix) else: - # Sharded checkpoint + # Sharded checkpoint or whole but low_cpu_mem_usage==True + # This should always be a list but, just to be sure. 
if not isinstance(resolved_archive_file, list): resolved_archive_file = [resolved_archive_file] @@ -2018,6 +2127,10 @@ def _find_mismatched_keys( mismatched_keys = [] for shard_file in resolved_archive_file: state_dict = load_state_dict(shard_file) + + if low_cpu_mem_usage: + model_state_dict = model.state_dict() + # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not # matching the weights in the model. mismatched_keys += _find_mismatched_keys( @@ -2028,7 +2141,13 @@ def _find_mismatched_keys( remove_prefix_from_model, ignore_mismatched_sizes, ) - error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix) + + if low_cpu_mem_usage: + error_msgs += _load_state_dict_into_meta_model( + model_to_load, state_dict, loaded_keys, start_prefix + ) + else: + error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix) if len(error_msgs) > 0: error_msg = "\n\t".join(error_msgs) @@ -2093,13 +2212,13 @@ def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=Fal return retrieved_modules @staticmethod - def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file): + def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file, start_prefix=""): """ This is an experimental function that loads the model using ~1.x model size CPU memory - Before it gets called we do: + Before you call it do: - 1. save which state_dict keys we have + 1. save which state_dict keys are available 2. drop state_dict before model is created, since the latter takes 1x model size memory Here then we continue: @@ -2110,58 +2229,11 @@ def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archi Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed. """ - require_version_core("torch>=1.9") - if is_deepspeed_zero3_enabled(): - raise ValueError("low_cpu_mem_usage arg cannot be used with DeepSpeed ZeRO-3") - - # a helper util to find the last sub-module and the param/buffer name - def find_submodule_and_param_name(model, long_key): - split_key = long_key.split(".") - submodule = model - while len(split_key) > 1: - if hasattr(submodule, split_key[0]): - submodule = getattr(submodule, split_key[0]) - del split_key[0] - else: - submodule = None - break - return submodule, split_key[0] - - # dematerialize param storage for keys that are going to be replaced by state_dict, by - # putting those on the meta device - for k in loaded_state_dict_keys: - submodule, param_name = find_submodule_and_param_name(model, k) - if submodule is not None: - # selectively switch to the meta device only those params/buffers that will - # be next replaced from state_dict. This a complex way to do p.to_("meta") - # since we have no in-place to_ for tensors. 
- new_val = getattr(submodule, param_name) - if isinstance(new_val, torch.nn.Parameter): - # isinstance returns False for Params on meta device, so switch after the check - new_val = torch.nn.Parameter(new_val.to("meta")) - else: - new_val = new_val.to("meta") - setattr(submodule, param_name, new_val) - # only now can load state_dict(s) - if not isinstance(resolved_archive_file, list): - resolved_archive_file = [resolved_archive_file] - - for archive_file in resolved_archive_file: - state_dict = torch.load(archive_file, map_location="cpu") - - # materialize state_dict entries one by one on CPU - for k in loaded_state_dict_keys: - if k in state_dict: - submodule, param_name = find_submodule_and_param_name(model, k) - if submodule is not None: - param_dtype = getattr(submodule, param_name).dtype - new_val = state_dict[k].to(param_dtype) - if isinstance(getattr(submodule, param_name), torch.nn.Parameter): - new_val = torch.nn.Parameter(new_val) - setattr(submodule, param_name, new_val) - - del state_dict + _move_model_to_meta(model, loaded_state_dict_keys, start_prefix) + state_dict = load_state_dict(resolved_archive_file) + error_msgs = _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix) + return error_msgs @classmethod def register_for_auto_class(cls, auto_class="AutoModel"): diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index ec681597f0fb01..b60c7942097a14 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -17,6 +17,7 @@ import logging import os import re +import shlex import shutil import sys import tempfile @@ -667,6 +668,20 @@ def require_librosa(test_case): return test_case +def cmd_exists(cmd): + return shutil.which(cmd) is not None + + +def require_usr_bin_time(test_case): + """ + Decorator marking a test that requires `/usr/bin/time` + """ + if not cmd_exists("/usr/bin/time"): + return unittest.skip("test requires /usr/bin/time")(test_case) + else: + return test_case + + def get_gpu_count(): """ Return the number of available gpus (regardless of whether torch, tf or jax is used) @@ -1178,6 +1193,39 @@ def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None): return tmp_dir + def python_one_liner_max_rss(self, one_liner_str): + """ + Runs the passed python one liner (just the code) and returns how much max cpu memory was used to run the + program. + + Args: + one_liner_str (`string`): + a python one liner code that gets passed to `python -c` + + Returns: + max cpu memory bytes used to run the program. This value is likely to vary slightly from run to run. 
+ + Requirements: + this helper needs `/usr/bin/time` to be installed (`apt install time`) + + Example: + + ``` + one_liner_str = 'from transformers import AutoModel; AutoModel.from_pretrained("t5-large")' + max_rss = self.python_one_liner_max_rss(one_liner_str) + ``` + """ + + if not cmd_exists("/usr/bin/time"): + raise ValueError("/usr/bin/time is required, install with `apt install time`") + + cmd = shlex.split(f"/usr/bin/time -f %M python -c '{one_liner_str}'") + with CaptureStd() as cs: + execute_subprocess_async(cmd, env=self.get_env()) + # returned data is in KB so convert to bytes + max_rss = int(cs.err.split("\n")[-2].replace("stderr: ", "")) * 1024 + return max_rss + def tearDown(self): # get_auto_remove_tmp_dir feature: remove registered temp dirs diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index db24ece11faf96..4a4a0eba044f76 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -52,6 +52,7 @@ is_staging_test, require_torch, require_torch_multi_gpu, + require_usr_bin_time, slow, torch_device, ) @@ -2489,6 +2490,56 @@ def test_checkpoint_sharding_from_hub(self): for p1, p2 in zip(model.parameters(), ref_model.parameters()): self.assertTrue(torch.allclose(p1, p2)) + def test_from_pretrained_low_cpu_mem_usage_functional(self): + # test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and + # sharded models + + mnames = [ + "hf-internal-testing/tiny-random-bert-sharded", + "hf-internal-testing/tiny-random-bert", + ] + for mname in mnames: + _ = BertModel.from_pretrained(mname, low_cpu_mem_usage=True) + + @require_usr_bin_time + def test_from_pretrained_low_cpu_mem_usage_measured(self): + # test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default + + mname = "bert-base-cased" + + preamble = "from transformers import AutoModel" + one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=False)' + max_rss_normal = self.python_one_liner_max_rss(one_liner_str) + # print(f"{max_rss_normal=}") + + one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=True)' + max_rss_low_mem = self.python_one_liner_max_rss(one_liner_str) + # print(f"{max_rss_low_mem=}") + + diff_bytes = max_rss_normal - max_rss_low_mem + diff_percent = diff_bytes / max_rss_low_mem + # print(f"{diff_bytes=}, {diff_percent=}") + # ideally we would compare that the diff is close to ~1x checkpoint size in bytes, but + # measuring cpu memory on linux is very tricky and inconsistent, so instead let's check that + # it's at least 15% less cpu memory consumed + + self.assertGreater( + diff_percent, + 0.15, + "should use less CPU memory for low_cpu_mem_usage=True, " + f"but got max_rss_normal={max_rss_normal} and max_rss_low_mem={max_rss_low_mem}", + ) + + # if you want to compare things manually, let's first look at the size of the model in bytes + # model = BertModel.from_pretrained(mname, low_cpu_mem_usage=False) + # total_numel = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values()) + # total_bytes = total_numel * 4 # 420MB + # Now the diff_bytes should be very close to total_bytes, but the reports are inconsistent. + # The easiest way to test this is to switch the model and torch.load to do all the work on + # gpu - that way one can measure exactly the total and peak memory used. 
Perhaps once we add + # functionality to load models directly on gpu, this test can be rewritten to use torch's + # cuda memory tracking and then we should be able to do a much more precise test. + def test_cached_files_are_used_when_internet_is_down(self): # A mock response for an HTTP head request to emulate server down response_mock = mock.Mock()
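
Below, a minimal self-contained sketch of the meta-device technique that the new `_move_model_to_meta` and `_load_state_dict_into_meta_model` helpers implement: first dematerialize the parameters that the checkpoint will overwrite, then rematerialize them one by one from the loaded `state_dict`. It is not part of the diff above; the toy `nn.Sequential` model and the simplified key-resolution helper (no `start_prefix` handling, no `hasattr` guard, no error handling) are illustrative assumptions. It needs torch>=1.9 for the meta device.

```python
# Sketch of the low_cpu_mem_usage loading idea (illustrative, not the library code).
# Requires torch>=1.9 for the meta device.
from torch import nn


def find_submodule_and_param_name(model, long_key):
    """Resolve a flat state_dict key like "0.weight" to (submodule, "weight")."""
    split_key = long_key.split(".")
    submodule = model
    while len(split_key) > 1:
        submodule = getattr(submodule, split_key[0])
        del split_key[0]
    return submodule, split_key[0]


model = nn.Sequential(nn.Linear(4, 4))  # stands in for the freshly instantiated model
state_dict = model.state_dict()         # stands in for the checkpoint read from disk
loaded_state_dict_keys = list(state_dict.keys())  # step 1: remember which keys the checkpoint provides
# (step 2 in the real code: del state_dict here, so only ~1x model size stays resident)

# step 3: dematerialize the params that will be overwritten by moving them to the meta device,
# which frees their CPU storage (a roundabout p.to_("meta"), since tensors have no in-place to_)
for k in loaded_state_dict_keys:
    submodule, param_name = find_submodule_and_param_name(model, k)
    old_val = getattr(submodule, param_name)
    if isinstance(old_val, nn.Parameter):
        # isinstance is False for Parameters already on meta, so check before converting
        new_val = nn.Parameter(old_val.to("meta"))
    else:
        new_val = old_val.to("meta")
    setattr(submodule, param_name, new_val)

# step 4: in the real code the checkpoint (or each shard) is re-read from disk at this point

# step 5: rematerialize each param/buffer on CPU from the checkpoint values
for k in loaded_state_dict_keys:
    submodule, param_name = find_submodule_and_param_name(model, k)
    param_dtype = getattr(submodule, param_name).dtype
    new_val = state_dict[k].to(param_dtype)
    if isinstance(getattr(submodule, param_name), nn.Parameter):
        new_val = nn.Parameter(new_val)
    setattr(submodule, param_name, new_val)

assert not any(p.is_meta for p in model.parameters())  # everything is back on a real device
```

Because the meta copies carry no storage, peak CPU memory during loading stays near 1x the model size rather than roughly 2x (freshly initialized model plus checkpoint), which is what the new `test_from_pretrained_low_cpu_mem_usage_measured` test checks via `/usr/bin/time -f %M`. At the user level the whole feature is just `AutoModel.from_pretrained(mname, low_cpu_mem_usage=True)`.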