Skip to content

Commit 81b8854

Browse files
fegin and atalman authored
[DSD] Add a test to verify FSDP lazy initialization case (#127069) (#127130)
* [DSD] Add a test to verify FSDP lazy initialization case (#127069) Summary: Distributed state_dict should not error out because the `model.state_dict()` will trigger FSDP to initialize. Pull Request resolved: #127069 Approved by: https://github.com/wz337 * Add missing import get_optimizer_state_dict --------- Co-authored-by: Andrey Talman <atalman@fb.com>
1 parent e63004b commit 81b8854

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

test/distributed/checkpoint/test_state_dict.py

+14
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
_patch_model_state_dict,
2323
_patch_optimizer_state_dict,
2424
get_model_state_dict,
25+
get_optimizer_state_dict,
2526
get_state_dict,
2627
set_model_state_dict,
2728
set_state_dict,
@@ -555,6 +556,19 @@ def _test_activation_ckpt_fqns_fsdp1(self, use_orig_params: bool) -> None:
555556

556557
self.assertEqual(original_keys, new_keys)
557558

559+
@with_comms
560+
@skip_if_lt_x_gpu(2)
561+
def test_fsdp_root_not_initialized(self) -> None:
562+
# This test verifies that FSDP root is not initialized but we should
563+
# still be able to get the state_dict without errors because
564+
# fsdp_model.state_dict() will trigger the FSDP initialization.
565+
device_mesh = init_device_mesh("cuda", (self.world_size,))
566+
model = CompositeParamModel(device=torch.device("cuda"))
567+
fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh)
568+
fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
569+
get_model_state_dict(fsdp_model)
570+
get_optimizer_state_dict(fsdp_model, fsdp_optim)
571+
558572

559573
if __name__ == "__main__":
560574
run_tests()

0 commit comments

Comments (0)