-
Notifications
You must be signed in to change notification settings - Fork 9
/
training_io.py
266 lines (225 loc) · 10.4 KB
/
training_io.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
"""Provides IO support for training:
* checkpoint save and load
* metrics logging
* profiling of XLA computations
* reporting FLOPs per device
"""
import jax
import jax.numpy as jnp
from jax.experimental import multihost_utils
from typing import Tuple, Any
from dataclasses import dataclass
import os
import fsspec
import zarr
from numcodecs import blosc
from clearml import Logger
import numpy as np
import datetime
import concurrent
import jax.profiler
import tempfile
import shutil
from jax.lib import xla_client
PyTree = Any
@dataclass
class IOConfig:
    """Configuration for IO-bound training tasks (checkpoint save and load)."""
    # Max number of threads to use for IO-bound tasks like saving and loading checkpoints.
    # Recommendation: about 1MiB/thread is typical, so 1024 threads is reasonable for 1GiB of overhead.
    # Since this work is IO-bound rather than CPU-bound, it is fine to have many more threads than
    # CPU cores.
    max_io_threads: int
def log(step: int, logger: Logger, output: PyTree):
    """Logs the output of a training step. The output must be a PyTree of f32 arrays."""
    # Only one host reports metrics; the others have identical (replicated) outputs.
    if jax.process_index() != 0:
        return
    metrics_dict = {}
    for keypath, leaf in jax.tree_util.tree_leaves_with_path(output):
        name = jax.tree_util.keystr(keypath)
        value = jax.device_get(leaf)
        is_f32 = value.dtype == jnp.float32
        if is_f32 and value.shape == ():
            # Scalars: report as a time series and include in the console summary line.
            if logger:
                logger.report_scalar(
                    title=name, series=name, value=value, iteration=step)
            metrics_dict[name] = float(value)
        elif is_f32:
            # Non-scalar f32 arrays: report as histograms.
            if logger:
                logger.report_histogram(
                    title=name, series=name, values=value, iteration=step)
        else:
            raise ValueError(f"Output {name} has unsupported shape {value.shape} and dtype {value.dtype}.")
    now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"[{now}] Step {step}: {metrics_dict}")
def load_checkpoint_if_it_exists(checkpoint_dir: str, state: PyTree, config: IOConfig) -> Tuple[PyTree, int]:
    """Loads the latest checkpoint if it exists, otherwise return the initial state.

    In either case, uses the sharding and PyTree structure of `state` to produce the output.

    Since the state may occupy a large amount of memory, this function makes sure to delete `state`
    before loading the checkpoint. To facilitate this, callers should ensure not to hold on to any
    additional references to `state` when calling this function.

    Returns state and step number. Step 0 is the initial state, which may or may not have been loaded
    from a checkpoint.
    """
    blosc.use_threads = False  # Blindly following recommendation from https://zarr.readthedocs.io/en/stable/tutorial.html#parallel-computing-and-synchronization
    # Resolve the fsspec filesystem object and path once; keep only those two.
    checkpoint_dir_pseudofile = fsspec.open(checkpoint_dir)
    fs = checkpoint_dir_pseudofile.fs
    checkpoint_dir_path = checkpoint_dir_pseudofile.path
    del checkpoint_dir_pseudofile
    # Check working_dir for checkpoint files.
    # Process index 0 selects the checkpoint, then broadcasts it to everyone else.
    selected_checkpoint = -1
    if jax.process_index() == 0:
        if fs.exists(checkpoint_dir_path):
            # fs.mkdir(checkpoint_dir, create_parents=False)
            checkpoint_dirs = fs.ls(checkpoint_dir_path)
            # Checkpoint names are zero-padded (see step_to_str), so lexicographic order
            # equals numeric order: scan from the newest step downwards.
            for c in reversed(sorted(checkpoint_dirs)):
                try:
                    checkpoint_number = int(os.path.basename(c))
                except ValueError:
                    # Not a step-numbered checkpoint directory; ignore it.
                    continue
                root = zarr.open_group(zarr.storage.FSStore(c, fs=fs))
                # Only trust checkpoints whose final commit marker was written (see save_zarr).
                if "write_completed" not in root.attrs:
                    print(f"zarr 'write_completed' marker is missing in checkpoint {c}; skipping.")
                    continue
                selected_checkpoint = checkpoint_number
                break
    # Every process must participate in the broadcast so they all agree on the choice.
    selected_checkpoint = multihost_utils.broadcast_one_to_all(jnp.int32(selected_checkpoint))
    if selected_checkpoint == -1:
        print(f"No checkpoints found in {checkpoint_dir_path}, starting from initial state.")
        return state, 0
    print(f'Found checkpoint {selected_checkpoint} in {checkpoint_dir_path}, starting from there.')
    return load_zarr(os.path.join(checkpoint_dir, step_to_str(selected_checkpoint)), state, config), selected_checkpoint
