diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7f1b12386202fa..668c8bfca9ed9d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1984,8 +1984,14 @@ def _fix_key(key): "properly saved?" ) - if state_dict is not None: - # Whole checkpoint + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ): mismatched_keys = [] if ignore_mismatched_sizes: for checkpoint_key in loaded_keys: @@ -2006,6 +2012,18 @@ def _fix_key(key): ) del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ) error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix) else: # Sharded checkpoint @@ -2014,30 +2032,19 @@ def _fix_key(key): resolved_archive_file = [resolved_archive_file] error_msgs = [] + mismatched_keys = [] for shard_file in resolved_archive_file: state_dict = load_state_dict(shard_file) # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not # matching the weights in the model. - mismatched_keys = [] - if ignore_mismatched_sizes: - for checkpoint_key in loaded_keys: - model_key = checkpoint_key - if remove_prefix_from_model: - # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. - model_key = f"{prefix}.{checkpoint_key}" - elif add_prefix_to_model: - # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. - model_key = ".".join(checkpoint_key.split(".")[1:]) - - if ( - model_key in model_state_dict - and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape - ): - mismatched_keys.append( - (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) - ) - del state_dict[checkpoint_key] - + mismatched_keys += _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ) error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix) if len(error_msgs) > 0: