-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
fix loop bug in SlicedAttnProcessor #8836
fix loop bug in SlicedAttnProcessor #8836
Conversation
Thank you! Do you think our tests need to be updated to catch bugs like this?
|
Bugs like this in test_attention_slicing_forward_pass have been fixed, include "SlicedAttnProcessor" and "SlicedAttnAddedKVProcessor". May be there is no need to catch bugs like this in test_attention_slicing_forward_pass. But maybe tests on other function or module are needed. |
Or, if it's needed to update the test "test_attention_slicing_forward_pass" to catch bugs like this, I am happy to do that |
I have update the test_attention_slicing_forward_pass to catch bugs like this. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks a lot!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
space in blank line have been removed |
thanks for fixing this for us! |
ok, wait a minute |
done |
…fix_bug_of_SlicedAttnProcessor
sorry for tests bug, have fixed it |
hey I think the failing tests are relevant here https://github.com/huggingface/diffusers/actions/runs/9970139809/job/27588409202?pr=8836#step:7:17944 can you look into them? |
…fix_bug_of_SlicedAttnProcessor
I have look into them, and this is relevant with unet_2d_condition. I change unet_2d_conditon a little, I don't know if this is appropriate. Please check it. |
This commit: 0ede41c |
inputs = self.get_dummy_inputs(generator_device) | ||
output_with_slicing2 = pipe(**inputs)[0] | ||
|
||
pipe.enable_attention_slicing(slice_size=3) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we remove the slice_size=3
test? I think the CI would pass without this, no?
@@ -815,7 +815,10 @@ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): | |||
size = slice_size[i] | |||
dim = sliceable_head_dims[i] | |||
if size is not None and size > dim: | |||
raise ValueError(f"size {size} has to be smaller or equal to {dim}.") | |||
slice_size[i] = dim |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's try not to make this update and change the test instead (we should try not to update the user inputs for user, we always prefer to be explicit and throw an error message, )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, get it. if remove the slice_size=3, the CI will pass.
…fix_bug_of_SlicedAttnProcessor
I have remove slice_size=3, and make unet_2d_condition unchanged. |
* fix loop bug in SlicedAttnProcessor --------- Co-authored-by: neoshang <neoshang@tencent.com>
Fixes # (loop bug in SlicedAttnProcessor)
@sayakpaul @yiyixuxu @DN6