diff --git a/tests/ignite/metrics/test_hsic.py b/tests/ignite/metrics/test_hsic.py index b43ef35dcdc..0fcb24f35c9 100644 --- a/tests/ignite/metrics/test_hsic.py +++ b/tests/ignite/metrics/test_hsic.py @@ -137,11 +137,10 @@ def test_integration(self, sigma_x: float, sigma_y: float): if device.type != "xla": metric_devices.append(device) - lin = nn.Linear(n_dims_x, n_dims_y) for metric_device in metric_devices: x = torch.randn((n_iters * batch_size, n_dims_x)).float().to(device) - lin.to(device) + lin = nn.Linear(n_dims_x, n_dims_y).to(device) y = torch.sin(lin(x) * 100) + torch.randn(n_iters * batch_size, n_dims_y) * 1e-4 def data_loader(i):