Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[Core] Add Kolors #8812

Merged
merged 20 commits into from
Jul 11, 2024
Merged

[Core] Add Kolors #8812

merged 20 commits into from
Jul 11, 2024

Conversation

asomoza
Copy link
Member

@asomoza asomoza commented Jul 9, 2024

What does this PR do?

Adds Kolors from the Kwai-Kolors team

Text to Image

import torch

from diffusers import DPMSolverMultistepScheduler, KolorsPipeline

pipe = KolorsPipeline.from_pretrained("Kwai-Kolors/Kolors-diffusers", torch_dtype=torch.float16, variant="fp16").to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)

prompt = '一张瓢虫的照片,微距,变焦,高质量,电影,拿着一个牌子,写着"可图"'

image = pipe(
    prompt=prompt,
    negative_prompt="",
    guidance_scale=6.5,
    num_inference_steps=25,
).images[0]
20240708220004_3124802607 20240708221159_2604715491 20240708221455_2017696674

Image to Image

import torch
import math

from diffusers import DPMSolverMultistepScheduler, KolorsImg2ImgPipeline
from diffusers.utils import load_image

pipe = KolorsImg2ImgPipeline.from_pretrained("Kwai-Kolors/Kolors-diffusers", torch_dtype=torch.float16, variant="fp16").to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)

source_image = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/kolors/bunny_source.png?download=true"
)

prompt = "high quality image of a capybara wearing sunglasses. In the background of the image there are trees, poles, grass and other objects. At the bottom of the object there is the road., 8k, highly detailed."

strength = 0.65
steps = 25

image = pipe(
    prompt=prompt,
    image=source_image,
    negative_prompt="",
    guidance_scale=6.5,
    num_inference_steps=math.ceil(steps / strength),
    strength = strength
).images[0]
source strength 0.65 strength 0.9
20240501130625_2487854446 20240709012944_3628531632 20240709023731_3598424131

Fixes #8801

TODO

  • Tests
  • Docs

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@yiyixuxu @sayakpaul @JincanDeng

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good to me overall!
I wonder if we should call it pipeline_stable_diffusion_xl_kolors? it is sort of a "variant" , let's ask the author for their preference!

@yiyixuxu yiyixuxu requested a review from sayakpaul July 9, 2024 03:01
@asomoza
Copy link
Member Author

asomoza commented Jul 9, 2024

I'm ignoring for now the failed test with the docs until I finish with all the other stuff. Kind of curious why the docs are the only ones that fail with a circular import.

@s9anus98a
Copy link

u guys need made it under diffusionpipeline just like playground

@asomoza
Copy link
Member Author

asomoza commented Jul 9, 2024

@s9anus98a We decided not to do that because this pipeline doesn't share the same defaults and also nothing from SDXL works with this model right now, like IP Adapters, LoRAs and TIs

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@DN6 DN6 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me. Just some minor fixes related to docstrings. Could we also add fast tests please.

Copy link
Member

@a-r-r-o-w a-r-r-o-w left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks absolutely amazing! Just a few missing copied-froms

@vladmandic
Copy link
Contributor

it seems its not picking up the selected variant correctly?

import torch
import diffusers
from rich.traceback import install

install()

pipe = diffusers.KolorsPipeline.from_pretrained(
    'Kwai-Kolors/Kolors',
    variant = 'fp16',
    torch_dtype = torch.float16,
    cache_dir = '/mnt/models/Diffusers',
)
OSError: Error no file named diffusion_pytorch_model.bin found in directory /mnt/models/Diffusers/models--Kwai-Kolors--Kolors/snapshots/f0d70fcea007ce197eea9631a35507d827e4a72e/unet.

there are both fp32 and fp16 variants available in https://huggingface.co/Kwai-Kolors/Kolors/tree/main/unet, but they were not picked up.

