From 9ed5a6d18a5390404c3fdffb7fcd4199aebad468 Mon Sep 17 00:00:00 2001 From: Stella Biderman Date: Sat, 23 Jan 2021 21:27:51 -0500 Subject: [PATCH 1/2] Trying out a patch @sdtblck is this correct? --- deepspeed/runtime/pipe/module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 8035a8b97e78..0999fa6625b0 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -567,8 +567,8 @@ def load_state_dir(self, load_dir, strict=True): self._synchronize_tied_weights() def _is_checkpointable(self, funcs): - if self.__class__.__name__ == 'GPT2ModelPipe': - return all('ParallelTransformerLayerPipe' in f.__class__.__name__ + if self.__class__.__name__ == 'GPT_NeoXPipe': + return all('TransformerBlock' in f.__class__.__name__ for f in funcs) params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)] From c33b23953dcfbdfea3a7fc81a890e94df015e30d Mon Sep 17 00:00:00 2001 From: Stella Biderman Date: Sat, 23 Jan 2021 21:33:03 -0500 Subject: [PATCH 2/2] Fixed with Sid's explanation --- deepspeed/runtime/pipe/module.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 0999fa6625b0..18498766296c 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -567,9 +567,8 @@ def load_state_dir(self, load_dir, strict=True): self._synchronize_tied_weights() def _is_checkpointable(self, funcs): - if self.__class__.__name__ == 'GPT_NeoXPipe': - return all('TransformerBlock' in f.__class__.__name__ + if self.__class__.__name__ == 'GPT2ModelPipe': + return all('ParallelTransformerLayerPipe' in f.__class__.__name__ + for f in funcs) + return all('TransformerBlock' in f.__class__.__name__ for f in funcs) - - params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)] - return any(len(list(p)) > 0 for p in params)