
fix "Expected all tensors to be on the same device, but found at least two devices" error #11690


Status: Open. Wants to merge 14 commits into base: main.

Conversation

@yao-matrix (Contributor) commented Jun 11, 2025

  1. Running pytest -rA tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_device_map_from_hub_local on 8 devices (CUDA, XPU) raises RuntimeError: "Expected all tensors to be on the same device, but found at least two devices, cuda:4 and cuda:1!". Fixed by moving the tensors to the same device.
  2. Enabled one GPU-only case on accelerator (XPU test passed).

@sayakpaul, please help review, thanks.

@sayakpaul sayakpaul requested a review from SunMarc June 11, 2025 02:25
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yao-matrix (Contributor Author)

@SunMarc, could you please help review? Thanks very much.

@SunMarc (Member) left a comment


Thanks! Left a comment.

Comment on lines +2560 to +2561
if hidden_states.device != res_hidden_states.device:
res_hidden_states = res_hidden_states.to(hidden_states.device)
@SunMarc (Member)


We shouldn't need that, since both hidden_states and res_hidden_states should already be on the same device, no? The pre-forward hook added by accelerate should move all the inputs to the same device.
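The mechanism being referred to can be sketched with a plain torch forward pre-hook. This is a minimal illustration, not accelerate's actual hook implementation; the hook name and structure here are assumptions:

```python
import torch
from torch import nn

def align_inputs_pre_hook(module, args):
    # Move every tensor argument onto the device of the module's own
    # parameters before forward runs, mirroring what accelerate's
    # pre-forward hook does for dispatched submodules.
    device = next(module.parameters()).device
    return tuple(a.to(device) if isinstance(a, torch.Tensor) else a for a in args)

layer = nn.Linear(4, 4)
layer.register_forward_pre_hook(align_inputs_pre_hook)

x = torch.randn(2, 4)  # already on the module's device here, so the move is a no-op
out = layer(x)
```

Note that such a hook fires only at module boundaries; a bare function call like torch.cat inside a parent module's forward is never wrapped this way.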

@yao-matrix (Contributor Author)


@SunMarc, I suppose this is a corner case? torch.cat is a weight-less function, so it seems it cannot be covered by the pre-forward hook set by accelerate...

@SunMarc (Member)


I mean that since hidden_states and res_hidden_states_tuple are in the forward definition, they should be moved to the same device by the pre-forward hook added by accelerate.

@yao-matrix (Contributor Author) commented Jun 18, 2025


@SunMarc We run into a corner case here. Since we have 8 cards, the device_map determined by https://github.com/huggingface/diffusers/blob/1bc6f3dc0f21779480db70a4928d14282c0198ed/src/diffusers/models/model_loading_utils.py#L64C5-L64C26 is

device_map: OrderedDict([('conv_in', 0), ('time_proj', 0), ('time_embedding', 0), ('down_blocks.0', 0), ('down_blocks.1.resnets.0', 1), ('up_blocks.0.resnets.0', 1), ('up_blocks.0.resnets.1', 2), ('up_blocks.0.upsamplers', 2), ('up_blocks.1', 3), ('mid_block.attentions', 3), ('conv_norm_out', 4), ('conv_act', 4), ('conv_out', 4), ('mid_block.resnets', 4)])
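One way to see the split, assuming the device_map above, is to collect the devices used under a given module prefix (the helper name here is hypothetical, for illustration only):

```python
from collections import OrderedDict

device_map = OrderedDict([
    ('conv_in', 0), ('time_proj', 0), ('time_embedding', 0), ('down_blocks.0', 0),
    ('down_blocks.1.resnets.0', 1), ('up_blocks.0.resnets.0', 1),
    ('up_blocks.0.resnets.1', 2), ('up_blocks.0.upsamplers', 2),
    ('up_blocks.1', 3), ('mid_block.attentions', 3),
    ('conv_norm_out', 4), ('conv_act', 4), ('conv_out', 4), ('mid_block.resnets', 4),
])

def devices_under(device_map, prefix):
    # Collect every device that hosts a submodule of `prefix`.
    return {dev for name, dev in device_map.items()
            if name == prefix or name.startswith(prefix + '.')}

print(devices_under(device_map, 'up_blocks.0'))    # spans two devices
print(devices_under(device_map, 'down_blocks.0'))  # stays on a single device
```

Here up_blocks.0 spans devices 1 and 2, so any cross-submodule tensor operation inside its forward can see mismatched devices.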

We can see that UpBlock is not an atomic module; its submodules are assigned to different devices (up_blocks.0.resnets.0 on device 1, up_blocks.0.resnets.1 on device 2), so the pre-hook for UpBlock will not help in this case. And since torch.cat is not pre-hooked (and cannot be, since it's a function rather than a module), the issue happens.

If there were no torch.cat between the sub-blocks in UpBlock, everything would be fine.
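The fix in the diff quoted above amounts to a defensive device alignment before the concatenation. A standalone sketch (the function name is hypothetical):

```python
import torch

def cat_skip_connection(hidden_states, res_hidden_states, dim=1):
    # torch.cat requires all inputs on one device, and no accelerate
    # pre-forward hook covers this bare function call, so align manually.
    if hidden_states.device != res_hidden_states.device:
        res_hidden_states = res_hidden_states.to(hidden_states.device)
    return torch.cat([hidden_states, res_hidden_states], dim=dim)

# On a single-device machine both tensors already share a device,
# so the guard is a no-op and only the concatenation happens.
h = torch.randn(1, 3, 8, 8)
res = torch.randn(1, 3, 8, 8)
out = cat_skip_connection(h, res)
```

The device check makes the guard free in the common single-device case while fixing the multi-device corner case.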

@yao-matrix (Contributor Author)


@SunMarc, we need your input on how to proceed with this corner case, thanks.

3 participants