Add StableDiffusion3 #1820
Conversation
* Add Vgg16 backbone * update names * update tests * update test * add image classifier * incorporate review comments * Update test case * update backbone test * add image classifier * classifier cleanup * code reformat * add vgg16 image classifier * make vgg generic * update doc string * update docstring * add classifier test * update tests * update docstring * address review comments * code reformat * update the configs * address review comments * fix task saved model test * update init * code reformatted
* Add ResNetV1 and ResNetV2 * Address comments
* Add CSP DarkNet * Add CSP DarkNet * snake_case function names * change use_depthwise to block_type
…Backbone` (keras-team#1769) * Add FeaturePyramidBackbone and update ResNetBackbone * Simplify the implementation * Fix CI * Make ResNetBackbone compatible with timm and add FeaturePyramidBackbone * Add conversion implementation * Update docstrings * Address comments
* Add DenseNet * fix testcase * address comments * nit * fix lint errors * move description
* add vit det vit_det_backbone * update docstring * code reformat * fix tests * address review comments * bump year on all files * address review comments * rename backbone * fix tests * change back to ViT * address review comments * update image shape
* Add MixTransformer * fix testcase * test changes and comments * lint fix * update config list * modify testcase for 2 layers
* update input_image_shape -> image_shape * update docstring example * code reformat * update tests
add missing __init__ file to vit_det
This is a temporary way to test out the keras-hub branch. - Does a global rename of all symbols during package build. - Registers the "old" name on symbol export for saving compat. - Adds a github action to publish every commit to keras-hub as a new package. - Removes our descriptions on PyPI temporarily, until we want to message this more broadly.
* Add `CLIPTokenizer`, `T5XXLTokenizer`, `CLIPTextEncoder` and `T5XXLTextEncoder`. * Make CLIPTextEncoder as Backbone * Add `T5XXLPreprocessor` and remove `T5XXLTokenizer` Add `CLIPPreprocessor` * Use `tf = None` at the top * Replace manual implementation of `CLIPAttention` with `MultiHeadAttention`
* Bounding box utils * - Correct test cases * - Remove hard tensorflow dtype * - fix api gen * - Fix import for test cases - Use setup for converters test case * - fix api_gen issue * - FIx api gen * - Fix api gen error * - Correct test cases as per new api changes
* mobilenet_v3 added in keras-nlp * minor bug fixed in mobilenet_v3_backbone * formatting corrected * refactoring backbone * correct_pad_downsample method added * refactoring backbone * parameters updated * Testcase updated, expected output shape corrected * code formatted with black * testcase updated * refactoring and description added * comments updated * added mobilenet v1 and v2 * merge conflict resolved * version arg removed, and config options added * input_shape changed to image_shape in arg * config updated * input shape corrected * comments resolved * activation function format changed * minor bug fixed * minor bug fixed * added vision_backbone_test * channel_first bug resolved * channel_first cases working * comments resolved * formatting fixed * refactoring --------- Co-authored-by: ushareng <usha.rengaraju@gmail.com>
* migrating efficientnet models to keras-hub * merging changes from other sources * autoformatting pass * initial consolidation of efficientnet_backbone * most updates and removing separate implementation * cleanup, autoformatting, keras generalization * removed layer examples outside of efficientnet * many, mainly documentation changes, small test fixes
* Add ResNet_vd to ResNet backbone * Addressed requested parameter changes * Fixed tests and updated comments * Added new parameters to docstring
* Add `VAEImageDecoder` for StableDiffusionV3 * Use `keras.Model` for `VAEImageDecoder` and follows the coding style in `VAEAttention`
* add pyramid outputs * fix testcase * format fix * make common testcase for pyramid outputs * change default shape * simplify testcase * test case change and add channel axis
* Add `MMDiT` * Update * Update * Update implementation
* - Add formats, iou, utils for bounding box * - Add `AnchorGenerator`, `BoxMatcher` and `NonMaxSupression` layers * - Remove scope_name not required. * use default keras name scope * - Correct format error * - Remove layers as of now and keep them at model level till keras core supports them * - Correct api_gen
Awesome work! The samples are very exciting. Still reading through it, but some initial comments.
Yeah, definitely see what you mean. I think the best way to break this up would be to have …. In the case of clip, it actually seems like we would like a standalone clip model with its own …. In the case of …. WDYT?
You mean in the colab, right? The code looks like it's still supporting it everywhere. I think that is fine, probably the right way to start. Part of me is tempted to just rip out the T5 part entirely for now and wait for someone to ask for it. It'd make the initial implementation a lot simpler, and it seems like what almost all users will want anyway.
Sounds good to me. This should be the first nested backbone model in KerasHub but it probably won’t be the last. 😅
I’m not too familiar with CLIP, but I noticed ….
Yeah, I second that. Will make ….
Actually, in the colab, T5 wasn't loaded because I set ….
Yeah, that sounds like the right way to do it. Right now our backbone weights are monolithic and loaded in full. We could consider supporting a partial instantiation of weights, but I might actually do this as a separate upload.
Something like that. Then users wouldn't even have to download t5 in the "more usual" path. But anyway, we can figure that out later; supporting both as config options sounds good if it doesn't sound like it'd be too bad to implement.
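A minimal sketch of the "optional T5" idea discussed above, assuming a hypothetical constructor argument (`t5=None`); the class and argument names here are illustrative, not the PR's actual API:

import keras

class TextEncoders(keras.layers.Layer):
    """Hypothetical bundle of the two CLIP encoders plus an optional T5 encoder."""

    def __init__(self, clip_l, clip_g, t5=None, **kwargs):
        super().__init__(**kwargs)
        self.clip_l = clip_l
        self.clip_g = clip_g
        # When `t5` is None, the T5 weights are never built or downloaded.
        self.t5 = t5

    def call(self, clip_l_token_ids, clip_g_token_ids, t5_token_ids=None):
        outputs = [
            self.clip_l(clip_l_token_ids),
            self.clip_g(clip_g_token_ids),
        ]
        if self.t5 is not None and t5_token_ids is not None:
            outputs.append(self.t5(t5_token_ids))
        return outputs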
Agree with this, and it shouldn't be difficult to implement. Some updates: …
Let me know when this is ready. If so, I will add docstrings and the weight conversion script (in …).
@james77777778 We have renamed the repo and code to KerasHub! Sorry about this disruptive change, but it's a one-time cost. Please feel free to close this PR and open a new one against the new master.
No problem. I will be back on 9/22 and will resubmit the PR soon.
Looking good to me! I think we need to think a bit about saving and loading more, but we can do that in a follow up PR.
Interesting questions there.
Should `keras_hub.tokenizers.Tokenizer.from_preset("sd3_preset_name")` return something? Do we want to add a way to create each tokenizer individually? What does instantiation look like?
This will also need a rebase after our big symbol rename change.
    metrics="auto",
    **kwargs,
):
    # TODO: Figure out how to compile.
Probably we can just chain to `super()` here and clear the `generate_function`. Compile doesn't actually create a traced function; we do that lazily the first time a generate, predict, or train function is called.
If we ever had an argument that we thought would be common to a lot of models and requires recompiling the function (like `sampler` for text models), we could consider adding it here, but I don't think we need to do that now.
Updated. I referred to this implementation:
https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L410-L414
def compile(
    self,
    optimizer="auto",
    loss="auto",
    *,
    metrics="auto",
    **kwargs,
):
    # Ref: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L410-L414
    if optimizer == "auto":
        optimizer = keras.optimizers.AdamW(
            1e-4, weight_decay=1e-2, epsilon=1e-8, clipnorm=1.0
        )
    if loss == "auto":
        loss = keras.losses.MeanSquaredError()
    if metrics == "auto":
        metrics = [keras.metrics.MeanSquaredError()]
    super().compile(
        optimizer=optimizer,
        loss=loss,
        metrics=metrics,
        **kwargs,
    )
    self.generate_function = None
@keras_nlp_export("keras_nlp.models.StableDiffusion3Backbone")
class StableDiffusion3Backbone(Backbone):
Backbone is looking good! One good thing to test is `model.summary()` and make sure it's looking reasonable, now that we are on a functional model.
I have added some helper layers to get a cleaner `model.summary()`. Now it looks like this:
Model: "stable_diffusion3_backbone"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ clip_l_token_ids (InputLayer) │ (None, None) │ 0 │ - │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ clip_g_token_ids (InputLayer) │ (None, None) │ 0 │ - │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ clip_l_negative_token_ids │ (None, None) │ 0 │ - │
│ (InputLayer) │ │ │ │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ clip_g_negative_token_ids │ (None, None) │ 0 │ - │
│ (InputLayer) │ │ │ │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ clip_l (CLIPTextEncoder) │ [(None, None, 768), │ 123,060,480 │ clip_l_token_ids[0][0], │
│ │ (None, None, 768)] │ │ clip_l_negative_token_ids… │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ clip_g (CLIPTextEncoder) │ [(None, None, 1280), │ 693,021,440 │ clip_g_token_ids[0][0], │
│ │ (None, None, 1280)] │ │ clip_g_negative_token_ids… │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ num_steps (InputLayer) │ () │ 0 │ - │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate_1 (Concatenate) │ (None, None, 2048) │ 0 │ clip_l[0][0], clip_g[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate_3 (Concatenate) │ (None, None, 2048) │ 0 │ clip_l[1][0], clip_g[1][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ cast_63 (Cast) │ (None, None, 768) │ 0 │ clip_l[0][1] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ cast_64 (Cast) │ (None, None, 1280) │ 0 │ clip_g[0][1] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ cast_65 (Cast) │ (None, None, 768) │ 0 │ clip_l[1][1] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ cast_66 (Cast) │ (None, None, 1280) │ 0 │ clip_g[1][1] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ scheduler │ () │ 0 │ num_steps[0][0], │
│ (FlowMatchEulerDiscreteSched… │ │ │ num_steps[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ pad (Pad) │ (None, None, 4096) │ 0 │ concatenate_1[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ pad_2 (Pad) │ (None, None, 4096) │ 0 │ concatenate_3[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ clip_l_projection │ (None, 768) │ 589,824 │ cast_63[0][0], │
│ (CLIPProjection) │ │ │ clip_l_token_ids[0][0], │
│ │ │ │ cast_65[0][0], │
│ │ │ │ clip_l_negative_token_ids… │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ clip_g_projection │ (None, 1280) │ 1,638,400 │ cast_64[0][0], │
│ (CLIPProjection) │ │ │ clip_g_token_ids[0][0], │
│ │ │ │ cast_66[0][0], │
│ │ │ │ clip_g_negative_token_ids… │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ multiply (Multiply) │ () │ 0 │ scheduler[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ latents (InputLayer) │ (None, 100, 100, 16) │ 0 │ - │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ pad_1 (Pad) │ (None, None, 4096) │ 0 │ pad[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ pad_3 (Pad) │ (None, None, 4096) │ 0 │ pad_2[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate (Concatenate) │ (None, 2048) │ 0 │ clip_l_projection[0][0], │
│ │ │ │ clip_g_projection[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate_2 (Concatenate) │ (None, 2048) │ 0 │ clip_l_projection[1][0], │
│ │ │ │ clip_g_projection[1][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ broadcast_to (BroadcastTo) │ (None) │ 0 │ multiply[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate_4 (Concatenate) │ (None, None, 4096) │ 0 │ pad_1[0][0], pad_3[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate_6 (Concatenate) │ (None, 100, 100, 16) │ 0 │ latents[0][0], │
│ │ │ │ latents[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate_5 (Concatenate) │ (None, 2048) │ 0 │ concatenate[0][0], │
│ │ │ │ concatenate_2[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate_7 (Concatenate) │ (None) │ 0 │ broadcast_to[0][0], │
│ │ │ │ broadcast_to[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ diffuser (MMDiT) │ (None, None, None, 16) │ 2,084,951,104 │ concatenate_4[0][0], │
│ │ │ │ concatenate_6[0][0], │
│ │ │ │ concatenate_5[0][0], │
│ │ │ │ concatenate_7[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ cast_67 (Cast) │ (None, None, None, 16) │ 0 │ diffuser[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ guidance_scale (InputLayer) │ () │ 0 │ - │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ classifier_free_guidance │ (None, None, None, 16) │ 0 │ cast_67[0][0], │
│ (ClassifierFreeGuidance) │ │ │ guidance_scale[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ euler_step (EulerStep) │ (None, 100, 100, 16) │ 0 │ latents[0][0], │
│ │ │ │ classifier_free_guidance[… │
│ │ │ │ scheduler[0][0], │
│ │ │ │ scheduler[1][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ latent_calibration │ (None, 100, 100, 16) │ 0 │ euler_step[0][0] │
│ (LatentCalibration) │ │ │ │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ decoder (VAEImageDecoder) │ (None, 800, 800, 3) │ 49,545,475 │ latent_calibration[0][0] │
└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘
Total params: 2,952,806,723 (5.50 GB)
Trainable params: 2,952,806,723 (5.50 GB)
Non-trainable params: 0 (0.00 B)
config = super().get_config()
config.update(
    {
        "clip_l_preprocessor": layers.serialize(
Do clip_l and clip_g have separate tokenizers? We should think about serialization a bit here. Right now for our "preset" saving, we assume one tokenizer, in a fixed directory of `assets/tokenizer`. We probably need to tweak our saving and loading a bit.
Probably makes sense to do this as a separate PR? I'll think about how to best do this and post some thoughts.
Here's a scaffold of what we could do here. #1860
I will keep the current implementation for now.
Should I create a new PR for this after SD3 is merged, or will you finish it?
Feel free to patch this into a PR you own. I think actually using it to save sd3 assets will be an important way to test it out.
for batched inputs.

Args:
    latents: A <float>[batch_size, height, width, channels] tensor
I don't think we use this notation usually? I'd just say "A float tensor with shape `(batch_size, height, width, channels)`…"
There might be some legacy docstrings in the codebase:
keras-hub/keras_hub/src/models/gemma/gemma_causal_lm.py
Lines 328 to 358 in 9b1cf95
"""Score a generation represented by the provided token ids. | |
Args: | |
token_ids: A <int>[batch_size, num_tokens] tensor containing tokens | |
to score. Typically, this tensor captures the output from a call | |
to `GemmaCausalLM.generate()`, i.e., tokens for both the input | |
text and the model-generated text. | |
padding_mask: A <bool>[batch_size, num_tokens] tensor indicating the | |
tokens that should be preserved during generation. This is an | |
artifact required by the GemmaBackbone and isn't influential on | |
the computation of this function. If omitted, this function uses | |
`keras.ops.ones()` to create a tensor of the appropriate shape. | |
scoring_mode: The type of scores to return, either "logits" or | |
"loss", both will be per input token. | |
layer_intercept_fn: An optional function for augmenting activations | |
with additional computation, for example, as part of | |
interpretability research. This function will be passed the | |
activations as its first parameter and a numeric index | |
associated with that backbone layer. _This index _is not_ an | |
index into `self.backbone.layers`_. The index -1 accompanies the | |
embeddings returned by calling `self.backbone.token_embedding()` | |
on `token_ids` in the forward direction. All subsequent indexes | |
will be 0-based indices for the activations returned by each of | |
the Transformers layers in the backbone. This function must | |
return a <float>[batch_size, num_tokens, hidden_dims] tensor | |
that can be passed as an input to the next layer in the model. | |
target_ids: An <bool>[batch_size, num_tokens] tensor containing the | |
predicted tokens against which the loss should be computed. If a | |
span of tokens is provided (sequential truthy values along | |
axis=1 in the tensor), the loss will be computed as the | |
aggregate across those tokens. |
I adapted it from there. It's fixed now.
I think it probably makes sense to pull this in without compilation or saving all the way figured out, and take those on as two (hopefully independent) follow-ups. Nice work!
I have switched to ….
@mattdangerw @divyashreepathihalli I have addressed the above comments, and the model works as-is. Please let me know if the implementation is ready. I will add unit tests and complete the missing docstrings afterward. Also, please let me know if I should take over #1860 or not.
@james77777778 The implementation is looking good and the results are looking good! Please go ahead and finish up the docstrings and unit tests! Thanks!
@divyashreepathihalli
@james77777778 The jax test is failing for Keras 3.1 but passing on others; let me know if it is fixable or just a bug on Keras 3.1.
It should be fixed now. The root cause is that Keras 3.1 doesn't have a setter for dtype policy. The solution is to assign the dtype policy attribute directly.
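A hedged sketch of that workaround; it assumes the policy is stored in a `_dtype_policy` attribute, which is an implementation detail of Keras 3.x rather than public API:

import keras

layer = keras.layers.Dense(8)
policy = keras.dtype_policies.DTypePolicy("mixed_float16")
try:
    # Newer Keras versions expose a setter for `dtype_policy`.
    layer.dtype_policy = policy
except AttributeError:
    # Keras 3.1 has no setter, so assign the backing attribute directly.
    layer._dtype_policy = policy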
Thanks James!! Merging this!
This is more of a draft, as we may need further discussion regarding the implementation.
Notes for reviewing:
- `StableDiffusion3Backbone` is a large model, resulting in a very long init signature. Is this acceptable? How could we refactor it?
- Compiling the entire `text_to_image` function ran into unexpected OOM issues. However, when splitting it into `encode`, `denoise` and `decode` functions, it worked fine. I'm unsure about the performance impact of not compiling the entire function (see the sketch below).
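As a rough illustration of that split (stand-in functions only, with made-up shapes, not the PR's implementation), the idea is to run three smaller functions instead of one end-to-end `text_to_image` call:

import numpy as np

def encode(prompts):
    # Stand-in for the CLIP-L/CLIP-G (and optional T5) text encoders.
    return np.zeros((len(prompts), 154, 4096), dtype="float32")

def denoise(embeddings, latents, num_steps):
    # Stand-in for the MMDiT diffusion loop; a real implementation would
    # condition each Euler step on `embeddings`.
    for _ in range(num_steps):
        latents = latents - 0.01 * latents
    return latents

def decode(latents):
    # Stand-in for the VAE image decoder.
    return np.clip(latents, 0.0, 1.0)

prompts = ["a cat holding a sign that says hello world"]
latents = np.random.rand(1, 100, 100, 16).astype("float32")
images = decode(denoise(encode(prompts), latents, num_steps=28))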
Refs:
- Demo colab (including weights conversion for https://huggingface.co/stabilityai/stable-diffusion-3-medium): https://colab.research.google.com/drive/1rrQMs0nlKSEzYNhIJChQwgnrZNiydexS?usp=sharing

Sample prompts from the demo (generated images omitted):
- "a cat holding a sign that says hello world"
- "cute wallpaper art of a cat"

TODO:
- `stable_diffusion_3`

cc @divyashreepathihalli @mattdangerw @SamanehSaadat
BTW, I will be unavailable from 9/17~9/22.