Skip to content

Commit 79ed5de

Browse files
committed
fix wavegrad test
1 parent a2a142d commit 79ed5de

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tests/test_wavegrad_train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22

3+
import numpy as np
34
import torch
45
from torch import optim
56
from TTS.vocoder.models.wavegrad import Wavegrad
@@ -33,7 +34,8 @@ def test_train_step(self): # pylint: disable=no-self-use
3334
[1, 2, 4, 8]])
3435
model.train()
3536
model.to(device)
36-
model.compute_noise_level(1000, 1e-6, 1e-2)
37+
betas = np.linspace(1e-6, 1e-2, 1000)
38+
model.compute_noise_level(betas)
3739
model_ref.load_state_dict(model.state_dict())
3840
model_ref.to(device)
3941
count = 0

0 commit comments

Comments
 (0)