
Consistency Distillation with Target Timestep Selection and Decoupled Guidance


RedAIGC/Target-Driven-Distillation


✨Target-Driven Distillation✨

Links: arXiv | Project page | Hugging Face Model | Hugging Face Space (SDXL) | Hugging Face Space (SVD)

Target-Driven Distillation (TDD) is a state-of-the-art consistency distillation method that greatly accelerates the inference of diffusion models. Thanks to its carefully designed strategies of target timestep selection and decoupled guidance, models distilled with TDD can generate highly detailed images in only a few steps.


Samples generated by TDD-distilled SDXL with only 4-8 steps.

News

  • Sept. 21, 2024: A demo of FLUX-TDD-BETA (4-8 steps) is now available on Hugging Face Spaces (FLUX).
  • Sept. 20, 2024: We have released the training code for FLUX. Our 4-8-step LoRAs for FLUX.1-dev are coming soon!
  • Sept. 12, 2024: Demos of TDD-SDXL and TDD-SVD are now available on Hugging Face Spaces (SDXL, SVD). Give them a try!
  • Sept. 4, 2024: Our detailed research paper is now on arXiv.
  • Aug. 29, 2024: We have released the training and inference code for SDXL, as well as pretrained models both with and without adversarial loss (adv).
  • Aug. 22, 2024: Project launched.

Demos

Comparison with previous works (LCM, PCM, TCD). Using the same seeds, our method (TDD) shows advantages in both image complexity and clarity.


Video samples generated by AnimateLCM-distilled (top) and TDD-distilled (bottom) SVD-xt 1.1, also with 4-8 steps.


Samples generated by different base models distilled with TDD, and by SDXL combined with different LoRA adapters or ControlNets.


Usage

Inference

  • Clone this repository.
git clone https://github.com/RedAIGC/Target-Driven-Distillation.git
cd Target-Driven-Distillation
  • FLUX: Download the pretrained LoRA with the script below, or from the Hugging Face model page, and generate images.
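import torch
from huggingface_hub import hf_hub_download
from diffusers import FluxPipeline

# Load the FLUX.1-dev base model in bfloat16
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# Download the TDD LoRA (beta) and fuse it into the base weights with a scale of 0.125
pipe.load_lora_weights(hf_hub_download("RED-AIGC/TDD", "TDD-FLUX.1-dev-lora-beta.safetensors"))
pipe.fuse_lora(lora_scale=0.125)
pipe.to("cuda")

prompt = "A photo of a cat made of water."  # example prompt

image_flux = pipe(
    prompt=[prompt],
    generator=torch.Generator().manual_seed(3413),
    num_inference_steps=8,  # FLUX-TDD-BETA targets 4-8 steps
    guidance_scale=2.0,
    height=1024,
    width=1024,
    max_sequence_length=256
).images[0]
image_flux.save("tdd_flux.png")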
  • SDXL: Download the pretrained LoRA with the script below, or from the Hugging Face model page.
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="RED-AIGC/TDD", filename="sdxl_tdd_lora_weights.safetensors", local_dir="./tdd_lora")
  • Generate images.
# !pip install opencv-python transformers accelerate
import torch
from diffusers import StableDiffusionXLPipeline
from tdd_scheduler import TDDScheduler  # scheduler provided in this repository

device = "cuda"
tdd_lora_path = "tdd_lora/sdxl_tdd_lora_weights.safetensors"

# Load the SDXL base model in fp16
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16").to(device)

# Swap in the TDD scheduler, then load the TDD LoRA and fuse it into the base weights
pipe.scheduler = TDDScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights(tdd_lora_path, adapter_name="accelerate")
pipe.fuse_lora()

prompt = "A photo of a cat made of water."

image = pipe(
    prompt=prompt,
    num_inference_steps=4,  # TDD is designed for 4-8 steps
    guidance_scale=1.7,
    eta=0.2,  # stochastic sampling parameter forwarded to the scheduler
    generator=torch.Generator(device=device).manual_seed(546237),
).images[0]

image.save("tdd.png")

Training

See the scripts in the train directory.

Introduction

Target-Driven Distillation (TDD) features three key designs that differ from previous consistency distillation methods:

  1. TDD adopts a careful selection strategy for target timesteps, which improves training efficiency. Specifically, it first chooses from a predefined set of equidistant denoising schedules (e.g. 4-8 steps), then adds a stochastic offset to accommodate non-deterministic sampling (e.g. $\gamma$-sampling).
  2. TDD uses decoupled guidance during training, which allows the guidance scale to be tuned afterwards, at inference time. Specifically, it replaces a portion of the text conditions with unconditional (i.e. empty) prompts, to align with the standard training process used for CFG.
  3. TDD can optionally be equipped with non-equidistant sampling and x0 clipping, enabling more flexible and accurate image sampling.
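
The target timestep selection in item 1 can be illustrated with a small pure-Python sketch. It is not the repository's training code: the 1000 training timesteps, the candidate schedules of 4 to 8 steps, and the way the stochastic offset is drawn (moving the target toward 0 by a random fraction of at most `gamma`, loosely in the spirit of γ-sampling) are assumptions made for illustration, and both function names are hypothetical.

```python
import random


def equidistant_schedule(num_steps, num_train_timesteps=1000):
    """Evenly spaced timesteps from high to low noise, e.g. 4 steps -> [999, 749, 499, 249]."""
    stride = num_train_timesteps // num_steps
    return [num_train_timesteps - 1 - i * stride for i in range(num_steps)]


def select_target_timestep(t, step_choices=(4, 5, 6, 7, 8), gamma=0.2,
                           num_train_timesteps=1000, rng=random):
    """Hypothetical target selection for a training sample at timestep t.

    1. Pick one of the predefined equidistant schedules at random.
    2. Take the first timestep of that schedule that lies below t (0 if none).
    3. Apply a random offset that moves the target toward 0 by at most
       gamma * target (an assumed stand-in for TDD's stochastic offset).
    """
    num_steps = rng.choice(list(step_choices))
    schedule = equidistant_schedule(num_steps, num_train_timesteps)
    target = next((s for s in schedule if s < t), 0)
    offset = int(rng.uniform(0.0, gamma) * target)
    return target - offset


if __name__ == "__main__":
    rng = random.Random(0)
    for t in (999, 800, 500, 200):
        print(t, "->", select_target_timestep(t, rng=rng))
```

With `gamma=0.0` the sketch always returns the next point of an equidistant schedule; a positive `gamma` spreads targets slightly below that point, which is the kind of coverage a non-deterministic sampler would need.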

An overview of TDD. (a) The training process features target timestep selection and decoupled guidance. (b) The inference process can optionally adopt non-equidistant denoising schedules.
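
Part (a) of the overview refers to decoupled guidance (item 2 above): during training, a portion of the text conditions is replaced with the empty prompt, as in standard classifier-free guidance (CFG) training. The sketch below illustrates that idea together with the usual CFG combination applied at sampling time. The drop probability of 0.1 and the helper names are assumptions for illustration, not values or functions taken from the TDD code.

```python
import random


def drop_text_conditions(prompts, drop_prob=0.1, rng=random):
    """Replace each prompt with the empty (unconditional) prompt with probability drop_prob."""
    return ["" if rng.random() < drop_prob else p for p in prompts]


def classifier_free_guidance(eps_uncond, eps_cond, guidance_scale):
    """Standard CFG combination, element-wise: eps_uncond + w * (eps_cond - eps_uncond)."""
    return [u + guidance_scale * (c - u) for u, c in zip(eps_uncond, eps_cond)]


if __name__ == "__main__":
    batch = ["a cat made of water", "a red car", "a forest at dawn", "a bowl of fruit"]
    # Some prompts in the batch become "" and act as unconditional samples
    print(drop_text_conditions(batch, drop_prob=0.5, rng=random.Random(1)))
    # Combine toy unconditional and conditional predictions with w = 2.0
    print(classifier_free_guidance([0.0, 1.0], [1.0, 1.0], guidance_scale=2.0))
```

Because the model sees both conditional and unconditional inputs during distillation, the guidance scale remains a free knob at inference, such as `guidance_scale=1.7` in the SDXL example.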

For further details of TDD, please refer to our paper on arXiv.
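
Item 3 and part (b) of the overview mention optional non-equidistant sampling and x0 clipping. The sketch below shows what these two ideas can look like inside a deterministic DDIM-style update. It is an illustration under stated assumptions, not the TDDScheduler implementation: the power-law spacing, the clipping range [-1, 1] (SDXL latents are not naturally bounded to this range), and the "scaled linear" noise schedule with betas from 0.00085 to 0.012 over 1000 steps are all assumptions, and every function name is hypothetical.

```python
import math


def scaled_linear_alpha_bar(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012):
    """Cumulative products alpha_bar[t] of a 'scaled linear' beta schedule."""
    s0, s1 = math.sqrt(beta_start), math.sqrt(beta_end)
    alpha_bar, prod = [], 1.0
    for i in range(num_train_timesteps):
        beta = (s0 + (s1 - s0) * i / (num_train_timesteps - 1)) ** 2
        prod *= 1.0 - beta
        alpha_bar.append(prod)
    return alpha_bar


def power_schedule(num_steps, power=2.0, num_train_timesteps=1000):
    """Assumed non-equidistant timesteps: t_i = t_max * ((n - i) / n) ** power.

    power == 1 gives (roughly) equidistant steps; power > 1 places more
    steps at low noise levels.
    """
    t_max = num_train_timesteps - 1
    return [round(t_max * ((num_steps - i) / num_steps) ** power) for i in range(num_steps)]


def clip_x0(x0, limit=1.0):
    """Clamp every element of the predicted clean sample to [-limit, limit]."""
    return [max(-limit, min(limit, v)) for v in x0]


def ddim_step(x_t, eps, t, t_prev, alpha_bar, clip=True):
    """Deterministic DDIM-style update from t to t_prev (t_prev < 0 returns the predicted x0)."""
    a_t = alpha_bar[t]
    a_prev = alpha_bar[t_prev] if t_prev >= 0 else 1.0
    # Predict x0 from the noisy sample and the predicted noise
    x0 = [(x - math.sqrt(1.0 - a_t) * e) / math.sqrt(a_t) for x, e in zip(x_t, eps)]
    if clip:
        x0 = clip_x0(x0)
    # Re-noise the (optionally clipped) x0 to the previous timestep
    return [math.sqrt(a_prev) * x0_i + math.sqrt(1.0 - a_prev) * e for x0_i, e in zip(x0, eps)]


if __name__ == "__main__":
    alpha_bar = scaled_linear_alpha_bar()
    timesteps = power_schedule(4)
    print("timesteps:", timesteps)
    x = [0.8, -1.5, 0.2]     # toy "latent"
    eps = [0.1, -0.3, 0.05]  # toy noise prediction; a real model would recompute it each step
    for t, t_prev in zip(timesteps, timesteps[1:] + [-1]):
        x = ddim_step(x, eps, t, t_prev, alpha_bar)
    print("final sample:", x)
```

With `clip=False` and `power=1.0` the loop reduces to a plain, roughly equidistant DDIM sampler. For the options TDD actually exposes, refer to `tdd_scheduler.py` in this repository.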

Acknowledgements

  • Thanks to sdbds for helping us train on FLUX, which allowed us to distill FLUX with a larger batch size.
  • Thanks to PSNbst for providing a compressed version of TDD that is smaller than 20MB. Truly impressive.
  • Thanks to the PCM team for their ADV_loss support!
  • Thanks to the Hugging Face Gradio team for their free GPU support!

Contact, Collaboration, and Citation

If you have any questions about the code, please do not hesitate to contact me!

Email: polu@xiaohongshu.com
Email: wangcunzheng2000@163.com