This is the official code implementation for the paper "ConsistencyTTA: Accelerating Diffusion-Based Text-to-Audio Generation with Consistency Distillation" from Microsoft Applied Science Group and UC Berkeley, by Yatong Bai, Trung Dang, Dung Tran, Kazuhito Koishida, and Somayeh Sojoudi.
[🤗 Live Demo]
[Preprint Paper]
[Project Homepage]
[Code]
[Model Checkpoints]
[Generation Examples]
2024/06 Updates:
- We have hosted an interactive live demo of ConsistencyTTA on 🤗 Hugging Face.
- ConsistencyTTA has been accepted to INTERSPEECH 2024! We look forward to meeting you on Kos Island.
- We added a simpler inference-only implementation to the easy_inference directory of this repo.
This work proposes a consistency distillation framework to train text-to-audio (TTA) generation models that only require a single neural network query, reducing the computation of the core step of diffusion-based TTA models by a factor of 400. By incorporating classifier-free guidance into the distillation framework, our models retain diffusion models' impressive generation quality and diversity. Furthermore, the non-recurrent differentiable structure of the consistency model allows for end-to-end fine-tuning with novel loss functions such as the CLAP score, further boosting performance.
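Classifier-free guidance, which this work folds into the distillation framework, combines a conditional and an unconditional noise prediction using a guidance weight. The sketch below shows only the standard combination rule on plain Python lists. The function name `cfg_combine` and the list representation are illustrative assumptions, not part of this codebase, and ConsistencyTTA's exact handling of the guidance weight may differ.

```python
from typing import List


def cfg_combine(eps_uncond: List[float], eps_cond: List[float], w: float) -> List[float]:
    """Standard classifier-free guidance: eps_uncond + w * (eps_cond - eps_uncond).

    w = 1 recovers the purely conditional prediction; w > 1 pushes the output
    further in the direction suggested by the text condition.
    """
    if len(eps_uncond) != len(eps_cond):
        raise ValueError("conditional and unconditional predictions must have the same length")
    return [u + w * (c - u) for u, c in zip(eps_uncond, eps_cond)]


# Toy usage with made-up numbers
print(cfg_combine([0.0, 1.0], [1.0, 1.0], 3.0))  # [3.0, 1.0]
```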
This codebase performs training, evaluation, and inference.
If you only wish to run inference, a simpler implementation is available in easy_inference.
This codebase uses PyTorch as its central implementation tool and makes extensive use of Hugging Face's Accelerate library.
The required packages are listed in environment.yml.
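If you use conda, the environment would typically be created with `conda env create -f environment.yml`. Afterwards, a quick import check can confirm that the core libraries resolve. This is a minimal sketch that assumes the import names `torch` (PyTorch) and `accelerate` (Hugging Face Accelerate). The package list is illustrative and does not mirror environment.yml.

```python
import importlib.util
from typing import Dict, Iterable


def check_packages(names: Iterable[str]) -> Dict[str, bool]:
    """Report whether each top-level module can be found in the active environment."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Assumed import names for the libraries this README mentions.
    for name, found in check_packages(["torch", "accelerate"]).items():
        print(f"{name}: {'ok' if found else 'MISSING'}")
```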
We share three model checkpoints:
- ConsistencyTTA directly distilled from a diffusion model;
- ConsistencyTTA fine-tuned by optimizing the CLAP score;
- The diffusion teacher model from which ConsistencyTTA is distilled.
The first two models are capable of high-quality single-step text-to-audio generation. Generations are 10 seconds long.
These model checkpoints are available on our Hugging Face page.
After downloading and unzipping the files, place them in the saved directory.
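The demo command in this README expects the distilled checkpoint at specific paths under saved/. The sketch below checks only those two paths, which come from that command. Other checkpoints may use different file names, and the helper `missing_checkpoint_files` is hypothetical.

```python
from pathlib import Path
from typing import List

# Paths taken from the demo command in this README; other checkpoints may differ.
EXPECTED_FILES = [
    "saved/ConsistencyTTA/summary.jsonl",
    "saved/ConsistencyTTA/epoch_60/pytorch_model_2.bin",
]


def missing_checkpoint_files(repo_root: str = ".") -> List[str]:
    """Return the expected checkpoint files that are not present under repo_root."""
    root = Path(repo_root)
    return [rel for rel in EXPECTED_FILES if not (root / rel).is_file()]


if __name__ == "__main__":
    missing = missing_checkpoint_files()
    print("All expected files found." if not missing else f"Missing: {missing}")
```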
ConsistencyTTA models are trained on the AudioCaps dataset. Please download the dataset following the instructions on their website (we cannot share the data).
The .json files in the data directory are used for training and evaluation.
Once you have downloaded your copy of the data, you should be able to map it to our format using the file IDs provided in the .json files.
Please modify the file locations in the .json files accordingly.
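Updating the file locations can be scripted. The sketch below makes several assumptions that you should verify against the actual files in data/. It assumes each line is one JSON object, that the audio path is stored under a `location` key, and that the file name (which carries the file ID) stays the same while only the directory changes. The function `remap_locations` and the file names in the usage line are hypothetical.

```python
import json
from pathlib import Path


def remap_locations(src_json: str, dst_json: str, audio_dir: str, key: str = "location") -> int:
    """Rewrite each record's audio path so that it points into audio_dir.

    ASSUMPTION: src_json holds one JSON object per line, and record[key] is an
    audio path whose file name (containing the file ID) is kept unchanged.
    Returns the number of records written.
    """
    count = 0
    with open(src_json) as fin, open(dst_json, "w") as fout:
        for line in fin:
            if not line.strip():
                continue  # skip blank lines
            record = json.loads(line)
            file_name = Path(record[key]).name  # keep the ID-based file name
            record[key] = str(Path(audio_dir) / file_name)
            fout.write(json.dumps(record) + "\n")
            count += 1
    return count


# Hypothetical usage:
# remap_locations("data/train.json", "data/train_local.json", "/path/to/audiocaps/train")
```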
To run an interactive demo, in which the model generates audio from the user's input prompts, run the following script:

```bash
python demo.py --original_args saved/ConsistencyTTA/summary.jsonl \
    --model saved/ConsistencyTTA/epoch_60/pytorch_model_2.bin --use_ema
```
Some example prompts include:
- Food sizzling with some knocking and banging followed by a dog barking.
- Train diesel engine rumbling and a baby crying.
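Judging by its name, the `--use_ema` flag probably loads an exponential moving average (EMA) of the model weights instead of the raw training weights. That reading is an assumption. The sketch below illustrates what such an average is, using scalar weights in a dict. The function name and decay value are illustrative. A real implementation would apply the same rule tensor by tensor under `torch.no_grad()`.

```python
from typing import Dict


def ema_update(ema: Dict[str, float], params: Dict[str, float], decay: float = 0.999) -> None:
    """In-place EMA step: ema <- decay * ema + (1 - decay) * params.

    `ema` must already contain every parameter name in `params`.
    """
    for name, value in params.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value
```

A decay close to 1 averages over many recent training steps, which smooths out step-to-step noise in the weights.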
Training our consistency model consists of three stages:
- (Optional) Distill a diffusion model with adjustable guidance strength.
- Perform consistency distillation.
- (Optional) Fine-tune the consistency model by optimizing the CLAP score.
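Stage 2 builds on consistency distillation. For orientation, the generic objective from the consistency models paper of Song et al. (2023) is shown below. It is a reference form only. The choices made in this work for the ODE solver $\Phi$, the distance $d$, the weighting $\lambda$, and the way the guidance strength enters the teacher may differ. Here $f_\theta$ is the student consistency function, $\theta^-$ is its EMA target copy, $\phi$ is the diffusion teacher, and $\hat{x}^{\phi}_{t_n}$ is the teacher's one-step ODE estimate:

$$
\begin{aligned}
\hat{x}^{\phi}_{t_n} &= x_{t_{n+1}} + (t_n - t_{n+1})\,\Phi(x_{t_{n+1}}, t_{n+1}; \phi),\\
\mathcal{L}_{\mathrm{CD}}(\theta, \theta^-; \phi) &= \mathbb{E}\Big[\lambda(t_n)\, d\big(f_\theta(x_{t_{n+1}}, t_{n+1}),\ f_{\theta^-}(\hat{x}^{\phi}_{t_n}, t_n)\big)\Big],\\
\theta^- &\leftarrow \operatorname{stopgrad}\big(\mu\,\theta^- + (1-\mu)\,\theta\big).
\end{aligned}
$$

Because the student maps a noisy latent directly to a clean estimate, generation needs a single query of $f_\theta$ instead of iterating the teacher's solver.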
The file train.sh contains the training script for all three stages.
The trained model checkpoints will be stored in the /saved directory.
The teacher model for our distilled consistency models is based on TANGO, a state-of-the-art TTA generation framework based on latent diffusion models.
The training script should automatically download the AudioLDM weights from here.
However, if the download is slow or you face any other issues, you can:
i) download the audioldm-s-full file from here,
ii) rename it to audioldm-s-full.ckpt,
and iii) keep it in the /home/user/.cache/audioldm/ directory.
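Steps ii) and iii) can be combined in a few lines. In this sketch, the default location replaces `/home/user` with the current user's home directory. It copies the file rather than moving it, so the original download is left untouched. The helper `install_audioldm_ckpt` is hypothetical. Pass `cache_dir` explicitly if your cache lives elsewhere.

```python
import shutil
from pathlib import Path
from typing import Optional


def install_audioldm_ckpt(downloaded_file: str, cache_dir: Optional[str] = None) -> Path:
    """Copy a manually downloaded audioldm-s-full file to <cache_dir>/audioldm-s-full.ckpt.

    Defaults to ~/.cache/audioldm, mirroring /home/user/.cache/audioldm/ for the current user.
    """
    target_dir = Path(cache_dir) if cache_dir else Path.home() / ".cache" / "audioldm"
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / "audioldm-s-full.ckpt"  # the expected file name from step ii)
    shutil.copyfile(downloaded_file, target)
    return target
```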
For fine-tuning and evaluating with CLAP, we use this CLAP model checkpoint from this repository.
After downloading, place it into the /ckpt directory.
On two Nvidia RTX 6000 Ada GPUs, Stage 1 (40 epochs) should take ~40 hours, Stage 2 (60 epochs) should take ~80 hours, and Stage 3 (10 epochs) ~30 hours.
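The quoted times imply rough per-epoch costs and an overall budget. The arithmetic sketch below uses only the numbers above. The dictionary layout is illustrative.

```python
# Approximate figures quoted above: (epochs, wall-clock hours) on two RTX 6000 Ada GPUs.
STAGES = {
    "stage1 (optional)": (40, 40.0),
    "stage2": (60, 80.0),
    "stage3 (optional)": (10, 30.0),
}


def hours_per_epoch(epochs: int, hours: float) -> float:
    """Average wall-clock hours per epoch for one stage."""
    return hours / epochs


total_hours = sum(hours for _, hours in STAGES.values())
required_hours = STAGES["stage2"][1]  # stage 2 is the only non-optional stage

for name, (epochs, hours) in STAGES.items():
    print(f"{name}: ~{hours_per_epoch(epochs, hours):.2f} h/epoch")
print(f"all three stages: ~{total_hours:.0f} h; stage 2 only: ~{required_hours:.0f} h")
```

Stage 3 has the highest per-epoch cost, presumably because CLAP fine-tuning runs the CLAP model on the generated audio. That explanation is an inference, not a measured breakdown.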
To perform inference using a trained consistency model and evaluate the generated audio clips, please refer to inference.sh.
The generated audio clips will be stored in the /outputs directory.
To evaluate existing audio generations, use evaluate_existing.py.
An example script is in inference.sh.
Our evaluation metrics include Fréchet Audio Distance (FAD), Fréchet Distance (FD), KL Divergence, and CLAP Scores.
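FAD and FD are both Fréchet distances between Gaussians fitted to embeddings of reference and generated audio. They differ in the embedding network, which is configured inside audioldm_eval and not restated here. For reference, the closed form for a reference distribution $\mathcal{N}(\mu_r, \Sigma_r)$ and a generated distribution $\mathcal{N}(\mu_g, \Sigma_g)$ is:

$$
\mathrm{FD} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\big(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\big)
$$

Identical distributions give zero, which is why lower is better for these columns.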
| Model | # Queries (↓) | CLAP_T (↑) | CLAP_A (↑) | FAD (↓) | FD (↓) | KLD (↓) |
|---|---|---|---|---|---|---|
| Diffusion (Baseline) | 400 | 24.57 | 72.79 | 1.908 | 19.57 | 1.350 |
| Consistency + CLAP FT (Ours) | 1 | 24.69 | 72.54 | 2.406 | 20.97 | 1.358 |
| Consistency (Ours) | 1 | 22.50 | 72.30 | 2.575 | 22.08 | 1.354 |
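A CLAP score measures how closely two CLAP embeddings align. This sketch assumes the common definition of 100 times the cosine similarity. It also assumes that CLAP_T pairs a caption embedding with a generated-audio embedding, and that CLAP_A pairs generated-audio and reference-audio embeddings. Both pairings are assumptions, as is any clipping of negative values. Reported numbers would be averages over the evaluation clips.

```python
import math
from typing import Sequence


def clap_score(a: Sequence[float], b: Sequence[float]) -> float:
    """100 * cosine similarity between two embedding vectors (assumed CLAP score definition)."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    if norm == 0.0:
        raise ValueError("embeddings must be non-zero")
    return 100.0 * dot / norm
```

Because the score is differentiable in the generated audio's embedding, it can also serve as a fine-tuning objective, as in stage 3.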
This Papers with Code benchmark shows how our single-step models compare with previous methods, most of which require hundreds of generation steps.
```bibtex
@article{bai2023consistencytta,
  title={ConsistencyTTA: Accelerating Diffusion-Based Text-to-Audio Generation with Consistency Distillation},
  author={Bai, Yatong and Dang, Trung and Tran, Dung and Koishida, Kazuhito and Sojoudi, Somayeh},
  journal={arXiv preprint arXiv:2309.10740},
  year={2023}
}
```
Third-Party Code. The structure of this repository roughly follows TANGO, which in turn relies heavily on Diffusers and AudioLDM.
We made modifications to the audioldm, audioldm_eval, and diffusers directories for training and evaluating ConsistencyTTA.
We sincerely thank the authors of these repositories for open-sourcing them.
Please refer to NOTICE.md for license information.
Trademarks. This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft’s Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third parties’ policies.