import os
import random

import numpy as np
import torch
import torch.utils.data
import librosa
from librosa.filters import mel as librosa_mel_fn
from pghipy import pghi


def load_wav(full_path, sample_rate):
    # Load audio resampled to `sample_rate` and downmixed to mono.
    data, _ = librosa.load(full_path, sr=sample_rate, mono=True)
    return data


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    return dynamic_range_compression_torch(magnitudes)


def spectral_de_normalize_torch(magnitudes):
    return dynamic_range_decompression_torch(magnitudes)


# Per-device caches, keyed by the STFT/mel parameters: the mel filterbank and
# Hann window, and the pseudo-inverse of the mel filterbank.
mel_window = {}
inv_mel_window = {}


def param_string(sampling_rate, n_fft, num_mels, fmin, fmax, win_size, device):
    return f"{sampling_rate}-{n_fft}-{num_mels}-{fmin}-{fmax}-{win_size}-{device}"


def mel_spectrogram(
    y,
    n_fft,
    num_mels,
    sampling_rate,
    hop_size,
    win_size,
    fmin,
    fmax,
    center=True,
    in_dataset=False,
):
    global mel_window
    device = torch.device("cpu") if in_dataset else y.device
    ps = param_string(sampling_rate, n_fft, num_mels, fmin, fmax, win_size, device)
    if ps in mel_window:
        mel_basis, hann_window = mel_window[ps]
    else:
        # Keyword arguments keep this compatible with librosa >= 0.10, where
        # positional arguments to `mel` were removed.
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis = torch.from_numpy(mel).float().to(device)
        hann_window = torch.hann_window(win_size).to(device)
        mel_window[ps] = (mel_basis.clone(), hann_window.clone())

    spec = torch.stft(
        y.to(device),
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window.to(device),
        center=center,
        return_complex=True,
    )
    spec = mel_basis.to(device) @ spec.abs()
    spec = spectral_normalize_torch(spec)

    return spec  # [batch_size, num_mels, frames]


def inverse_mel(
    mel,
    n_fft,
    num_mels,
    sampling_rate,
    hop_size,
    win_size,
    fmin,
    fmax,
    in_dataset=False,
):
    global inv_mel_window, mel_window
    device = torch.device("cpu") if in_dataset else mel.device
    ps = param_string(sampling_rate, n_fft, num_mels, fmin, fmax, win_size, device)
    if ps in inv_mel_window:
        inv_basis = inv_mel_window[ps]
    else:
        if ps in mel_window:
            mel_basis, _ = mel_window[ps]
        else:
            mel_np = librosa_mel_fn(
                sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
            )
            mel_basis = torch.from_numpy(mel_np).float().to(device)
            hann_window = torch.hann_window(win_size).to(device)
            mel_window[ps] = (mel_basis.clone(), hann_window.clone())
        # Approximate linear-spectrogram recovery via the Moore-Penrose
        # pseudo-inverse of the mel filterbank.
        inv_basis = mel_basis.pinverse()
        inv_mel_window[ps] = inv_basis.clone()

    return inv_basis.to(device) @ spectral_de_normalize_torch(mel.to(device))


def amp_pha_specturm(y, n_fft, hop_size, win_size):
    hann_window = torch.hann_window(win_size).to(y.device)
    stft_spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=True,
        return_complex=True,
    )  # complex tensor, [batch_size, n_fft//2+1, frames]
    log_amplitude = torch.log(
        stft_spec.abs() + 1e-5
    )  # [batch_size, n_fft//2+1, frames]
    phase = stft_spec.angle()  # [batch_size, n_fft//2+1, frames]
    return log_amplitude, phase, stft_spec.real, stft_spec.imag


def get_dataset_filelist(input_training_wav_list, input_validation_wav_list):
    # Despite their names, both arguments are directories whose contents are
    # listed to build the file lists.
    training_files = [
        os.path.join(input_training_wav_list, f)
        for f in os.listdir(input_training_wav_list)
    ]
    validation_files = [
        os.path.join(input_validation_wav_list, f)
        for f in os.listdir(input_validation_wav_list)
    ]
    return training_files, validation_files


class Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        training_files,
        segment_size,
        n_fft,
        num_mels,
        hop_size,
        win_size,
        sampling_rate,
        fmin,
        fmax,
        meloss,
        split=True,
        shuffle=True,
        n_cache_reuse=1,
        device=None,
        inv_mel=False,
        use_pghi=False,
    ):
        self.audio_files = training_files
        random.seed(1234)
        if shuffle:
            random.shuffle(self.audio_files)
        self.segment_size = segment_size
        self.sampling_rate = sampling_rate
        self.split = split
        self.n_fft = n_fft
        self.num_mels = num_mels
        self.hop_size = hop_size
        self.win_size = win_size
        self.fmin = fmin
        self.fmax = fmax
        self.cached_wav = None
        self.n_cache_reuse = n_cache_reuse
        self._cache_ref_count = 0
        self.device = device
        self.meloss = meloss  # fmax used for the mel-loss target spectrogram
        self.inv_mel = inv_mel
        self.pghi = use_pghi

    def __getitem__(self, index):
        filename = self.audio_files[index]
        # Reuse the decoded waveform for `n_cache_reuse` consecutive fetches
        # to amortize the cost of loading from disk.
        if self._cache_ref_count == 0:
            audio = load_wav(filename, self.sampling_rate)
            self.cached_wav = audio
            self._cache_ref_count = self.n_cache_reuse
        else:
            audio = self.cached_wav
            self._cache_ref_count -= 1

        audio = torch.FloatTensor(audio)  # [T]
        audio = audio.unsqueeze(0)  # [1, T]

        if self.split:
            if audio.size(1) >= self.segment_size:
                # Crop a random segment of length `segment_size`.
                max_audio_start = audio.size(1) - self.segment_size
                audio_start = random.randint(0, max_audio_start)
                audio = audio[:, audio_start : audio_start + self.segment_size]
            else:
                # Zero-pad short clips up to `segment_size`.
                audio = torch.nn.functional.pad(
                    audio, (0, self.segment_size - audio.size(1)), "constant"
                )

        mel = mel_spectrogram(
            audio,
            self.n_fft,
            self.num_mels,
            self.sampling_rate,
            self.hop_size,
            self.win_size,
            self.fmin,
            self.fmax,
            center=True,
            in_dataset=True,
        )
        # Second mel with fmax = self.meloss, used as the mel-loss target.
        meloss1 = mel_spectrogram(
            audio,
            self.n_fft,
            self.num_mels,
            self.sampling_rate,
            self.hop_size,
            self.win_size,
            self.fmin,
            self.meloss,
            center=True,
            in_dataset=True,
        )
        log_amplitude, phase, rea, imag = amp_pha_specturm(
            audio, self.n_fft, self.hop_size, self.win_size
        )  # [1, n_fft//2+1, frames]

        inv_mel = (
            inverse_mel(
                mel,
                self.n_fft,
                self.num_mels,
                self.sampling_rate,
                self.hop_size,
                self.win_size,
                self.fmin,
                self.fmax,
            )
            .abs()
            .clamp_min(1e-5)
            .squeeze()
            if self.inv_mel
            else torch.tensor([0])
        )

        if self.pghi:
            # Phase Gradient Heap Integration: estimate a phase for the
            # magnitude spectrogram recovered from the mel.
            pghid = torch.tensor(
                pghi(inv_mel.squeeze(0).T.numpy(), self.win_size, self.hop_size)
            ).T
            # Wrap the estimated phase into (-pi, pi].
            pghid = torch.polar(torch.ones_like(pghid), pghid).angle()
        else:
            pghid = torch.tensor([0])

        return (
            mel.squeeze(),
            log_amplitude.squeeze(),
            phase.squeeze(),
            rea.squeeze(),
            imag.squeeze(),
            audio.squeeze(0),
            meloss1.squeeze(),
            inv_mel,
            pghid,
        )

    def __len__(self):
        return len(self.audio_files)
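

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of wiring this Dataset into a DataLoader. All hyperparameter
# values and the "wavs/train" / "wavs/val" directories below are assumptions
# chosen for illustration; substitute your own configuration.
if __name__ == "__main__":
    training_files, validation_files = get_dataset_filelist(
        "wavs/train", "wavs/val"  # hypothetical directories of .wav files
    )
    dataset = Dataset(
        training_files,
        segment_size=8192,
        n_fft=1024,
        num_mels=80,
        hop_size=256,
        win_size=1024,
        sampling_rate=22050,
        fmin=0,
        fmax=8000,
        meloss=None,  # fmax for the loss-target mel; None keeps the full band
        split=True,
        shuffle=True,
    )
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True
    )
    mel, log_amp, phase, rea, imag, audio, meloss1, inv_mel, pghid = next(iter(loader))
    print(mel.shape, audio.shape)  # e.g. [16, 80, frames] and [16, 8192]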