Device agnostic testing #5612
Conversation
@patrickvonplaten
The changes look very reasonable to me! Thanks a lot for making everything device-agnostic.
Just a bit worried about the `is_torch_fp16_available` check, because we're essentially just saying that if the matmul doesn't work, fp16 is not available, but the matmul might also fail for other reasons (badly installed CUDA, OOM, ...).
In PyTorch there is actually an `is_bf16_available()`: https://github.com/pytorch/pytorch/blob/d64bc8f0f81bd9b514eb1a5ee6f5b03094e4e6e9/torch/cuda/__init__.py#L141
The function seems to check some device properties, which is probably less brittle. I guess it's hard to do this for fp16 here, but can we maybe make sure that we don't accidentally misinterpret other errors as fp16 not being available?
Thanks, @patrickvonplaten. I agree that just catching any exception might not be the best way to do this, but I'm not sure if there is a specific exception that would be agnostic to any accelerator or device. On CPU and XLA I believe you get a different error type.

One suggestion is to write it so that the exception is logged and the error is printed when the tests are run, to specify why FP16 is not working. This would make it clear to the user whether it is unsupported behaviour or an issue with their setup.

Another suggestion is to add a CUDA-specific check and skip the FP16 matmul check in that case. This would still be in line with the changes, as these backends have defaults specified for them in the custom function dispatch as well; the matmul check would then only be used if a non-default device is being used. We could do this and also log the error to make it explicit to the user.

Let me know if that makes sense to you and I will add those changes, or any other suggestions you have. Thanks!
Could we maybe do something like this: https://github.com/huggingface/diffusers/pull/5612/files#r1399038284, just to add an extra safety mechanism so that a user doesn't interpret the function incorrectly in case CUDA is badly set up? Also, can we make the function private for now, e.g. add an underscore so that it's clearly internal?
I've added the changes, slightly restructured so that the FP16 op-check happens by default for all accelerators and the CUDA error is raised only if the device type is `cuda`.
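For reference, here is a minimal sketch of what that restructured check could look like; the function name `_is_torch_fp16_available`, the error message and the logging details are assumptions for illustration, not the exact code in this PR:

```python
import logging

import torch

logger = logging.getLogger(__name__)


def _is_torch_fp16_available(device):
    """Probe fp16 support on `device` with a tiny half-precision matmul."""
    device = torch.device(device)
    try:
        x = torch.zeros((2, 2), dtype=torch.float16, device=device)
        _ = x @ x
        return True
    except Exception as e:
        if device.type == "cuda":
            # On a CUDA device fp16 matmuls are expected to work, so a failure
            # here points to a broken setup (driver/CUDA install, OOM, ...)
            # rather than missing fp16 support; surface it instead of
            # silently returning False.
            raise ValueError(f"CUDA device found but the fp16 check failed, CUDA may not be set up correctly: {e}")
        # For other backends, log why fp16 was deemed unavailable so the user
        # can tell unsupported behaviour apart from a setup issue.
        logger.warning(f"fp16 is not available on {device}: {e}")
        return False
```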
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Cool thanks for iterating here. This PR LGTM - should we merge it now or do you want to add tests for other classes directly here?
```python
# Guard for when Torch is not available
if is_torch_available():
```
Is this meant to run when torch isn't available, or if `DIFFUSERS_TEST_DEVICE_SPEC` is set?
The guard is there because the function dispatch should only run if torch is available; it doesn't strictly matter whether `DIFFUSERS_TEST_DEVICE_SPEC` is set. For example, for a GPU, CPU or MPS device a spec doesn't need to be set, but torch must still be available to dispatch to the default torch device functions.
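To illustrate what this guard protects, here is a minimal sketch of a guarded dispatch table. The names `BACKEND_MANUAL_SEED` and `backend_manual_seed` follow the general pattern described in this PR but should be read as assumptions for this example, not the exact implementation:

```python
from diffusers.utils import is_torch_available

# The dispatch table references torch.* callables, so it can only be built
# when torch is importable; this is why the guard exists regardless of
# whether DIFFUSERS_TEST_DEVICE_SPEC is set.
if is_torch_available():
    import torch

    # Defaults for the built-in devices. A device spec file pointed to by
    # DIFFUSERS_TEST_DEVICE_SPEC could register an entry for another backend.
    BACKEND_MANUAL_SEED = {
        "cuda": torch.cuda.manual_seed,
        "cpu": torch.manual_seed,
        "mps": torch.manual_seed,
        "default": torch.manual_seed,
    }

    def backend_manual_seed(device: str, seed: int):
        # Dispatch to the device-specific seeding function, falling back to
        # the default entry for devices without a registered function.
        fn = BACKEND_MANUAL_SEED.get(device, BACKEND_MANUAL_SEED["default"])
        return fn(seed)
```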
If that's okay I'll add some more before merging 😄 I had a few other tests ready but removed them for this PR to keep it minimal.
I've added more test coverage. The latest commit has the changes for most of the model classes (unet, vae, vq, unet2d and some common files) and one pipeline test (SD2). Any more tests could be added in future PRs.
Cool! The changes look good to me - @DN6 wdyt? Feel free to merge once you're happy with it
LGTM 👍🏽 Nice work @arsalanu!
* utils and test modifications to enable device agnostic testing
* device for manual seed in unet1d
* fix generator condition in vae test
* consistency changes to testing
* make style
* add device agnostic testing changes to source and one model test
* make dtype check fns private, log cuda fp16 case
* remove dtype checks from import utils, move to testing_utils
* adding tests for most model classes and one pipeline
* fix vae import
What does this PR do?
Adds new features to `testing_utils.py` and `import_utils.py` to make testing with non-default PyTorch backends (beyond just `cuda`, `cpu` and `mps`) possible. This should not affect any current testing within the repo or the behaviour of the devices they are run on. This is heavily based on similar work we have done for Transformers, see: Transformers PR #25870.
Adds some device agnostic functions which dispatch to specific backend functions. This is mainly applicable to functions which are device-specific (e.g. `torch.cuda.manual_seed`). Users can specify new backends, and backend functions for the device agnostic functions, by creating a device specification file and pointing the test suite to it using the environment variable `DIFFUSERS_TEST_DEVICE_SPEC`, and can add a new device for PyTorch using `DIFFUSERS_TEST_DEVICE`.
Example of a device specification to run the tests with an alternative accelerator:
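A hypothetical spec file for an imaginary `npu`-style accelerator might look like the following. The module-level names (`DEVICE_NAME`, `MANUAL_SEED_FN`, `EMPTY_CACHE_FN`, `DEVICE_COUNT_FN`) follow the convention from the Transformers PR referenced above and should be treated as assumptions here rather than the exact contract:

```python
# spec.py: passed to the test suite via DIFFUSERS_TEST_DEVICE_SPEC=spec.py
import torch

# Hypothetical third-party backend; replace with your accelerator's package.
# import torch_npu  # noqa: F401

# Device string the tests should use, i.e. torch.device("npu").
DEVICE_NAME = "npu"

# Callables dispatched to in place of the torch.cuda.* defaults.
MANUAL_SEED_FN = torch.manual_seed


def EMPTY_CACHE_FN():
    # No-op if the backend exposes no cache-clearing API.
    pass


def DEVICE_COUNT_FN():
    # Report a single device for this sketch.
    return 1
```

The tests could then be launched with something like `DIFFUSERS_TEST_DEVICE="npu" DIFFUSERS_TEST_DEVICE_SPEC="spec.py" python -m pytest tests/`.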
Implementation details are fully outlined in issue #5562.
I have modified a single file (the UNet2D condition model tests, and `test_modeling_common` as this is used in the UNet2D tests) rather than all the tests, as this PR is more focused on the implementation of the features required for device-agnostic testing.