
cuda device mismatch in DataParallel when not using cuda:0 #60

Open
janfb opened this issue May 15, 2024 · 1 comment · May be fixed by #65

Comments

janfb commented May 15, 2024

Hi there, thanks for this package, it's really helpful!

On a cluster with multiple GPUs, I have my model on device cuda:1.

When a gen function is passed, new samples are generated during the FID calculation. To that end, a model_fn(x) function is defined here:

        if use_dataparallel:
            model = torch.nn.DataParallel(model)
        def model_fn(x): return model(x)

and if use_dataparallel=True, the model will be wrapped with model = torch.nn.DataParallel(model).

Problem: DataParallel has a kwarg device_ids=None which defaults to all available devices and then selects the first one as the "source" device, i.e., cuda:0. Later it asserts that all parameters and buffers of the module are on that device.
So if device_ids is not passed, the forward pass fails with a RuntimeError, because my model lives on cuda:1 rather than cuda:0.
I am not sure why DataParallel hard-codes the source device to the first available one, but in any case this can be worked around on the clean-fid side.
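For reference, here is a minimal sketch of the mismatch (assuming a machine with at least two GPUs; the toy Linear module is just for illustration, not code from this repo):

    import torch

    # Wrapping a module that lives on cuda:1 without passing device_ids makes
    # DataParallel pick cuda:0 as the source device and raise at forward time.
    model = torch.nn.Linear(8, 8).to("cuda:1")
    model = torch.nn.DataParallel(model)  # device_ids defaults to all visible GPUs, source = cuda:0
    x = torch.randn(4, 8, device="cuda:1")
    model(x)  # RuntimeError: module must have its parameters and buffers on device cuda:0 ...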

Solution: pass device_ids with the device of the model:

        if use_dataparallel:
            device_ids = [torch.cuda.current_device()]  # or use next(model.parameters()).device
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        def model_fn(x): return model(x)
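A variant of the same fix that derives the device from the model itself (so it also works if torch.cuda.set_device was never called); this is just a sketch, not tested against the repo:

        if use_dataparallel:
            # device of the first parameter, e.g. device(type='cuda', index=1)
            model_device = next(model.parameters()).device
            model = torch.nn.DataParallel(model, device_ids=[model_device])
        def model_fn(x): return model(x)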

I would be happy to make a PR fixing this. Unless I am missing something?

Cheers,
Jan

janfb changed the title from "cuda device mismatch when not using cuda:0" to "cuda device mismatch in DataParallel when not using cuda:0" on May 15, 2024

GaParmar (Owner) commented Jul 1, 2024

Hi Jan,

Thank you for pointing this out!
Feel free to make a PR.
Your proposed solution makes a lot of sense, I will add it to the main repo!

-Gaurav

janfb linked a pull request on Jul 22, 2024 that will close this issue