Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[Flax] added broadcast_to_shape_from_left helper and Scheduler tests #864

Merged
merged 15 commits into from
Oct 25, 2022

Conversation

kashif
Copy link
Contributor

@kashif kashif commented Oct 17, 2022

instead of the while loop

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Oct 17, 2022

The documentation is not available anymore as the PR was closed or merged.

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me - could we add one test?
Also did we make sure that the Flax pipeline still works (e.g. did we run the slow Flax tests once? )

@kashif
Copy link
Contributor Author

kashif commented Oct 17, 2022

yes let me add flax scheduler tests! good idea!

@kashif
Copy link
Contributor Author

kashif commented Oct 17, 2022

I will integrate #580 here

@kashif kashif changed the title [Flax] added broadcast_to_shape_from_left helper [Flax] added broadcast_to_shape_from_left helper and Scheduler tests Oct 19, 2022
@kashif
Copy link
Contributor Author

kashif commented Oct 19, 2022

@anton-l i have added some flax scheduler tests... whenever you get a chance i would not mind a review. Thanks!

@patrickvonplaten
Copy link
Contributor

Did we test the changes on a TPUv3-8 or TPUv2-8?

@patrickvonplaten
Copy link
Contributor

Happy to merge once confirmed everything works on TPUv3-8 and fast tests pass. Think the current failure of the fast tests is unrelated and has been fixed by Suraj here: #928

@kashif
Copy link
Contributor Author

kashif commented Oct 20, 2022

@patrickvonplaten thanks let me update my branch!

@pcuenca
Copy link
Member

pcuenca commented Oct 21, 2022

The following test passes in a TPU v3-8:

def test_stable_diffusion_v1_4(self):

But the others in the same class don't. For example, we are now getting (8, 1, 128, 128, 3) here:

assert images.shape == (8, 1, 64, 64, 3)
. The same thing happens in main, but it works in 7c22626.

I'm a bit puzzled, I'll take another look when I've had some rest.

@patrickvonplaten
Copy link
Contributor

Ok as this PR gives the same results as main merging this for now

@patrickvonplaten patrickvonplaten merged commit 240abdd into huggingface:main Oct 25, 2022
@kashif kashif deleted the jax-broadcast branch October 25, 2022 11:48
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
…uggingface#864)

* added broadcast_to_shape_from_left helper

* initial tests

* fixed pndm tests

* shape required for pndm

* added require_flax

* fix style

* fix more imports

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants