Commit 41e4779

sayakpaul authored on Oct 31, 2024, with Benjamin Bossan (BenjaminBossan), Marc Sun (SunMarc), and Steven Liu (stevhliu) co-authoring.

[LoRA] fix: lora loading when using with a device_mapped model. (#9449)

* fix: lora loading when using with a device_mapped model.
* better attributing
* empty
* Apply suggestions from code review
* minors
* better error messages.
* fix-copies
* add: tests, docs.
* add hardware note.
* quality
* Update docs/source/en/training/distributed_inference.md
* fixes
* skip properly.
* fixes

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

1 parent ff182ad · commit 41e4779

22 files changed: +546 −8 lines changed
docs/source/en/training/distributed_inference.md (+2)

@@ -237,3 +237,5 @@ with torch.no_grad():
 ```
 
 By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.
+
+This workflow is also compatible with LoRAs via [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. However, only LoRAs without text encoder components are currently supported in this workflow.
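
For context, a minimal sketch of the workflow this documentation change refers to. The checkpoint and LoRA repo ids below are illustrative placeholders (not from this diff), and it assumes diffusers' model-level device_map support via accelerate:

import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel

# Shard the largest component across the available GPUs with a device map.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
    device_map="auto",
    torch_dtype=torch.float16,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    unet=unet,
    torch_dtype=torch.float16,
)

# With this fix, LoRA loading works on the device-mapped pipeline,
# provided the LoRA has no text encoder components.
pipeline.load_lora_weights("some-user/some-sdxl-lora")  # placeholder repo id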

src/diffusers/loaders/lora_base.py (+11 −1)

@@ -31,6 +31,7 @@
     delete_adapter_layers,
     deprecate,
     is_accelerate_available,
+    is_accelerate_version,
     is_peft_available,
     is_transformers_available,
     logging,
@@ -214,9 +215,18 @@ def _optionally_disable_offloading(cls, _pipeline):
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
 
+        def model_has_device_map(model):
+            if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
+                return False
+            return getattr(model, "hf_device_map", None) is not None
+
         if _pipeline is not None and _pipeline.hf_device_map is None:
             for _, component in _pipeline.components.items():
-                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                if (
+                    isinstance(component, nn.Module)
+                    and hasattr(component, "_hf_hook")
+                    and not model_has_device_map(component)
+                ):
                     if not is_model_cpu_offload:
                         is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                     if not is_sequential_cpu_offload:
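
The key subtlety in the hunk above: accelerate attaches an `_hf_hook` to components for CPU offloading and for device-mapped dispatch alike, so `hasattr(component, "_hf_hook")` alone cannot tell the two apart; the `hf_device_map` attribute is what marks a device-mapped component. A toy illustration of the new condition, with hand-set attributes standing in for a real accelerate dispatch:

import torch.nn as nn

def model_has_device_map(model):
    # Mirrors the helper added above, minus the accelerate version check.
    return getattr(model, "hf_device_map", None) is not None

component = nn.Linear(4, 4)
component._hf_hook = object()      # present for offloaded AND device-mapped models
component.hf_device_map = {"": 0}  # present only for device-mapped models

should_touch_hooks = (
    isinstance(component, nn.Module)
    and hasattr(component, "_hf_hook")
    and not model_has_device_map(component)
)
print(should_touch_hooks)  # False: device-mapped components keep their hooks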

src/diffusers/loaders/unet.py (+11 −1)

@@ -39,6 +39,7 @@
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
+    is_accelerate_version,
     is_peft_version,
     is_torch_version,
     logging,
@@ -398,9 +399,18 @@ def _optionally_disable_offloading(cls, _pipeline):
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
 
+        def model_has_device_map(model):
+            if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
+                return False
+            return getattr(model, "hf_device_map", None) is not None
+
         if _pipeline is not None and _pipeline.hf_device_map is None:
             for _, component in _pipeline.components.items():
-                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                if (
+                    isinstance(component, nn.Module)
+                    and hasattr(component, "_hf_hook")
+                    and not model_has_device_map(component)
+                ):
                     if not is_model_cpu_offload:
                         is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                     if not is_sequential_cpu_offload:

src/diffusers/pipelines/pipeline_loading_utils.py (+7)

@@ -36,6 +36,7 @@
     deprecate,
     get_class_from_dynamic_module,
     is_accelerate_available,
+    is_accelerate_version,
     is_peft_available,
     is_transformers_available,
     logging,
@@ -947,3 +948,9 @@ def _get_ignore_patterns(
         )
 
     return ignore_patterns
+
+
+def model_has_device_map(model):
+    if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
+        return False
+    return getattr(model, "hf_device_map", None) is not None
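
A quick, hypothetical check of this helper as added at module level here (assumes accelerate >= 0.14.0 is installed; the attribute assignment imitates what accelerate records on a dispatched model):

import torch.nn as nn
from diffusers.pipelines.pipeline_loading_utils import model_has_device_map

model = nn.Linear(4, 4)
print(model_has_device_map(model))  # False: no `hf_device_map` attribute

model.hf_device_map = {"": 0}       # set by accelerate on dispatched models
print(model_has_device_map(model))  # True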

src/diffusers/pipelines/pipeline_utils.py (+31)

@@ -85,6 +85,7 @@
     _update_init_kwargs_with_connected_pipeline,
     load_sub_model,
     maybe_raise_or_warn,
+    model_has_device_map,
     variant_compatible_siblings,
     warn_deprecated_model_variant,
 )
@@ -406,6 +407,16 @@ def module_is_offloaded(module):
 
             return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
 
+        # device-mapped modules should not go through any device placements.
+        device_mapped_components = [
+            key for key, component in self.components.items() if model_has_device_map(component)
+        ]
+        if device_mapped_components:
+            raise ValueError(
+                "The following pipeline components have been found to use a device map: "
+                f"{device_mapped_components}. This is incompatible with explicitly setting the device using `to()`."
+            )
+
         # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
         pipeline_is_sequentially_offloaded = any(
             module_is_sequentially_offloaded(module) for _, module in self.components.items()
@@ -1002,6 +1013,16 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
                 The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
                 default to "cuda".
         """
+        # device-mapped modules should not go through any device placements.
+        device_mapped_components = [
+            key for key, component in self.components.items() if model_has_device_map(component)
+        ]
+        if device_mapped_components:
+            raise ValueError(
+                "The following pipeline components have been found to use a device map: "
+                f"{device_mapped_components}. This is incompatible with `enable_model_cpu_offload()`."
+            )
+
         is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
         if is_pipeline_device_mapped:
             raise ValueError(
@@ -1104,6 +1125,16 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
                 The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
                 default to "cuda".
         """
+        # device-mapped modules should not go through any device placements.
+        device_mapped_components = [
+            key for key, component in self.components.items() if model_has_device_map(component)
+        ]
+        if device_mapped_components:
+            raise ValueError(
+                "The following pipeline components have been found to use a device map: "
+                f"{device_mapped_components}. This is incompatible with `enable_sequential_cpu_offload()`."
+            )
+
         if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
             from accelerate import cpu_offload
         else:
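
Taken together, the three guards make explicit placement fail fast instead of silently moving a device-mapped component. A hedged sketch of the resulting behavior (the repo id and the choice of component are placeholders, and model-level device_map support via accelerate is assumed):

import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel

# Load one component with a device map, then assemble the pipeline around it.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", device_map="auto"
)
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16
)

# Each of these now raises a ValueError naming the offending components,
# e.g. "... use a device map: ['unet']. This is incompatible with ...".
for call in (
    lambda: pipe.to("cuda"),
    pipe.enable_model_cpu_offload,
    pipe.enable_sequential_cpu_offload,
):
    try:
        call()
    except ValueError as err:
        print(err)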

tests/pipelines/audioldm2/test_audioldm2.py (+5)

@@ -506,9 +506,14 @@ def test_to_dtype(self):
         model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))
 
+    @unittest.skip("Test currently not supported.")
     def test_sequential_cpu_offload_forward_pass(self):
         pass
 
+    @unittest.skip("Test currently not supported.")
+    def test_calling_mco_raises_error_device_mapped_components(self):
+        pass
+
 
 @nightly
 class AudioLDM2PipelineSlowTests(unittest.TestCase):
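
The skipped test names in this and the following files abbreviate the three guarded entry points: `mco` stands for `enable_model_cpu_offload()`, `sco` for `enable_sequential_cpu_offload()`, and `to` for `.to()`. The shared mixin tests themselves are not part of this excerpt; a hedged sketch of what such a test presumably asserts, with a hypothetical fixture for building a device-mapped pipeline:

import unittest

class DeviceMapGuardSketch(unittest.TestCase):
    """Sketch only: the real mixin test (not shown in this diff) presumably
    builds a pipeline with a device-mapped component and asserts the guard."""

    @unittest.skip("Sketch only: needs a pipeline with a device-mapped component.")
    def test_calling_mco_raises_error_device_mapped_components(self):
        pipe = self.get_device_mapped_pipeline()  # hypothetical fixture
        with self.assertRaises(ValueError):
            pipe.enable_model_cpu_offload()

if __name__ == "__main__":
    unittest.main()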

tests/pipelines/controlnet/test_controlnet.py (+24)

@@ -514,6 +514,18 @@ def test_inference_multiple_prompt_input(self):
 
         assert image.shape == (4, 64, 64, 3)
 
+    @unittest.skip("Test not supported.")
+    def test_calling_mco_raises_error_device_mapped_components(self):
+        pass
+
+    @unittest.skip("Test not supported.")
+    def test_calling_to_raises_error_device_mapped_components(self):
+        pass
+
+    @unittest.skip("Test not supported.")
+    def test_calling_sco_raises_error_device_mapped_components(self):
+        pass
+
 
 class StableDiffusionMultiControlNetOneModelPipelineFastTests(
     IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
@@ -697,6 +709,18 @@ def test_save_pretrained_raise_not_implemented_exception(self):
         except NotImplementedError:
             pass
 
+    @unittest.skip("Test not supported.")
+    def test_calling_mco_raises_error_device_mapped_components(self):
+        pass
+
+    @unittest.skip("Test not supported.")
+    def test_calling_to_raises_error_device_mapped_components(self):
+        pass
+
+    @unittest.skip("Test not supported.")
+    def test_calling_sco_raises_error_device_mapped_components(self):
+        pass
+
 
 @slow
 @require_torch_gpu

tests/pipelines/controlnet/test_controlnet_img2img.py (+12)

@@ -389,6 +389,18 @@ def test_save_pretrained_raise_not_implemented_exception(self):
         except NotImplementedError:
             pass
 
+    @unittest.skip("Test not supported.")
+    def test_calling_mco_raises_error_device_mapped_components(self):
+        pass
+
+    @unittest.skip("Test not supported.")
+    def test_calling_to_raises_error_device_mapped_components(self):
+        pass
+
+    @unittest.skip("Test not supported.")
+    def test_calling_sco_raises_error_device_mapped_components(self):
+        pass
+
 
 @slow
 @require_torch_gpu

tests/pipelines/controlnet/test_controlnet_inpaint.py (+12)

@@ -441,6 +441,18 @@ def test_save_pretrained_raise_not_implemented_exception(self):
         except NotImplementedError:
             pass
 
+    @unittest.skip("Test not supported.")
+    def test_calling_mco_raises_error_device_mapped_components(self):
+        pass
+
+    @unittest.skip("Test not supported.")
+    def test_calling_to_raises_error_device_mapped_components(self):
+        pass
+
+    @unittest.skip("Test not supported.")
+    def test_calling_sco_raises_error_device_mapped_components(self):
+        pass
+
 
 @slow
 @require_torch_gpu

tests/pipelines/controlnet/test_controlnet_sdxl.py (+24)

@@ -683,6 +683,18 @@ def test_inference_batch_single_identical(self):
     def test_save_load_optional_components(self):
         return self._test_save_load_optional_components()
 
+    @unittest.skip("Test not supported.")
+    def test_calling_mco_raises_error_device_mapped_components(self):
+        pass
+
+    @unittest.skip("Test not supported.")
+    def test_calling_to_raises_error_device_mapped_components(self):
+        pass
+
+    @unittest.skip("Test not supported.")
+    def test_calling_sco_raises_error_device_mapped_components(self):
+        pass
+
 
 class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
     PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
@@ -887,6 +899,18 @@ def test_negative_conditions(self):
 
         self.assertTrue(np.abs(image_slice_without_neg_cond - image_slice_with_neg_cond).max() > 1e-2)
 
+    @unittest.skip("Test not supported.")
+    def test_calling_mco_raises_error_device_mapped_components(self):
+        pass
+
+    @unittest.skip("Test not supported.")
+    def test_calling_to_raises_error_device_mapped_components(self):
+        pass
+
+    @unittest.skip("Test not supported.")
+    def test_calling_sco_raises_error_device_mapped_components(self):
+        pass
+
 
 @slow
 @require_torch_gpu