def save_checkpoint(checkpoint_dir: str, step: int, state: PyTree, config: IOConfig):
    """Saves a checkpoint for the specified step number.

    See docs/pytree-zarr-checkpoint.md for the checkpoint format.
    """
    blosc.use_threads = False
    step_name = step_to_str(step)
    checkpoint_file = os.path.join(checkpoint_dir, step_name)
    if jax.process_index() == 0:
        # A partially written checkpoint from a previous run might already exist at this
        # step; remove it so save_zarr can create the group from scratch.
        opened = fsspec.open(checkpoint_dir)
        stale_path = os.path.join(opened.path, step_name)
        if opened.fs.exists(stale_path):
            opened.fs.rm(stale_path, recursive=True)
    print(f"[{datetime.datetime.now()}] Saving checkpoint {step} to {checkpoint_file}.")
    save_zarr(checkpoint_file, state, config)
    print(f"[{datetime.datetime.now()}] Finished saving checkpoint {step} to {checkpoint_file}.")
def load_zarr(filename: str, state: PyTree, config: IOConfig) -> PyTree:
    """Loads a zarr checkpoint from disk.

    See docs/pytree-zarr-checkpoint.md for the checkpoint format.
    """
    root = zarr.open_group(filename, mode="r")
    if "write_completed" not in root.attrs:
        raise ValueError(f"zarr 'write_completed' marker is missing. Should not have selected this checkpoint to load from.")
    def load_one(path: Tuple, prev: jax.Array) -> jax.Array:
        # Loads one leaf array, using `prev` (the corresponding leaf of `state`) as the
        # template for expected shape, dtype, and sharding.
        path = jax.tree_util.keystr(path)
        shape = prev.shape
        sharding = prev.sharding
        arr = root[path]
        assert arr.shape == shape, f'Expected shape {shape} but got {arr.shape} for {path} in (unknown)'
        assert arr.dtype == prev.dtype, f'Expected dtype {prev.dtype} but got {arr.dtype} for {path} in (unknown)'
        del prev  # Deallocate memory before loading its replacement!
        # Each device shard reads only its own slice of the zarr array.
        return jax.make_array_from_callback(shape, sharding, lambda shard_index: arr[shard_index])
    state, treedef = jax.tree_util.tree_flatten_with_path(state)
    # NOTE(review): the loop variable `shape` below is actually the previous leaf *array*
    # (the second element of each (path, leaf) pair), not a shape — confusingly named.
    with concurrent.futures.ThreadPoolExecutor(max_workers=config.max_io_threads) as executor:
        state_futures = [executor.submit(load_one, path, shape) for (path, shape) in state]
        states = [f.result() for f in state_futures]
    return jax.tree_util.tree_unflatten(treedef, states)
def save_zarr(filename: str, state: PyTree, config: IOConfig):
    """Saves a zarr checkpoint to disk.

    See docs/pytree-zarr-checkpoint.md for the checkpoint format.
    """
    state, _treedef = jax.tree_util.tree_flatten_with_path(state)
    if jax.process_index() == 0:
        # Create the zarr file and all the arrays.
        try:
            # Mode 'w-' fails if the group already exists, preventing accidental overwrite.
            root = zarr.open_group(filename, mode='w-')
        except zarr.errors.ContainsGroupError:
            raise ValueError(f"Checkpoint (unknown) already exists.")
        for path, arr in state:
            path = jax.tree_util.keystr(path)
            # One zarr chunk per device shard, so each shard write covers whole chunks.
            chunk_shape = arr.sharding.shard_shape(arr.shape)
            root.empty(path, shape=arr.shape, chunks=chunk_shape, dtype=arr.dtype)
    # Barrier: no process may write shards until process 0 has created every array.
    multihost_utils.sync_global_devices("save_zarr_begin")
    root = zarr.open_group(filename, mode='r+')
    def save_shard(dst: zarr.Array, shard: jax.Array, index: Tuple[int, ...]):
        # Writes one shard's slice into the zarr array; runs on a worker thread.
        dst[index] = np.asarray(shard)
    with concurrent.futures.ThreadPoolExecutor(max_workers=config.max_io_threads) as executor:
        for path, arr in state:
            path = jax.tree_util.keystr(path)
            dst = root[path]
            assert dst.chunks == arr.sharding.shard_shape(arr.shape)
            for shard in arr.addressable_shards:
                # Only replica 0 writes, so replicated arrays are not written redundantly.
                if shard.replica_id == 0:
                    executor.submit(save_shard, dst, shard.data, shard.index)
    # Barrier: every host must finish its shard writes before the commit marker goes in.
    multihost_utils.sync_global_devices("save_zarr_end")
    if jax.process_index() == 0:
        # Commit marker: loaders only trust checkpoints with this attribute set.
        root.attrs["write_completed"] = True
    multihost_utils.sync_global_devices("save_zarr_committed")
