From 0e5c3f9c2ee9e91aef2e9e42587ea26f82749111 Mon Sep 17 00:00:00 2001 From: Yejing-Lai Date: Wed, 27 Mar 2024 15:02:01 +0800 Subject: [PATCH] Enable mixtral 8x7b autotp (#5257) This PR aims to enable mixtral 8x7b (MoE model) autotp. Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/module_inject/auto_tp.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index bf9c2d74c6352..88f7086518e8e 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -133,7 +133,7 @@ def is_load_module(module): load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm] load_layer_names = [ "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear", - "MistralRMSNorm", "T5LayerNorm" + "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm" ] return module.__class__ in load_layers or module._get_name() in load_layer_names @@ -303,6 +303,9 @@ def tp_parser(model): elif 'self_attention.dense' in layer and 'falcon' in str( type(module)): # this is a hack to get the right linear layer for this model! gem_list = gem_list + [layer] + # Mixtral-7x8b used w2*act(w1*w3) linear. need to replace w2 to linearallreduce. + elif 'w2' in layer and 'Mixtral' in str(type(module)): + gem_list = gem_list + [layer] layer_list = [] if gem_list != []: @@ -322,6 +325,9 @@ def _replace(self, child, name, conv_linear_layer): return weight_shape = child.weight.shape mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group) + # For mixtral-7x8b, need to skip MoE gate linear replace. + if name == "block_sparse_moe.gate": + return child if name in self.all_reduce_linears: # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size] # else [weight_shape[0], weight_shape[1] // mp_size]