diff --git a/ignite/distributed/comp_models/xla.py b/ignite/distributed/comp_models/xla.py
index 2c2cb27c9d18..620370f65c53 100644
--- a/ignite/distributed/comp_models/xla.py
+++ b/ignite/distributed/comp_models/xla.py
@@ -140,7 +140,7 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
         if group is not None and not self._check_group_type(group):
             raise ValueError("Argument group should be list of int")
         op = self._reduce_op_map[op]
-        xm.all_reduce(op, [tensor], groups=group)
+        xm.all_reduce(op, [tensor], groups=[group])
         return tensor
 
     def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
@@ -152,11 +152,11 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
         group_size = self.get_world_size()
         output = torch.zeros((group_size,) + tensor.shape, dtype=tensor.dtype, device=tensor.device)
         output[self.get_rank() % group_size] = tensor
-        xm.all_reduce("sum", [output], groups=group)
+        xm.all_reduce("sum", [output], groups=[group])
         return output.reshape(-1, *output.shape[2:])
 
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
-        return [ranks]
+        return ranks
 
     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         # from https://github.com/jysohn23/xla/blob/model-parallel-colab/Gather_Scatter_Broadcast_PyTorch_XLA.ipynb
diff --git a/tests/ignite/distributed/utils/__init__.py b/tests/ignite/distributed/utils/__init__.py
index 798cdd311837..040954a27956 100644
--- a/tests/ignite/distributed/utils/__init__.py
+++ b/tests/ignite/distributed/utils/__init__.py
@@ -155,17 +155,19 @@ def _test_distrib_all_reduce_group(device):
 
 
 def _test_distrib_all_gather(device):
+    rank = idist.get_rank()
+
     res = torch.tensor(idist.all_gather(10), device=device)
     true_res = torch.tensor([10] * idist.get_world_size(), device=device)
     assert (res == true_res).all()
 
-    t = torch.tensor(idist.get_rank(), device=device)
+    t = torch.tensor(rank, device=device)
     res = idist.all_gather(t)
     true_res = torch.tensor([i for i in range(idist.get_world_size())], device=device)
     assert (res == true_res).all()
 
     x = "test-test"
-    if idist.get_rank() == 0:
+    if rank == 0:
         x = "abc"
     res = idist.all_gather(x)
     true_res = ["abc"] + ["test-test"] * (idist.get_world_size() - 1)
@@ -173,14 +175,14 @@ def _test_distrib_all_gather(device):
 
     base_x = "tests/ignite/distributed/utils/test_native.py" * 2000
     x = base_x
-    if idist.get_rank() == 0:
+    if rank == 0:
         x = "abc"
     res = idist.all_gather(x)
     true_res = ["abc"] + [base_x] * (idist.get_world_size() - 1)
     assert res == true_res
 
-    t = torch.arange(100, device=device).reshape(4, 25) * (idist.get_rank() + 1)
+    t = torch.arange(100, device=device).reshape(4, 25) * (rank + 1)
     in_dtype = t.dtype
     res = idist.all_gather(t)
     assert res.shape == (idist.get_world_size() * 4, 25)
@@ -218,8 +220,6 @@ def _test_distrib_all_gather_group(device):
         res = idist.all_gather(t, group=ranks)
         assert torch.equal(res, torch.tensor(ranks, device=device))
 
-    ranks = "abc"
-
     if bnd in ("nccl", "gloo", "mpi"):
         with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
             res = idist.all_gather(t, group="abc")
@@ -307,7 +307,7 @@ def _test_distrib_new_group(device):
         if rank in ranks:
             assert g1.rank() == g2.rank()
     elif idist.has_xla_support and bnd in ("xla-tpu"):
-        assert idist.new_group(ranks) == [ranks]
+        assert idist.new_group(ranks) == ranks
     elif idist.has_hvd_support and bnd in ("horovod"):
         from horovod.common.process_sets import ProcessSet
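For context on the `[group]` wrapping above: torch_xla's `xm.all_reduce` takes `groups` as a list of replica groups, i.e. a list of lists of ordinals. With `_do_new_group` now returning the flat rank list (so `idist.new_group(ranks) == ranks` on XLA, matching the rank lists other backends accept), the extra level of nesting is added only at the `_do_*` call sites. Below is a minimal sketch of that torch_xla semantics, not part of this PR; it assumes torch_xla is installed with XLA devices available, and the entry point name `_mp_fn` and the tensor values are purely illustrative.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    device = xm.xla_device()
    t = torch.tensor([xm.get_ordinal() + 1], dtype=torch.float32, device=device)
    # One replica group containing every ordinal: a list of lists, which is why the
    # comp model passes groups=[group] rather than the flat list groups=group.
    groups = [list(range(xm.xrt_world_size()))]
    xm.all_reduce(xm.REDUCE_SUM, [t], groups=groups)  # list input => in-place all-reduce
    print(index, t.item())


if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=())

Launched this way on several XLA devices, every process should print the same reduced value, since they all belong to the single replica group.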