PyTorch 2.1 Updates (Weight Norm and TorchAudio I/O) #3176

Merged · 4 commits · Nov 9, 2023
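PyTorch 2.1 deprecates the hook-based `torch.nn.utils.weight_norm` / `remove_weight_norm` pair in favor of the parametrization-based API, and this PR migrates every call site in TTS accordingly. A minimal sketch of the two APIs side by side (assuming PyTorch >= 2.1; the layer shapes are illustrative only):

```python
import torch.nn as nn
from torch.nn.utils import parametrize

conv = nn.Conv1d(16, 32, kernel_size=3)

# Old, deprecated in 2.1: registers weight_g / weight_v plus a forward pre-hook.
# conv = nn.utils.weight_norm(conv)
# nn.utils.remove_weight_norm(conv)

# New: the weight is recomputed from a parametrization module on each access.
conv = nn.utils.parametrizations.weight_norm(conv, name="weight")
parametrize.remove_parametrizations(conv, "weight")  # bake the current value back into conv.weight
```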
13 changes: 7 additions & 6 deletions TTS/tts/layers/delightful_tts/conv_layers.py
@@ -3,6 +3,7 @@
import torch
import torch.nn as nn # pylint: disable=consider-using-from-import
import torch.nn.functional as F
+from torch.nn.utils import parametrize

from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor

@@ -73,7 +74,7 @@ def __init__(
)
nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain))
if self.use_weight_norm:
-            self.conv = nn.utils.weight_norm(self.conv)
+            self.conv = nn.utils.parametrizations.weight_norm(self.conv)

def forward(self, signal, mask=None):
conv_signal = self.conv(signal)
@@ -113,7 +114,7 @@ def __init__(
dilation=1,
w_init_gain="relu",
)
-            conv_layer = nn.utils.weight_norm(conv_layer.conv, name="weight")
+            conv_layer = nn.utils.parametrizations.weight_norm(conv_layer.conv, name="weight")
convolutions.append(conv_layer)

self.convolutions = nn.ModuleList(convolutions)
@@ -567,7 +568,7 @@ def __init__( # pylint: disable=dangerous-default-value

self.convt_pre = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
-            nn.utils.weight_norm(
+            nn.utils.parametrizations.weight_norm(
nn.ConvTranspose1d(
in_channels,
in_channels,
@@ -584,7 +585,7 @@ def __init__( # pylint: disable=dangerous-default-value
self.conv_blocks.append(
nn.Sequential(
nn.LeakyReLU(lReLU_slope),
-                    nn.utils.weight_norm(
+                    nn.utils.parametrizations.weight_norm(
nn.Conv1d(
in_channels,
in_channels,
@@ -665,6 +666,6 @@ def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=25

def remove_weight_norm(self):
self.kernel_predictor.remove_weight_norm()
-        nn.utils.remove_weight_norm(self.convt_pre[1])
+        parametrize.remove_parametrizations(self.convt_pre[1], "weight")
for block in self.conv_blocks:
-            nn.utils.remove_weight_norm(block[1])
+            parametrize.remove_parametrizations(block[1], "weight")
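A note on the removal calls above: `remove_parametrizations` takes the parametrized tensor name explicitly, and with the default `leave_parametrized=True` it re-materializes the weight at its current value, matching what `remove_weight_norm` did at inference time. A quick sanity check, as a sketch (assumes PyTorch >= 2.1):

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize

layer = nn.utils.parametrizations.weight_norm(nn.Conv1d(4, 4, 3))
w = layer.weight.detach().clone()
parametrize.remove_parametrizations(layer, "weight")  # leave_parametrized=True by default
assert torch.allclose(w, layer.weight)         # weight value preserved
assert not parametrize.is_parametrized(layer)  # parametrization removed
```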
23 changes: 13 additions & 10 deletions TTS/tts/layers/delightful_tts/kernel_predictor.py
@@ -1,4 +1,5 @@
import torch.nn as nn # pylint: disable=consider-using-from-import
+from torch.nn.utils import parametrize


class KernelPredictor(nn.Module):
@@ -36,7 +37,9 @@ def __init__( # pylint: disable=dangerous-default-value
kpnet_bias_channels = conv_out_channels * conv_layers # l_b

self.input_conv = nn.Sequential(
-            nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
+            nn.utils.parametrizations.weight_norm(
+                nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
+            ),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)

@@ -46,7 +49,7 @@ def __init__( # pylint: disable=dangerous-default-value
self.residual_convs.append(
nn.Sequential(
nn.Dropout(kpnet_dropout),
-                    nn.utils.weight_norm(
+                    nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
@@ -56,7 +59,7 @@ def __init__( # pylint: disable=dangerous-default-value
)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
-                    nn.utils.weight_norm(
+                    nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
@@ -68,7 +71,7 @@ def __init__( # pylint: disable=dangerous-default-value
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
)
-        self.kernel_conv = nn.utils.weight_norm(
+        self.kernel_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_kernel_channels,
@@ -77,7 +80,7 @@ def __init__( # pylint: disable=dangerous-default-value
bias=True,
)
)
-        self.bias_conv = nn.utils.weight_norm(
+        self.bias_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_bias_channels,
@@ -117,9 +120,9 @@ def forward(self, c):
return kernels, bias

def remove_weight_norm(self):
-        nn.utils.remove_weight_norm(self.input_conv[0])
-        nn.utils.remove_weight_norm(self.kernel_conv)
-        nn.utils.remove_weight_norm(self.bias_conv)
+        parametrize.remove_parametrizations(self.input_conv[0], "weight")
+        parametrize.remove_parametrizations(self.kernel_conv, "weight")
+        parametrize.remove_parametrizations(self.bias_conv, "weight")
for block in self.residual_convs:
-            nn.utils.remove_weight_norm(block[1])
-            nn.utils.remove_weight_norm(block[3])
+            parametrize.remove_parametrizations(block[1], "weight")
+            parametrize.remove_parametrizations(block[3], "weight")
13 changes: 7 additions & 6 deletions TTS/tts/layers/generic/wavenet.py
@@ -1,5 +1,6 @@
import torch
from torch import nn
+from torch.nn.utils import parametrize


@torch.jit.script
@@ -62,7 +63,7 @@ def __init__(
# init conditioning layer
if c_in_channels > 0:
cond_layer = torch.nn.Conv1d(c_in_channels, 2 * hidden_channels * num_layers, 1)
-            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
+            self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")
# intermediate layers
for i in range(num_layers):
dilation = dilation_rate**i
@@ -75,7 +76,7 @@ def __init__(
in_layer = torch.nn.Conv1d(
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
)
-            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
+            in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)

if i < num_layers - 1:
@@ -84,7 +85,7 @@ def __init__(
res_skip_channels = hidden_channels

res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
-            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
+            res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
# setup weight norm
if not weight_norm:
@@ -115,11 +116,11 @@ def forward(self, x, x_mask=None, g=None, **kwargs): # pylint: disable=unused-a

def remove_weight_norm(self):
if self.c_in_channels != 0:
-            torch.nn.utils.remove_weight_norm(self.cond_layer)
+            parametrize.remove_parametrizations(self.cond_layer, "weight")
for l in self.in_layers:
-            torch.nn.utils.remove_weight_norm(l)
+            parametrize.remove_parametrizations(l, "weight")
for l in self.res_skip_layers:
-            torch.nn.utils.remove_weight_norm(l)
+            parametrize.remove_parametrizations(l, "weight")


class WNBlocks(nn.Module):
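Unlike the old hooks, the parametrized modules above no longer expose `weight_g` / `weight_v` attributes; the original tensors live under `module.parametrizations.weight`. A sketch of the new layout (attribute names per PyTorch >= 2.1):

```python
import torch.nn as nn
from torch.nn.utils import parametrize

layer = nn.utils.parametrizations.weight_norm(nn.Conv1d(2, 2, 3), name="weight")
g = layer.parametrizations.weight.original0  # magnitude, formerly weight_g
v = layer.parametrizations.weight.original1  # direction, formerly weight_v
assert parametrize.is_parametrized(layer, "weight")
```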
2 changes: 1 addition & 1 deletion TTS/tts/layers/glow_tts/glow.py
@@ -186,7 +186,7 @@ def __init__(
self.sigmoid_scale = sigmoid_scale
# input layer
start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
-        start = torch.nn.utils.weight_norm(start)
+        start = torch.nn.utils.parametrizations.weight_norm(start)
self.start = start
# output layer
# Initializing last layer to 0 makes the affine coupling layers
42 changes: 23 additions & 19 deletions TTS/tts/layers/tortoise/vocoder.py
@@ -1,11 +1,11 @@
import json
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
+import torch.nn.utils.parametrize as parametrize

MAX_WAV_VALUE = 32768.0

@@ -44,7 +44,9 @@ def __init__(
kpnet_bias_channels = conv_out_channels * conv_layers # l_b

self.input_conv = nn.Sequential(
-            nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
+            nn.utils.parametrizations.weight_norm(
+                nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
+            ),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)

@@ -54,7 +56,7 @@ def __init__(
self.residual_convs.append(
nn.Sequential(
nn.Dropout(kpnet_dropout),
-                    nn.utils.weight_norm(
+                    nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
@@ -64,7 +66,7 @@ def __init__(
)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
-                    nn.utils.weight_norm(
+                    nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
@@ -76,7 +78,7 @@ def __init__(
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
)
-        self.kernel_conv = nn.utils.weight_norm(
+        self.kernel_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_kernel_channels,
@@ -85,7 +87,7 @@ def __init__(
bias=True,
)
)
-        self.bias_conv = nn.utils.weight_norm(
+        self.bias_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_bias_channels,
@@ -125,12 +127,12 @@ def forward(self, c):
return kernels, bias

def remove_weight_norm(self):
-        nn.utils.remove_weight_norm(self.input_conv[0])
-        nn.utils.remove_weight_norm(self.kernel_conv)
-        nn.utils.remove_weight_norm(self.bias_conv)
+        parametrize.remove_parametrizations(self.input_conv[0], "weight")
+        parametrize.remove_parametrizations(self.kernel_conv, "weight")
+        parametrize.remove_parametrizations(self.bias_conv, "weight")
for block in self.residual_convs:
-            nn.utils.remove_weight_norm(block[1])
-            nn.utils.remove_weight_norm(block[3])
+            parametrize.remove_parametrizations(block[1], "weight")
+            parametrize.remove_parametrizations(block[3], "weight")


class LVCBlock(torch.nn.Module):
@@ -169,7 +171,7 @@ def __init__(

self.convt_pre = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
-            nn.utils.weight_norm(
+            nn.utils.parametrizations.weight_norm(
nn.ConvTranspose1d(
in_channels,
in_channels,
@@ -186,7 +188,7 @@ def __init__(
self.conv_blocks.append(
nn.Sequential(
nn.LeakyReLU(lReLU_slope),
-                    nn.utils.weight_norm(
+                    nn.utils.parametrizations.weight_norm(
nn.Conv1d(
in_channels,
in_channels,
@@ -267,9 +269,9 @@ def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=25

def remove_weight_norm(self):
self.kernel_predictor.remove_weight_norm()
-        nn.utils.remove_weight_norm(self.convt_pre[1])
+        parametrize.remove_parametrizations(self.convt_pre[1], "weight")
for block in self.conv_blocks:
-            nn.utils.remove_weight_norm(block[1])
+            parametrize.remove_parametrizations(block[1], "weight")


class UnivNetGenerator(nn.Module):
@@ -314,11 +316,13 @@ def __init__(
)
)

-        self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect"))
+        self.conv_pre = nn.utils.parametrizations.weight_norm(
+            nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect")
+        )

self.conv_post = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
-            nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
+            nn.utils.parametrizations.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
nn.Tanh(),
)

@@ -346,11 +350,11 @@ def eval(self, inference=False):
self.remove_weight_norm()

def remove_weight_norm(self):
-        nn.utils.remove_weight_norm(self.conv_pre)
+        parametrize.remove_parametrizations(self.conv_pre, "weight")

for layer in self.conv_post:
if len(layer.state_dict()) != 0:
-                nn.utils.remove_weight_norm(layer)
+                parametrize.remove_parametrizations(layer, "weight")

for res_block in self.res_stack:
res_block.remove_weight_norm()
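One consequence of this migration worth flagging: checkpoints saved with the old API store keys like `...weight_g` / `...weight_v`, while the parametrized modules expect `...parametrizations.weight.original0` / `...original1`. A hedged sketch of a remapping helper for loading older checkpoints (hypothetical, not part of this PR):

```python
def remap_weight_norm_keys(state_dict):
    """Translate hook-era weight_norm keys to the parametrize layout (sketch)."""
    remapped = {}
    for key, value in state_dict.items():
        key = key.replace(".weight_g", ".parametrizations.weight.original0")
        key = key.replace(".weight_v", ".parametrizations.weight.original1")
        remapped[key] = value
    return remapped
```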
2 changes: 1 addition & 1 deletion TTS/tts/layers/vits/discriminator.py
@@ -14,7 +14,7 @@ class DiscriminatorS(torch.nn.Module):

def __init__(self, use_spectral_norm=False):
super().__init__()
-        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
+        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
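Only the weight-norm branch is migrated here; the legacy `nn.utils.spectral_norm` still works in PyTorch 2.1. If the spectral branch were given the same treatment later, a parametrized counterpart exists (a sketch, not part of this PR):

```python
import torch.nn as nn

use_spectral_norm = False  # mirrors the DiscriminatorS argument
norm_f = (
    nn.utils.parametrizations.spectral_norm
    if use_spectral_norm
    else nn.utils.parametrizations.weight_norm
)
```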
19 changes: 10 additions & 9 deletions TTS/tts/layers/xtts/hifigan_decoder.py
@@ -3,7 +3,8 @@
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
-from torch.nn.utils import remove_weight_norm, weight_norm
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations

from TTS.utils.io import load_fsspec

@@ -120,9 +121,9 @@ def forward(self, x):

def remove_weight_norm(self):
for l in self.convs1:
-            remove_weight_norm(l)
+            remove_parametrizations(l, "weight")
for l in self.convs2:
-            remove_weight_norm(l)
+            remove_parametrizations(l, "weight")


class ResBlock2(torch.nn.Module):
@@ -176,7 +177,7 @@ def forward(self, x):

def remove_weight_norm(self):
for l in self.convs:
-            remove_weight_norm(l)
+            remove_parametrizations(l, "weight")


class HifiganGenerator(torch.nn.Module):
@@ -251,10 +252,10 @@ def __init__(
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)

if not conv_pre_weight_norm:
-            remove_weight_norm(self.conv_pre)
+            remove_parametrizations(self.conv_pre, "weight")

if not conv_post_weight_norm:
-            remove_weight_norm(self.conv_post)
+            remove_parametrizations(self.conv_post, "weight")

if self.cond_in_each_up_layer:
self.conds = nn.ModuleList()
@@ -317,11 +318,11 @@ def inference(self, c):
def remove_weight_norm(self):
print("Removing weight norm...")
for l in self.ups:
-            remove_weight_norm(l)
+            remove_parametrizations(l, "weight")
for l in self.resblocks:
l.remove_weight_norm()
-        remove_weight_norm(self.conv_pre)
-        remove_weight_norm(self.conv_post)
+        remove_parametrizations(self.conv_pre, "weight")
+        remove_parametrizations(self.conv_post, "weight")

def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
3 changes: 1 addition & 2 deletions TTS/tts/models/xtts.py
@@ -1,5 +1,4 @@
import os
-from contextlib import contextmanager
from dataclasses import dataclass

import librosa
@@ -8,7 +7,7 @@
import torchaudio
from coqpit import Coqpit

-from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
+from TTS.tts.layers.tortoise.audio_utils import wav_to_univnet_mel
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support