Apple MPS error in unet_2d_condition.py #358
Comments
@pcuenca Thank you for the fixes! Much obliged! Will wait for the merge eagerly 😄

Unfortunately, with the 0.3.0 release installed, this issue crops up on line 95 in

Update: tagging @pcuenca since the ticket is closed and I'm not sure if anybody gets notified.
Hi @FahimF! Works for me. Would you mind sharing a code snippet so I can try to reproduce? Also, some information about your setup could be useful. Thanks a lot!
@pcuenca Thank you for taking a look. Let me try to remove the fix I put in the code and come up with something simple to demonstrate the issue. I know at least one of those was while using the

Here's what I have in my console from that particular crash:

I see from the error above that I was also using

As far as setup goes: latest PyTorch nightly (installed today), diffusers 0.3.0 (installed today), on a 2021 MBP. If you need any additional info (I don't know what would help and what won't), please let me know and I'll provide.

Update:
Thanks a lot, @FahimF, I'll test the image-to-image pipeline and report back.
Just updating that this issue might have been due to something in the PyTorch nightlies. I could not generate valid images via img2img either, then updated the PyTorch nightly (and did a clean install), and that issue went away. So I tested for this one, and it's gone too ...
Actually, I could reproduce this issue using the img2img pipeline as you said, with the
Cool 😄 I did see that you had a PR, but I'm letting you know in case I sent you on a wild-goose chase. I really have no idea what happened, but at least two bugs I had three days ago have disappeared with the PyTorch nightly from yesterday.
That's interesting; maybe PyTorch merged some fixes? We'll have to test again in case they are falling back to CPU and performance degrades. Thanks!
Sure thing 😄 If you need any additional info, please let me know, but the nightly build I'm running where I don't have the issues is:
@pcuenca Sorry to bug you about a totally separate issue, but I tagged you in a closed ticket about an issue which was fixed but still persists (in a different file) here: #239 (comment). Just mentioning it since I don't know if you get notifications for closed tickets 😄 If you'd prefer that I create a new ticket for that, I can do so. Please let me know.
* Initial support for mps in Stable Diffusion pipeline.
* Initial "warmup" implementation when using mps.
* Make some deterministic tests pass with mps.
* Disable training tests when using mps.
* SD: generate latents in CPU then move to device. This is especially important when using the mps device, because generators are not supported there. See for example pytorch/pytorch#84288. In addition, the other pipelines seem to use the same approach: generate the random samples then move to the appropriate device. After this change, generating an image in MPS produces the same result as when using the CPU, if the same seed is used.
* Remove prints.
* Pass AutoencoderKL test_output_pretrained with mps. Sampling from `posterior` must be done in CPU.
* Style
* Do not use torch.long for log op in mps device.
* Perform incompatible padding ops in CPU. UNet tests now pass. See pytorch/pytorch#84535
* Style: fix import order.
* Remove unused symbols.
* Remove MPSWarmupMixin, do not apply automatically. We do apply warmup in the tests, but not during normal use. This adopts some PR suggestions by @patrickvonplaten.
* Add comment for mps fallback to CPU step.
* Add README_mps.md for mps installation and use.
* Apply `black` to modified files.
* Restrict README_mps to SD, show measures in table.
* Make PNDM indexing compatible with mps. Addresses huggingface#239.
* Do not use float64 when using LDMScheduler. Fixes huggingface#358.
* Fix typo identified by @patil-suraj. Co-authored-by: Suraj Patil <surajp815@gmail.com>
* Adapt example to new output style.
* Restore 1:1 results reproducibility with CompVis. However, mps latents need to be generated in CPU because generators don't work in the mps device.
* Move PyTorch nightly to requirements.
* Adapt `test_scheduler_outputs_equivalence` to MPS.
* mps: skip training tests instead of ignoring silently.
* Make VQModel tests pass on mps.
* mps ddim tests: warmup, increase tolerance.
* ScoreSdeVeScheduler indexing made mps compatible.
* Make ldm pipeline tests pass using warmup.
* Style
* Simplify casting as suggested in PR.
* Add Known Issues to readme.
* `isort` import order.
* Remove _mps_warmup helpers from ModelMixin. And just make changes to the tests.
* Skip tests using unittest decorator for consistency.
* Remove temporary var.
* Remove spurious blank space.
* Remove unused symbol.
* Remove README_mps.

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
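One item from the list above worth calling out is "generate latents in CPU then move to device". A minimal sketch of that idea (the latent shape, seed, and variable names are illustrative, not the pipeline's actual code):

```python
import torch

# torch.Generator is not supported on the mps device, so random latents are
# drawn on the CPU with a seeded generator and then moved to the target device.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

generator = torch.Generator(device="cpu").manual_seed(42)    # generators work on CPU
latents = torch.randn((1, 4, 64, 64), generator=generator)   # sample on CPU for reproducibility
latents = latents.to(device)                                  # then move to mps (or the CPU fallback)
```

This is why the same seed now produces the same image on mps and CPU: the random draw itself always happens on the CPU.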
* Fix LMS scheduler indexing in `add_noise` huggingface#358.
* Fix DDIM and DDPM indexing with mps device.
* Verify format is PyTorch before using `.to()`
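For the last item, the underlying idea is presumably a type check before calling `.to()`, since scheduler timesteps can arrive as numpy arrays or plain ints that have no `.to()` method. A hedged sketch (this helper is hypothetical and illustrative, not the merged code):

```python
import numpy as np
import torch

def timesteps_to_device(timesteps, sample: torch.Tensor) -> torch.Tensor:
    # Convert non-PyTorch inputs (numpy arrays, Python scalars) to a tensor
    # before moving them to the sample's device.
    if not torch.is_tensor(timesteps):
        timesteps = torch.as_tensor(np.asarray(timesteps))
    return timesteps.to(sample.device)
```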
Describe the bug
When you use an LMSDiscreteScheduler on an Apple Silicon machine, you'll get the following error:
Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead
The offending line is 134 in unet_2d_condition.py.
The current code is:
timesteps = timesteps[None].to(sample.device)
Changing that to the following stops the crash:
timesteps = timesteps[None].long().to(sample.device)
However, I believe you'd really want to check whether the current device is MPS and only do the dtype conversion in that case?
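For illustration, such a device-conditional cast might look like the following hypothetical helper (a sketch of the suggestion above, not the fix that was actually merged):

```python
import torch

def broadcast_timesteps(timesteps: torch.Tensor, sample: torch.Tensor) -> torch.Tensor:
    # float64 is not supported by the MPS backend, so only downcast the
    # timesteps when the sample actually lives on an mps device.
    timesteps = timesteps[None]  # add a batch dimension, as in the original line
    if sample.device.type == "mps" and timesteps.dtype == torch.float64:
        timesteps = timesteps.long()
    return timesteps.to(sample.device)
```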
Reproduction
When you use an LMSDiscreteScheduler on an Apple Silicon machine, you should see the crash.
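A reproduction sketch along these lines (the model id, scheduler parameters, prompt, and output access are assumptions based on the diffusers 0.3.x API, not taken from the report):

```python
import torch
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler

# Assumed reproduction: pass an LMS scheduler to the Stable Diffusion pipeline
# and run it on the mps device.
scheduler = LMSDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", scheduler=scheduler
).to("mps")

# On affected versions this call raises:
# "Cannot convert a MPS Tensor to float64 dtype ..."
image = pipe("a photo of an astronaut riding a horse").images[0]
```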
Logs
No response
System Info
The current main branch from the repo, since that appears to be different from the current release version (0.2.4?).