Commit 52e7e01

Author: Andrew Gu (committed)
Used per-parameter FSDP
1 parent 8dd5798 commit 52e7e01

File tree

4 files changed, +94 −194 lines


torchtrain/meta_init.py

-48
This file was deleted.

torchtrain/models/llama/model.py

+34 −55
@@ -45,9 +45,7 @@ def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.empty(dim))
-
-        # re-enable if not using meta-init
-        # self.reset_parameters()
+        self.reset_parameters()
 
     def _norm(self, x: torch.Tensor):
         """
@@ -207,19 +205,10 @@ def __init__(self, model_args: ModelArgs):
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )
 
-    def reset_parameters(self, init_std):
-        for item in (self.wq, self.wk, self.wv):
-            nn.init.trunc_normal_(
-                item.weight,
-                mean=0.0,
-                std=0.02,
-            )
-
-        nn.init.trunc_normal_(
-            self.wo.weight,
-            mean=0.0,
-            std=init_std,
-        )
+    def init_weights(self, init_std: float):
+        for linear in (self.wq, self.wk, self.wv):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
+        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
 
     def forward(
         self,
@@ -309,19 +298,10 @@ def __init__(
     def forward(self, x):
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
-    def reset_parameters(self, init_std):
-        nn.init.trunc_normal_(
-            self.w1.weight,
-            mean=0.0,
-            std=0.02,
-        )
-
-        for item in (self.w2, self.w3):
-            nn.init.trunc_normal_(
-                item.weight,
-                mean=0.0,
-                std=init_std,
-            )
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+        for linear in (self.w2, self.w3):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
 
 
 class RotaryEmbedding(nn.Module):
@@ -333,13 +313,13 @@ def __init__(self, model_args: ModelArgs):
         super().__init__()
         self.model_args = model_args
         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
+        self.init_weights()
 
-        self.freqs_cis = precompute_freqs_cis(
-            # Note that self.model_args.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
-            # of models is 4096.
-            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training
-            # or fine-tuning.
+    def _precompute_freqs_cis(self):
+        return precompute_freqs_cis(
             self.model_args.dim // self.model_args.n_heads,
+            # Need to compute until at least the max token limit for generation
+            # (use 2x max sequence length to be safe)
             self.model_args.max_seq_len * 2,
         )
 
@@ -359,6 +339,16 @@ def forward(self, tokens: torch.Tensor):
         freqs_cis = self.freqs_cis[0:seqlen]
         return h, freqs_cis
 
+    def init_weights(self):
+        if hasattr(self, "freqs_cis"):
+            with torch.device(self.freqs_cis.device):
+                self.freqs_cis = self._precompute_freqs_cis()
+        else:
+            self.register_buffer(
+                "freqs_cis", self._precompute_freqs_cis(), persistent=False
+            )
+        nn.init.normal_(self.tok_embeddings.weight)
+
 
 class TransformerBlock(nn.Module):
     """
@@ -400,6 +390,7 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
             self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
         else:
             self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5
+        self.init_weights()
 
     def forward(
         self,
@@ -421,13 +412,11 @@ def forward(
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
 
-    def reset_parameters(self):
-        """reset params and norms for entire block"""
-        self.attention_norm.reset_parameters()
-        self.ffn_norm.reset_parameters()
-
-        self.attention.reset_parameters(self.weight_init_std)
-        self.feed_forward.reset_parameters(self.weight_init_std)
+    def init_weights(self):
+        for norm in (self.attention_norm, self.ffn_norm):
+            norm.reset_parameters()
+        self.attention.init_weights(self.weight_init_std)
+        self.feed_forward.init_weights(self.weight_init_std)
 
 
 class Transformer(nn.Module):
@@ -457,29 +446,19 @@ def __init__(self, model_args: ModelArgs):
         self.model_dim = model_args.dim
 
         self.embeddings = RotaryEmbedding(model_args)
-
         self.layers = torch.nn.ModuleList()
         for layer_id in range(model_args.n_layers):
             self.layers.append(TransformerBlock(layer_id, model_args))
 
         self.norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
+        self.init_weights()
 
-        # init model weights
-
-        # we are doing meta_init, which will call reset_parameters() after
-        # the model is moved to actual device.
-        # If you modify and are not using meta_init, you will need to call
-        # reset_parameters() manually as below:
-
-        # self.reset_parameters()
-
-    def reset_parameters(
-        self,
-    ):
+    def init_weights(self):
         for layer in self.layers:
-            layer.reset_parameters()
+            layer.init_weights()
         self.norm.reset_parameters()
+        self.embeddings.init_weights()
         final_out_std = self.model_dim**-0.5
         cutoff_factor = 3
         nn.init.trunc_normal_(
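
Taken together, the model.py changes swap the old meta_init/reset_parameters flow for explicit init_weights hooks. The sketch below shows one way such hooks typically combine with meta-device construction and per-parameter FSDP; the Block and ToyModel modules, the sizes, and the launch details are illustrative assumptions (run under torchrun), not the repository's actual training-script wiring.

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

    class Block(nn.Module):
        def __init__(self, dim: int):
            super().__init__()
            self.proj = nn.Linear(dim, dim, bias=False)
            self.init_weights()

        def init_weights(self) -> None:
            nn.init.trunc_normal_(self.proj.weight, mean=0.0, std=0.02)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)

    class ToyModel(nn.Module):
        def __init__(self, dim: int = 256, n_layers: int = 4):
            super().__init__()
            self.layers = nn.ModuleList(Block(dim) for _ in range(n_layers))
            self.init_weights()

        def init_weights(self) -> None:
            for layer in self.layers:
                layer.init_weights()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            for layer in self.layers:
                x = layer(x)
            return x

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # 1. Construct on the meta device: no storage is allocated, and the
    #    init_weights() calls inside __init__ run as shape-only no-ops.
    with torch.device("meta"):
        model = ToyModel()

    # 2. Apply per-parameter FSDP; parameters become sharded DTensors, still on meta.
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    for layer in model.layers:
        fully_shard(layer, mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy)

    # 3. Materialize real (sharded) storage on the GPU, then do the actual init.
    model.to_empty(device="cuda")
    model.init_weights()

    dist.destroy_process_group()

After step 3 the parameters are sharded DTensors backed by real GPU memory, and init_weights performs the real initialization on them, which is why the __init__-time calls no longer need the old "re-enable if not using meta-init" comments.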

torchtrain/parallelisms/parallelize_llama.py

+39 −66
@@ -8,19 +8,13 @@
 from typing import Tuple
 
 import torch
-from torch.distributed._tensor import Replicate, Shard
 
+from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
     CheckpointImpl,
 )
-from torch.distributed.fsdp import (
-    BackwardPrefetch,
-    FullyShardedDataParallel as FSDP,
-    MixedPrecision,
-    ShardingStrategy,
-)
-from torch.distributed.fsdp.wrap import enable_wrap, wrap
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,
@@ -33,7 +27,6 @@
 
 from torchtrain.config_manager import JobConfig
 from torchtrain.logging_utils import logger
-from torchtrain.meta_init import meta_to_real_init_fn
 
 
 # for selective AC
@@ -75,7 +68,6 @@ def selective_checkpointing_context_fn():
             preserve_rng_state=False,
         )
     elif config.mode == "full":
-        # full AC
         return ptd_checkpoint_wrapper(
             module,
             checkpoint_impl=CheckpointImpl.NO_REENTRANT,
@@ -136,28 +128,23 @@ def get_tp_parallel_strategy(
 
 def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
-    Apply parallelisms to the model, including PTD parallelisms, and AC.
+    Apply parallelisms and activation checkpointing to the model.
 
-    NOTE: the model passed in preferrablably shoule be a meta device model,
-    otherwise the model needs to be small enough on GPU or can fit into CPU.
+    NOTE: The passed-in model preferably should be on meta device. Otherwise,
+    the model must fit on GPU or CPU memory.
     """
-    # apply PTD parallelisms
     if parallel_dims.pp_enabled:
         raise NotImplementedError("PP not implemented yet.")
 
-    # First we apply Tensor Parallelism if it's enabled
     if parallel_dims.tp_enabled:
         tp_mesh = world_mesh["tp"]
-        tp_degree = job_config.training.tensor_parallel_degree
-
         row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
             job_config
         )
 
-        # First:
-        # 1. parallelize the first embedding and the last linear proj layer
-        # 2. parallelize the root norm layer by sequence dim
-        # 3. shard the first layer of transformer block
+        # 1. Parallelize the first embedding and the last linear proj layer
+        # 2. Parallelize the root norm layer over the sequence dim
+        # 3. Shard the first transformer block's inputs
         model = parallelize_module(
             model,
             tp_mesh,
@@ -167,9 +154,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 ),
                 "output": col_parallel_strategy(
                     input_layouts=Shard(0),
-                    output_layouts=Shard(-1)
-                    if parallel_dims.loss_parallel_enabled
-                    else Replicate(),
+                    output_layouts=(
+                        Shard(-1)
+                        if parallel_dims.loss_parallel_enabled
+                        else Replicate()
+                    ),
                     use_local_output=not parallel_dims.loss_parallel_enabled,
                 ),
                 "norm": SequenceParallel(sequence_dim=0),
@@ -181,7 +170,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             },
         )
 
-        # apply tensor + sequence parallelism to every transformer block
+        # Apply tensor + sequence parallelism to every transformer block
         for layer_id, transformer_block in enumerate(model.layers):
             layer_plan = {
                 "attention": PrepareModuleInput(
@@ -203,62 +192,46 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 "ffn_norm": SequenceParallel(sequence_dim=0),
             }
 
-            # adjust num_heads in attention layer to local heads
+            # Adjust attention module to use the local number of heads
            attn_layer = transformer_block.attention
-            attn_layer.n_heads = attn_layer.n_heads // tp_degree
-            attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_degree
+            attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
+            attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
 
             parallelize_module(
                 module=transformer_block,
                 device_mesh=tp_mesh,
                 parallelize_plan=layer_plan,
             )
 
-        logger.info("Applied Sequence Parallelism to the model")
+        logger.info("Applied Tensor Parallelism to the model")
 
     if parallel_dims.dp_enabled:
-        dp_mesh = world_mesh["dp"]
-
-        fsdp_config = {
-            "mixed_precision": MixedPrecision(
-                param_dtype=torch.bfloat16,
-                # TODO: see whether we should expose a option to user
-                reduce_dtype=torch.float32,
-            ),
-            "sharding_strategy": ShardingStrategy.FULL_SHARD,
-            "backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
-            # When torch.compile is active, it requires us to set use_orig_params=True
-            "use_orig_params": True,
-            "device_mesh": dp_mesh,
-            "param_init_fn": meta_to_real_init_fn,
-        }
-
+        dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
+        assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
+        mp_policy = MixedPrecisionPolicy(
+            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+        )
         ac_mode = job_config.activation_checkpoint.mode
-        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
-            for layer_id, transformer_block in enumerate(model.layers):
-                # apply AC to the transformer block
-                if ac_mode in ("full", "selective"):
-                    # wrap the transformer block with checkpoint wrapper, using config settings
-                    transformer_block = checkpoint_wrapper(
-                        transformer_block, job_config.activation_checkpoint
-                    )
-
-                # Wraps each layer with FSDP
-                model.layers[layer_id] = wrap(transformer_block)
-
-            # wrap the rest layers with FSDP
-            model = wrap(model)
-
+        fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+        for layer_id, transformer_block in enumerate(model.layers):
+            if job_config.activation_checkpoint.mode in ("full", "selective"):
+                transformer_block = checkpoint_wrapper(
+                    transformer_block, job_config.activation_checkpoint
+                )
+            # As an optimization, do not reshard after forward for the last
+            # transformer block since FSDP would prefetch it immediately
+            reshard_after_forward = layer_id < len(model.layers) - 1
+            fully_shard(
+                transformer_block,
+                **fsdp_config,
+                reshard_after_forward=reshard_after_forward,
+            )
+            model.layers[layer_id] = transformer_block
+        model = fully_shard(model, **fsdp_config)
         if ac_mode in ("full", "selective"):
             logger.info(f"Applied {ac_mode} activation checkpointing to the model")
         logger.info("Applied FSDP to the model")
     else:
-        meta_to_real_init_fn(model)
         model.cuda()
 
-        # we have now moved from meta to device,
-        # reset parameters for proper initialization
-        model.reset_parameters()
-        logger.info("Model fully initialized via reset_parameters")
-
     return model
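
For reference, here is a self-contained sketch of the per-parameter FSDP application pattern this file now uses, applied to a hypothetical stack of linear layers instead of the Llama TransformerBlocks; the toy model, the sizes, and the torchrun environment handling are illustrative assumptions rather than code from this commit.

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
    from torch.distributed.device_mesh import init_device_mesh

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # 1-D data-parallel mesh over all ranks (stands in for world_mesh["dp"]).
    dp_mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("dp",))

    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

    # Each "layer" becomes its own FSDP parameter group, like model.layers above.
    model = nn.Sequential(*(nn.Linear(1024, 1024, bias=False) for _ in range(8))).cuda()

    for layer_id, layer in enumerate(model):
        # Skip resharding after forward for the last layer: FSDP would prefetch
        # (all-gather) it again right at the start of backward anyway.
        reshard_after_forward = layer_id < len(model) - 1
        fully_shard(layer, **fsdp_config, reshard_after_forward=reshard_after_forward)
    fully_shard(model, **fsdp_config)  # root group picks up any remaining parameters

    out = model(torch.randn(2, 1024, device="cuda"))
    out.sum().backward()
    dist.destroy_process_group()

The per-layer fully_shard calls replace the old enable_wrap/wrap scheme, and keeping the last block unsharded after forward mirrors the comment in the diff: its parameters would be gathered again immediately when backward starts.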

0 commit comments
