From 163c8bbdc9c1b2b328b26c3e7acac2083625888a Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Tue, 28 Jan 2025 11:38:45 +0100
Subject: [PATCH] Fix: loading DBRX back from saved path (#35728)

* fix dtype as dict for some models + add test

* add comment in tests
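A minimal usage sketch of the dict form this patch enables (illustrative
only; `CKPT` and `tmp_dir` are hypothetical placeholders, and any
composite-config model class works the same way):

    from transformers import LlavaForConditionalGeneration

    CKPT = "..."     # placeholder: any Llava-style composite checkpoint
    tmp_dir = "..."  # placeholder: a scratch directory

    # one dtype per sub-config; the "" entry covers modules that belong
    # to no sub-config, e.g. the multi-modal projector
    model = LlavaForConditionalGeneration.from_pretrained(
        CKPT, torch_dtype={"text_config": "float16", "vision_config": "float32", "": "bfloat16"}
    )

    # the round trip that used to fail for composite configs such as DBRX:
    # sub-configs now pop the serialized `torch_dtype` instead of raising
    # on an unknown kwarg
    model.save_pretrained(tmp_dir)
    model = LlavaForConditionalGeneration.from_pretrained(tmp_dir)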
---
 src/transformers/modeling_utils.py            | 29 ++++++++--
 .../models/dbrx/configuration_dbrx.py         |  4 +-
 tests/test_modeling_common.py                 |  6 +++
 tests/utils/test_modeling_utils.py            | 54 +++++++++++++++++++
 4 files changed, 87 insertions(+), 6 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 8eb2d7439ef3..ec855c8347b6 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4020,10 +4020,31 @@ def from_pretrained(
                             )
                 elif hasattr(torch, torch_dtype):
                     torch_dtype = getattr(torch, torch_dtype)
-                else:
-                    raise ValueError(
-                        f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}'
-                    )
+                    for sub_config_key in config.sub_configs.keys():
+                        sub_config = getattr(config, sub_config_key)
+                        sub_config.torch_dtype = torch_dtype
+                elif isinstance(torch_dtype, torch.dtype):
+                    for sub_config_key in config.sub_configs.keys():
+                        sub_config = getattr(config, sub_config_key)
+                        sub_config.torch_dtype = torch_dtype
+                elif isinstance(torch_dtype, dict):
+                    for key, curr_dtype in torch_dtype.items():
+                        if hasattr(config, key):
+                            value = getattr(config, key)
+                            value.torch_dtype = curr_dtype
+                    # main torch dtype for modules that aren't part of any sub-config
+                    torch_dtype = torch_dtype.get("")
+                    config.torch_dtype = torch_dtype
+                    if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
+                        torch_dtype = getattr(torch, torch_dtype)
+                    elif torch_dtype is None:
+                        torch_dtype = torch.float32
+                else:
+                    raise ValueError(
+                        f"`torch_dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `torch_dtype` "
+                        f"for each sub-config in composite configs, but received {torch_dtype}"
+                    )
+
                 dtype_orig = cls._set_default_torch_dtype(torch_dtype)
 
         # Check if `_keep_in_fp32_modules` is not None
diff --git a/src/transformers/models/dbrx/configuration_dbrx.py b/src/transformers/models/dbrx/configuration_dbrx.py
index 7935b1d1beb7..72df1fe335ba 100644
--- a/src/transformers/models/dbrx/configuration_dbrx.py
+++ b/src/transformers/models/dbrx/configuration_dbrx.py
@@ -57,7 +57,7 @@ def __init__(
         self.kv_n_heads = kv_n_heads
         self.rope_theta = rope_theta
 
-        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash"]:
+        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"]:
             if k in kwargs:
                 kwargs.pop(k)
         if len(kwargs) != 0:
@@ -109,7 +109,7 @@ def __init__(
         self.moe_loss_weight = moe_loss_weight
         self.moe_normalize_expert_weights = moe_normalize_expert_weights
 
-        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash"]:
+        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"]:
             if k in kwargs:
                 kwargs.pop(k)
         if len(kwargs) != 0:
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 922d5b827252..45a10d90b755 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -331,6 +331,12 @@ def check_save_load(out1, out2):
                 with torch.no_grad():
                     second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
 
+                # Save and load a second time because `from_pretrained` adds a bunch of new config fields
+                # so we need to make sure those fields can be loaded back after saving
+                # Initializing with `model(config)` alone doesn't add those fields
+                model.save_pretrained(tmpdirname)
+                model = model_class.from_pretrained(tmpdirname)
+
                 if isinstance(first, tuple) and isinstance(second, tuple):
                     for tensor1, tensor2 in zip(first, second):
                         check_save_load(tensor1, tensor2)
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 383f0cbe60e1..57c756d343fc 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -460,6 +460,60 @@ def test_model_from_config_torch_dtype_str(self):
         with self.assertRaises(ValueError):
             model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")
 
+    def test_model_from_config_torch_dtype_composite(self):
+        """
+        Test that `from_pretrained` works with `torch_dtype` passed as a dict with one dtype per sub-config
+        of a composite config. Tiny-Llava's saved `torch_dtype` is `torch.float32` for all modules.
+        """
+        # should be able to set torch_dtype as a simple string and the model loads it correctly
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.float32)
+
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype=torch.float16)
+        self.assertEqual(model.language_model.dtype, torch.float16)
+        self.assertEqual(model.vision_tower.dtype, torch.float16)
+
+        # should be able to set torch_dtype as a dict for each sub-config
+        model = LlavaForConditionalGeneration.from_pretrained(
+            TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"}
+        )
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.float16)
+        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
+
+        # should be able to set the values as torch.dtype (not str)
+        model = LlavaForConditionalGeneration.from_pretrained(
+            TINY_LLAVA, torch_dtype={"text_config": torch.float32, "vision_config": torch.float16, "": torch.bfloat16}
+        )
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.float16)
+        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
+
+        # should be able to set the values in configs directly and pass them to `from_pretrained`
+        config = copy.deepcopy(model.config)
+        config.text_config.torch_dtype = torch.float32
+        config.vision_config.torch_dtype = torch.bfloat16
+        config.torch_dtype = torch.float16
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
+        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
+
+        # but if the model has `_keep_in_fp32_modules`, those modules stay in fp32 no matter what
+        LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
+        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
+
+        # torch.set_default_dtype() supports only float dtypes, so it will fail with a non-float type
+        with self.assertRaises(ValueError):
+            model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="int64")
+            model = LlavaForConditionalGeneration.from_pretrained(
+                TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "int64", "": "float16"}
+            )
+
     @require_torch
     def test_model_from_pretrained_meta_device(self):
         def is_on_meta(model_id, dtype):