Skip to content

Commit

Permalink
fix baichuan lm_head replace issue (deepspeedai#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yejing-Lai authored Nov 23, 2023
1 parent 0ebb1ed commit 547ac96
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions deepspeed/module_inject/replace_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,13 @@ def set_lm_head(module):
module.lm_head, "weight") and module.lm_head.weight.is_meta:
module.lm_head.weight = embedding_weight
# enable tensor parallel for the last linear
if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and not module.lm_head.weight.is_meta:
if hasattr(module, "lm_head") and hasattr(module.lm_head,
"weight") and not module.lm_head.weight.is_meta and isinstance(
module.lm_head, torch.nn.Linear):
module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head")
elif hasattr(module, "embed_out") and hasattr(module.embed_out,
"weight") and not module.embed_out.weight.is_meta:
"weight") and not module.embed_out.weight.is_meta and isinstance(
module.embed_out, torch.nn.Linear):
module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
return module

Expand Down

0 comments on commit 547ac96

Please sign in to comment.