
Commit

[modeling utils] revamp `from_pretrained(..., low_cpu_mem_usage=True)` + tests (huggingface#16657)

* add low_cpu_mem_usage tests

* wip: revamping

* wip

* install /usr/bin/time

* wip

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* fix assert

* put the wrapper back

* cleanup; switch to bert-base-cased

* Trigger CI

* Trigger CI
stas00 authored and elusenji committed Jun 12, 2022
1 parent 588c53b commit f5eea6a
Showing 4 changed files with 254 additions and 83 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -217,7 +217,7 @@ jobs:
keys:
- v0.4-torch-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng time
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
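The `time` package added to the CI image above installs GNU `/usr/bin/time`, presumably so the new `low_cpu_mem_usage` tests can measure peak memory of a loading run. As a hedged illustration only (not the test code added by this commit), one way a test could read a child process's peak RSS with it:

```python
import subprocess
import sys

# Hypothetical sketch: run a small loading script under GNU time and read its
# peak resident set size. With GNU time, "-f %M" prints the maximum RSS in
# kilobytes to stderr.
program = (
    "from transformers import AutoModel; "
    'AutoModel.from_pretrained("bert-base-cased", low_cpu_mem_usage=True)'
)
result = subprocess.run(
    ["/usr/bin/time", "-f", "%M", sys.executable, "-c", program],
    capture_output=True,
    text=True,
    check=True,
)
peak_rss_kb = int(result.stderr.strip().splitlines()[-1])
print(f"peak RSS: {peak_rss_kb / 1024:.0f} MiB")
```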
236 changes: 154 additions & 82 deletions src/transformers/modeling_utils.py
@@ -400,6 +400,95 @@ def load(module: nn.Module, prefix=""):
return error_msgs


def find_submodule_and_param_name(model, long_key, start_prefix):
"""
A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied, it will be removed
from the start of the key.
"""

if len(start_prefix) > 0 and long_key.startswith(start_prefix):
long_key = ".".join(long_key.split(".")[1:])

split_key = long_key.split(".")
submodule = model
while len(split_key) > 1:
if hasattr(submodule, split_key[0]):
submodule = getattr(submodule, split_key[0])
del split_key[0]
else:
submodule = None
break
if submodule == model:
submodule = None
return submodule, split_key[0]
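As a quick illustration (not part of the diff), here is a minimal sketch of what the helper returns for a dotted checkpoint key. The `Toy`/`Pooler` classes and the `"bert."` prefix are invented for the example, and the import assumes a transformers install that already contains this commit.

```python
from torch import nn

# Assumes a transformers install that includes this commit.
from transformers.modeling_utils import find_submodule_and_param_name


class Pooler(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 4)


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.pooler = Pooler()


model = Toy()

# "bert." is stripped as the start_prefix, then "pooler.dense" is walked down to
# the owning sub-module; the last component is returned as the param/buffer name.
submodule, param_name = find_submodule_and_param_name(model, "bert.pooler.dense.weight", "bert.")
print(submodule)   # Linear(in_features=4, out_features=4, bias=True)
print(param_name)  # weight
```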


def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
"""
Moves the params/buffers listed in `loaded_state_dict_keys` to the meta device, which frees up the memory taken by those params.
`start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
`bert.pooler.dense.weight`
"""

# meta device was added in pt=1.9
require_version_core("torch>=1.9")

# dematerialize param storage for keys that are going to be replaced by state_dict, by
# putting those on the meta device
for k in loaded_state_dict_keys:
submodule, param_name = find_submodule_and_param_name(model, k, start_prefix)
if submodule is not None:
# selectively switch to the meta device only those params/buffers that will
# be next replaced from state_dict. This is a complex way to do p.to_("meta")
# since we have no in-place to_ for tensors.
new_val = getattr(submodule, param_name)
if isinstance(new_val, torch.nn.Parameter):
# isinstance returns False for Params on meta device, so switch after the check
new_val = torch.nn.Parameter(new_val.to("meta"))
else:
new_val = new_val.to("meta")
setattr(submodule, param_name, new_val)
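To make the meta-device step concrete, here is a small self-contained sketch (plain PyTorch >= 1.9, toy `nn.Linear`; not part of the diff) of the swap the loop above performs for a single parameter:

```python
import torch
from torch import nn

layer = nn.Linear(1024, 1024)
print(layer.weight.device)                                    # cpu
print(layer.weight.element_size() * layer.weight.nelement())  # ~4 MB of real storage

# Equivalent of the loop body above for one parameter: move to the meta device
# and re-wrap as a Parameter, since there is no in-place to_() for tensors.
layer.weight = nn.Parameter(layer.weight.to("meta"))

print(layer.weight.device)  # meta
print(layer.weight.shape)   # torch.Size([1024, 1024]) - shape/dtype kept, storage freed
```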


def _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
params back to the normal device, but only for `loaded_state_dict_keys`.
`start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
`bert.pooler.dense.weight`
"""

# XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
# - deepspeed zero 3 support
# - need to copy metadata if any - see _load_state_dict_into_model
# - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
# - Is there a situation where some keys aren't in `loaded_state_dict_keys`, in which case
# they won't get loaded?

if is_deepspeed_zero3_enabled():
raise ValueError("low_cpu_mem_usage arg cannot currently be used with DeepSpeed ZeRO-3")

error_msgs = []

# materialize state_dict entries one by one on CPU
for k in loaded_state_dict_keys:
if k in state_dict:
submodule, param_name = find_submodule_and_param_name(model, k, start_prefix)
if submodule is not None:
param_dtype = getattr(submodule, param_name).dtype
new_val = state_dict[k].to(param_dtype)
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
new_val = torch.nn.Parameter(new_val)
setattr(submodule, param_name, new_val)

return error_msgs
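Putting the two helpers together, a hedged end-to-end sketch (toy `nn.Sequential`, in-memory `state_dict`; the underscore-prefixed helpers are internal, and the imports assume a transformers install with this commit) of the dematerialize-then-rematerialize round trip:

```python
import torch
from torch import nn

# Assumes a transformers install that includes this commit.
from transformers.modeling_utils import (
    _load_state_dict_into_meta_model,
    _move_model_to_meta,
)

model = nn.Sequential(nn.Linear(8, 8))
state_dict = {k: v.clone() for k, v in model.state_dict().items()}
loaded_keys = list(state_dict.keys())  # ["0.weight", "0.bias"]

_move_model_to_meta(model, loaded_keys, "")
print(model[0].weight.device)  # meta - the params no longer hold storage

_load_state_dict_into_meta_model(model, state_dict, loaded_keys, "")
print(model[0].weight.device)                                # cpu
print(torch.equal(model[0].weight, state_dict["0.weight"]))  # True
```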


class ModuleUtilsMixin:
"""
A few utilities for `torch.nn.Modules`, to be used as a mixin.
@@ -1529,7 +1618,24 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
>>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
>>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
>>> model = BertModel.from_pretrained("bert-base-uncased", from_flax=True)
```"""
```
* `low_cpu_mem_usage` algorithm:
This is an experimental function that loads the model using ~1x model size CPU memory
Here is how it works:
1. save which state_dict keys we have
2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory
3. after the model has been instantiated switch to the meta device all params/buffers that
are going to be replaced from the loaded state_dict
4. load state_dict 2nd time
5. replace the params/buffers from the state_dict
Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors
"""
config = kwargs.pop("config", None)
state_dict = kwargs.pop("state_dict", None)
cache_dir = kwargs.pop("cache_dir", None)
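For context (not part of the diff), the algorithm described in the docstring above is exercised by a call like the one below; the checkpoint name is only an example and the ~1x figure is the intent, not a guarantee:

```python
from transformers import AutoModel

# Loads the checkpoint with roughly 1x (instead of ~2x) the model size in CPU RAM.
# Requires torch >= 1.9 for meta-device support.
model = AutoModel.from_pretrained("bert-base-cased", low_cpu_mem_usage=True)
```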
@@ -1778,6 +1884,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
if not is_sharded and state_dict is None:
# Time to load the checkpoint
state_dict = load_state_dict(resolved_archive_file)

# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
@@ -1801,13 +1908,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
)
dtype_orig = cls._set_default_torch_dtype(torch_dtype)

if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
loaded_state_dict_keys = [k for k in state_dict.keys()]
if low_cpu_mem_usage:
# save the keys
if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
loaded_state_dict_keys = [k for k in state_dict.keys()]
del state_dict # free CPU memory - will reload again later
state_dict = None

config.name_or_path = pretrained_model_name_or_path

@@ -1825,11 +1931,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
with no_init_weights(_enable=_fast_init):
model = cls(config, *model_args, **model_kwargs)

if from_pt:
# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

if from_tf:
if resolved_archive_file.endswith(".index"):
# Load from a TensorFlow 1.X checkpoint - provided by original authors
@@ -1859,18 +1960,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
raise
elif from_pt:

if low_cpu_mem_usage:
cls._load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file)
else:
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
)
# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=low_cpu_mem_usage,
)

# make sure token embedding weights are still tied if needed
model.tie_weights()
@@ -1894,16 +1998,17 @@ def _load_pretrained_model(
cls,
model,
state_dict,
loaded_keys,
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=False,
sharded_metadata=None,
_fast_init=True,
low_cpu_mem_usage=False,
):
# Retrieve missing & unexpected_keys
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
loaded_keys = list(state_dict.keys()) if state_dict is not None else sharded_metadata["all_checkpoint_keys"]
prefix = model.base_model_prefix

def _fix_key(key):
@@ -1994,9 +2099,12 @@ def _find_mismatched_keys(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]

return mismatched_keys

if low_cpu_mem_usage:
model_state_dict = None # free references to model's params to allow memory freeing
_move_model_to_meta(model, loaded_keys, start_prefix)

if state_dict is not None:
# Whole checkpoint
mismatched_keys = _find_mismatched_keys(
@@ -2009,7 +2117,8 @@ def _find_mismatched_keys(
)
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
else:
# Sharded checkpoint
# Sharded checkpoint or whole but low_cpu_mem_usage==True

# This should always be a list but, just to be sure.
if not isinstance(resolved_archive_file, list):
resolved_archive_file = [resolved_archive_file]
@@ -2018,6 +2127,10 @@ def _find_mismatched_keys(
mismatched_keys = []
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)

if low_cpu_mem_usage:
model_state_dict = model.state_dict()

# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys += _find_mismatched_keys(
@@ -2028,7 +2141,13 @@ def _find_mismatched_keys(
remove_prefix_from_model,
ignore_mismatched_sizes,
)
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)

if low_cpu_mem_usage:
error_msgs += _load_state_dict_into_meta_model(
model_to_load, state_dict, loaded_keys, start_prefix
)
else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)

if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
@@ -2093,13 +2212,13 @@ def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=Fal
return retrieved_modules

@staticmethod
def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file):
def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file, start_prefix=""):
"""
This is an experimental function that loads the model using ~1.x model size CPU memory
Before it gets called we do:
Before you call it do:
1. save which state_dict keys we have
1. save which state_dict keys are available
2. drop state_dict before model is created, since the latter takes 1x model size memory
Here then we continue:
@@ -2110,58 +2229,11 @@ def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archi
Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed.
"""
require_version_core("torch>=1.9")
if is_deepspeed_zero3_enabled():
raise ValueError("low_cpu_mem_usage arg cannot be used with DeepSpeed ZeRO-3")

# a helper util to find the last sub-module and the param/buffer name
def find_submodule_and_param_name(model, long_key):
split_key = long_key.split(".")
submodule = model
while len(split_key) > 1:
if hasattr(submodule, split_key[0]):
submodule = getattr(submodule, split_key[0])
del split_key[0]
else:
submodule = None
break
return submodule, split_key[0]

# dematerialize param storage for keys that are going to be replaced by state_dict, by
# putting those on the meta device
for k in loaded_state_dict_keys:
submodule, param_name = find_submodule_and_param_name(model, k)
if submodule is not None:
# selectively switch to the meta device only those params/buffers that will
# be next replaced from state_dict. This a complex way to do p.to_("meta")
# since we have no in-place to_ for tensors.
new_val = getattr(submodule, param_name)
if isinstance(new_val, torch.nn.Parameter):
# isinstance returns False for Params on meta device, so switch after the check
new_val = torch.nn.Parameter(new_val.to("meta"))
else:
new_val = new_val.to("meta")
setattr(submodule, param_name, new_val)

# only now can load state_dict(s)
if not isinstance(resolved_archive_file, list):
resolved_archive_file = [resolved_archive_file]

for archive_file in resolved_archive_file:
state_dict = torch.load(archive_file, map_location="cpu")

# materialize state_dict entries one by one on CPU
for k in loaded_state_dict_keys:
if k in state_dict:
submodule, param_name = find_submodule_and_param_name(model, k)
if submodule is not None:
param_dtype = getattr(submodule, param_name).dtype
new_val = state_dict[k].to(param_dtype)
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
new_val = torch.nn.Parameter(new_val)
setattr(submodule, param_name, new_val)

del state_dict
_move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
state_dict = load_state_dict(resolved_archive_file)
error_msgs = _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix)
return error_msgs
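For illustration only, a hedged sketch of driving the rewritten static method by hand, following the two preparation steps its docstring lists. The checkpoint path is a placeholder, the `bert-base-cased` config is just an example, and `from_pretrained(..., low_cpu_mem_usage=True)` is the supported way to get this behaviour:

```python
import torch
from transformers import AutoConfig, AutoModel
from transformers.modeling_utils import no_init_weights

archive_file = "/path/to/pytorch_model.bin"  # placeholder checkpoint path

# 1. save which state_dict keys are available
state_dict = torch.load(archive_file, map_location="cpu")
loaded_keys = list(state_dict.keys())
# 2. drop state_dict before the model is created, freeing that 1x copy
del state_dict

config = AutoConfig.from_pretrained("bert-base-cased")
with no_init_weights():
    model = AutoModel.from_config(config)

# 3.-5. move params to meta, reload the checkpoint, materialize the weights
# (static method, so the model is passed explicitly)
model._load_pretrained_model_low_mem(model, loaded_keys, archive_file)
```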

@classmethod
def register_for_auto_class(cls, auto_class="AutoModel"):