Fix sharding when no device_map is passed #8531

Merged
9 commits merged into huggingface:main on Jun 18, 2024

Conversation

@SunMarc (Member) commented on Jun 13, 2024

What does this PR do?

This PR fixes loading of sharded checkpoints when no device_map is passed. Currently, the following doesn't work:

from diffusers import UNet2DConditionModel
import torch

# Loading a sharded UNet checkpoint without passing device_map
unet = UNet2DConditionModel.from_pretrained(
    "sayakpaul/sdxl-unet-sharded", torch_dtype=torch.float16
)

More details are available here.
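For context, a minimal sketch of how such a sharded checkpoint can be produced and then reloaded without a device_map, which is the code path this PR fixes. This assumes the shards were created with save_pretrained's max_shard_size argument and uses a local output directory; it is an illustration, not code from the PR:

from diffusers import UNet2DConditionModel
import torch

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
    torch_dtype=torch.float16,
)
# Split the weights into ~5 GB shards on disk.
unet.save_pretrained("sdxl-unet-sharded", max_shard_size="5GB")

# Reloading without device_map previously failed for sharded checkpoints.
reloaded = UNet2DConditionModel.from_pretrained(
    "sdxl-unet-sharded", torch_dtype=torch.float16
)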

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SunMarc (Member, Author) commented on Jun 13, 2024

There is still a path where sharding is not handled: when low_cpu_mem_usage=False. I see that low_cpu_mem_usage is set to True by default; is that the case for most models? cc @sayakpaul
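For reference, a minimal reproduction of that remaining unhandled path (repo id taken from the PR description; this is an illustrative sketch, not code from the PR):

from diffusers import UNet2DConditionModel
import torch

# low_cpu_mem_usage=False takes the non-accelerate init path, which at the
# time of this comment did not yet handle sharded checkpoints.
unet = UNet2DConditionModel.from_pretrained(
    "sayakpaul/sdxl-unet-sharded",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,
)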

@SunMarc SunMarc requested review from sayakpaul and yiyixuxu and removed request for yiyixuxu June 13, 2024 12:51
@yiyixuxu (Collaborator) left a comment

Thank you! Very nice tests too :)

Is it possible to explain device_map=None in the docstring for device_map too?

@SunMarc (Member, Author) commented on Jun 14, 2024

Is it possible to explain device_map=None in the docstring for device_map too?

Done!
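A hedged sketch of what the added docstring note might look like, in the Hugging Face docstring style (illustrative wording; the exact text merged may differ):

device_map (`str`, *optional*):
    A map that specifies where each submodule should go. If set to None (the
    default), the model is loaded as a regular, non-dispatched model onto a
    single device; sharded checkpoints are merged at load time.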

@@ -872,6 +872,39 @@ def test_model_parallelism(self):

@require_torch_gpu
def test_sharded_checkpoints(self):
A Member commented on the diff:

This test is already here:

def test_sharded_checkpoints(self):

Is it different?

A Collaborator replied:

He renamed that test to test_sharded_checkpoints_device_map because it loads with the device_map='auto' flag; this is a new test covering the default value of device_map.
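A rough sketch of the distinction, inside the existing model test class (test names come from the diff; the bodies and the self.model_class attribute are illustrative assumptions, not the code merged in this PR):

@require_torch_gpu
def test_sharded_checkpoints(self):
    # New test: exercises the default device_map=None path fixed here.
    model = self.model_class.from_pretrained("sayakpaul/sdxl-unet-sharded")

@require_torch_gpu
def test_sharded_checkpoints_device_map(self):
    # Renamed from test_sharded_checkpoints: loads with device_map="auto"
    # so weights are dispatched across available devices by accelerate.
    model = self.model_class.from_pretrained(
        "sayakpaul/sdxl-unet-sharded", device_map="auto"
    )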

The Member replied:

Thanks for explaining.

@SunMarc (Member, Author) replied:

Yes, I renamed the tests since it makes more sense this way.

@sayakpaul (Member) left a comment

Thanks so much, Marc. I think there's some confusion in the tests, as they already exist in main. Am I missing something?

@sayakpaul (Member) commented
Alright then! Let’s merge this.

@sayakpaul (Member) left a comment
Thanks!

@yiyixuxu yiyixuxu merged commit 96399c3 into huggingface:main Jun 18, 2024
14 of 15 checks passed
yiyixuxu pushed a commit that referenced this pull request Jun 20, 2024
* Fix sharding when no device_map is passed

* style

* add tests

* align

* add docstring

* format

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
sayakpaul added a commit that referenced this pull request Dec 23, 2024