from typing import Tuple

import torch
-from torch.distributed._tensor import Replicate, Shard

+from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
    CheckpointImpl,
)
-from torch.distributed.fsdp import (
-    BackwardPrefetch,
-    FullyShardedDataParallel as FSDP,
-    MixedPrecision,
-    ShardingStrategy,
-)
-from torch.distributed.fsdp.wrap import enable_wrap, wrap
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
@@ -33,7 +27,6 @@

from torchtrain.config_manager import JobConfig
from torchtrain.logging_utils import logger
-from torchtrain.meta_init import meta_to_real_init_fn


# for selective AC
@@ -75,7 +68,6 @@ def selective_checkpointing_context_fn():
            preserve_rng_state=False,
        )
    elif config.mode == "full":
-        # full AC
        return ptd_checkpoint_wrapper(
            module,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
@@ -136,28 +128,23 @@ def get_tp_parallel_strategy(

def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
    """
-    Apply parallelisms to the model, including PTD parallelisms, and AC.
+    Apply parallelisms and activation checkpointing to the model.

-    NOTE: the model passed in preferrablably shoule be a meta device model,
-    otherwise the model needs to be small enough on GPU or can fit into CPU.
+    NOTE: The passed-in model preferably should be on meta device. Otherwise,
+    the model must fit on GPU or CPU memory.
    """
-    # apply PTD parallelisms
    if parallel_dims.pp_enabled:
        raise NotImplementedError("PP not implemented yet.")

-    # First we apply Tensor Parallelism if it's enabled
    if parallel_dims.tp_enabled:
        tp_mesh = world_mesh["tp"]
-        tp_degree = job_config.training.tensor_parallel_degree
-
        row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
            job_config
        )

-        # First:
-        # 1. parallelize the first embedding and the last linear proj layer
-        # 2. parallelize the root norm layer by sequence dim
-        # 3. shard the first layer of transformer block
+        # 1. Parallelize the first embedding and the last linear proj layer
+        # 2. Parallelize the root norm layer over the sequence dim
+        # 3. Shard the first transformer block's inputs
        model = parallelize_module(
            model,
            tp_mesh,
@@ -167,9 +154,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                ),
                "output": col_parallel_strategy(
                    input_layouts=Shard(0),
-                    output_layouts=Shard(-1)
-                    if parallel_dims.loss_parallel_enabled
-                    else Replicate(),
+                    output_layouts=(
+                        Shard(-1)
+                        if parallel_dims.loss_parallel_enabled
+                        else Replicate()
+                    ),
                    use_local_output=not parallel_dims.loss_parallel_enabled,
                ),
                "norm": SequenceParallel(sequence_dim=0),
@@ -181,7 +170,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
            },
        )

-        # apply tensor + sequence parallelism to every transformer block
+        # Apply tensor + sequence parallelism to every transformer block
        for layer_id, transformer_block in enumerate(model.layers):
            layer_plan = {
                "attention": PrepareModuleInput(
@@ -203,62 +192,46 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                "ffn_norm": SequenceParallel(sequence_dim=0),
            }

-            # adjust num_heads in attention layer to local heads
+            # Adjust attention module to use the local number of heads
            attn_layer = transformer_block.attention
-            attn_layer.n_heads = attn_layer.n_heads // tp_degree
-            attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_degree
+            attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
+            attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

            parallelize_module(
                module=transformer_block,
                device_mesh=tp_mesh,
                parallelize_plan=layer_plan,
            )

-        logger.info("Applied Sequence Parallelism to the model")
+        logger.info("Applied Tensor Parallelism to the model")

    if parallel_dims.dp_enabled:
-        dp_mesh = world_mesh["dp"]
-
-        fsdp_config = {
-            "mixed_precision": MixedPrecision(
-                param_dtype=torch.bfloat16,
-                # TODO: see whether we should expose a option to user
-                reduce_dtype=torch.float32,
-            ),
-            "sharding_strategy": ShardingStrategy.FULL_SHARD,
-            "backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
-            # When torch.compile is active, it requires us to set use_orig_params=True
-            "use_orig_params": True,
-            "device_mesh": dp_mesh,
-            "param_init_fn": meta_to_real_init_fn,
-        }
-
+        dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
+        assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
+        mp_policy = MixedPrecisionPolicy(
+            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+        )
        ac_mode = job_config.activation_checkpoint.mode
-        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
-            for layer_id, transformer_block in enumerate(model.layers):
-                # apply AC to the transformer block
-                if ac_mode in ("full", "selective"):
-                    # wrap the transformer block with checkpoint wrapper, using config settings
-                    transformer_block = checkpoint_wrapper(
-                        transformer_block, job_config.activation_checkpoint
-                    )
-
-                # Wraps each layer with FSDP
-                model.layers[layer_id] = wrap(transformer_block)
-
-            # wrap the rest layers with FSDP
-            model = wrap(model)
-
+        fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+        for layer_id, transformer_block in enumerate(model.layers):
+            if job_config.activation_checkpoint.mode in ("full", "selective"):
+                transformer_block = checkpoint_wrapper(
+                    transformer_block, job_config.activation_checkpoint
+                )
+            # As an optimization, do not reshard after forward for the last
+            # transformer block since FSDP would prefetch it immediately
+            reshard_after_forward = layer_id < len(model.layers) - 1
+            fully_shard(
+                transformer_block,
+                **fsdp_config,
+                reshard_after_forward=reshard_after_forward,
+            )
+            model.layers[layer_id] = transformer_block
+        model = fully_shard(model, **fsdp_config)
        if ac_mode in ("full", "selective"):
            logger.info(f"Applied {ac_mode} activation checkpointing to the model")
        logger.info("Applied FSDP to the model")
    else:
-        meta_to_real_init_fn(model)
        model.cuda()

-    # we have now moved from meta to device,
-    # reset parameters for proper initialization
-    model.reset_parameters()
-    logger.info("Model fully initialized via reset_parameters")
-
    return model
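
For context on the DTensor-based TP API used above, here is a minimal, self-contained sketch of `parallelize_module` with a colwise/rowwise plan. The `ToyFeedForward` module, its `w1`/`w2` submodule names, and the `apply_tp` helper are illustrative placeholders rather than torchtrain code, and the sketch assumes the default process group is already initialized (e.g. the script was launched with `torchrun`). It also shows why the diff can drop `tp_degree` from the job config: the same value is available as `tp_mesh.size()`.

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyFeedForward(nn.Module):
    """Illustrative two-layer MLP; w1/w2 are placeholder names."""

    def __init__(self, dim: int = 256, hidden: int = 1024):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


def apply_tp(model: ToyFeedForward, tp_degree: int) -> ToyFeedForward:
    # 1D tensor-parallel mesh over all ranks in the default process group.
    tp_mesh = init_device_mesh("cuda", (tp_degree,), mesh_dim_names=("tp",))
    assert tp_mesh.size() == tp_degree  # the mesh itself carries the TP degree
    # Colwise-shard the up projection and rowwise-shard the down projection so
    # the intermediate activation stays sharded and only the second matmul's
    # output needs an all-reduce.
    return parallelize_module(
        model.cuda(),
        tp_mesh,
        {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
    )
```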
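
The data-parallel change above replaces the FSDP1 wrapper path (`FullyShardedDataParallel`, `enable_wrap`/`wrap`, `MixedPrecision`) with the composable FSDP2 API. Below is a minimal sketch of that per-block `fully_shard` pattern with a `MixedPrecisionPolicy`. `ToyBlock`, `ToyModel`, and `apply_fsdp2` are hypothetical stand-ins for torchtrain's Llama modules, not the real ones, and the mesh setup again assumes the process group is already initialized (e.g. via `torchrun`).

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.device_mesh import init_device_mesh


class ToyBlock(nn.Module):
    """Stand-in for a transformer block."""

    def __init__(self, dim: int = 256):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x):
        return x + self.ff(x)


class ToyModel(nn.Module):
    def __init__(self, n_layers: int = 4, dim: int = 256):
        super().__init__()
        self.layers = nn.ModuleList(ToyBlock(dim) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def apply_fsdp2(model: ToyModel, world_size: int) -> ToyModel:
    # 1D data-parallel mesh over all ranks.
    dp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))
    # Params are cast to bf16 for compute; gradients are reduced in fp32.
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    )
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

    for layer_id, block in enumerate(model.layers):
        # Each block becomes its own FSDP group; the last one stays unsharded
        # after forward because backward will need it again immediately.
        fully_shard(
            block,
            **fsdp_config,
            reshard_after_forward=layer_id < len(model.layers) - 1,
        )
    # The root wrap covers any parameters not owned by a wrapped block.
    return fully_shard(model, **fsdp_config)
```

Skipping `reshard_after_forward` on the last block is the same optimization the diff comments on: FSDP would all-gather that block again as soon as backward begins, so resharding it after forward only adds communication.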