[Resubmit #41318] NCCL backend support for torch bool (#41959)
Summary:
Resubmit of pytorch/pytorch#41318, pushed to the ci-all branch.

Original description:
Closes pytorch/pytorch#24137.
This PR adds support for the torch.bool tensor type to ProcessGroupNCCL. For most types we use the existing mapping, but since bool is not supported as a native ncclDataType_t, we add the following logic:

- Map at::kBool to ncclUint8.
- During reduction (allreduce, for example), if the operation is SUM, override it to MAX to avoid overflow: a bool is carried as uint8, so a sum across 256 or more ranks that all contribute True would wrap around to 0 (False). The rest of the operations work with no changes. For booleans, changing SUM to MAX makes no correctness difference, since both function as a bitwise OR.

The reduction logic (for reduce/allreduce, for example) is as follows, with a short illustration after this list:
- SUM, MAX = bitwise OR
- PRODUCT, MIN = bitwise AND
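
A standalone sketch of both points, using plain tensor ops to stand in for NCCL's per-element reduction across ranks (a hypothetical illustration, not part of this commit):

import torch

# Bools are carried as uint8, so a plain SUM can wrap: 256 ranks each
# contributing True would accumulate to 256 % 256 == 0, i.e. False.
ones = torch.ones(256, dtype=torch.uint8)
print(ones.sum(dtype=torch.uint8))  # tensor(0, dtype=torch.uint8)

# With the uint8 encoding {False: 0, True: 1}, the reduction table holds:
a = torch.tensor([1, 1, 0, 0], dtype=torch.uint8)
b = torch.tensor([1, 0, 1, 0], dtype=torch.uint8)
assert torch.equal(torch.max(a, b), a | b)  # max == bitwise OR
assert torch.equal(torch.min(a, b), a & b)  # min == bitwise AND
assert torch.equal(a * b, a & b)            # product == bitwise AND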

Note that this PR doesn't add support for BAND/BOR/BXOR, because those reduction ops are not currently supported by the NCCL backend; see pytorch/pytorch#41362.
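
End to end, this is the kind of program the change enables (a hypothetical two-process sketch, assuming a torchrun-style launcher sets the usual env:// rendezvous variables; not part of this commit):

import torch
import torch.distributed as dist

def main():
    # Assumes MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE are set,
    # e.g. by `torchrun --nproc_per_node=2 bool_allreduce.py`.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    # Rank 0 contributes [True, False]; rank 1 contributes [False, False].
    t = torch.tensor([rank == 0, False], device=f"cuda:{rank}")
    # SUM on bool tensors is rewritten to MAX internally; both act as OR.
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    assert t.tolist() == [True, False]
    dist.destroy_process_group()

if __name__ == "__main__":
    main()

Before this change, the same script would fail with "Unsupported data type for NCCL process group".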

Pull Request resolved: pytorch/pytorch#41959

Reviewed By: mrshenli

Differential Revision: D22719665

Pulled By: rohan-varma

fbshipit-source-id: 8bc4194a8d1268589640242277124f277d2ec9f1
rohan-varma authored and facebook-github-bot committed Jul 25, 2020
1 parent 3858042 commit 366c014
Showing 2 changed files with 146 additions and 7 deletions.
127 changes: 127 additions & 0 deletions test/distributed/test_distributed.py
@@ -3,6 +3,7 @@
import errno
import fcntl
import os
import random
import sys
import time
import tempfile
@@ -2249,6 +2250,132 @@ def test_SyncBatchNorm_process_group(self):
        process_group_sync = res50_model_sync.layer1[0].bn1.process_group
        self.assertEqual(process_group_sync, process_group)

    def _run_reduction_test(
        self, tensor, expected_tensor, op, reduction_fn=dist.all_reduce, dst=None
    ):
        if reduction_fn != dist.all_reduce and dst is None:
            raise ValueError(f"Reduction fn {reduction_fn} must specify dst!")
        if dst is not None:
            reduction_fn(tensor, dst, op)
            # Only destination rank tensor is expected to have final result.
            if dist.get_rank() == dst:
                self.assertEqual(tensor, expected_tensor)
        else:
            reduction_fn(tensor, op)
            self.assertEqual(tensor, expected_tensor)

    @require_backend({"nccl"})
    @require_backends_available({"nccl"})
    @skip_if_lt_x_gpu(2)
    @skip_if_rocm
    def test_nccl_backend_bool_allreduce(self):
        torch.cuda.set_device(self.rank)
        # Run all_reduce with PRODUCT
        element = self.rank % 2 == 0
        for op in [dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN]:
            input_tensor = torch.tensor([element, element]).to(self.rank)
            self._run_reduction_test(
                input_tensor, torch.tensor([False, False]).to(self.rank), op
            )
            # Ensure that all ranks contributing True (cast to 1) results in the
            # correct reduction.
            input_tensor = torch.tensor([True, True]).to(self.rank)
            expected_tensor = input_tensor.clone()
            self._run_reduction_test(
                input_tensor, expected_tensor, op
            )

        # Run all_reduce with SUM
        for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]:
            input_tensor = torch.tensor([element, element]).to(self.rank)
            self._run_reduction_test(
                input_tensor, torch.tensor([True, True]).to(self.rank), op
            )
        # TODO: NCCL backend does not work correctly for bitwise reduction ops
        # (see https://github.com/pytorch/pytorch/issues/41362). Add tests for
        # these once it is supported.

    @require_backend({"nccl"})
    @require_backends_available({"nccl"})
    @skip_if_lt_x_gpu(2)
    @skip_if_rocm
    def test_nccl_backend_bool_allgather(self):
        torch.cuda.set_device(self.rank)
        inp = {0: [True, True], 1: [False, True]}
        input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
        # Preserve a copy of the tensor to compare against after allgather.
        input_tensor_copy = input_tensor.clone()
        tensor_list = [
            torch.tensor([False, False]).to(self.rank)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(tensor_list, input_tensor)

        self.assertEqual(len(tensor_list), dist.get_world_size())
        for i, t in enumerate(tensor_list):
            expected = torch.tensor(inp[i % 2]).to(self.rank)
            self.assertEqual(t, expected)
        # Ensure that the input tensor is not modified, since this collective
        # does not modify its input.
        self.assertEqual(input_tensor_copy, input_tensor)

    @require_backend({"nccl"})
    @require_backends_available({"nccl"})
    @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
    @skip_if_rocm
    def test_nccl_backend_bool_reduce(self):
        torch.cuda.set_device(self.rank)
        inp = {0: [True, True], 1: [False, False]}
        # Run reduce() with product op
        for op in [dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN]:
            input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
            expected = torch.tensor([False, False]).to(self.rank)
            self._run_reduction_test(
                input_tensor, expected, op, dist.reduce, dst=0
            )
            # Ensure that all ranks contributing True (cast to 1) results in the
            # correct reduction.
            input_tensor = torch.tensor([True, True]).to(self.rank)
            expected_tensor = input_tensor.clone()
            self._run_reduction_test(
                input_tensor, expected_tensor, op, dist.reduce, dst=0
            )

        for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]:
            input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
            expected = (
                torch.tensor([True, True]).to(self.rank)
                if self.rank == 0
                else input_tensor.clone()
            )
            self._run_reduction_test(
                input_tensor, expected, op, dist.reduce, dst=0
            )

    @require_backend({"nccl"})
    @require_backends_available({"nccl"})
    @skip_if_lt_x_gpu(2)
    @skip_if_rocm
    def test_nccl_backend_bool_broadcast(self):
        tensor_size = 10
        bcast_tensor = torch.tensor(
            [
                (random.random() < 0.5 if self.rank == 0 else False)
                for _ in range(tensor_size)
            ]
        ).to(self.rank)
        dist.broadcast(bcast_tensor, src=0)
        # Now allgather and ensure the tensors are equal.
        tensor_list = [
            torch.tensor([False for _ in range(tensor_size)]).to(self.rank)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(tensor_list, bcast_tensor)
        expected = tensor_list[0]
        for tensor in tensor_list[1:]:
            self.assertEqual(tensor, expected)


if BACKEND == "gloo" or BACKEND == "nccl":
    WORLD_SIZE = os.environ["WORLD_SIZE"]

26 changes: 19 additions & 7 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -52,18 +52,30 @@ std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
    {at::kInt, ncclInt32},
    {at::kLong, ncclInt64},
    {at::kHalf, ncclHalf},
+   {at::kBool, ncclUint8},
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 301
    {at::kBFloat16, ncclBfloat16},
#endif
};

// Helper function that gets the data type and issues error if not supported
ncclDataType_t getNcclDataType(at::ScalarType type) {
-  try {
-    return ncclDataType.at(type);
-  } catch (std::out_of_range& e) {
-    throw std::runtime_error("Unsupported data type for NCCL process group");
-  }
+  auto it = ncclDataType.find(type);
+  TORCH_CHECK(
+      it != ncclDataType.end(),
+      "Input tensor data type is not supported for NCCL process group: ",
+      type);
+  return it->second;
}

+ncclRedOp_t getNcclReduceOp(const ReduceOp reduceOp, at::Tensor& input) {
+  if (reduceOp == ReduceOp::SUM && input.scalar_type() == at::kBool) {
+    // For bool tensors, map sum to max, which both represent a bitwise or.
+    // This is to prevent overflow issues with sum, since we use uint8 to
+    // represent a bool (see ncclDataType mapping).
+    return ncclMax;
+  }
+  return ncclOp[reduceOp];
+}

// Get the deviceList String from the list of devices
@@ -795,7 +807,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allreduce(
            output.data_ptr(),
            input.numel(),
            getNcclDataType(input.scalar_type()),
-           ncclOp[opts.reduceOp],
+           getNcclReduceOp(opts.reduceOp, input),
            comm,
            stream.stream());
      });
@@ -849,7 +861,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce(
            output.data_ptr(),
            input.numel(),
            getNcclDataType(input.scalar_type()),
-           ncclOp[opts.reduceOp],
+           getNcclReduceOp(opts.reduceOp, input),
            root,
            comm,
            stream.stream());
@@ -931,7 +943,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce_scatter(
            output.data_ptr(),
            output.numel(),
            getNcclDataType(input.scalar_type()),
-           ncclOp[opts.reduceOp],
+           getNcclReduceOp(opts.reduceOp, input),
            comm,
            stream.stream());
      },
