graph: backend: compiler: update llama mlp
yifeizh2 authored and vpirogov committed Aug 2, 2024
1 parent d6c216a commit ff680fc
Showing 1 changed file with 46 additions and 14 deletions.
src/graph/backend/graph_compiler/patterns/mlp_pattern.hpp: 60 changes (46 additions, 14 deletions)
@@ -257,7 +257,7 @@ void create_gpt_mlp(const std::shared_ptr<pb_graph_t> &pgraph,
 
 pm::pb_node_t *append_rms_norm_option1(
         const std::shared_ptr<pb_graph_t> &pgraph, pm::pb_node_t *input,
-        bool is_bf16 = false, bool is_int8 = false) {
+        bool is_bf16 = false, bool is_int8 = false, bool end_cast = false) {
     if (is_bf16) {
         auto typecast = pgraph->append_op(
                 graph::op_kind::TypeCast, {in_edge(0, input, 0)});
@@ -277,14 +277,18 @@ pm::pb_node_t *append_rms_norm_option1(
     auto mul2 = pgraph->append_op(
             graph::op_kind::Multiply, {in_edge(0, cast1, 0)});
     mul2->allow_external_outputs();
-    UNUSED(is_bf16);
+    pm::pb_node_t *output = mul2;
+    if (end_cast) {
+        output = pgraph->append_op(
+                graph::op_kind::TypeCast, {in_edge(0, mul2, 0)});
+    }
     UNUSED(is_int8);
-    return mul2;
+    return output;
 };
 
 pm::pb_node_t *append_rms_norm_option2(
         const std::shared_ptr<pb_graph_t> &pgraph, pm::pb_node_t *input,
-        bool is_bf16 = false, bool is_int8 = false) {
+        bool is_bf16 = false, bool is_int8 = false, bool end_cast = false) {
     pm::pb_node_t *pow_in = input;
     pm::pb_node_t *mul1_in = input;
     if (is_bf16) {
@@ -311,9 +315,13 @@ pm::pb_node_t *append_rms_norm_option2(
     auto mul2 = pgraph->append_op(
             graph::op_kind::Multiply, {in_edge(0, cast1, 0)});
     mul2->allow_external_outputs();
-    UNUSED(is_bf16);
+    pm::pb_node_t *output = mul2;
+    if (end_cast) {
+        output = pgraph->append_op(
+                graph::op_kind::TypeCast, {in_edge(0, mul2, 0)});
+    }
     UNUSED(is_int8);
-    return mul2;
+    return output;
 };
 
 /*
@@ -343,20 +351,20 @@
 */
 void create_llama_mlp(const std::shared_ptr<pb_graph_t> &pgraph,
         bool is_bf16 = false, bool is_int8 = false,
-        bool use_rms_norm_alternative = false,
-        bool split_smooth_quant = false) {
+        bool use_rms_norm_alternative = false, bool split_smooth_quant = false,
+        bool end_cast = false) {
     auto matmul1 = create_dequant_matmul(pgraph, nullptr, is_bf16, is_int8);
     auto add1
             = pgraph->append_op(graph::op_kind::Add, {in_edge(0, matmul1, 0)});
     add1->allow_external_outputs();
     auto norm1 = use_rms_norm_alternative
-            ? append_rms_norm_option1(pgraph, add1, is_bf16, is_int8)
-            : append_rms_norm_option2(pgraph, add1, is_bf16, is_int8);
+            ? append_rms_norm_option1(pgraph, add1, is_bf16, is_int8, end_cast)
+            : append_rms_norm_option2(pgraph, add1, is_bf16, is_int8, end_cast);
 
     pm::pb_node_t *norm1_for_lhs = norm1, *norm1_for_rhs = norm1;
     if (is_int8) {
         auto extra_cast_before_mul = append_single_op_repetition_subgraph(
-                pgraph, graph::op_kind::TypeCast, norm1);
+                pgraph, graph::op_kind::TypeCast, norm1_for_lhs);
         auto smooth_quant_mul1 = append_single_op_repetition_subgraph(
                 pgraph, graph::op_kind::Multiply, extra_cast_before_mul);
         auto extra_cast_after_mul = append_single_op_repetition_subgraph(
@@ -366,7 +374,7 @@ void create_llama_mlp(const std::shared_ptr<pb_graph_t> &pgraph,
         if (split_smooth_quant) {
             auto extra_cast_before_mul_rhs
                     = append_single_op_repetition_subgraph(
-                            pgraph, graph::op_kind::TypeCast, norm1);
+                            pgraph, graph::op_kind::TypeCast, norm1_for_rhs);
             auto smooth_quant_mul1_rhs = append_single_op_repetition_subgraph(
                     pgraph, graph::op_kind::Multiply,
                     extra_cast_before_mul_rhs);
@@ -414,8 +422,8 @@ void create_llama_mlp(const std::shared_ptr<pb_graph_t> &pgraph,
             graph::op_kind::Add, {in_edge(0, matmul4, 0), in_edge(1, add1, 0)});
     add2->allow_external_outputs();
     auto norm2 = use_rms_norm_alternative
-            ? append_rms_norm_option1(pgraph, add2, is_bf16, is_int8)
-            : append_rms_norm_option2(pgraph, add2, is_bf16, is_int8);
+            ? append_rms_norm_option1(pgraph, add2, is_bf16, is_int8, end_cast)
+            : append_rms_norm_option2(pgraph, add2, is_bf16, is_int8, end_cast);
     if (is_int8) {
         auto extra_cast_before_mul = append_single_op_repetition_subgraph(
                 pgraph, graph::op_kind::TypeCast, norm2);
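The net effect of end_cast on both RMSNorm helpers is that the matched subgraph may now end in a TypeCast fed by the final Multiply. A minimal caller-side sketch, reusing the helper and type names from this file; the wrapper name and argument values are illustrative, not taken from the patch:

    // Sketch: opting into the trailing TypeCast via the new end_cast flag.
    // Without end_cast the returned node is mul2 itself; with end_cast the
    // returned node is a TypeCast consuming mul2's output.
    pm::pb_node_t *build_norm_tail(
            const std::shared_ptr<pb_graph_t> &pgraph, pm::pb_node_t *add) {
        return append_rms_norm_option1(pgraph, add,
                /*is_bf16=*/true, /*is_int8=*/true, /*end_cast=*/true);
    }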
@@ -813,6 +821,22 @@ COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(compiler, int8_bf16_llama_mlp)
         .set_attr<FCreatePattern>("FCreatePattern",
                 [](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
                     create_llama_mlp(pgraph, true, true, true, true);
-                });
+                })
+        .set_attr<FCreatePattern>("FCreatePattern",
+                [](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
+                    create_llama_mlp(pgraph, true, true, false, false, true);
+                })
+        .set_attr<FCreatePattern>("FCreatePattern",
+                [](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
+                    create_llama_mlp(pgraph, true, true, true, false, true);
+                })
+        .set_attr<FCreatePattern>("FCreatePattern",
+                [](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
+                    create_llama_mlp(pgraph, true, true, false, true, true);
+                })
+        .set_attr<FCreatePattern>("FCreatePattern",
+                [](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
+                    create_llama_mlp(pgraph, true, true, true, true, true);
+                });
 COMPILER_BACKEND_REGISTER_PASSES_DEF_END
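The four added registrations with end_cast = true enumerate every combination of use_rms_norm_alternative and split_smooth_quant for the int8 + bf16 pass. A compact sketch of that grid; in the actual pass each combination gets its own FCreatePattern lambda and pattern graph, so the shared pgraph and loop below are illustrative only:

    // Sketch: the argument grid covered by the new int8+bf16 variants.
    void enumerate_end_cast_variants(
            const std::shared_ptr<pb_graph_t> &pgraph) {
        for (bool alt : {false, true}) // use_rms_norm_alternative
            for (bool split : {false, true}) // split_smooth_quant
                create_llama_mlp(pgraph, /*is_bf16=*/true, /*is_int8=*/true,
                        alt, split, /*end_cast=*/true);
    }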

@@ -1077,6 +1101,14 @@ COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(compiler, bf16_llama_mlp)
         .set_attr<FCreatePattern>("FCreatePattern",
                 [](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
                     create_llama_mlp(pgraph, true, false, true);
-                });
+                })
+        .set_attr<FCreatePattern>("FCreatePattern",
+                [](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
+                    create_llama_mlp(pgraph, true, false, false, false, true);
+                })
+        .set_attr<FCreatePattern>("FCreatePattern",
+                [](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
+                    create_llama_mlp(pgraph, true, false, true, false, true);
+                });
 COMPILER_BACKEND_REGISTER_PASSES_DEF_END
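The bf16-only pass gains two end_cast variants rather than four: split_smooth_quant stays false here, consistent with the pre-existing registrations, since the smooth-quant Multiply branches in create_llama_mlp are only built under is_int8. The two added calls side by side (assuming a pgraph in scope as in the lambdas above):

    // Only use_rms_norm_alternative differs between the two new patterns.
    create_llama_mlp(pgraph, /*is_bf16=*/true, /*is_int8=*/false,
            /*use_rms_norm_alternative=*/false, /*split_smooth_quant=*/false,
            /*end_cast=*/true);
    create_llama_mlp(pgraph, /*is_bf16=*/true, /*is_int8=*/false,
            /*use_rms_norm_alternative=*/true, /*split_smooth_quant=*/false,
            /*end_cast=*/true);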
