
[Core] add support for using sharded models in the pipeline context #8428

Closed
sayakpaul wants to merge 5 commits from the model-sharding-with-pipeline branch

Conversation

sayakpaul
Member

@sayakpaul sayakpaul commented Jun 7, 2024

What does this PR do?

After adding support for model sharding through #6396 and #7830, it's now time for us to use a sharded model within a pipeline.

This PR enables that.

TODOs

  • Add docs
  • Add tests

Currently, the following works:

from diffusers import UNet2DConditionModel, StableDiffusionXLPipeline 
import torch

unet = UNet2DConditionModel.from_pretrained(
    "sayakpaul/sdxl-unet-sharded", device_map="auto"
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.float16
).to("cuda")

image = pipeline("dog", num_inference_steps=20).images[0]

However, it prints the following:

You shouldn't move a model that is dispatched using accelerate hooks.

Is this relevant/unsafe in our context?
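(For context, a minimal sketch, not part of this PR, of how one could detect this situation before calling .to(). The _hf_hook attribute is what accelerate attaches to modules it dispatches; hf_device_map, if present, records where each submodule was placed.)

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "sayakpaul/sdxl-unet-sharded", device_map="auto"
)

# accelerate attaches `_hf_hook` to modules it dispatches; if it is there,
# moving the model with .to() conflicts with the hooks' device placement.
if hasattr(unet, "_hf_hook") or getattr(unet, "hf_device_map", None):
    print("unet is managed by accelerate hooks; calling .to() on it is unsafe")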

Once I get initial approval from @SunMarc, I will proceed with the rest of the TODOs.

@sayakpaul sayakpaul requested a review from SunMarc June 7, 2024 09:48
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SunMarc
Member

SunMarc commented Jun 7, 2024

You get this error: You shouldn't move a model that is dispatched using accelerate hooks because you moved the pipeline to "cuda". However, the unet was probably dispatched across multiple GPUs or with force_hook=True when you used device_map="auto".

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.float16
).to("cuda")

I'm not sure how to handle the case where we pass a model loaded with device_map='auto' to a pipeline that can itself also be loaded with device_map='auto' while taking a model as input. It might be confusing to users, and we should try to make clear what the best way is.

@sayakpaul
Member Author

I'm not sure how to handle the case where we pass a model loaded with device_map='auto' to a pipeline that can itself also be loaded with device_map='auto' while taking a model as input. It might be confusing to users, and we should try to make clear what the best way is.

But we are not doing auto device_map in the pipeline. What restrictions can we impose internally within the pipeline implementation? Perhaps not force hooks when _no_split_modules is not None?
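(As an illustration, that restriction could look roughly like the hypothetical helper below; the _no_split_modules attribute exists on diffusers models, but the helper itself is made up.)

def should_force_hooks(model, device_map) -> bool:
    # Hypothetical: only force accelerate hooks when the model is actually
    # being dispatched via a device_map. If no device_map was requested and
    # the model declares sharding support via `_no_split_modules`, plain CPU
    # loading without hooks is enough.
    if device_map is None and getattr(model, "_no_split_modules", None) is not None:
        return False
    return True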

@SunMarc
Member

SunMarc commented Jun 7, 2024

But we are not doing auto device_map in the pipeline. What restrictions can we impose internally within the pipeline implementation? Perhaps not force hooks when _no_split_modules is not None?

Yes, but I was thinking that the user might do that. Not sure yet. I will test a few combinations to see what errors pop up. But I don't think we should create a model with device_map and then use it in a pipeline. In a multi-GPU setup, I think the above script will fail.

@sayakpaul
Member Author

A model within diffusers will most likely always be used with pipelines. So, we need to consider it.

For a multi-GPU setup I think we should restrict it, for sure.

@sayakpaul
Member Author

@SunMarc I'll let you run a couple of tests; let me know if the above plan is good.

@SunMarc
Member

SunMarc commented Jun 10, 2024

I tested this feature a bit, and I think we should try to make the following work instead.

from diffusers import UNet2DConditionModel, StableDiffusionXLPipeline 
import torch

unet = UNet2DConditionModel.from_pretrained(
    "sayakpaul/sdxl-unet-sharded", torch_dtype=torch.float16
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.float16
).to("cuda")

image = pipeline("dog", num_inference_steps=20).images[0]

Right now, it returns an error: since device_map is None, we go through the following code when loading the sharded checkpoint (in load_checkpoint_in_model). This fails because, in diffusers, strict is hardcoded to True to capture a specific error. However, to load a sharded checkpoint with model.load_state_dict, we would have to set strict=False, so going this way is not an option.

for checkpoint_file in checkpoint_files:
    loaded_checkpoint = load_state_dict(checkpoint_file, device_map=device_map)
    if device_map is None:
        model.load_state_dict(loaded_checkpoint, strict=strict)
        unexpected_keys.update(set(loaded_checkpoint.keys()) - model_keys)
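(To see concretely why shard-by-shard loading cannot be strict, here is a small self-contained PyTorch demo, unrelated to the diffusers code above: each shard holds only a subset of the model's keys, so loading one shard with strict=True reports every parameter from the other shards as missing.)

import torch

# Two-layer model; pretend its state dict was saved as two shards.
model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 2))
shard_0 = {k: v for k, v in model.state_dict().items() if k.startswith("0.")}

try:
    model.load_state_dict(shard_0, strict=True)   # "1.weight"/"1.bias" missing
except RuntimeError as e:
    print("strict=True fails on a single shard:", e)

model.load_state_dict(shard_0, strict=False)      # tolerated with strict=False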

One way to do that would be the following modification. This would simply load the sharded model on CPU without the hooks, so you can safely move the model, and we would still be able to capture the AttributeError:

else:  # else let accelerate handle loading and dispatching.
    # Load weights and dispatch according to the device_map
    # by default the device_map is None and the weights are loaded on the CPU
    force_hook = True
    device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
    if device_map is None and is_sharded:
        # we load the parameters on the cpu
        device_map = {"": "cpu"}
        force_hook = False
    try:
        accelerate.load_checkpoint_and_dispatch(
            model,
            model_file if not is_sharded else sharded_ckpt_cached_folder,
            device_map,
            max_memory=max_memory,
            offload_folder=offload_folder,
            offload_state_dict=offload_state_dict,
            dtype=torch_dtype,
            force_hooks=force_hook,
            strict=True,
        )

Another way would be to make the following logic compatible with loading sharded checkpoints, but we would essentially be rewriting what load_checkpoint_and_dispatch already does (from_pretrained code):

if device_map is None and not is_sharded:
    param_device = "cpu"
    state_dict = load_state_dict(model_file, variant=variant)
    ...
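(A rough, hypothetical sketch of what that second option boils down to, simplified to torch.load-style shard files; real diffusers checkpoints may be safetensors and go through diffusers' own load_state_dict. The idea: merge every shard into one full state dict so the existing strict=True path keeps working, at the cost of holding the whole model in CPU memory at once.)

import torch

def load_sharded_on_cpu(model, shard_files):
    # Hypothetical helper: merge all shards, then load strictly in one shot.
    merged = {}
    for shard_file in shard_files:
        merged.update(torch.load(shard_file, map_location="cpu"))
    model.load_state_dict(merged, strict=True)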

LMK what you think @sayakpaul @yiyixuxu!

Lastly, about passing device_map="auto" to the model and then passing it to the pipeline: this should be restricted in a multi-GPU setup. However, even in a single-device setup, because force_hook is set to True by default, we can't move the pipeline whenever we want, as we might get a device mismatch. So we should either modify a bit how we set force_hook, remove the hooks when passing a model to a pipeline, or raise an error when that happens.

from diffusers import UNet2DConditionModel, StableDiffusionXLPipeline 
import torch

unet = UNet2DConditionModel.from_pretrained(
    "sayakpaul/sdxl-unet-sharded", torch_dtype=torch.float16, device_map=0
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.float16
).to(1)

image = pipeline("dog", num_inference_steps=20).images[0]
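(One shape the "raise an error" option could take, as a hypothetical guard rather than what was eventually merged: refuse to move a pipeline when one of its components is dispatched by accelerate.)

def assert_pipeline_is_movable(pipeline):
    # Hypothetical guard: accelerate attaches `_hf_hook` to dispatched modules,
    # so its presence on any component means .to() would fight the hooks.
    for name, component in pipeline.components.items():
        if hasattr(component, "_hf_hook"):
            raise ValueError(
                f"Component '{name}' is dispatched with accelerate hooks; "
                "load it without `device_map` if you want to move the pipeline."
            )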

@sayakpaul
Member Author

Thanks for the investigation! I think you already have a good idea; would you mind opening a PR?

@yiyixuxu
Collaborator

yiyixuxu commented Jun 13, 2024

@SunMarc, thanks for the long explanation!
I see that huggingface/accelerate#2640 makes it complicated to load sharded checkpoints with device_map=None.

However, I think the error message we want to catch here is specifically for loading from a deprecated checkpoint:

except AttributeError as e:

since:

  1. users will not be able to save the checkpoint in the deprecated format once they have loaded it
  2. sharding in diffusers is a newly introduced feature

I think it is safe to set the strict flag to False when it is a sharded checkpoint?

cc @pcuenca here too if you have time!

@SunMarc
Member

SunMarc commented Jun 13, 2024

Hi @yiyixuxu, I'm fine with setting strict to False if it is okay! Note, however, that we would also have to set assign to True in accelerate: model.load_state_dict(loaded_checkpoint, strict=strict, assign=True).
Otherwise, we would get warnings like the one you might have seen:

/home/marc/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py:2047: UserWarning: for mid_block.attentions.0.transformer_blocks.7.norm2.weight: copying from a non-meta parameter in the checkpoint to a meta parameter in the current model, which is a no-op. (Did you mean to pass assign=True to assign items in the state dictionary to their corresponding key in the module instead of copying them in place?)

This might have some complications, as explained here, with respect to training. I've opened a PR with the solution I proposed in the long comment above. Let me know which one you prefer!
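(For reference, a self-contained PyTorch demo of what assign=True changes when parameters live on the meta device; requires torch >= 2.1. Without assign=True, copying into a meta parameter is a no-op, which is exactly what the warning above says.)

import torch

with torch.device("meta"):
    layer = torch.nn.Linear(2, 2)          # parameters allocated on "meta"

state = {"weight": torch.randn(2, 2), "bias": torch.randn(2)}

layer.load_state_dict(state, assign=True)  # replaces the meta tensors outright
print(layer.weight.device)                 # cpu: parameters are now materialized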

@sayakpaul
Member Author

@SunMarc should I close this PR in light of #8531?

@SunMarc
Member

SunMarc commented Jun 18, 2024

@SunMarc should I close this PR in light of #8531?

Yes, and you can merge the other PR!

@sayakpaul sayakpaul closed this Jun 18, 2024
@sayakpaul
Member Author

sayakpaul commented Jun 21, 2024

@SunMarc when I run the following, with the changes from your PR and this PR:

from diffusers import UNet2DConditionModel, StableDiffusionXLPipeline 
import torch

unet = UNet2DConditionModel.from_pretrained(
    "sayakpaul/sdxl-unet-sharded", torch_dtype=torch.float16, device_map="auto"
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.float16
).to("cuda")

image = pipeline("a cute dog running on the grass", num_inference_steps=30).images[0]
image.save("dog.png")

I still face:

You shouldn't move a model that is dispatched using accelerate hooks.

@sayakpaul sayakpaul reopened this Jun 21, 2024
@SunMarc
Member

SunMarc commented Jun 21, 2024

Hi @sayakpaul, remove the device_map arg from

unet = UNet2DConditionModel.from_pretrained(
    "sayakpaul/sdxl-unet-sharded", torch_dtype=torch.float16, device_map="auto"
)

and it should work as expected!

@sayakpaul
Member Author

My bad.

@sayakpaul sayakpaul closed this Jun 21, 2024
@sayakpaul sayakpaul deleted the model-sharding-with-pipeline branch June 21, 2024 10:26