We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a2a142d commit 79ed5deCopy full SHA for 79ed5de
tests/test_wavegrad_train.py
@@ -1,5 +1,6 @@
1
import unittest
2
3
+import numpy as np
4
import torch
5
from torch import optim
6
from TTS.vocoder.models.wavegrad import Wavegrad
@@ -33,7 +34,8 @@ def test_train_step(self): # pylint: disable=no-self-use
33
34
[1, 2, 4, 8]])
35
model.train()
36
model.to(device)
- model.compute_noise_level(1000, 1e-6, 1e-2)
37
+ betas = np.linspace(1e-6, 1e-2, 1000)
38
+ model.compute_noise_level(betas)
39
model_ref.load_state_dict(model.state_dict())
40
model_ref.to(device)
41
count = 0
0 commit comments