-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdistributed.py
124 lines (99 loc) · 3.8 KB
/
distributed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
""" Pytorch Distributed utils
This code was heavily inspired by its equivalent in fairseq-py:
https://github.com/pytorch/fairseq
"""
from __future__ import print_function
import math
import pickle
import torch.distributed
from others.logging import logger
def is_master(gpu_ranks, device_id):
    """Tell whether the process bound to ``device_id`` is the master.

    Args:
        gpu_ranks: container indexed by device id whose values are the
            distributed ranks assigned to each device.
        device_id: index into ``gpu_ranks`` for the current process.

    Returns:
        True when the rank stored at ``gpu_ranks[device_id]`` equals 0,
        False otherwise.
    """
    return gpu_ranks[device_id] == 0
def multi_init(device_id, world_size, gpu_ranks,
               init_method='tcp://localhost:10000', backend='nccl'):
    """Join the default torch.distributed process group for this process.

    Args:
        device_id: index into ``gpu_ranks`` for the current process.
        world_size: total number of processes taking part in the group.
        gpu_ranks: container indexed by device id giving each device's rank.
        init_method: rendezvous URL passed to ``init_process_group``.
            Defaults to the previously hard-coded ``tcp://localhost:10000``.
        backend: communication backend passed to ``init_process_group``.
            Defaults to the previously hard-coded ``nccl``.

    Returns:
        The rank reported by ``torch.distributed.get_rank()`` once the
        process group has been initialised.
    """
    print(gpu_ranks)
    torch.distributed.init_process_group(
        backend=backend, init_method=init_method,
        world_size=world_size, rank=gpu_ranks[device_id])
    gpu_rank = torch.distributed.get_rank()
    if not is_master(gpu_ranks, device_id):
        # Only the master keeps logging enabled; presumably this avoids
        # duplicated log lines from every worker (``logger`` is the shared
        # project logger imported at module level).
        logger.disabled = True
    return gpu_rank
def all_reduce_and_rescale_tensors(tensors, rescale_denom,
                                   buffer_size=10485760):
    """All-reduce and rescale tensors in chunks of the specified size.

    Small tensors are packed into one flat staging buffer so that a single
    all-reduce call covers many of them; a tensor larger than the buffer is
    all-reduced on its own. Every tensor is updated in place with the
    reduced value divided by ``rescale_denom``.

    Args:
        tensors: list of Tensors to all-reduce (modified in place). An
            empty list is a no-op.
        rescale_denom: denominator for rescaling summed Tensors
        buffer_size: all-reduce chunk size in bytes

    Returns:
        None.
    """
    tensors = list(tensors)
    if not tensors:
        # Nothing to reduce; the original indexed tensors[0] and crashed.
        return

    # buffer_size is in bytes; derive the equivalent number of elements from
    # the first tensor's element size. The staging buffer takes that
    # tensor's dtype/device.
    # NOTE(review): packing assumes all tensors share the first tensor's
    # element size -- confirm with callers.
    buffer_t = tensors[0].new_zeros(
        int(math.ceil(buffer_size / tensors[0].element_size())))

    def all_reduce_buffer(pending):
        # Pack the pending tensors back-to-back into the staging buffer.
        offset = 0
        for t in pending:
            numel = t.numel()
            buffer_t[offset:offset + numel].copy_(t.view(-1))
            offset += numel

        # All-reduce and rescale only the part that was actually filled.
        used = buffer_t[:offset]
        torch.distributed.all_reduce(used)
        used.div_(rescale_denom)

        # Scatter the reduced values back into the original tensors.
        offset = 0
        for t in pending:
            numel = t.numel()
            t.view(-1).copy_(buffer_t[offset:offset + numel])
            offset += numel

    pending = []
    filled = 0  # bytes currently queued in ``pending``
    for t in tensors:
        sz = t.numel() * t.element_size()
        if sz > buffer_size:
            # Tensor is bigger than the buffer: all-reduce and rescale it
            # directly, leaving the queued tensors untouched.
            torch.distributed.all_reduce(t)
            t.div_(rescale_denom)
        elif filled + sz > buffer_size:
            # Buffer would overflow: flush what is queued, then start a new
            # batch with this tensor.
            all_reduce_buffer(pending)
            pending = [t]
            filled = sz
        else:
            # Tensor still fits: queue it.
            pending.append(t)
            filled += sz

    if pending:
        all_reduce_buffer(pending)
def all_gather_list(data, max_size=4096):
    """Gathers arbitrary data from all nodes into a list.

    ``data`` is pickled, prefixed with a two-byte length header and written
    into a fixed-size CUDA byte buffer, which is then exchanged with
    ``torch.distributed.all_gather``. The buffers are cached as attributes
    on this function and reused by later calls with the same ``max_size``
    and world size.

    Args:
        data: any picklable object to share with the other workers.
        max_size: size in bytes of the exchange buffer; must be below
            255 * 256 because of the two-byte length header, and at least
            two bytes larger than the pickled payload.

    Returns:
        A list with one unpickled object per rank, in rank order.

    Raises:
        ValueError: if the pickled payload plus header exceeds
            ``max_size``, or if ``max_size`` is too large for the header.
    """
    # Encode and validate first so bad input fails before any collective
    # call or GPU allocation happens.
    enc = pickle.dumps(data)
    enc_size = len(enc)
    if enc_size + 2 > max_size:
        raise ValueError(
            'encoded data exceeds max_size: {}'.format(enc_size + 2))
    if max_size >= 255 * 256:
        # The header stores the size as two base-255 digits, one per byte.
        raise ValueError(
            'max_size must be smaller than {}: {}'.format(255 * 256, max_size))

    world_size = torch.distributed.get_world_size()
    # Compare element counts: Tensor.size() is a torch.Size, which never
    # equals a plain int, so the original check reallocated on every call.
    if not hasattr(all_gather_list, '_in_buffer') or \
            max_size != all_gather_list._in_buffer.numel() or \
            world_size != len(all_gather_list._out_buffers):
        all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
        all_gather_list._out_buffers = [
            torch.cuda.ByteTensor(max_size)
            for _ in range(world_size)
        ]
    in_buffer = all_gather_list._in_buffer
    out_buffers = all_gather_list._out_buffers

    in_buffer[0] = enc_size // 255  # this encoding works for max_size < 65k
    in_buffer[1] = enc_size % 255
    in_buffer[2:enc_size + 2] = torch.ByteTensor(list(enc))

    torch.distributed.all_gather(out_buffers, in_buffer)

    results = []
    for out_buffer in out_buffers:
        high, low = out_buffer[:2].tolist()
        size = 255 * high + low
        payload = bytes(out_buffer[2:size + 2].tolist())
        # NOTE: pickle.loads executes arbitrary code from the payload; this
        # is only acceptable because the peers are assumed to be trusted
        # workers of the same job.
        results.append(pickle.loads(payload))
    return results