# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

import os
import re
import socket
import torch
import torch.distributed
from . import training_stats
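
# Device used by training_stats to aggregate statistics across processes;
# None when running single-process, in which case no synchronization is
# needed.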
_sync_device = None

#----------------------------------------------------------------------------
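
# Initialize torch.distributed for the current process, filling in
# single-process defaults (rank 0, world size 1, localhost rendezvous)
# whenever the launcher has not already set the standard environment
# variables.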
def init():
    global _sync_device
    if not torch.distributed.is_initialized():
        # Setup some reasonable defaults for env-based distributed init if
        # not set by the running environment.
        if 'MASTER_ADDR' not in os.environ:
            os.environ['MASTER_ADDR'] = 'localhost'
        if 'MASTER_PORT' not in os.environ:
            # Bind to port 0 to let the OS pick a free port, then release it
            # for torch.distributed to use as the rendezvous port.
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.bind(('', 0))
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            os.environ['MASTER_PORT'] = str(s.getsockname()[1])
            s.close()
        if 'RANK' not in os.environ:
            os.environ['RANK'] = '0'
        if 'LOCAL_RANK' not in os.environ:
            os.environ['LOCAL_RANK'] = '0'
        if 'WORLD_SIZE' not in os.environ:
            os.environ['WORLD_SIZE'] = '1'
        backend = 'gloo' if os.name == 'nt' else 'nccl'
        torch.distributed.init_process_group(backend=backend, init_method='env://')
        torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
    _sync_device = torch.device('cuda') if get_world_size() > 1 else None
    training_stats.init_multiprocessing(rank=get_rank(), sync_device=_sync_device)

#----------------------------------------------------------------------------
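
# Return the rank of the current process, or 0 if torch.distributed has
# not been initialized.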
def get_rank():
    return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0

#----------------------------------------------------------------------------
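
# Return the total number of processes, or 1 if torch.distributed has not
# been initialized.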
def get_world_size():
    return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1

#----------------------------------------------------------------------------
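
# Job-control hooks for should_stop(), should_suspend(), request_suspend(),
# and update_progress() below. In this version of the code they are no-op
# placeholders; a cluster-specific variant could override them to support
# graceful stopping, suspend/resume, and progress reporting.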
def should_stop():
    return False

#----------------------------------------------------------------------------

def should_suspend():
    return False

#----------------------------------------------------------------------------

def request_suspend():
    pass

#----------------------------------------------------------------------------

def update_progress(cur, total):
    pass

#----------------------------------------------------------------------------
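
# print() variant that only produces output on rank 0, avoiding duplicated
# log lines in multi-process runs.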
def print0(*args, **kwargs):
    if get_rank() == 0:
        print(*args, **kwargs)

#----------------------------------------------------------------------------
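
# Helper for saving and restoring a set of named state objects (networks,
# optimizers, plain dicts, etc.) as a single .pt checkpoint file.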
class CheckpointIO:
    def __init__(self, **kwargs):
        self._state_objs = kwargs

    def save(self, pt_path, verbose=True):
        if verbose:
            print0(f'Saving {pt_path} ... ', end='', flush=True)
        data = dict()
        for name, obj in self._state_objs.items():
            if obj is None:
                data[name] = None
            elif isinstance(obj, dict):
                data[name] = obj
            elif hasattr(obj, 'state_dict'):
                data[name] = obj.state_dict()
            elif hasattr(obj, '__getstate__'):
                data[name] = obj.__getstate__()
            elif hasattr(obj, '__dict__'):
                data[name] = obj.__dict__
            else:
                raise ValueError(f'Invalid state object of type {type(obj).__name__}')
        if get_rank() == 0: # only rank 0 writes the checkpoint file
            torch.save(data, pt_path)
        if verbose:
            print0('done')

    def load(self, pt_path, verbose=True):
        if verbose:
            print0(f'Loading {pt_path} ... ', end='', flush=True)
        data = torch.load(pt_path, map_location=torch.device('cpu'))
        for name, obj in self._state_objs.items():
            if obj is None:
                pass
            elif isinstance(obj, dict):
                obj.clear()
                obj.update(data[name])
            elif hasattr(obj, 'load_state_dict'):
                obj.load_state_dict(data[name])
            elif hasattr(obj, '__setstate__'):
                obj.__setstate__(data[name])
            elif hasattr(obj, '__dict__'):
                obj.__dict__.clear()
                obj.__dict__.update(data[name])
            else:
                raise ValueError(f'Invalid state object of type {type(obj).__name__}')
        if verbose:
            print0('done')

    def load_latest(self, run_dir, pattern=r'training-state-(\d+).pt', verbose=True):
        # Find the checkpoint with the highest numeric index and load it.
        # Returns the path of the loaded checkpoint, or None if none exist.
        fnames = [entry.name for entry in os.scandir(run_dir) if entry.is_file() and re.fullmatch(pattern, entry.name)]
        if len(fnames) == 0:
            return None
        pt_path = os.path.join(run_dir, max(fnames, key=lambda x: float(re.fullmatch(pattern, x).group(1))))
        self.load(pt_path, verbose=verbose)
        return pt_path

#----------------------------------------------------------------------------
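
# Example usage (a minimal sketch; `MyNet` and the 'runs/00000' paths are
# hypothetical stand-ins, not part of this module):
#
#   init()
#   net = MyNet().cuda()
#   opt = torch.optim.Adam(net.parameters())
#   state = CheckpointIO(net=net, opt=opt, misc=dict(cur_step=0))
#   if state.load_latest('runs/00000') is None:
#       print0('No checkpoint found, starting from scratch.')
#   ...training loop...
#   state.save('runs/00000/training-state-000100.pt')

#----------------------------------------------------------------------------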