diff --git a/tests/ignite/metrics/test_hsic.py b/tests/ignite/metrics/test_hsic.py index f2d99d970f2..6ee4237b96e 100644 --- a/tests/ignite/metrics/test_hsic.py +++ b/tests/ignite/metrics/test_hsic.py @@ -120,54 +120,54 @@ def test_accumulator_detached(): @pytest.mark.usefixtures("distributed") class TestDistributed: - @pytest.mark.parametrize("sigma_x", [-1.0, 1.0]) - @pytest.mark.parametrize("sigma_y", [-1.0, 1.0]) - def test_integration(self, sigma_x: float, sigma_y: float): - tol = 2e-5 - n_iters = 100 - batch_size = 20 - n_dims_x = 100 - n_dims_y = 50 + @pytest.mark.parametrize("sigma_x", [-1.0, 1.0]) + @pytest.mark.parametrize("sigma_y", [-1.0, 1.0]) + def test_integration(self, sigma_x: float, sigma_y: float): + tol = 2e-5 + n_iters = 100 + batch_size = 20 + n_dims_x = 100 + n_dims_y = 50 + + rank = idist.get_rank() + torch.manual_seed(12 + rank) - rank = idist.get_rank() - torch.manual_seed(12 + rank) - - device = idist.device() - metric_devices = [torch.device("cpu")] - if device.type != "xla": - metric_devices.append(device) + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) - for metric_device in metric_devices: - x = torch.randn((n_iters * batch_size, n_dims_x)).float().to(device) + for metric_device in metric_devices: + x = torch.randn((n_iters * batch_size, n_dims_x)).float().to(device) - lin = nn.Linear(n_dims_x, n_dims_y).to(device) - y = torch.sin(lin(x) * 100) + torch.randn(n_iters * batch_size, n_dims_y) * 1e-4 + lin = nn.Linear(n_dims_x, n_dims_y).to(device) + y = torch.sin(lin(x) * 100) + torch.randn(n_iters * batch_size, n_dims_y) * 1e-4 - def data_loader(i, input_x, input_y): - return input_x[i * batch_size : (i + 1) * batch_size], input_y[i * batch_size : (i + 1) * batch_size] + def data_loader(i, input_x, input_y): + return input_x[i * batch_size : (i + 1) * batch_size], input_y[i * batch_size : (i + 1) * batch_size] - engine = Engine(lambda e, i: data_loader(i, x, y)) + engine = Engine(lambda e, i: data_loader(i, x, y)) - m = HSIC(sigma_x=sigma_x, sigma_y=sigma_y, device=metric_device) - m.attach(engine, "hsic") + m = HSIC(sigma_x=sigma_x, sigma_y=sigma_y, device=metric_device) + m.attach(engine, "hsic") - data = list(range(n_iters)) - engine.run(data=data, max_epochs=1) + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) - assert "hsic" in engine.state.metrics - res = engine.state.metrics["hsic"] + assert "hsic" in engine.state.metrics + res = engine.state.metrics["hsic"] - x = idist.all_gather(x) - y = idist.all_gather(y) - total_n_iters = idist.all_reduce(n_iters) + x = idist.all_gather(x) + y = idist.all_gather(y) + total_n_iters = idist.all_reduce(n_iters) - np_res = 0.0 - for i in range(total_n_iters): - x_batch, y_batch = data_loader(i, x, y) - np_res += np_hsic(x_batch, y_batch, sigma_x, sigma_y) + np_res = 0.0 + for i in range(total_n_iters): + x_batch, y_batch = data_loader(i, x, y) + np_res += np_hsic(x_batch, y_batch, sigma_x, sigma_y) - expected_hsic = np_res / total_n_iters - assert pytest.approx(expected_hsic, abs=tol) == res + expected_hsic = np_res / total_n_iters + assert pytest.approx(expected_hsic, abs=tol) == res def test_accumulator_device(self): device = idist.device()