Enable mixtral 8x7b autotp (#5257)
This PR enables AutoTP (automatic tensor parallelism) for the Mixtral 8x7B MoE model.
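For reference, here is a minimal usage sketch (not part of this commit) of the AutoTP path this change extends; the checkpoint name, dtype, and tp_size are illustrative assumptions:

# run with e.g.: deepspeed --num_gpus 2 run_mixtral.py
import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mixtral-8x7B-v0.1"        # assumed checkpoint name
tp_size = int(os.getenv("WORLD_SIZE", "2"))     # one shard per GPU

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# replace_with_kernel_inject=False selects AutoTP (automatic tensor parallelism)
engine = deepspeed.init_inference(model,
                                  tensor_parallel={"tp_size": tp_size},
                                  dtype=torch.bfloat16,
                                  replace_with_kernel_inject=False)

inputs = tokenizer("DeepSpeed is", return_tensors="pt").to(f"cuda:{torch.cuda.current_device()}")
outputs = engine.module.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))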

Co-authored-by: Logan Adams <[email protected]>
Yejing-Lai and loadams authored Mar 27, 2024
1 parent 4520edd commit 0e5c3f9
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion deepspeed/module_inject/auto_tp.py
@@ -133,7 +133,7 @@ def is_load_module(module):
         load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
         load_layer_names = [
             "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
-            "MistralRMSNorm", "T5LayerNorm"
+            "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm"
         ]
         return module.__class__ in load_layers or module._get_name() in load_layer_names

@@ -303,6 +303,9 @@ def tp_parser(model):
                 elif 'self_attention.dense' in layer and 'falcon' in str(
                         type(module)):  # this is a hack to get the right linear layer for this model!
                     gem_list = gem_list + [layer]
+                # Mixtral-8x7B's expert MLP computes w2(act(w1(x)) * w3(x)); w2 must be replaced with LinearAllreduce.
+                elif 'w2' in layer and 'Mixtral' in str(type(module)):
+                    gem_list = gem_list + [layer]

             layer_list = []
             if gem_list != []:
@@ -322,6 +325,9 @@ def _replace(self, child, name, conv_linear_layer):
             return
         weight_shape = child.weight.shape
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
+        # For Mixtral-8x7B, skip replacing the MoE gate linear.
+        if name == "block_sparse_moe.gate":
+            return child
         if name in self.all_reduce_linears:
             # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
             # else [weight_shape[0], weight_shape[1] // mp_size]
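Background on the two Mixtral-specific branches above (not part of the diff): each expert MLP in Mixtral computes w2(act(w1(x)) * w3(x)). AutoTP slices w1 and w3 along their output dimension and w2 along its input dimension, so every rank produces a partial hidden-size output that must be summed across ranks; that sum is exactly what LinearAllreduce provides, which is why w2 is added to gem_list. The router gate (block_sparse_moe.gate, a small hidden_size x num_experts linear) is left unsharded because every rank needs the complete expert scores. A small numerical sketch of the equivalence, with made-up shapes:

# Illustrative only: shows that summing per-rank partial products of the sliced
# w2 reproduces the unsharded expert MLP output (the role of the all-reduce).
import torch

torch.manual_seed(0)
hidden, ffn, tp = 8, 16, 2
x = torch.randn(1, hidden, dtype=torch.float64)
w1 = torch.randn(ffn, hidden, dtype=torch.float64)   # gate projection
w3 = torch.randn(ffn, hidden, dtype=torch.float64)   # up projection
w2 = torch.randn(hidden, ffn, dtype=torch.float64)   # down projection
act = torch.nn.SiLU()

# Unsharded expert MLP: w2(act(w1 x) * (w3 x))
ref = (act(x @ w1.t()) * (x @ w3.t())) @ w2.t()

# Sharded: w1/w3 split along the ffn dim, w2 split along its input (ffn) dim.
out = torch.zeros_like(ref)
for rank in range(tp):
    shard = slice(rank * ffn // tp, (rank + 1) * ffn // tp)
    partial = act(x @ w1[shard].t()) * (x @ w3[shard].t())   # local ffn shard
    out += partial @ w2[:, shard].t()                        # sum == all-reduce
assert torch.allclose(ref, out)
print("per-rank partial sums match the unsharded output")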
