Add ControlNet-XS support #5827
…nto controlnet-xs
Cc: @DN6
Amazing job @UmerHA
@UmerHA Great work! Can you provide a controlnet-xs training example?
@universewill Sure - see https://github.com/UmerHA/diffusers/tree/cnxs-training/examples/controlnet_xs. I've checked that the scripts run, but haven't verified full training runs. When I have more time, I'll do that and open a PR. In the meantime, let me know if you encounter any issues!
* Check in 23-10-05 * check-in 23-10-06 * check-in 23-10-07 2pm * check-in 23-10-08 * check-in 231009T1200 * check-in 230109 * checkin 231010 * init + forward run * checkin * checkin * ControlNetXSModel is now saveable+loadable * Forward works * checkin * Pipeline works with `no_control=True` * checkin * debug: save intermediate outputs of resnet * checkin * Understood time error + fixed connection error * checkin * checkin 231106T1600 * turned off detailled debug prints * time debug logs * small fix * Separated control_scale for connections/time * simplified debug logging * Full denoising works with control scale = 0 * aligned logs * Added control_attention_head_dim param * Passing n_heads instead of dim_head into ctrl unet * Fixed ctrl midblock bug * Cleanup * Fixed time dtype bug * checkin * 1. from_unet, 2. base passed, 3. all unet params * checkin * Finished docstrings * cleanup * make style * checkin * more tests pass * Fixed tests * removed debug logs * make style + quality * make fix-copies * fixed documentation * added cnxs to doc toc * added control start/end param * Update controlnetxs_sdxl.md * tried to fix copies.. 
* Fixed norm_num_groups in from_unet * added sdxl-depth test * created SD2.1 controlnet-xs pipeline * re-added debug logs * Adjusting group norm ; readded logs * Added debug log statements * removed debug logs ; started tests for sd2.1 * updated sd21 tests * fixed tests * fixed tests * slightly increased error tolerance for 1 test * make style & quality * Added docs for CNXS-SD * make fix-copies * Fixed sd compile test ; fixed gradient ckpointing * vae downs = cnxs conditioning downs; removed guess * make style & quality * Fixed tests * fixed test * Incorporated review feedback * simplified control model surgery * fixed tests & make style / quality * Updated docs; deleted pip & cursor files * Rolled back minimal change to resnet * Update resnet.py * Update resnet.py * Update src/diffusers/models/controlnetxs.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/models/controlnetxs.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Incorporated review feedback * Update docs/source/en/api/pipelines/controlnetxs_sdxl.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/api/pipelines/controlnetxs.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/api/pipelines/controlnetxs.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/api/pipelines/controlnetxs.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update src/diffusers/models/controlnetxs.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update src/diffusers/models/controlnetxs.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/api/pipelines/controlnetxs.md Co-authored-by: Steven Liu 
<59462357+stevhliu@users.noreply.github.com> * Update src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Incorporated doc feedback --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Sorry, we had to move the implementation to the research folder for now, as the design was not in line with the usual diffusers design (e.g. the unet is forwarded into the controlnet-xs function etc.). We should have caught that when reviewing the PR, but sadly failed to do so. We still very much want to add ControlNet-XS to [...] Very sorry @UmerHA that we missed these things in the initial review 🙏
        super().__init__()

        # 1 - Create control unet
        self.control_model = UNet2DConditionModel(
The control_model should not be a UNet2DConditionModel if we have to apply a lot of surgery afterwards. Let's make sure we directly instantiate the correct torch.nn.Modules right away.
        # 2 - Do model surgery on control model
        # 2.1 - Allow to use the same time information as the base model
        adjust_time_dims(self.control_model, time_embedding_input_dim, time_embedding_dim)

        # 2.2 - Allow for information infusion from base model

        # We concat the output of each base encoder subblocks to the input of the next control encoder subblock
        # (We ignore the 1st element, as it represents the `conv_in`.)
        extra_input_channels = [input_channels for input_channels, _ in base_model_channel_sizes["down"][1:]]
        it_extra_input_channels = iter(extra_input_channels)

        for b, block in enumerate(self.control_model.down_blocks):
            for r in range(len(block.resnets)):
                increase_block_input_in_encoder_resnet(
                    self.control_model, block_no=b, resnet_idx=r, by=next(it_extra_input_channels)
                )

            if block.downsamplers:
                increase_block_input_in_encoder_downsampler(
                    self.control_model, block_no=b, by=next(it_extra_input_channels)
                )

        increase_block_input_in_mid_resnet(self.control_model, by=extra_input_channels[-1])

        # 2.3 - Make group norms work with modified channel sizes
        adjust_group_norms(self.control_model)
We can't do surgery here, let's make sure to instead instantiate the correct classes right away
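One way to avoid patching group norms after the fact is to choose a valid group count when each layer is constructed. The sketch below is a hypothetical helper, not the PR's `find_denominator` (whose body is not shown on this page). It assumes the goal is the largest group count at or below a preferred maximum that divides the channel count evenly, since `torch.nn.GroupNorm` requires `num_channels` to be divisible by `num_groups`. The channel counts are illustrative only.

```python
def largest_valid_num_groups(num_channels: int, max_groups: int = 32) -> int:
    """Return the largest g <= max_groups such that num_channels % g == 0.

    Hypothetical helper for illustration; the name and defaults are assumptions.
    """
    if num_channels < 1 or max_groups < 1:
        raise ValueError("num_channels and max_groups must be positive")
    # Walk downwards from the preferred maximum until a divisor is found.
    for groups in range(min(max_groups, num_channels), 0, -1):
        if num_channels % groups == 0:
            return groups
    return 1  # not reached in practice: 1 divides every positive integer


# Illustrative channel counts, not taken from any real checkpoint.
print(largest_valid_num_groups(320))  # 32
print(largest_valid_num_groups(330))  # 30
```

With a helper like this, the group count could be passed straight to the norm layer's constructor, so no existing module needs to be mutated afterwards.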
        # In the minimal implementation setting, we only need the control model up to the mid block
        del self.control_model.up_blocks
        del self.control_model.conv_norm_out
        del self.control_model.conv_out
They should instead never have been instantiated
    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        self.control_model.set_attention_slice(slice_size)
Let's not provide a set_attention_slice() operation anymore, since with FlashAttention it's pretty useless to slice the attention.
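For readers unfamiliar with the option being dropped, the arithmetic described in the docstring can be sketched as follows. This is a standalone illustration, not the diffusers implementation: `resolve_num_slices` is a hypothetical name, it handles only the scalar forms of `slice_size` (not the `list(int)` form), and it reads `"max"` as a slice size of 1.

```python
from typing import Union


def resolve_num_slices(attention_head_dim: int, slice_size: Union[str, int] = "auto") -> int:
    """Number of attention steps implied by `slice_size`, following the docstring."""
    if slice_size == "auto":
        # The input to the attention heads is halved, so attention runs in two steps.
        return 2
    if slice_size == "max":
        # Assumption: "one slice at a time" means a slice size of 1.
        return attention_head_dim
    if not isinstance(slice_size, int) or slice_size <= 0:
        raise ValueError(f"unsupported slice_size: {slice_size!r}")
    if attention_head_dim % slice_size != 0:
        raise ValueError("attention_head_dim must be a multiple of slice_size")
    return attention_head_dim // slice_size


print(resolve_num_slices(8, "auto"))  # 2
print(resolve_num_slices(8, 4))       # 2
print(resolve_num_slices(8, "max"))   # 8
```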
    def forward(
        self,
        base_model: UNet2DConditionModel,
We can't pass the base_model here into the forward method
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
Do we need the class_labels parameter?
        if base_model.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if base_model.config.class_embed_type == "timestep":
                class_labels = base_model.time_proj(class_labels)

            class_emb = base_model.class_embedding(class_labels).to(dtype=self.dtype)
            temb = temb + class_emb
Do we really need this? I think the class labels are only required for Stable Diffusion Upsampling
        if base_model.config.addition_embed_type is not None:
            if base_model.config.addition_embed_type == "text":
                aug_emb = base_model.add_embedding(encoder_hidden_states)
            elif base_model.config.addition_embed_type == "text_image":
                raise NotImplementedError()
            elif base_model.config.addition_embed_type == "text_time":
                # SDXL - style
                if "text_embeds" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                    )
                text_embeds = added_cond_kwargs.get("text_embeds")
                if "time_ids" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                    )
                time_ids = added_cond_kwargs.get("time_ids")
                time_embeds = base_model.add_time_proj(time_ids.flatten())
                time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
                add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
                add_embeds = add_embeds.to(temb.dtype)
                aug_emb = base_model.add_embedding(add_embeds)
            elif base_model.config.addition_embed_type == "image":
                raise NotImplementedError()
            elif base_model.config.addition_embed_type == "image_hint":
                raise NotImplementedError()
We should not call the base model here - instead self.add_embedding(...) should be called
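Independently of which module ends up owning `add_embedding`, the two near-identical `text_time` checks in the hunk above could be folded into one helper. The following is a minimal sketch under that assumption; `require_added_cond` and its parameters are hypothetical names that do not appear in the PR, and the error text simply mirrors the messages in the diff.

```python
from typing import Any, Dict, Iterable, Tuple


def require_added_cond(
    added_cond_kwargs: Dict[str, Any],
    keys: Iterable[str],
    owner: str,
    embed_type: str,
) -> Tuple[Any, ...]:
    """Return the requested entries of `added_cond_kwargs`, raising if one is missing."""
    values = []
    for key in keys:
        if key not in added_cond_kwargs:
            raise ValueError(
                f"{owner} has the config param `addition_embed_type` set to '{embed_type}' "
                f"which requires the keyword argument `{key}` to be passed in `added_cond_kwargs`"
            )
        values.append(added_cond_kwargs[key])
    return tuple(values)


# Placeholder values stand in for the real tensors.
text_embeds, time_ids = require_added_cond(
    {"text_embeds": "<text tensor>", "time_ids": "<time ids tensor>"},
    keys=("text_embeds", "time_ids"),
    owner="ControlNetXSModel",
    embed_type="text_time",
)
```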
    return False


def to_sub_blocks(blocks):
We should not have to call such a method in the forward pass of the controlnet
        for d in b.downsamplers:
            sub_blocks.append([d])

    return list(map(SubBlock, sub_blocks))
Let's try to avoid map here as this breaks torch.compile
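A minimal sketch of the same grouping written with an explicit list comprehension instead of `map`. `SubBlock` and `Block` below are hypothetical stand-ins, since only the tail of the PR's `to_sub_blocks` is visible on this page; the real function may group layers differently (for example, pairing resnets with attention layers). Whether this change alone satisfies `torch.compile` is not verified here.

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class SubBlock:
    """Hypothetical stand-in for the PR's `SubBlock` wrapper."""
    layers: List[Any]


@dataclass
class Block:
    """Hypothetical stand-in for a down block with optional downsamplers."""
    resnets: List[Any]
    downsamplers: List[Any] = field(default_factory=list)


def to_sub_blocks(blocks: List[Block]) -> List[SubBlock]:
    sub_blocks: List[List[Any]] = []
    for b in blocks:
        for r in b.resnets:
            sub_blocks.append([r])
        for d in b.downsamplers:
            sub_blocks.append([d])
    # Explicit comprehension in place of `list(map(SubBlock, sub_blocks))`.
    return [SubBlock(layers) for layers in sub_blocks]


result = to_sub_blocks([Block(["r0", "r1"], ["d0"]), Block(["r2"])])
```

Calling such a function once at construction time and storing the result would also address the earlier comment about not invoking it inside the forward pass.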
        a.norm.num_groups = find_denominator(a.norm.num_channels, start=max_num_group)


def is_iterable(o):
Let's not use such a try-except function, this breaks torch.compile
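One exception-free alternative is an `isinstance` check against the standard-library `Iterable` ABC. This is a sketch of that approach, not the replacement the PR eventually adopted, and its behaviour under `torch.compile` is not verified here. Note the semantic difference from a `try: iter(o)` check, described in the comment.

```python
from collections.abc import Iterable


def is_iterable(o: object) -> bool:
    # The ABC check looks for `__iter__` and never raises. Objects that are
    # iterable only through the legacy `__getitem__` protocol report False.
    return isinstance(o, Iterable)


print(is_iterable([1, 2, 3]))  # True
print(is_iterable(3))          # False
```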
    unet.time_embedding.linear_1 = nn.Linear(in_dim, out_dim)


def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no, resnet_idx, by):
We should not have to use such surgery methods
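To construct layers with the right width from the start, the input size of every control encoder sub-block can be computed before anything is instantiated. The sketch below mirrors the concatenation rule from the `__init__` hunk earlier in the review (skip the `conv_in` entry, then add the base input channels to each control sub-block). `plan_control_encoder_inputs` is a hypothetical name, the mid-block case is left out, and all channel numbers are illustrative.

```python
from typing import Dict, List, Tuple


def plan_control_encoder_inputs(
    control_in_channels: List[int],
    base_channel_sizes: Dict[str, List[Tuple[int, int]]],
) -> List[int]:
    """Input width of each control encoder sub-block after base features are concatenated."""
    # Skip the first (in, out) pair, which represents `conv_in`.
    extra = [in_ch for in_ch, _ in base_channel_sizes["down"][1:]]
    if len(extra) != len(control_in_channels):
        raise ValueError("expected one base entry per control encoder sub-block")
    return [own + add for own, add in zip(control_in_channels, extra)]


# Illustrative sizes only, not taken from a real model config.
base_sizes = {"down": [(4, 320), (320, 320), (320, 640), (640, 640)]}
planned = plan_control_encoder_inputs([16, 16, 32], base_sizes)
print(planned)  # [336, 336, 672]
```

Each planned value could then be passed as the input channel count when the corresponding resnet or downsampler is created, so no `increase_block_input_*` call is needed afterwards.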
Hi @patrickvonplaten, no hurt feelings & fully understand. I'll start to change it to better fit diffusers, and open a new PR. Happy New Year btw :)
Hi @patrickvonplaten, could you answer two questions before I start the new implementation:
Thanks!
Hey @UmerHA, thanks for the write-up! Yes, this design makes a lot of sense to me :-)
I think ControlNet-XS should go into "core"
What does this PR do?
Adds ControlNet-XS support (and therefore fixes #5168).
Project page: https://vislearn.github.io/ControlNet-XS/
See here for a full working example
This PR is work in progress. Still to do:
* SD canny ✅, SD depth ✅, SDXL depth ✅
* Add documentation ✅
* A few other (iiuc) minor things ✅
Still, I would love your feedback!