diff --git a/src/graph/backend/graph_compiler/patterns/mlp_pattern.hpp b/src/graph/backend/graph_compiler/patterns/mlp_pattern.hpp index 4744336b200..590a25faa6e 100644 --- a/src/graph/backend/graph_compiler/patterns/mlp_pattern.hpp +++ b/src/graph/backend/graph_compiler/patterns/mlp_pattern.hpp @@ -257,7 +257,7 @@ void create_gpt_mlp(const std::shared_ptr &pgraph, pm::pb_node_t *append_rms_norm_option1( const std::shared_ptr &pgraph, pm::pb_node_t *input, - bool is_bf16 = false, bool is_int8 = false) { + bool is_bf16 = false, bool is_int8 = false, bool end_cast = false) { if (is_bf16) { auto typecast = pgraph->append_op( graph::op_kind::TypeCast, {in_edge(0, input, 0)}); @@ -277,14 +277,18 @@ pm::pb_node_t *append_rms_norm_option1( auto mul2 = pgraph->append_op( graph::op_kind::Multiply, {in_edge(0, cast1, 0)}); mul2->allow_external_outputs(); - UNUSED(is_bf16); + pm::pb_node_t *output = mul2; + if (end_cast) { + output = pgraph->append_op( + graph::op_kind::TypeCast, {in_edge(0, mul2, 0)}); + } UNUSED(is_int8); - return mul2; + return output; }; pm::pb_node_t *append_rms_norm_option2( const std::shared_ptr &pgraph, pm::pb_node_t *input, - bool is_bf16 = false, bool is_int8 = false) { + bool is_bf16 = false, bool is_int8 = false, bool end_cast = false) { pm::pb_node_t *pow_in = input; pm::pb_node_t *mul1_in = input; if (is_bf16) { @@ -311,9 +315,13 @@ pm::pb_node_t *append_rms_norm_option2( auto mul2 = pgraph->append_op( graph::op_kind::Multiply, {in_edge(0, cast1, 0)}); mul2->allow_external_outputs(); - UNUSED(is_bf16); + pm::pb_node_t *output = mul2; + if (end_cast) { + output = pgraph->append_op( + graph::op_kind::TypeCast, {in_edge(0, mul2, 0)}); + } UNUSED(is_int8); - return mul2; + return output; }; /* @@ -343,20 +351,20 @@ pm::pb_node_t *append_rms_norm_option2( */ void create_llama_mlp(const std::shared_ptr &pgraph, bool is_bf16 = false, bool is_int8 = false, - bool use_rms_norm_alternative = false, - bool split_smooth_quant = false) { + bool use_rms_norm_alternative = false, bool split_smooth_quant = false, + bool end_cast = false) { auto matmul1 = create_dequant_matmul(pgraph, nullptr, is_bf16, is_int8); auto add1 = pgraph->append_op(graph::op_kind::Add, {in_edge(0, matmul1, 0)}); add1->allow_external_outputs(); auto norm1 = use_rms_norm_alternative - ? append_rms_norm_option1(pgraph, add1, is_bf16, is_int8) - : append_rms_norm_option2(pgraph, add1, is_bf16, is_int8); + ? append_rms_norm_option1(pgraph, add1, is_bf16, is_int8, end_cast) + : append_rms_norm_option2(pgraph, add1, is_bf16, is_int8, end_cast); pm::pb_node_t *norm1_for_lhs = norm1, *norm1_for_rhs = norm1; if (is_int8) { auto extra_cast_before_mul = append_single_op_repetition_subgraph( - pgraph, graph::op_kind::TypeCast, norm1); + pgraph, graph::op_kind::TypeCast, norm1_for_lhs); auto smooth_quant_mul1 = append_single_op_repetition_subgraph( pgraph, graph::op_kind::Multiply, extra_cast_before_mul); auto extra_cast_after_mul = append_single_op_repetition_subgraph( @@ -366,7 +374,7 @@ void create_llama_mlp(const std::shared_ptr &pgraph, if (split_smooth_quant) { auto extra_cast_before_mul_rhs = append_single_op_repetition_subgraph( - pgraph, graph::op_kind::TypeCast, norm1); + pgraph, graph::op_kind::TypeCast, norm1_for_rhs); auto smooth_quant_mul1_rhs = append_single_op_repetition_subgraph( pgraph, graph::op_kind::Multiply, extra_cast_before_mul_rhs); @@ -414,8 +422,8 @@ void create_llama_mlp(const std::shared_ptr &pgraph, graph::op_kind::Add, {in_edge(0, matmul4, 0), in_edge(1, add1, 0)}); add2->allow_external_outputs(); auto norm2 = use_rms_norm_alternative - ? append_rms_norm_option1(pgraph, add2, is_bf16, is_int8) - : append_rms_norm_option2(pgraph, add2, is_bf16, is_int8); + ? append_rms_norm_option1(pgraph, add2, is_bf16, is_int8, end_cast) + : append_rms_norm_option2(pgraph, add2, is_bf16, is_int8, end_cast); if (is_int8) { auto extra_cast_before_mul = append_single_op_repetition_subgraph( pgraph, graph::op_kind::TypeCast, norm2); @@ -813,6 +821,22 @@ COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(compiler, int8_bf16_llama_mlp) .set_attr("FCreatePattern", [](const std::shared_ptr &pgraph) -> void { create_llama_mlp(pgraph, true, true, true, true); + }) + .set_attr("FCreatePattern", + [](const std::shared_ptr &pgraph) -> void { + create_llama_mlp(pgraph, true, true, false, false, true); + }) + .set_attr("FCreatePattern", + [](const std::shared_ptr &pgraph) -> void { + create_llama_mlp(pgraph, true, true, true, false, true); + }) + .set_attr("FCreatePattern", + [](const std::shared_ptr &pgraph) -> void { + create_llama_mlp(pgraph, true, true, false, true, true); + }) + .set_attr("FCreatePattern", + [](const std::shared_ptr &pgraph) -> void { + create_llama_mlp(pgraph, true, true, true, true, true); }); COMPILER_BACKEND_REGISTER_PASSES_DEF_END @@ -1077,6 +1101,14 @@ COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(compiler, bf16_llama_mlp) .set_attr("FCreatePattern", [](const std::shared_ptr &pgraph) -> void { create_llama_mlp(pgraph, true, false, true); + }) + .set_attr("FCreatePattern", + [](const std::shared_ptr &pgraph) -> void { + create_llama_mlp(pgraph, true, false, false, false, true); + }) + .set_attr("FCreatePattern", + [](const std::shared_ptr &pgraph) -> void { + create_llama_mlp(pgraph, true, false, true, false, true); }); COMPILER_BACKEND_REGISTER_PASSES_DEF_END