
Commit 02923f0

Used per-parameter FSDP (#165)
**Numeric Parity**

1D FSDP

- Eager: 1k steps of minipile on 8 H100 GPUs, local batch size 8, sequence length 2048, AC/SAC, bf16 mixed precision, fp32 reduce-scatter
  - FSDP1 (AC): 24.81% peak active, 33.82% peak reserved, 6100-6200 WPS
  - FSDP1 (SAC): 52.98% peak active, 67.23% peak reserved, 6500-6700 WPS
  - FSDP2 (AC): 23.92% peak active, 32.64% peak reserved, 6100-6300 WPS
  - FSDP2 (SAC): 52.13% peak active, 62.51% peak reserved, 6600-6800 WPS
  - Loss curves match between FSDP1 and FSDP2
  - Memory numbers are reported as percentages since that is how they are logged; they can be converted against 95.0396 GiB of GPU memory
- Compile: same setup as eager
  - FSDP2 (AC), buffer reuse disabled: 28.72 GiB (30.22%) peak reserved, 7200-7500 WPS, 33% MFU
  - FSDP2 (AC), buffer reuse enabled: 28.90 GiB (30.40%) peak reserved, 7200-7500 WPS, 33% MFU
  - FSDP2 (SAC), buffer reuse enabled: 53.83 GiB (56.64%) peak reserved, 8100-8400 WPS, 36% MFU
  - Loss curves are slightly better than eager
- For fun -- how much can we push MFU?
  - If we use FSDP2 (SAC) with local batch size 16 (doubled), we get 88.23 GiB (92.84%) peak reserved, 8600 WPS, 38% MFU.
  - If we use FSDP2 (no AC) with local batch size 8, we get 90.28 GiB (94.99%) peak reserved, 9100-9300 WPS, 40% MFU.
- Why is FSDP2 faster? (1) The fp32 reduce-scatter uses only one division kernel instead of two, and (2) `reshard_after_forward=False` is used for the last transformer block.

2D FSDP

- Eager (2-way SP, 4-way FSDP): 1k steps of minipile on 8 H100 GPUs, local batch size 16 (to preserve the global batch size), sequence length 2048, bf16 mixed precision, fp32 reduce-scatter
  - FSDP2 (AC): 50.12% peak active, 60.97% peak reserved, 5800-5900 WPS
  - FSDP2 (SAC): 76.49% peak active, 90.14% peak reserved, 6100-6300 WPS
  - Loss curves match 8-way FSDP
  - FSDP1 + SP has incorrect numerics because `FSDP.clip_grad_norm_` does not all-reduce over the TP mesh dimension

<details>
<summary>Loss curves</summary>
<img width="732" alt="Screenshot 2024-03-26 at 3 31 19 PM" src="https://github.com/pytorch/torchtrain/assets/31054793/59ec71cc-ad0a-4dd1-b5c6-a8cbf9ab5e85">
</details>

**Meta-Device Initialization**

- The PyTorch Core guideline is for `module.reset_parameters()` to only initialize parameters/buffers immediately owned by `module` (i.e. `module.parameters(recurse=False)` and `module.buffers(recurse=False)`).
- This makes it challenging to specify custom initializations for core modules like `nn.Linear` and `nn.Embedding`. For example, in @lessw2020's depth-wise truncated normal initialization, the `trunc_normal_` standard deviation depends on the layer ID, which is a property of the `TransformerBlock` but affects the child `nn.Linear`s.
- To disambiguate, I suggest avoiding the name `reset_parameters()` in the cases where we violate the PyTorch Core guideline and instead using a different name (e.g. `init_weights`); a minimal sketch of this split follows the commit message.

**DCP & Save/Load**

- Tested 1D and 2D by specifying `checkpoint_folder = "/tmp/checkpoint_andgu"` in the `.toml`, training until a checkpoint was saved, terminating the run, and restarting training to load the checkpoint -- the loss after loading looks reasonable.
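As referenced under **Meta-Device Initialization**, here is a minimal sketch of the `reset_parameters()` vs. `init_weights()` split. The classes and the depth-dependent std formula below are illustrative stand-ins, not the torchtrain modules.

```python
import torch
import torch.nn as nn


class Norm(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(dim))
        self.reset_parameters()

    def reset_parameters(self):
        # Follows the PyTorch Core guideline: only touch tensors that this
        # module directly owns (self.weight), never those of child modules.
        nn.init.ones_(self.weight)


class Block(nn.Module):
    def __init__(self, layer_id: int, dim: int):
        super().__init__()
        self.norm = Norm(dim)
        self.linear = nn.Linear(dim, dim, bias=False)
        # Hypothetical depth-dependent std: it depends on the layer ID, which
        # the parent block knows about but the child nn.Linear does not.
        self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5

    def init_weights(self):
        # Cross-module, custom initialization lives under a different name so
        # it does not violate the `reset_parameters()` contract above.
        self.norm.reset_parameters()
        nn.init.trunc_normal_(self.linear.weight, mean=0.0, std=self.weight_init_std)
```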
1 parent 479694f commit 02923f0

File tree

4 files changed, +97 -194 lines changed


torchtrain/meta_init.py

-48
This file was deleted.

torchtrain/models/llama/model.py

+41 -56
@@ -45,9 +45,7 @@ def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.empty(dim))
-
-        # re-enable if not using meta-init
-        # self.reset_parameters()
+        self.reset_parameters()

     def _norm(self, x: torch.Tensor):
         """
@@ -207,19 +205,10 @@ def __init__(self, model_args: ModelArgs):
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )

-    def reset_parameters(self, init_std):
-        for item in (self.wq, self.wk, self.wv):
-            nn.init.trunc_normal_(
-                item.weight,
-                mean=0.0,
-                std=0.02,
-            )
-
-        nn.init.trunc_normal_(
-            self.wo.weight,
-            mean=0.0,
-            std=init_std,
-        )
+    def init_weights(self, init_std: float):
+        for linear in (self.wq, self.wk, self.wv):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
+        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)

     def forward(
         self,
@@ -309,19 +298,10 @@ def __init__(
     def forward(self, x):
         return self.w2(F.silu(self.w1(x)) * self.w3(x))

-    def reset_parameters(self, init_std):
-        nn.init.trunc_normal_(
-            self.w1.weight,
-            mean=0.0,
-            std=0.02,
-        )
-
-        for item in (self.w2, self.w3):
-            nn.init.trunc_normal_(
-                item.weight,
-                mean=0.0,
-                std=init_std,
-            )
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+        for linear in (self.w2, self.w3):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)


 class RotaryEmbedding(nn.Module):
@@ -333,13 +313,15 @@ def __init__(self, model_args: ModelArgs):
         super().__init__()
         self.model_args = model_args
         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
+        self.register_buffer(
+            "freqs_cis", self._precompute_freqs_cis(), persistent=False
+        )

-        self.freqs_cis = precompute_freqs_cis(
-            # Note that self.model_args.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
-            # of models is 4096.
-            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training
-            # or fine-tuning.
+    def _precompute_freqs_cis(self):
+        return precompute_freqs_cis(
             self.model_args.dim // self.model_args.n_heads,
+            # Need to compute until at least the max token limit for generation
+            # (use 2x max sequence length to be safe)
             self.model_args.max_seq_len * 2,
         )

@@ -355,10 +337,14 @@ def forward(self, tokens: torch.Tensor):
         """
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
-        self.freqs_cis = self.freqs_cis.to(h.device)
         freqs_cis = self.freqs_cis[0:seqlen]
         return h, freqs_cis

+    def init_weights(self):
+        with torch.device(self.freqs_cis.device):
+            self.freqs_cis = self._precompute_freqs_cis()
+        nn.init.normal_(self.tok_embeddings.weight)
+

 class TransformerBlock(nn.Module):
     """
@@ -421,13 +407,11 @@ def forward(
         out = h + self.feed_forward(self.ffn_norm(h))
         return out

-    def reset_parameters(self):
-        """reset params and norms for entire block"""
-        self.attention_norm.reset_parameters()
-        self.ffn_norm.reset_parameters()
-
-        self.attention.reset_parameters(self.weight_init_std)
-        self.feed_forward.reset_parameters(self.weight_init_std)
+    def init_weights(self):
+        for norm in (self.attention_norm, self.ffn_norm):
+            norm.reset_parameters()
+        self.attention.init_weights(self.weight_init_std)
+        self.feed_forward.init_weights(self.weight_init_std)


 class Transformer(nn.Module):
@@ -457,28 +441,29 @@ def __init__(self, model_args: ModelArgs):
         self.model_dim = model_args.dim

         self.embeddings = RotaryEmbedding(model_args)
-
         self.layers = torch.nn.ModuleList()
         for layer_id in range(model_args.n_layers):
             self.layers.append(TransformerBlock(layer_id, model_args))

         self.norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
+        self.init_weights()

-        # init model weights
-
-        # we are doing meta_init, which will call reset_parameters() after
-        # the model is moved to actual device.
-        # If you modify and are not using meta_init, you will need to call
-        # reset_parameters() manually as below:
-
-        # self.reset_parameters()
-
-    def reset_parameters(
-        self,
-    ):
+    def init_weights(self):
+        """
+        [Note: On ``init_weights`` vs. ``reset_parameters``]
+        Modules may define ``reset_parameters`` to initialize parameter values.
+        ``reset_parameters`` is meant to only initialize directly owned
+        parameters/buffers, not those of their child modules, and it can be
+        used to give the initial values for these tensors.
+        Separately, users may want custom initialization for their modules,
+        different from that in ``reset_parameters``. For this, we define
+        ``init_weights``. We only call it in the constructor of this
+        ``Transformer`` root module to avoid reinitializing tensors.
+        """
+        self.embeddings.init_weights()
         for layer in self.layers:
-            layer.reset_parameters()
+            layer.init_weights()
         self.norm.reset_parameters()
         final_out_std = self.model_dim**-0.5
         cutoff_factor = 3
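A side note on the `RotaryEmbedding` change above: registering `freqs_cis` with `persistent=False` keeps it out of the state dict, and `init_weights` recomputes it under the buffer's current device, so a model materialized from meta tensors ends up with real values while checkpoints stay free of recomputable data. A small standalone sketch of those buffer semantics (toy module and names, not the model code):

```python
import torch
import torch.nn as nn


class WithCache(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent buffers are excluded from state_dict(), so they never
        # bloat checkpoints and must be recomputable at init/load time.
        self.register_buffer(
            "cache", torch.arange(4, dtype=torch.float32), persistent=False
        )


m = WithCache()
assert "cache" not in m.state_dict()

# Mirrors the `init_weights` pattern: entering the buffer's device as a context
# makes factory functions allocate on that device before reassigning the buffer.
with torch.device(m.cache.device):
    m.cache = torch.arange(4, dtype=torch.float32)
```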

torchtrain/parallelisms/parallelize_llama.py

+40 -66
@@ -8,19 +8,13 @@
 from typing import Tuple

 import torch
-from torch.distributed._tensor import Replicate, Shard

+from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
     CheckpointImpl,
 )
-from torch.distributed.fsdp import (
-    BackwardPrefetch,
-    FullyShardedDataParallel as FSDP,
-    MixedPrecision,
-    ShardingStrategy,
-)
-from torch.distributed.fsdp.wrap import enable_wrap, wrap
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,
@@ -33,7 +27,6 @@

 from torchtrain.config_manager import JobConfig
 from torchtrain.logging_utils import logger
-from torchtrain.meta_init import meta_to_real_init_fn


 # for selective AC
@@ -75,7 +68,6 @@ def selective_checkpointing_context_fn():
             preserve_rng_state=False,
         )
     elif config.mode == "full":
-        # full AC
         return ptd_checkpoint_wrapper(
             module,
             checkpoint_impl=CheckpointImpl.NO_REENTRANT,
@@ -136,28 +128,23 @@ def get_tp_parallel_strategy(

 def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
-    Apply parallelisms to the model, including PTD parallelisms, and AC.
+    Apply parallelisms and activation checkpointing to the model.

-    NOTE: the model passed in preferrablably shoule be a meta device model,
-    otherwise the model needs to be small enough on GPU or can fit into CPU.
+    NOTE: The passed-in model preferably should be on meta device. Otherwise,
+    the model must fit on GPU or CPU memory.
     """
-    # apply PTD parallelisms
     if parallel_dims.pp_enabled:
         raise NotImplementedError("PP not implemented yet.")

-    # First we apply Tensor Parallelism if it's enabled
     if parallel_dims.tp_enabled:
         tp_mesh = world_mesh["tp"]
-        tp_degree = job_config.training.tensor_parallel_degree
-
         row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
             job_config
         )

-        # First:
-        # 1. parallelize the first embedding and the last linear proj layer
-        # 2. parallelize the root norm layer by sequence dim
-        # 3. shard the first layer of transformer block
+        # 1. Parallelize the first embedding and the last linear proj layer
+        # 2. Parallelize the root norm layer over the sequence dim
+        # 3. Shard the first transformer block's inputs
         model = parallelize_module(
             model,
             tp_mesh,
@@ -167,9 +154,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                ),
                "output": col_parallel_strategy(
                    input_layouts=Shard(0),
-                    output_layouts=Shard(-1)
-                    if parallel_dims.loss_parallel_enabled
-                    else Replicate(),
+                    output_layouts=(
+                        Shard(-1)
+                        if parallel_dims.loss_parallel_enabled
+                        else Replicate()
+                    ),
                    use_local_output=not parallel_dims.loss_parallel_enabled,
                ),
                "norm": SequenceParallel(sequence_dim=0),
@@ -181,7 +170,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             },
         )

-        # apply tensor + sequence parallelism to every transformer block
+        # Apply tensor + sequence parallelism to every transformer block
         for layer_id, transformer_block in enumerate(model.layers):
             layer_plan = {
                 "attention": PrepareModuleInput(
@@ -203,62 +192,47 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 "ffn_norm": SequenceParallel(sequence_dim=0),
             }

-            # adjust num_heads in attention layer to local heads
+            # Adjust attention module to use the local number of heads
             attn_layer = transformer_block.attention
-            attn_layer.n_heads = attn_layer.n_heads // tp_degree
-            attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_degree
+            attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
+            attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

             parallelize_module(
                 module=transformer_block,
                 device_mesh=tp_mesh,
                 parallelize_plan=layer_plan,
             )

-        logger.info("Applied Sequence Parallelism to the model")
+        logger.info("Applied Tensor Parallelism to the model")

     if parallel_dims.dp_enabled:
-        dp_mesh = world_mesh["dp"]
-
-        fsdp_config = {
-            "mixed_precision": MixedPrecision(
-                param_dtype=torch.bfloat16,
-                # TODO: see whether we should expose a option to user
-                reduce_dtype=torch.float32,
-            ),
-            "sharding_strategy": ShardingStrategy.FULL_SHARD,
-            "backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
-            # When torch.compile is active, it requires us to set use_orig_params=True
-            "use_orig_params": True,
-            "device_mesh": dp_mesh,
-            "param_init_fn": meta_to_real_init_fn,
-        }
-
+        dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
+        assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
+        # TODO: Expose `reduce_dtype` as a config option.
+        mp_policy = MixedPrecisionPolicy(
+            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+        )
         ac_mode = job_config.activation_checkpoint.mode
-        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
-            for layer_id, transformer_block in enumerate(model.layers):
-                # apply AC to the transformer block
-                if ac_mode in ("full", "selective"):
-                    # wrap the transformer block with checkpoint wrapper, using config settings
-                    transformer_block = checkpoint_wrapper(
-                        transformer_block, job_config.activation_checkpoint
-                    )
-
-                # Wraps each layer with FSDP
-                model.layers[layer_id] = wrap(transformer_block)
-
-            # wrap the rest layers with FSDP
-            model = wrap(model)
-
+        fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+        for layer_id, transformer_block in enumerate(model.layers):
+            if job_config.activation_checkpoint.mode in ("full", "selective"):
+                transformer_block = checkpoint_wrapper(
+                    transformer_block, job_config.activation_checkpoint
+                )
+            # As an optimization, do not reshard after forward for the last
+            # transformer block since FSDP would prefetch it immediately
+            reshard_after_forward = layer_id < len(model.layers) - 1
+            fully_shard(
+                transformer_block,
+                **fsdp_config,
+                reshard_after_forward=reshard_after_forward,
+            )
+            model.layers[layer_id] = transformer_block
+        model = fully_shard(model, **fsdp_config)
         if ac_mode in ("full", "selective"):
             logger.info(f"Applied {ac_mode} activation checkpointing to the model")
         logger.info("Applied FSDP to the model")
     else:
-        meta_to_real_init_fn(model)
         model.cuda()

-        # we have now moved from meta to device,
-        # reset parameters for proper initialization
-        model.reset_parameters()
-        logger.info("Model fully initialized via reset_parameters")
-
     return model
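The FSDP2 application above reduces to one `fully_shard` call per transformer block plus one for the root module. Below is a self-contained sketch of that pattern; it assumes `torch.distributed` is already initialized, a model that exposes a `.layers` `ModuleList`, and a mesh built here via `init_device_mesh` rather than taken from `world_mesh["dp"]` as in the trainer.

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import init_device_mesh


def apply_fsdp(model: nn.Module, world_size: int) -> nn.Module:
    # 1D data-parallel mesh (the trainer gets this from world_mesh["dp"]).
    dp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))
    # bf16 compute with fp32 gradient reduce-scatter, matching the commit.
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    )
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

    layers = list(model.layers)
    for layer_id, block in enumerate(layers):
        # Keep the last block's parameters unsharded after forward; backward
        # would all-gather them again immediately anyway.
        fully_shard(
            block,
            **fsdp_config,
            reshard_after_forward=layer_id < len(layers) - 1,
        )
    # The root call groups the remaining parameters (embedding, norm, output).
    return fully_shard(model, **fsdp_config)
```

Leaving `reshard_after_forward=True` on the other blocks trades one extra all-gather per block in backward for lower peak memory, which is why only the last block opts out.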