this is the local snapshot state - no unet downloaded at all:

    └── f0d70fcea007ce197eea9631a35507d827e4a72e
        ├── model_index.json -> ../../blobs/d473050985693036d4be16253f0d12418b3c3378
        ├── scheduler
        │   └── scheduler_config.json -> ../../../blobs/4cb73ebda6cfb0a97a2eefd289ef0bd27d8f712e
        ├── text_encoder
        │   ├── config.json -> ../../../blobs/c6e19300822b25ae0a07125bbc171c6581dbeda4
        │   ├── pytorch_model-00001-of-00007.bin -> ../../../blobs/b6a6388dae55b598efe76c704e7f017bd84e6f6213466b7686a8f8326f78ab05
        │   ├── pytorch_model-00002-of-00007.bin -> ../../../blobs/2f96bef324acb5c3fe06b7a80f84272fe064d0327cbf14eddfae7af0d665a6ac
        │   ├── pytorch_model-00003-of-00007.bin -> ../../../blobs/2400101255213250d9df716f778b7d2325f2fa4a8acaedee788338fceee5b27e
        │   ├── pytorch_model-00004-of-00007.bin -> ../../../blobs/472567c1b0e448a19171fbb5b3dab5670426d0a5dfdfd2c3a87a60bb1f96037d
        │   ├── pytorch_model-00005-of-00007.bin -> ../../../blobs/ef2aea78fa386168958e5ba42ecf09cbb567ed3e77ce2be990d556b84081e2b9
        │   ├── pytorch_model-00006-of-00007.bin -> ../../../blobs/35191adf21a1ab632c2b175fcbb6c27601150026cb1ed5d602938d825954526f
        │   ├── pytorch_model-00007-of-00007.bin -> ../../../blobs/b7cdaa9b8ed183284905c49d19bf42360037fdf2f95acb3093039d3c3a459261
        │   └── pytorch_model.bin.index.json -> ../../../blobs/cd6a94a411934e8cc04e271f169a68754d14733d
        ├── tokenizer
        │   ├── tokenization_chatglm.py -> ../../../blobs/50e44b05e4b3e54d2f1c3f0cab8247ea53a7d4e5
        │   ├── tokenizer.model -> ../../../blobs/e7dc4c393423b76e4373e5157ddc34803a0189ba96b21ddbb40269d31468a6f2
        │   ├── tokenizer_config.json -> ../../../blobs/f6f13c88707490cebd8023da86e8bf7a56fa21e3
        │   └── vocab.txt -> ../../../blobs/e7dc4c393423b76e4373e5157ddc34803a0189ba96b21ddbb40269d31468a6f2
        ├── unet
        │   └── config.json -> ../../../blobs/52ab049bc6d1490ead84272eb3736822e1d02d63
        └── vae
            ├── config.json -> ../../../blobs/6e9694046afd2a944dd17a2390b98773cacf2f7c
            └── diffusion_pytorch_model.fp16.bin -> ../../../blobs/2ce744db8ec41697eaecabe3508566aa76e53d71f79e595b0d0f56c9f07405ce

@asomoza
Copy link
Member Author

asomoza commented Jul 10, 2024

@vladmandic it was the decision of the model authors to have a different repo for the diffusers compatible model, so the repo is Kwai-Kolors/Kolors-diffusers and not Kwai-Kolors/Kolors

@vladmandic
Copy link
Contributor

vladmandic commented Jul 10, 2024

thanks for pointing it out, it works
two notes:

  • Kolors requires pipe.vae.config.force_upcast = True which is no longer needed on most modern models.
    any chance authors can use fp16-fixed vae?
  • and option to quantize massive text_encoder during load would be a nice addition.
    right now, ChatGLMModel does not expose underlying transformers quantization options.

(yes, both can be monkey-patched in app, just would be nice to have out-of-the-box solution)

@lixiang007666
Copy link

https://huggingface.co/Kwai-Kolors/Kolors-diffusers/tree/main

Currently, only fp16 models are available under huggingface. Should fp32 models also be provided to allow control via the variant parameter in from_pretrained?

@asomoza
Copy link
Member Author

asomoza commented Jul 11, 2024

@vladmandic the vae will be changed to the fp16fix one, about the quantization we're still deciding about it, the original one has an external dependency and we want to probably use transformers for it.

edit: The vae won't be changed since there was a performance degradation detected, users will need to replace it by themselves.

@asomoza
Copy link
Member Author

asomoza commented Jul 11, 2024

Currently, only fp16 models are available under huggingface. Should fp32 models also be provided to allow control via the variant parameter in from_pretrained

We can add them but AFAIK the model was trained in fp16, so you'll be just wasting VRAM with this. @JincanDeng can you confirm this?

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! Looks veryyyyy good. Left a couple of comments.

Comment on lines +536 to +538
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a blocker for merging but did you have some time play around with these params? They are fun.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that you mention it, I've never played with these, not even with the base SDXL, I'll do some experiments but for kolors they change the generation but doesn't seem to do what they should. Maybe because of the training?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could very well be. @JincanDeng do you have any suggestions here?

@JincanDeng
Copy link

Currently, only fp16 models are available under huggingface. Should fp32 models also be provided to allow control via the variant parameter in from_pretrained

We can add them but AFAIK the model was trained in fp16, so you'll be just wasting VRAM with this. @JincanDeng can you confirm this?

Yes, the model was trained in fp16.

@vladmandic
Copy link
Contributor

vladmandic commented Jul 11, 2024

edit: The vae won't be changed since there was a performance degradation detected, users will need to replace it by themselves.

can you share that info? fp16-fixed should be better than requiring to run vae with upcast?

@asomoza
Copy link
Member Author

asomoza commented Jul 11, 2024

can you share that info? fp16-fixed should be better that requiring to run vae with upcast?

That was the answer I got, I don't have more info. I'm also curious about it so I intend to do some test when I have time.
Kolors could be more sensible to the vae since it's a lot more aesthetic.

Also probably the human eye doesn't see the difference between them but the numbers are more important to researchers (understandable) than the VRAM or if there's no visual difference.

@yiyixuxu yiyixuxu merged commit 87b9db6 into huggingface:main Jul 11, 2024
13 of 15 checks passed
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support for Kolors
10 participants