diff --git a/test/distributed/test_distributed.py b/test/distributed/test_distributed.py
index 74a4e255cac..2e570ed1505 100644
--- a/test/distributed/test_distributed.py
+++ b/test/distributed/test_distributed.py
@@ -3,6 +3,7 @@
 import errno
 import fcntl
 import os
+import random
 import sys
 import time
 import tempfile
@@ -2249,6 +2250,132 @@ def test_SyncBatchNorm_process_group(self):
         process_group_sync = res50_model_sync.layer1[0].bn1.process_group
         self.assertEqual(process_group_sync, process_group)
 
+    def _run_reduction_test(
+        self, tensor, expected_tensor, op, reduction_fn=dist.all_reduce, dst=None
+    ):
+        if reduction_fn != dist.all_reduce and dst is None:
+            raise ValueError(f"Reduction fn {reduction_fn} must specify dst!")
+        if dst is not None:
+            reduction_fn(tensor, dst, op)
+            # Only destination rank tensor is expected to have final result.
+            if dist.get_rank() == dst:
+                self.assertEqual(tensor, expected_tensor)
+        else:
+            reduction_fn(tensor, op)
+            self.assertEqual(tensor, expected_tensor)
+
+    @require_backend({"nccl"})
+    @require_backends_available({"nccl"})
+    @skip_if_lt_x_gpu(2)
+    @skip_if_rocm
+    def test_nccl_backend_bool_allreduce(self):
+        torch.cuda.set_device(self.rank)
+        # Run all_reduce with PRODUCT
+        element = self.rank % 2 == 0
+        for op in [dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN]:
+            input_tensor = torch.tensor([element, element]).to(self.rank)
+            self._run_reduction_test(
+                input_tensor, torch.tensor([False, False]).to(self.rank), op
+            )
+            # Ensure that all ranks contributing True (cast to 1) results in the
+            # correct reduction.
+            input_tensor = torch.tensor([True, True]).to(self.rank)
+            expected_tensor = input_tensor.clone()
+            self._run_reduction_test(
+                input_tensor, expected_tensor, op
+            )
+
+        # Run all_reduce with SUM
+        for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]:
+            input_tensor = torch.tensor([element, element]).to(self.rank)
+            self._run_reduction_test(
+                input_tensor, torch.tensor([True, True]).to(self.rank), op
+            )
+        # TODO: NCCL backend does not work correctly for bitwise reduction ops
+        # (see https://github.com/pytorch/pytorch/issues/41362). Add tests for
+        # these once it is supported.
+
+    @require_backend({"nccl"})
+    @require_backends_available({"nccl"})
+    @skip_if_lt_x_gpu(2)
+    @skip_if_rocm
+    def test_nccl_backend_bool_allgather(self):
+        torch.cuda.set_device(self.rank)
+        inp = {0: [True, True], 1: [False, True]}
+        input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
+        # Preserve a copy of the tensor to compare against after allgather.
+        input_tensor_copy = input_tensor.clone()
+        tensor_list = [
+            torch.tensor([False, False]).to(self.rank)
+            for _ in range(dist.get_world_size())
+        ]
+        dist.all_gather(tensor_list, input_tensor)
+
+        self.assertEqual(len(tensor_list), dist.get_world_size())
+        for i, t in enumerate(tensor_list):
+            expected = torch.tensor(inp[i % 2]).to(self.rank)
+            self.assertEqual(t, expected)
+        # Ensure that the input tensor is not modified, since this collective
+        # does not modify its input.
+        self.assertEqual(input_tensor_copy, input_tensor)
+
+    @require_backend({"nccl"})
+    @require_backends_available({"nccl"})
+    @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+    @skip_if_rocm
+    def test_nccl_backend_bool_reduce(self):
+        torch.cuda.set_device(self.rank)
+        inp = {0: [True, True], 1: [False, False]}
+        # Run reduce() with product op
+        for op in [dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN]:
+            input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
+            expected = torch.tensor([False, False]).to(self.rank)
+            self._run_reduction_test(
+                input_tensor, expected, op, dist.reduce, dst=0
+            )
+            # Ensure that all ranks contributing True (cast to 1) results in the
+            # correct reduction.
+            input_tensor = torch.tensor([True, True]).to(self.rank)
+            expected_tensor = input_tensor.clone()
+            self._run_reduction_test(
+                input_tensor, expected_tensor, op, dist.reduce, dst=0
+            )
+
+        for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]:
+            input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
+            expected = (
+                torch.tensor([True, True]).to(self.rank)
+                if self.rank == 0
+                else input_tensor.clone()
+            )
+            self._run_reduction_test(
+                input_tensor, expected, op, dist.reduce, dst=0
+            )
+
+    @require_backend({"nccl"})
+    @require_backends_available({"nccl"})
+    @skip_if_lt_x_gpu(2)
+    @skip_if_rocm
+    def test_nccl_backend_bool_broadcast(self):
+        tensor_size = 10
+        bcast_tensor = torch.tensor(
+            [
+                (random.random() < 0.5 if self.rank == 0 else False)
+                for _ in range(tensor_size)
+            ]
+        ).to(self.rank)
+        dist.broadcast(bcast_tensor, src=0)
+        # Now allgather and ensure the tensors are equal.
+        tensor_list = [
+            torch.tensor([False for _ in range(tensor_size)]).to(self.rank)
+            for _ in range(dist.get_world_size())
+        ]
+        dist.all_gather(tensor_list, bcast_tensor)
+        expected = tensor_list[0]
+        for tensor in tensor_list[1:]:
+            self.assertEqual(tensor, expected)
+
+
 
 if BACKEND == "gloo" or BACKEND == "nccl":
     WORLD_SIZE = os.environ["WORLD_SIZE"]
diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp
index c8be683dc03..225f1464dcd 100644
--- a/torch/lib/c10d/ProcessGroupNCCL.cpp
+++ b/torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -52,6 +52,7 @@ std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
     {at::kInt, ncclInt32},
     {at::kLong, ncclInt64},
     {at::kHalf, ncclHalf},
+    {at::kBool, ncclUint8},
 #if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 301
     {at::kBFloat16, ncclBfloat16},
 #endif
@@ -59,11 +60,22 @@ std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
 
 // Helper function that gets the data type and issues error if not supported
 ncclDataType_t getNcclDataType(at::ScalarType type) {
-  try {
-    return ncclDataType.at(type);
-  } catch (std::out_of_range& e) {
-    throw std::runtime_error("Unsupported data type for NCCL process group");
+  auto it = ncclDataType.find(type);
+  TORCH_CHECK(
+      it != ncclDataType.end(),
+      "Input tensor data type is not supported for NCCL process group: ",
+      type);
+  return it->second;
+}
+
+ncclRedOp_t getNcclReduceOp(const ReduceOp reduceOp, at::Tensor& input) {
+  if (reduceOp == ReduceOp::SUM && input.scalar_type() == at::kBool) {
+    // For bool tensors, map sum to max, which both represent a bitwise or.
+    // This is to prevent overflow issues with sum, since we use uint8 to
+    // represent a bool (see ncclDataType mapping).
+    return ncclMax;
   }
+  return ncclOp[reduceOp];
 }
 
 // Get the deviceList String from the list of devices
@@ -795,7 +807,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allreduce(
             output.data_ptr(),
             input.numel(),
             getNcclDataType(input.scalar_type()),
-            ncclOp[opts.reduceOp],
+            getNcclReduceOp(opts.reduceOp, input),
             comm,
             stream.stream());
       });
@@ -849,7 +861,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce(
             output.data_ptr(),
             input.numel(),
             getNcclDataType(input.scalar_type()),
-            ncclOp[opts.reduceOp],
+            getNcclReduceOp(opts.reduceOp, input),
             root,
             comm,
             stream.stream());
@@ -931,7 +943,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce_scatter(
            output.data_ptr(),
            output.numel(),
            getNcclDataType(input.scalar_type()),
-           ncclOp[opts.reduceOp],
+           getNcclReduceOp(opts.reduceOp, input),
            comm,
            stream.stream());
      },
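
For context, a minimal sketch (not part of the patch) of how the bool support added above would be exercised from the Python side; the rendezvous address, port, and two-process setup are assumptions for illustration only:

```python
import torch
import torch.distributed as dist

def demo_bool_allreduce(rank, world_size):
    # Assumed rendezvous settings; any init_method works in practice.
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:23456",  # hypothetical address/port
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)
    # Each rank contributes True only if its rank is even.
    flag = torch.tensor([rank % 2 == 0]).to(rank)
    # With this patch, SUM on bool tensors is routed to ncclMax
    # (see getNcclReduceOp), so the result acts as a logical OR across ranks.
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)
    # flag is now [True] on every rank, since rank 0 contributed True.
    dist.destroy_process_group()
```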