def step_to_str(step: int) -> str:
    """Converts a step number to a string with leading zeros.

    We pad up to 10 digits so that lexicographic order matches numerical. 1e10 training steps
    should be enough for anyone: the biggest runs as of 2024 are probably around 1e7 tokens/batch,
    1e13 tokens total, so 1e6 training steps total.
    """
    return f"{step:010d}"
# Local directory holding the in-progress profile; None when no profile is active.
_PROFILE_DIR = None

def start_profile():
    """Starts gathering a JAX profile."""
    global _PROFILE_DIR
    # Trace goes to a fresh local temp dir; stop_profile uploads and deletes it.
    trace_dir = tempfile.mkdtemp()
    _PROFILE_DIR = trace_dir
    print(f'[{datetime.datetime.now()}] Starting profile, saving to {trace_dir}')
    jax.profiler.start_trace(trace_dir, create_perfetto_trace=True)
def stop_profile(working_dir: str):
    """Stops gathering the JAX profile and saves it to a file."""
    global _PROFILE_DIR
    jax.profiler.stop_trace()
    print(f'[{datetime.datetime.now()}] Finished profile, copying to {working_dir}')
    # Upload the whole trace directory, then reclaim the local disk space.
    local_dir = _PROFILE_DIR
    fsspec_put(local_dir + '/', working_dir + '/')
    shutil.rmtree(local_dir)
    print(f'[{datetime.datetime.now()}] Finished copying profile to {working_dir}')
    _PROFILE_DIR = None
def fsspec_put(local_src: str, remote_dst: str):
    """Copies a file from local disk to a remote location specified by a fsspec path."""
    # Resolve the destination into a filesystem object plus a path on it.
    remote = fsspec.open(remote_dst)
    target_fs, target_path = remote.fs, remote.path
    del remote
    print(f'Put {local_src} to {target_path}')
    target_fs.put(local_src, target_path, recursive=True, create_parents=True)
def save_hlo_svg(filespec: str, compiled: jax.stages.Compiled):
    """Saves a compiled function's HLO to an SVG file.

    Renders the HLO module's dot graph with graphviz's `dot` tool, strips everything
    before the <svg> element, and uploads the result to `filespec` via fsspec_put.

    Raises:
        RuntimeError: if the `dot` command fails (e.g. graphviz is not installed).
    """
    compiled_hlo_dot = xla_client._xla.hlo_module_to_dot_graph(compiled.runtime_executable().hlo_modules()[0])
    with tempfile.TemporaryDirectory() as d:
        hlo_dot = os.path.join(d, "hlo.dot")
        with open(hlo_dot, "w") as f:
            f.write(compiled_hlo_dot)
        hlo_orig_svg = os.path.join(d, "hlo.original.svg")
        hlo_svg = os.path.join(d, "hlo.svg")
        # Previously the exit status was ignored, so a missing graphviz install surfaced
        # later as a confusing FileNotFoundError on hlo_orig_svg. Fail fast instead.
        # (Also refer to the dot file path directly rather than via f.name after f's
        # `with` block has closed.)
        status = os.system(f"dot -Tsvg {hlo_dot} -o{hlo_orig_svg}")
        if status != 0:
            raise RuntimeError(f"'dot' exited with status {status}; is graphviz installed?")
        # Edit the SVG to remove everything before <svg>. There's a bunch of hover CSS that massively slows down
        # rendering in Chrome and adds little value: it just highlights edges when you hover over them.
        with open(hlo_orig_svg, "r") as f:
            svg = f.read()
        svg = svg[svg.index("<svg "):]
        with open(hlo_svg, "w") as f:
            f.write(svg)
        fsspec_put(hlo_svg, filespec)
def mkdir(filespec: str):
    """Creates a directory at the specified (possibly remote) fsspec path."""
    opened = fsspec.open(filespec)
    target_fs, target_path = opened.fs, opened.path
    del opened
    # Only create the directory if it isn't already there.
    if not target_fs.exists(target_path):
        target_fs.mkdir(target_path, create_parents=False)
def get_flops_per_device():
    """Gets the FLOPS per device for the current device kind."""
    kind = jax.devices()[0].device_kind
    # Only A100s are recognized; any other device gets a deliberately tiny placeholder
    # so that derived utilization numbers are obviously wrong rather than subtly wrong.
    if kind.startswith("NVIDIA A100"):
        flops = 312e12  # A100 peak bf16 tensor-core throughput per the datasheet
    else:
        print(f'Unrecognized device, assuming ridiculously low 1 MFLOPS. Device name: {kind}')
        flops = 1e6
    print(f'Device kind: {kind}')
    print(f'FLOPS per device: {flops:_}')
    return flops