Commit

Fixed format
gyou2021 committed Jan 21, 2025
1 parent 9adcc49 commit d71c993
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions deepspeed/module_inject/auto_tp.py
@@ -261,7 +261,7 @@ def update_policy_list(policy_list, new_module, new_gems):
                 # if module already exists in policy, combine gems and remove duplicates
                 if policy[0] == type(new_module):
                     #new_gems = set(new_gems + policy[1])
-                    print('new_gems ,policy[1] type',type(new_gems), type(policy[1]))
+                    print('new_gems ,policy[1] type', type(new_gems), type(policy[1]))
                     new_gems = (new_gems + policy[1])
                     policy_list[i] = tuple([type(new_module), new_gems])
                     return policy_list
@@ -309,19 +309,21 @@ def tp_parser(model):
                 elif 'out_proj' in layer:
                     gem_list = gem_list + [layer]
                 elif 'o_proj' in layer:
                     gem_list = gem_list + [layer]
-                elif 'down_proj' in layer and not (('DeepseekV2' in str(type(module))) or ('qwen2_moe' in str(type(module)))):
-                    gem_list = gem_list + [layer]
-                elif 'shared_experts.down_proj' in layer and (('DeepseekV2' in str(type(module))) or ('qwen2_moe' in str(type(module)))):
+                elif 'down_proj' in layer and not (('DeepseekV2' in str(type(module))) or
+                                                   ('qwen2_moe' in str(type(module)))):
                     gem_list = gem_list + [layer]
-                elif 'mlp.down_proj' in layer and ('DeepseekV2' in str(type(module))):
+                elif 'shared_experts.down_proj' in layer and (('DeepseekV2' in str(type(module))) or
+                                                              ('qwen2_moe' in str(type(module)))):
+                    gem_list = gem_list + [layer]
+                elif 'mlp.down_proj' in layer and ('DeepseekV2' in str(type(module))):
                     gem_list = gem_list + [layer]
                 elif 'attention.dense' in layer and 'GPTNeoX' in str(model):
                     gem_list = gem_list + [layer]
                 elif 'self_attention.dense' in layer and 'falcon' in str(
                         type(module)):  # this is a hack to get the right linear layer for this model!
                     gem_list = gem_list + [layer]
-                # Mixtral-7x8b used w2*act(w1*w3) linear. need to replace w2 to linearallreduce.
+                # Mixtral-7x8b used w2*act(w1*w3) linear. need to replace w2 to linearallreduce.
                 elif 'w2' in layer and 'Mixtral' in str(type(module)):
                     gem_list = gem_list + [layer]
                 elif 'self_attn.dense' in layer and 'Phi' in str(type(module)):
@@ -371,10 +373,10 @@ def _replace(self, child, name, conv_linear_layer):
         arctic_w2_all_reduce_linear = False
         if 'Arctic' in str(self.module) and 'w2' in name:
             arctic_w2_all_reduce_linear = True
-        # For MoE MLP model, e.g., deepseek and jamba
+        # For MoE MLP model, e.g., deepseek and jamba
         down_proj = False
-        #Deepseek processes different down_proj in different ways.
-        if 'down_proj' in name and 'DeepseekV2' not in str(type(self.module)):
+        #Deepseek processes different down_proj in different ways.
+        if 'down_proj' in name and 'DeepseekV2' not in str(type(self.module)):
             down_proj = True
         # For MLP including chunk layer.
         if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
@@ -427,14 +429,14 @@ def _replace(self, child, name, conv_linear_layer):
                     torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name()))
             else:
                 data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
-                                               dim=1 if self.conv_linear_layer else 0)
+                                               dim=1 if self.conv_linear_layer else 0)
                 data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
                 del data

                 if child.bias is not None:
                     bias_data = child.bias.data.split(get_shard_size_list(
                         weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
-                                                      dim=0)
+                                                      dim=0)
                     bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy)
                     bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
                     del bias_data
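A note on the tp_parser hunk above: the down_proj branches are gated by a substring match on the module's class name, so MoE-style models (DeepseekV2, qwen2_moe) only register shared_experts.down_proj and mlp.down_proj for all-reduce replacement, while dense models keep the plain down_proj rule. Below is a minimal, self-contained sketch of that gating; the decoder-layer classes are hypothetical stand-ins, not DeepSpeed or transformers types.

# Minimal sketch of the class-name gating used in tp_parser.
# DeepseekV2DecoderLayer / LlamaDecoderLayer are hypothetical stand-ins for the
# transformers decoder-layer classes whose names the real code matches against.


class DeepseekV2DecoderLayer:
    pass


class LlamaDecoderLayer:
    pass


def is_moe_style(module) -> bool:
    # Same substring test as the diff: match on the module's class-name string.
    return ('DeepseekV2' in str(type(module))) or ('qwen2_moe' in str(type(module)))


def select_down_proj(layer: str, module) -> bool:
    # Mirrors the reordered elif branches: plain down_proj for dense models,
    # only shared_experts.down_proj / mlp.down_proj for the MoE-style ones.
    if 'down_proj' in layer and not is_moe_style(module):
        return True
    if 'shared_experts.down_proj' in layer and is_moe_style(module):
        return True
    if 'mlp.down_proj' in layer and 'DeepseekV2' in str(type(module)):
        return True
    return False


print(select_down_proj('.mlp.down_proj', LlamaDecoderLayer()))                      # True
print(select_down_proj('.mlp.experts.0.down_proj', DeepseekV2DecoderLayer()))       # False
print(select_down_proj('.mlp.shared_experts.down_proj', DeepseekV2DecoderLayer()))  # True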

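A note on the _replace hunks above: the re-aligned split(...) calls shard a layer's weight and bias across the tensor-parallel ranks along one dimension, with get_shard_size_list supplying the per-rank shard sizes. The sketch below reproduces that pattern with plain torch under stated assumptions: even_shard_sizes is a simplified stand-in for get_shard_size_list, and mp_size/rank are hypothetical values.

# A minimal sketch (not DeepSpeed's implementation) of the sharding pattern in
# the _replace hunks: split a weight and bias into per-rank shards along one
# dimension and keep only this rank's shard. even_shard_sizes is a simplified
# stand-in for DeepSpeed's get_shard_size_list; mp_size and rank are hypothetical.
import torch


def even_shard_sizes(total: int, mp_size: int) -> list:
    # Distribute `total` rows as evenly as possible across mp_size ranks.
    base, rem = divmod(total, mp_size)
    return [base + (1 if r < rem else 0) for r in range(mp_size)]


mp_size, rank = 4, 1               # hypothetical tensor-parallel group
weight = torch.randn(4096, 11008)  # stand-in Linear weight (out_features, in_features)
bias = torch.randn(4096)

# The non-conv path of the hunk resolves `dim=1 if self.conv_linear_layer else 0`
# to dim=0, i.e. the output dimension is sharded; the bias is sharded the same way.
w_local = weight.split(even_shard_sizes(weight.shape[0], mp_size), dim=0)[rank].contiguous()
b_local = bias.split(even_shard_sizes(bias.shape[0], mp_size), dim=0)[rank].contiguous()

print(w_local.shape, b_local.shape)  # torch.Size([1024, 11008]) torch.Size([1024])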