Skip to content

Commit

Permalink
Merge branch 'master' into lyj/lm_head_replace
Browse files Browse the repository at this point in the history
  • Loading branch information
loadams authored Jan 10, 2025
2 parents 32c0bc0 + 1d15ef0 commit 77a8107
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/nv-ds-chat.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
- name: Install deepspeed
run: |
pip install transformers==4.45.2
pip install transformers
pip install .[dev]
ds_report
Expand Down
4 changes: 4 additions & 0 deletions SECURITY.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ We prefer all communications to be in English.
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).

<!-- END MICROSOFT SECURITY.MD BLOCK -->

---

Please see [PyTorch's Security Policy](https://github.com/pytorch/pytorch/blob/main/SECURITY.md) for more information and recommendations on how to securely interact with models.
2 changes: 1 addition & 1 deletion blogs/windows/08-2024/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Regardless of the installation choice, you can check that the installation was s
We use an image classification model, CIFAR10, and a language model, BERT, to demonstrate pretraining on Windows with DeepSpeed.

## Pretraining CIFAR10
The scripts and codes required for CIFAR10 pretraining example are available in the following path: DeepSpeedExamples\training\cifar. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py deepspeed`. The final output should look something like this:
The scripts and codes required for CIFAR10 pretraining example are available in the following path: DeepSpeedExamples\training\cifar. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py --deepspeed`. The final output should look something like this:
<div align="center">
<img src="./media/cifar10_training.png" style="width:6.5in;height:3.42153in" />
</div>
Expand Down
15 changes: 10 additions & 5 deletions deepspeed/module_inject/auto_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def is_load_module(module):
load_layer_names = [
"LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
"MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding",
"Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding", "Qwen2RMSNorm"
"Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding", "Qwen2RMSNorm",
"DeepseekV2RMSNorm", "DeepseekV2YarnRotaryEmbedding", "MoEGate"
]
return module.__class__ in load_layers or module._get_name() in load_layer_names

Expand Down Expand Up @@ -332,9 +333,9 @@ def _replace(self, child, name, conv_linear_layer):
return
weight_shape = child.weight.shape
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
# For mixtral-7x8b, need to skip MoE gate linear replace.
if name == "block_sparse_moe.gate" or (('mlp.shared_expert_gate' == name or 'mlp.gate' == name)
and 'qwen2_moe' in str(type(self.module))):
# For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
if "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or (
('mlp.shared_expert_gate' == name or 'mlp.gate' == name) and 'qwen2_moe' in str(type(self.module))):
return child
# For Yuan model
if 'Yuan' in str(self.module):
Expand All @@ -350,11 +351,15 @@ def _replace(self, child, name, conv_linear_layer):
arctic_w2_all_reduce_linear = False
if 'Arctic' in str(self.module) and 'w2' in name:
arctic_w2_all_reduce_linear = True
# For MoE MLP model, e.g., deepseek and jamba
down_proj = False
if 'down_proj' in name:
down_proj = True
# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
return LinearLayer(weight=weight, bias=bias)
if name in self.all_reduce_linears or arctic_w2_all_reduce_linear:
if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0], weight_shape[1] // mp_size]

Expand Down
2 changes: 1 addition & 1 deletion deepspeed/module_inject/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def __init__(self, weight_shape=None, weight=None, bias=None):
self.offset = 2
super().__init__(weight_shape, weight=weight)

def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0, position_ids: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()

Expand Down
7 changes: 6 additions & 1 deletion deepspeed/module_inject/tp_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,16 @@ def get_num_attention_heads():
def get_shard_size(total_size, mp_size, name=None, rank=None):
global num_kv_heads
last_linear = ["lm_head", "embed_out"]
# MoE MLP layer use near even division will get better perf.
moe_mlp_layer = ["gate_proj", "up_proj", "down_proj", "w1", "w2", "w3"]
not_moe_mlp_layer = True
if name != None and any(s in str(name) for s in moe_mlp_layer):
not_moe_mlp_layer = False
# When we have num_kv_heads defined, uneven division is possible, otherwise enforce near even division
if rank == None:
rank = dist.get_rank()
if num_kv_heads != None and total_size % num_kv_heads == 0 and "mlp" not in str(name) and str(
name) not in last_linear:
name) not in last_linear and not_moe_mlp_layer:
my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
return total_size * my_slices // num_kv_heads
else:
Expand Down
5 changes: 4 additions & 1 deletion deepspeed/ops/transformer/inference/triton/matmul_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
# -----------------------------------------------------------------------------
# util class/functions for triton
def is_nfs_path(path):
if os.name == 'nt':
return False

# Normalize the path to get the absolute path
path = os.path.abspath(path)

Expand Down Expand Up @@ -99,7 +102,7 @@ def put(self, table):
with FileLock(self.lock_path):
with open(self.file_path + ".tmp", 'wb') as handle:
pickle.dump(table, handle)
os.rename(self.file_path + ".tmp", self.file_path)
os.replace(self.file_path + ".tmp", self.file_path)

def load(self):
if os.path.exists(self.file_path):
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,9 @@ def op_enabled(op_name):
include_package_data=True,
scripts=scripts,
classifiers=[
'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10'
'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12'
],
license='Apache Software License 2.0',
ext_modules=ext_modules,
Expand Down

0 comments on commit 77a8107

Please # to comment.