We have placed a strong emphasis on making training as fast as possible. Consequently, some pre-processing steps are required. Namely, before starting any training, we:
- Obtain training data as videos, audios, and captions.
- Encode training audio into spectrograms, and then encode the spectrograms with the VAE into mean/std
- Extract CLIP and synchronization features from videos
- Extract CLIP features from text (captions)
- Encode all extracted features into MemoryMappedTensors with TensorDict
NOTE: for maximum training speed (e.g., when training the base model with 2×H100s), you would need around 3~5 GB/s of random read speed. Spinning disks would not be able to keep up, and most consumer-grade SSDs would struggle. In my experience, the best bet is to have enough system memory that the OS can cache the data. This way, the data is read from RAM instead of disk.
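As a rough way to judge whether OS caching will cover your setup, you can compare the on-disk size of the extracted feature directories against your available RAM once extraction (described below) is done. This is a minimal sketch, not part of the repository, and the paths are placeholders:

```python
# Rough capacity check (sketch): total on-disk size of the extracted feature
# directories, to judge whether the OS page cache (system RAM) can hold them.
from pathlib import Path

# Placeholder paths -- point these at your TensorDict output directories after extraction.
feature_dirs = [Path("output/video_features"), Path("output/audio_features")]

total_bytes = sum(f.stat().st_size for d in feature_dirs for f in d.rglob("*") if f.is_file())
print(f"Extracted features: {total_bytes / 1e9:.1f} GB")
print("Ideally this fits in free system memory, so training reads hit the page cache instead of disk.")
```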
The current training script does not support `_v2` training.
These are my recommendations for a smooth and efficient training experience; they are not minimum requirements.
- Single-node machine. We did not implement multi-node training
- GPUs: for the small model, two 80G-H100s or above; for the large model, eight 80G-H100s or above
- System memory: for 16kHz training, 600GB+; for 44kHz training, 700GB+
- Storage: >2TB of fast NVMe storage. If you have enough system memory, OS caching will help and the storage does not need to be as fast.
- Install av-benchmark. We use this library to automatically evaluate on the validation set during training, and on the test set after training.
- Extract features for evaluation with av-benchmark for the validation and test sets, to serve as a validation cache and a test cache. You can also download the precomputed evaluation cache here.
- You will need `ffmpeg` to extract frames from videos. Note that `torchaudio` imposes a maximum version limit (`ffmpeg<7`). You can install it as follows: `conda install -c conda-forge 'ffmpeg<7'` (a quick version check is sketched after this list).
- Download the training datasets. We used VGGSound, AudioCaps, and WavCaps. Note that the audio files in the Hugging Face release of WavCaps have been downsampled to 32kHz. To the best of our ability, we located the original (high-sampling-rate) audio files and used them instead to prevent artifacts during 44.1kHz training. We did not use the "SoundBible" portion of WavCaps, since it is a small set with many short audio clips unsuitable for our training.
- Download the corresponding VAE (`v1-16.pth` for 16kHz training, and `v1-44.pth` for 44.1kHz training), vocoder models (`best_netG.pt` for 16kHz training; the vocoder for 44.1kHz training will be downloaded automatically), the empty string encoding, and the Synchformer weights from MODELS.md, and place them in `ext_weights/` (see the sketch after this list for a quick check).
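Before moving on to feature extraction, it may help to verify the environment. The sketch below is not part of the repository: it checks that the `ffmpeg` on your PATH satisfies the `ffmpeg<7` constraint and that the weights listed above are present in `ext_weights/`. The filenames for the empty-string encoding and the Synchformer checkpoint are placeholders; substitute the actual names from MODELS.md.

```python
# Environment sanity check (sketch, not part of the repository).
import re
import subprocess
from pathlib import Path

# 1) torchaudio requires ffmpeg<7; version string formats vary across builds.
out = subprocess.run(["ffmpeg", "-version"], capture_output=True, text=True, check=True).stdout
match = re.search(r"ffmpeg version\s+n?(\d+)", out)
if match and int(match.group(1)) >= 7:
    print(f"WARNING: ffmpeg {match.group(1)}.x found; torchaudio needs ffmpeg<7")

# 2) External weights named in MODELS.md. The last two filenames are placeholders --
#    use the actual names of the empty-string encoding and Synchformer weights.
expected = [
    "v1-16.pth",       # VAE for 16kHz training
    "v1-44.pth",       # VAE for 44.1kHz training
    "best_netG.pt",    # vocoder for 16kHz training (44.1kHz vocoder is auto-downloaded)
    "empty_string.pth",            # placeholder name
    "synchformer_state_dict.pth",  # placeholder name
]
for name in expected:
    status = "found" if (Path("ext_weights") / name).exists() else "MISSING"
    print(f"ext_weights/{name}: {status}")
```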
We have prepared some example data in `training/example_videos`.

`training/extract_video_training_latents.py` extracts audio, video, and text features and saves them as a TensorDict, along with a `.tsv` file containing metadata, to `output_dir`.

To run this script, use the `torchrun` utility:

`torchrun --standalone training/extract_video_training_latents.py`

You can run this script with multiple GPUs (add `--nproc_per_node=<n>` after `--standalone` and before the script name) to speed up extraction.

Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.
Change the data path definitions in `data_cfg` if necessary.
Arguments:

- `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
- `output_dir` -- where the TensorDict and the metadata file are saved.

Outputs produced in `output_dir` (a short inspection sketch follows the list):

- A directory named `vgg-{split}` (i.e., in the TensorDict format), containing
  a. `mean.memmap`: mean values predicted by the VAE encoder (number of videos X sequence length X channel size)
  b. `std.memmap`: standard deviation values predicted by the VAE encoder (number of videos X sequence length X channel size)
  c. `text_features.memmap`: text features extracted from CLIP (number of videos X 77 (sequence length) X 1024)
  d. `clip_features.memmap`: CLIP features extracted from the video frames (number of videos X 64 (8 fps) X 1024)
  e. `sync_features.memmap`: synchronization features extracted from Synchformer (number of videos X 192 (24 fps) X 768)
  f. `meta.json`: metadata for the above memory mappings
- A tab-separated values file named `vgg-{split}.tsv` with two columns: `id`, containing video file names without extension, and `label`, containing the corresponding text labels (i.e., captions)
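To spot-check what the extraction produced, a short sketch like the one below loads the memory-mapped TensorDict together with its metadata tsv. The paths, the split name, and the assumption that the keys match the memmap filenames above are mine, not guaranteed by the script:

```python
# Sketch: inspect the extracted video features (paths and split name are placeholders).
import pandas as pd
from tensordict import TensorDict

output_dir = "output/example_videos"                       # your output_dir
td = TensorDict.load_memmap(f"{output_dir}/vgg-example")   # the vgg-{split} directory
meta = pd.read_csv(f"{output_dir}/vgg-example.tsv", sep="\t")

print(td["mean"].shape)           # (num videos, sequence length, channels)
print(td["clip_features"].shape)  # (num videos, 64, 1024)
print(td["sync_features"].shape)  # (num videos, 192, 768)
print(meta.columns.tolist())      # expected: ['id', 'label']
print(meta.head())
```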
We have prepared some example data in `training/example_audios`.

- Run `training/partition_clips.py` to partition each audio file into clips (by finding start and end points; we do not save the partitioned audio to disk, to save space).
- Run `training/extract_audio_training_latents.py` to extract each clip's audio and text features and save them as a TensorDict, along with a `.tsv` file containing metadata, to `output_dir`.
Run `python training/partition_clips.py`.

Arguments:

- `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`)
- `output_dir` -- path to the output `.csv` file
- `start` -- optional; useful when you need to run multiple processes to speed up processing; this defines the beginning of the chunk to be processed
- `end` -- optional; useful when you need to run multiple processes to speed up processing; this defines the end of the chunk to be processed
Run `torchrun --standalone training/extract_audio_training_latents.py`.

You can run this with multiple GPUs (with `--nproc_per_node=<n>`) to speed up extraction.
Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.

Arguments:

- `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`), same as in the previous step
- `captions_tsv` -- path to the captions file, a tab-separated values (tsv) file with at least the columns `id` and `caption` (a quick format check is sketched after this list)
- `clips_tsv` -- path to the clips file generated in the previous step
- `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
- `output_dir` -- where the TensorDict and the metadata file are saved.
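Because `captions_tsv` only needs `id` and `caption` columns, a small check such as the following (the file path is hypothetical) can catch formatting issues before launching a long extraction run:

```python
# Sketch: verify a captions tsv has the columns extract_audio_training_latents.py expects.
import pandas as pd

captions = pd.read_csv("data/my_captions.tsv", sep="\t")   # hypothetical path
missing = {"id", "caption"} - set(captions.columns)
assert not missing, f"captions_tsv is missing columns: {missing}"
print(f"{len(captions)} rows, {captions['id'].nunique()} unique ids "
      "(duplicate ids are fine -- the script handles them)")
```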
Reference tsv files (with overlaps removed as mentioned in the paper) can be found here.
Note that these reference tsv files are outputs of `extract_audio_training_latents.py`, which means the `id` column might contain duplicate entries (one per clip). You can still use them as the `captions_tsv` input, though -- the script handles duplicates gracefully.

Among these reference tsv files, `audioset_sl.tsv`, `bbcsound.tsv`, and `freesound.tsv` are subsets that are part of WavCaps. These subsets might be smaller than the original datasets.
The Clotho data contains both the development set and the validation set.
Outputs produced in `output_dir`:

- A directory named `{basename(output_dir)}` (i.e., in the TensorDict format), containing
  a. `mean.memmap`: mean values predicted by the VAE encoder (number of audios X sequence length X channel size)
  b. `std.memmap`: standard deviation values predicted by the VAE encoder (number of audios X sequence length X channel size)
  c. `text_features.memmap`: text features extracted from CLIP (number of audios X 77 (sequence length) X 1024)
  d. `meta.json`: metadata for the above memory mappings
- A tab-separated values file named `{basename(output_dir)}.tsv` with two columns: `id`, containing audio file names without extension, and `label`, containing the corresponding text labels (i.e., captions)
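The extraction stores the VAE posterior statistics (`mean`/`std`) rather than sampled latents. A common way to use such statistics in latent-model training is to sample a fresh latent with the reparameterization trick each time an example is drawn; the sketch below illustrates that idea under assumed paths and keys, and is not necessarily the exact computation in MMAudio's training loop.

```python
# Sketch: sampling a latent from stored posterior statistics (reparameterization trick).
# Illustrative only -- paths/keys are assumptions, and the real training code may differ.
import os
import torch
from tensordict import TensorDict

output_dir = "output/example_audios"                   # placeholder -- your output_dir
name = os.path.basename(os.path.normpath(output_dir))  # the directory is named {basename(output_dir)}
td = TensorDict.load_memmap(os.path.join(output_dir, name))

idx = torch.randint(0, td.shape[0], (4,))              # a random mini-batch of audio clips
mean, std = td["mean"][idx], td["std"][idx]
latent = mean + std * torch.randn_like(std)            # z = mu + sigma * eps
print(latent.shape)                                    # (4, sequence length, channels)
```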
We use Distributed Data Parallel (DDP) for training.
First, specify the data path in `config/data/base.yaml`. If you used the default parameters in the scripts above to extract features for the example data, the `Example_video` and `Example_audio` items should already be correct.

To run training on the example data, use the following command:

`OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=1 train.py exp_id=debug compile=False debug=True example_train=True batch_size=1`
This will not train a useful model, but it will check if everything is set up correctly.
For full training on the base model with two GPUs, use the following command:
`OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=2 train.py exp_id=exp_1 model=small_16k`

Any outputs from training will be stored in `output/<exp_id>`.

More configuration options can be found in `config/base_config.yaml` and `config/train_config.yaml`.

For the medium and large models, set `vgg_oversample_rate=3` to reduce overfitting.
Model checkpoints, including optimizer states and the latest EMA weights, are available here: https://huggingface.co/hkchengrex/MMAudio
Godspeed!