Adding BigVGAN as Vocoder #14

Closed
crypticsymmetry opened this issue Feb 21, 2023 · 18 comments

@crypticsymmetry commented Feb 21, 2023

Hey, I'm trying to add my BigVGAN vocoder model to the inference script, but when it generates audio it always has a lot of noise compared to the inference script of the original BigVGAN code base. Any ideas on why that could be? It looks to be the same setup as HiFi-GAN. https://github.com/NVIDIA/BigVGAN. If you would like one of my trained models, let me know and I'll give you a download link so you can test with it, as there are currently no available models.

Thanks in advance!

%cd /content/BigVGAN

from __future__ import absolute_import, division, print_function, unicode_literals
#import sys
#sys.path.append("./content/BigVGAN")
import glob
import os
import argparse
import json

from scipy.io.wavfile import write
from env import AttrDict
from meldataset1 import mel_spectrogram, MAX_WAV_VALUE
from models1 import BigVGAN as Generator
import librosa
import numpy as np
import torch

torch.backends.cudnn.benchmark = False

def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


def scan_checkpoint(cp_dir, prefix):
    pattern = os.path.join(cp_dir, prefix + '*')
    cp_list = glob.glob(pattern)
    if len(cp_list) == 0:
        return ''
    return sorted(cp_list)[-1]
    
cp_g = scan_checkpoint("configs", 'g_001')

config_file = os.path.join(os.path.split(cp_g)[0], 'bigvgan_24khz_100band.json') #actually 80-band to work with the StyleTTS model
with open(config_file) as f:
    data = f.read()


json_config = json.loads(data)
h = AttrDict(json_config)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator(h).to(device)
state_dict_g = load_checkpoint(cp_g, device)

generator.load_state_dict(state_dict_g['generator'])
generator.eval()
generator.remove_weight_norm()
%cd /content/StyleTTS
import time
converted_samples = {}
start_time = time.time()

input_length = torch.LongTensor([tokens.shape[-1]]).to(device)
mask = length_to_mask(input_length).to(device)
with torch.no_grad():
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    m = length_to_mask(input_lengths).to(device)
    t_en = model.text_encoder(tokens, input_lengths, m)
    for key, (ref, _) in reference_embeddings.items():
        s = ref.squeeze(1).to(device)
        style = s
        d = model.predictor.text_encoder(t_en, style, input_lengths, m)

        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data), device=device)
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)

        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0))
        style = s.expand(en.shape[0], en.shape[1], -1)

        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0)), F0_pred, N_pred, ref.squeeze().unsqueeze(0))

        audio_signal = out.cpu().numpy().squeeze()

        #Apply the Mel Spectrogram transformation
        mel_spectrogram = librosa.feature.melspectrogram(y=audio_signal, sr=24000, n_fft=1024, hop_length=256, n_mels=80, win_length=1024)

        #Convert the Mel Spectrogram to decibels
        mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)
        out1 = torch.FloatTensor(mel_spectrogram_db).to(device)
        y_g_hat = generator(out)
        y_out = y_g_hat.squeeze()
        y_out1 = y_out * MAX_WAV_VALUE
        y_out2 = y_out1.cpu().numpy()

        converted_samples[key] = y_out2

end_time = time.time()
print("Time taken: ", end_time - start_time, "seconds")

I also tried the original approach, with the same result: grainy/noisy audio.

y_g_hat = generator(out)
y_out = y_g_hat.squeeze().cpu().numpy()
   
converted_samples[key] = y_out
@yl4579 (Owner) commented Feb 21, 2023

It depends on your BigVGAN model. I'm a little curious how you got your BigVGAN, as there is no official pretrained model released yet.

If it was not trained to match the data preprocessing in meldataset.py, it would not work. You will need to do some rescaling and interpolation to match the preprocessing there (if the sampling rate is the same), or you will need to retrain your StyleTTS or BigVGAN to use the same preprocessing.

When there is an official pretrained model, I will test the quality and see if it is worth changing all my repos to match that preprocessing for 24 kHz.
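
For illustration, this is roughly the kind of torchaudio log-mel preprocessing meldataset.py performs (a minimal sketch; the parameter values are drawn from later in this thread and should be verified against the repo):

import torch
import torchaudio

# Hedged sketch of StyleTTS-style mel preprocessing: torchaudio power mel,
# log compression, then mean/std normalization. All constants (n_fft=2048,
# win_length=1200, hop_length=300, 80 mels, 24 kHz, mean=-4, std=4) are
# assumptions taken from this thread.
to_mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=24000, n_fft=2048, win_length=1200, hop_length=300, n_mels=80)
mean, std = -4, 4

def styletts_mel(wave: torch.Tensor) -> torch.Tensor:
    mel = to_mel(wave)                           # (n_mels, frames)
    return (torch.log(1e-5 + mel) - mean) / std  # log-compress and normalize

A BigVGAN trained on librosa-style mels would see out-of-distribution inputs if fed mels computed this way, which would be consistent with the noise described above.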

yl4579 closed this as completed Feb 21, 2023
@crypticsymmetry (Author) commented Feb 21, 2023

I'm currently training a model (on vast.ai with 3× 3090s) based on the GitHub instructions, using the Lion-pytorch optimizer in place of AdamW. If you want a download of the generator, let me know and I'll get you a link to test.

I changed num_mels to 80 to match StyleTTS and left the rest matching the BigVGAN config. Could that be the issue?

{
    "resblock": "1",
    "num_gpus": 2,
    "batch_size": 32,
    "learning_rate": 0.00003,
    "adam_b1": 0.85, # original .8 -> .85 Lion recommends leaving at .9 and. 99, met in the middle "idk"
    "adam_b2": 0.99,
    "lr_decay": 0.999,
    "seed": 1234,

    "upsample_rates": [4,4,2,2,2,2],
    "upsample_kernel_sizes": [8,8,4,4,4,4],
    "upsample_initial_channel": 1536,
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],

    "activation": "snakebeta",
    "snake_logscale": true,

    "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
    "mpd_reshapes": [2, 3, 5, 7, 11],
    "use_spectral_norm": false,
    "discriminator_channel_mult": 1,

    "segment_size": 8192,
    "num_mels": 80, # Changed to 80 from 100
    "num_freq": 1025,
    "n_fft": 1024,
    "hop_size": 256,
    "win_size": 1024,

    "sampling_rate": 24000,

    "fmin": 0,
    "fmax": 12000,
    "fmax_for_loss": null,

    "num_workers": 4,

    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://0.0.0.0:54321",
        "world_size": 1
    }
}

@yl4579 (Owner) commented Feb 21, 2023

It is not about the configuration. You need to match the preprocessing here as well. See yl4579/StarGANv2-VC#59 (comment)

@lexkoro commented Feb 21, 2023

Since you are training BigVGAN yourself, you would also have to set n_fft=2048, win_length=1200, hop_length=300 in your BigVGAN config to match the params of StyleTTS.
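
In the config posted above that would look roughly like the fragment below (a sketch using that config's key names; note that in HiFi-GAN-style generators the product of upsample_rates must equal the hop size, so those values would also need adjusting for a hop of 300):

    "n_fft": 2048,
    "hop_size": 300,
    "win_size": 1200,
    "sampling_rate": 24000,
    "num_mels": 80,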

@crypticsymmetry (Author) commented Feb 21, 2023

Oh okay, gotcha, so basically either retrain the preprocessing to match the same config, or just match the settings. Sorry, I'm just a tinkerer realistically :P

Edit: I think I read somewhere that the hop length has to be 256 for some reason for BigVGAN.

So I could go either way: train BigVGAN with the StyleTTS settings, or retrain the preprocessors to match the BigVGAN settings?

@yl4579 (Owner) commented Feb 21, 2023

To add to @lexkoro's point, you also need to change how the mel spectrogram is computed. In HiFi-GAN and BigVGAN the mels are generated using librosa, while in all of my repos they are generated with torchaudio. There is a slight difference in how they compute the mel spectrogram, but all of this can be fixed by inverting the mel and recomputing the mel scale using the other library.

If I have extra time, I will change all of my repos to match BigVGAN (provided their pretrained models have better quality than the vocoder I have).
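
As a rough sketch of that reverse-mel idea (the STFT and normalization constants here are assumptions and would need to match the actual StyleTTS and BigVGAN configs):

import torch
import torchaudio
import librosa

# Hedged sketch: invert a torchaudio-style normalized log-mel back to a linear
# spectrogram, then re-project it with a librosa mel filter bank of the kind
# HiFi-GAN/BigVGAN are trained on. Not a drop-in conversion.
mean, std = -4, 4
n_fft, n_mels, sr = 2048, 80, 24000

inv_mel = torchaudio.transforms.InverseMelScale(
    n_stft=n_fft // 2 + 1, n_mels=n_mels, sample_rate=sr)
librosa_basis = torch.from_numpy(
    librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=0, fmax=12000)).float()

def torchaudio_mel_to_librosa_mel(mel_norm: torch.Tensor) -> torch.Tensor:
    mel = torch.exp(mel_norm * std + mean)       # undo normalization and log compression
    linear = inv_mel(mel)                        # approximate linear (power) spectrogram
    mel_librosa = librosa_basis @ linear.sqrt()  # librosa-style mels use magnitude spectra
    return torch.log(torch.clamp(mel_librosa, min=1e-5))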

@crypticsymmetry (Author) commented Feb 21, 2023

Gotcha, thanks for the help. Yeah, I was super curious, because of their examples, to see how it would work with your repo.

Here is the BigVGAN generator trained up to 100,000 steps, with pretty great quality IMO, though I haven't heard plain HiFi-GAN: https://drive.google.com/file/d/1VRUq3hjjXloTQ-gcYMRDsVpBqd_GXPrN/view?usp=sharing
First randomly selected unseen-speaker example:
0085.webm

config:

{
    "resblock": "1",
    "num_gpus": 2,
    "batch_size": 32,
    "learning_rate": 0.00003,
    "adam_b1": 0.85, # original .8 -> .85 Lion recommends leaving at .9 and. 99, met in the middle "idk"
    "adam_b2": 0.99,
    "lr_decay": 0.999,
    "seed": 1234,

    "upsample_rates": [4,4,2,2,2,2],
    "upsample_kernel_sizes": [8,8,4,4,4,4],
    "upsample_initial_channel": 1536,
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],

    "activation": "snakebeta",
    "snake_logscale": true,

    "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
    "mpd_reshapes": [2, 3, 5, 7, 11],
    "use_spectral_norm": false,
    "discriminator_channel_mult": 1,

    "segment_size": 8192,
    "num_mels": 80, # Changed to 80 from 100
    "num_freq": 1025,
    "n_fft": 1024,
    "hop_size": 256,
    "win_size": 1024,

    "sampling_rate": 24000,

    "fmin": 0,
    "fmax": 12000,
    "fmax_for_loss": null,

    "num_workers": 4,

    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://0.0.0.0:54321",
        "world_size": 1
    }
}

@yl4579 (Owner) commented Feb 21, 2023

You can either retrain the text aligner, pitch extractor, and StyleTTS altogether with the BigVGAN settings, or you can retrain BigVGAN with the StyleTTS settings, whichever way works faster for you. If you do the former, it also saves time for me, because you can directly tell me how well it works with BigVGAN and I can create a new branch with all of the pretrained models you have there.

@crypticsymmetry (Author) commented:

Sounds good. I might train everything; I'll let you know!

@crypticsymmetry (Author) commented Feb 21, 2023

Should I convert from torchaudio to librosa in AuxiliaryASR and PitchExtractor, or just leave them with torchaudio?
Something like this? (ChatGPT-converted:)

import os.path as osp
import torch
import librosa
import numpy as np

DEFAULT_DICT_PATH = osp.join(osp.dirname(__file__), 'word_index_dict.txt')
SPECT_PARAMS = {
    "n_fft": 1024,
    "win_length": 1024,
    "hop_length": 256
}
MEL_PARAMS = {
    "n_mels": 80,
    "n_fft": 1024,
    "win_length": 1024,
    "hop_length": 256
}

class MelDataset(torch.utils.data.Dataset):
    def __init__(self,
                 data_list,
                 dict_path=DEFAULT_DICT_PATH,
                 sr=24000
                ):

        spect_params = SPECT_PARAMS
        mel_params = MEL_PARAMS

        _data_list = [l[:-1].split('|') for l in data_list]
        self.data_list = [data if len(data) == 3 else (*data, 0) for data in _data_list]
        self.text_cleaner = TextCleaner(dict_path)
        self.sr = sr

        # librosa.feature.melspectrogram is a plain function (it needs the audio),
        # so wrap it to be callable like the torchaudio transform it replaces
        self.to_melspec = lambda wave: librosa.feature.melspectrogram(y=wave, sr=sr, **MEL_PARAMS)
        self.mean, self.std = -4, 4

@yl4579 (Owner) commented Feb 22, 2023

You should follow exactly the way BigVGAN processes the mel spectrogram. You probably need to replace the MelDataset class in this file entirely with this one: https://github.com/NVIDIA/BigVGAN/blob/main/meldataset.py, except you also return the text labels.
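
For example, something in this spirit (a rough sketch only: it assumes BigVGAN's mel_spectrogram helper with the signature from their meldataset.py, a three-field "wav_path|text|speaker_id" data_list, and TextCleaner from the original AuxiliaryASR file):

import torch
import torchaudio
from meldataset import mel_spectrogram  # the BigVGAN/NVIDIA version

class MelDataset(torch.utils.data.Dataset):
    def __init__(self, data_list, dict_path, sr=24000):
        # each line is assumed to be "wav_path|text|speaker_id"
        self.data_list = [l.strip().split('|') for l in data_list]
        self.text_cleaner = TextCleaner(dict_path)  # from the original AuxiliaryASR file
        self.sr = sr

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        wav_path, text, speaker_id = self.data_list[idx]
        wave, sr = torchaudio.load(wav_path)
        if sr != self.sr:
            wave = torchaudio.functional.resample(wave, sr, self.sr)
        # BigVGAN-style mel (librosa filter bank + dynamic range compression);
        # parameter values mirror the config posted earlier in this thread
        mel = mel_spectrogram(wave, n_fft=1024, num_mels=80, sampling_rate=self.sr,
                              hop_size=256, win_size=1024, fmin=0, fmax=12000).squeeze(0)
        text = torch.LongTensor(self.text_cleaner(text))
        return wave.squeeze(0), mel, text, int(speaker_id)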

@yl4579 (Owner) commented Feb 22, 2023

They have released their pretrained models at 22 kHz and they sound quite good. I will try to retrain all the models to match BigVGAN's settings and create a new branch for all of my repos.

@crypticsymmetry (Author) commented Feb 22, 2023

Lol, didn't expect them to release models so quickly. I might keep trying to convert the scripts too, but it might be a little hard for me; I'll post if there's any progress. That's awesome, can't wait! Good luck!

I wonder if it would be worth adding CLAP at some point to classify or direct the style or quality of the audio somehow: https://github.com/LAION-AI/CLAP. I have no idea if this is even worth mentioning 🤷‍♂️

@yl4579 (Owner) commented Feb 23, 2023

Maybe you are referring to something like this: http://dongchaoyang.top/InstructTTS?

Note that it probably works for StyleTTS too if you have the right data, but their data isn't public, and CLAP isn't suitable for text-instructed emotion control, since CLAP is about audio in general rather than speech.

@crypticsymmetry (Author) commented Feb 23, 2023

Gotcha. Well, thanks for all the help and information; I wish I could help out more. Much appreciated.

@arampacha commented:

Hello @yl4579,
I wonder if you still have plans to train and release a StyleTTS model compatible with BigVGAN?
I tried training one, but the quality of the outputs I've gotten so far is significantly worse than your original model. Could you share the TensorBoard logs from the LibriTTS model training? I'd like to compare the learning dynamics; maybe that will help me understand what is going wrong.

@yl4579 (Owner) commented Apr 16, 2023

@arampacha I'm very busy with my other projects, so it might be difficult to find the TensorBoard logs for the LibriTTS dataset now. But could you please tell me whether you were able to reproduce the results with the BigVGAN recipe on LJSpeech, or to train a model with the original recipe on LibriTTS?

I just want to make sure there's nothing wrong with the code I uploaded. I only tested this repo on LJSpeech, unfortunately; the LibriTTS model was trained using experimental code without cleaning, so there might be some differences I didn't realize. Also, I'd like to make sure the code you modified for the BigVGAN recipe has no problems. I was able to train a model on LJSpeech with the BigVGAN recipe, with similar quality, using the code from this repo, but I haven't tried it on LibriTTS, so I didn't update the repo with it. You can refer to this branch with the pretrained text aligner and pitch extractor for how to modify the references. Once I finish the paper on the E2E version of StyleTTS, I will update this repo with the BigVGAN recipe.

@arampacha commented:

Thanks for your response!
I've trained another model using an extended dataset (LibriTTS plus some other small datasets) with the same meldataset parameters (22 kHz, 80 mel bands). I changed the phonemizer settings to match yours; I had used with_stress=False the first time. Now the quality seems to be similar to that of the original model, but for some style vectors the duration prediction is not quite right, resulting in slower speech compared to the style reference. The second stage has not completely converged for this run yet, so maybe it will improve in this regard. Maybe the previous run was just unlucky; I had a jump in mel_loss during the early first stage, which didn't happen this time.
I didn't train an LJSpeech model with the BigVGAN meldataset params, nor did I try to reproduce your results on LibriTTS. If my model keeps having issues, I'll try to train a model with my additional data starting from your checkpoint.
