backend: compiler: update MHA training pattern
yifeizh2 authored and ElaineBao committed Dec 15, 2022
1 parent d845d95 commit 147d9bc
Showing 4 changed files with 203 additions and 73 deletions.
174 changes: 106 additions & 68 deletions src/backend/graph_compiler/patterns/mha_pattern.hpp
@@ -48,6 +48,46 @@ pm::repetition_t *create_append_transpose_repetition_subgraph(
return transpose;
};

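// Helper: optionally matches a single Multiply appended after `input`;
// when `allow_external` is set, the matched Multiply may keep consumers
// outside the fused partition.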
pm::repetition_t *create_optional_mul_subgraph(
const std::shared_ptr<pb_graph_t> &pgraph, pm::pb_op_t *input,
bool allow_external = false) {
auto optional_mul_subgraph
= std::make_shared<pb_graph_t>("optional_mul_subgraph");
auto optional_mul = optional_mul_subgraph->append_op(
impl::op_kind::Multiply, "optional_mul");
if (allow_external) { optional_mul->allow_external_outputs(); }

optional_mul_subgraph->create_input_port(0, optional_mul, 0);
optional_mul_subgraph->create_output_port(0, optional_mul, 0);
auto mul = pgraph->append_optional(
optional_mul_subgraph, {in_edge(0, input, 0)}, "optional_mul");
return mul;
};

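// Helper: matches either one Multiply or two consecutive Multiply ops
// after `input`; used below for the dropout gradient, which may carry an
// extra Select-related Multiply in the updated backward pattern.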
pm::alternation_t *create_alternative_mul_subgraph(
const std::shared_ptr<pb_graph_t> &pgraph, pm::pb_op_t *input) {
// Two alternatives: 1) a single Multiply; 2) two consecutive Multiply ops
auto successive_mul_subgraph
= std::make_shared<pb_graph_t>("successive_mul_subgraph");
auto mul1 = successive_mul_subgraph->append_op(
impl::op_kind::Multiply, "mul1");
auto mul2 = successive_mul_subgraph->append_op(
impl::op_kind::Multiply, {in_edge(0, mul1, 0)}, "mul2");
successive_mul_subgraph->create_input_port(0, mul1, 0);
successive_mul_subgraph->create_output_port(0, mul2, 0);

auto single_mul_subgraph
= std::make_shared<pb_graph_t>("single_mul_subgraph");
auto mul = single_mul_subgraph->append_op(impl::op_kind::Multiply, "mul");
single_mul_subgraph->create_input_port(0, mul, 0);
single_mul_subgraph->create_output_port(0, mul, 0);

auto mul_subgraph = pgraph->append_alternation(
{successive_mul_subgraph, single_mul_subgraph},
{in_edge(0, input, 0)}, "mul_subgraph");
return mul_subgraph;
}

COMPILER_BACKEND_REGISTER_PASSES_DEF_BEGIN(fp32_mha_pattern)
// fp32 MHA pattern
/*
@@ -237,10 +277,12 @@ COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(
\ /
Add
|
-Softmax [Dropout](f32)
-\ /
-Mul [ValueTrans](f32)
-\ /
+Softmax [Dropout](f32)
+\ /
+Multiply [Select](f32)
+\ /
+(optional)Multiply [ValueTrans](f32)
+\ /
MatMul
|
Transpose
@@ -274,8 +316,10 @@ COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(
auto dropout = pgraph->append_op(impl::op_kind::Multiply,
{in_edge(0, softmax, 0)}, "dropout");
dropout->allow_external_outputs();
+auto select = create_optional_mul_subgraph(
+pgraph, dropout, true);
auto matmul_v = pgraph->append_op(impl::op_kind::MatMul,
-{in_edge(0, dropout, 0)}, "matmul_v");
+{in_edge(0, select, 0)}, "matmul_v");
matmul_v->append_decision_function(
check_input_dtype<impl::data_type::f32>);
auto reshape
@@ -293,25 +337,27 @@ COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(
|
(f32)[DropoutOut] Transpose [ValueTrans](f32)
\ / \ /
-MatMul MatMul [Dropout](f32)
-| \ /
-[output](f32) Mul [SoftmaxOut](f32)
-/ \ /
-/ Mul
-| |
-| ReduceSum
-\ /
-Sub [SoftmaxOut](f32)
-\ /
-Mul [Fscore](f32)
-\ /
-Div|Mul [QueryTrans](f32)
-___________________/ \ /
-\ [KeyTrans](f32) MatMul
-\ / |
-MatMul [output](f32)
-|
-[output](f32)
+MatMul MatMul [Select](f32)
+| \ /
+[output](f32) (optional)Multiply [Dropout](f32)
+\ /
+Multiply [SoftmaxOut](f32)
+/ \ /
+/ Multiply
+| |
+| ReduceSum
+\ /
+Sub [SoftmaxOut](f32)
+\ /
+Multiply [Fscore](f32)
+\ /
+Div|Mul [QueryTrans](f32)
+___________________/ \ /
+\ [KeyTrans](f32) MatMul
+\ / |
+MatMul [output](f32)
+|
+[output](f32)
*/
COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(
compiler, fp32_mha_backward_pattern)
@@ -336,11 +382,8 @@ COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(
{in_edge(0, in_transpose, 0)}, "bmm_v_grad_data");
bmm_v_grad_data->append_decision_function(
check_input_dtype<impl::data_type::f32>);
-auto dropout_grad = pgraph->append_op(
-impl::op_kind::Multiply,
-{in_edge(0, bmm_v_grad_data, 0)}, "dropout_grad");
-dropout_grad->append_decision_function(
-check_input_dtype<impl::data_type::f32>);
+auto dropout_grad = create_alternative_mul_subgraph(
+pgraph, bmm_v_grad_data);
auto softmax_mul = pgraph->append_op(
impl::op_kind::Multiply,
{in_edge(0, dropout_grad, 0)}, "softmax_mul");
@@ -446,16 +489,8 @@ COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(
impl::op_kind::StaticTranspose,
{in_edge(0, key_reshape, 0)}, "key_transpose");

-auto optional_mul_subgraph = std::make_shared<pb_graph_t>(
-"optional_mul_subgraph");
-auto optional_mul = optional_mul_subgraph->append_op(
-impl::op_kind::Multiply, "optional_mul");
-optional_mul_subgraph->create_input_port(
-0, optional_mul, 0);
-optional_mul_subgraph->create_output_port(
-0, optional_mul, 0);
-auto mul = pgraph->append_optional(optional_mul_subgraph,
-{in_edge(0, key_transpose, 0)}, "mul");
+auto mul = create_optional_mul_subgraph(
+pgraph, key_transpose);

auto matmul_qk = pgraph->append_op(impl::op_kind::MatMul,
{in_edge(0, query_transpose, 0),
@@ -786,10 +821,12 @@ COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(
\ /
Add
|
-Softmax [Dropout](bf16)
-\ /
-Mul [ValueTrans](bf16)
-\ /
+Softmax [Dropout](bf16)
+\ /
+Multiply [Select](bf16)
+\ /
+(optional)Multiply [ValueTrans](bf16)
+\ /
MatMul
|
Transpose
@@ -821,8 +858,10 @@ COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(
auto dropout = pgraph->append_op(impl::op_kind::Multiply,
{in_edge(0, softmax, 0)}, "dropout");
dropout->allow_external_outputs();
+auto select = create_optional_mul_subgraph(
+pgraph, dropout, true);
auto matmul_v = pgraph->append_op(impl::op_kind::MatMul,
-{in_edge(0, dropout, 0)}, "matmul_v");
+{in_edge(0, select, 0)}, "matmul_v");
matmul_v->append_decision_function(
check_input_dtype<impl::data_type::bf16>);
auto reshape
@@ -840,25 +879,27 @@ COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(
|
(bf16)[DropoutOut] Transpose [ValueTrans](bf16)
\ / \ /
-MatMul MatMul [Dropout](bf16)
-| \ /
-[output](bf16) Mul [SoftmaxOut](bf16)
-/ \ /
-/ Mul
-| |
-| ReduceSum
-\ /
-Sub [SoftmaxOut](bf16)
-\ /
-Mul [Fscore](f32/bf16)
-\ /
-Div|Mul [QueryTrans](bf16)
-___________________/ \ /
-\ [KeyTrans](bf16) MatMul
-\ / |
-MatMul [output](bf16)
-|
-[output](bf16)
+MatMul MatMul [Select](bf16)
+| \ /
+[output](bf16) (optional)Multiply [Dropout](bf16)
+\ /
+Multiply [SoftmaxOut](bf16)
+/ \ /
+/ Multiply
+| |
+| ReduceSum
+\ /
+Sub [SoftmaxOut](bf16)
+\ /
+Multiply [Fscore](bf16)
+\ /
+Div|Mul [QueryTrans](bf16)
+___________________/ \ /
+\ [KeyTrans](bf16) MatMul
+\ / |
+MatMul [output](bf16)
+|
+[output](bf16)
*/
COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(
compiler, bf16_mha_backward_pattern)
@@ -883,11 +924,8 @@ COMPILER_BACKEND_REGISTER_TRANSFORMATION_PASS(
{in_edge(0, in_transpose, 0)}, "bmm_v_grad_data");
bmm_v_grad_data->append_decision_function(
check_input_dtype<impl::data_type::bf16>);
-auto dropout_grad = pgraph->append_op(
-impl::op_kind::Multiply,
-{in_edge(0, bmm_v_grad_data, 0)}, "dropout_grad");
-dropout_grad->append_decision_function(
-check_input_dtype<impl::data_type::bf16>);
+auto dropout_grad = create_alternative_mul_subgraph(
+pgraph, bmm_v_grad_data);
auto softmax_mul = pgraph->append_op(
impl::op_kind::Multiply,
{in_edge(0, dropout_grad, 0)}, "softmax_mul");
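For orientation, a minimal sketch of how the new helper composes with the
pattern-matcher API used in this file (a sketch only, not compilable on its
own; it assumes the same headers, namespaces, and op kinds already used
above, e.g. impl::op_kind::SoftMax):

    auto pgraph = std::make_shared<pb_graph_t>("mha_forward_sketch");
    auto softmax = pgraph->append_op(impl::op_kind::SoftMax, "softmax");
    // Dropout is matched as a plain Multiply whose output may leave the
    // partition.
    auto dropout = pgraph->append_op(impl::op_kind::Multiply,
            {in_edge(0, softmax, 0)}, "dropout");
    dropout->allow_external_outputs();
    // Optionally match a Select-style Multiply between dropout and the
    // value MatMul; passing `true` allows external consumers of its output.
    auto select = create_optional_mul_subgraph(pgraph, dropout, true);
    auto matmul_v = pgraph->append_op(impl::op_kind::MatMul,
            {in_edge(0, select, 0)}, "matmul_v");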
18 changes: 18 additions & 0 deletions tests/cpp/unit/backend/graph_compiler/test_compile_execute.cpp
@@ -725,6 +725,24 @@ TEST(GCGraphTest, BF16MHATrainingGraphCompileExecution) {
compile_execution_pipeline(agraph, 2);
}

TEST(GCGraphTest, FP32MHATrainingGraphCompileExecution2) {
REQUIRE_AVX512();
impl::graph_t agraph;
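// The trailing `true` is assumed to build the new variant of the MHA
// training subgraph with the extra Select Multiply exercised by this patch.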
compiler_utils::add_MHA_training_subgraph(&agraph, false, true);
agraph.build_graph();

compile_execution_pipeline(agraph, 2);
}

TEST(GCGraphTest, BF16MHATrainingGraphCompileExecution2) {
REQUIRE_BF16_AMXBF16();
impl::graph_t agraph;
compiler_utils::add_MHA_training_subgraph(&agraph, true, true);
agraph.build_graph();

compile_execution_pipeline(agraph, 2);
}

TEST(GCGraphTest, FP32IdenticalBottleneckCompileExecution) {
REQUIRE_AVX512();
REQUIRE_SINGLE_THREAD();
51 changes: 51 additions & 0 deletions tests/cpp/unit/backend/graph_compiler/test_pattern.cpp
@@ -705,6 +705,32 @@ TEST(GCPatternTests, FP32MHATrainingPattern) {
ASSERT_EQ(partitions[1]->get_outputs().size(), 3U);
}

TEST(GCPatternTests, FP32MHATrainingPattern2) {
REQUIRE_AVX512();
impl::graph_t agraph;
compiler_utils::add_MHA_training_subgraph(&agraph, false, true);
agraph.build_graph();

auto &compiler_backend_ptr
= compiler_impl::compiler_backend_t::get_singleton();

pass::pass_base_ptr fwd_apass
= get_pass(compiler_backend_ptr, "fp32_mha_forward_pattern");
pass::pass_base_ptr bwd_apass
= get_pass(compiler_backend_ptr, "fp32_mha_backward_pattern");
fwd_apass->run(agraph);
bwd_apass->run(agraph);

auto partitions = agraph.get_partitions();
ASSERT_EQ(partitions.size(), 2U);
ASSERT_EQ(partitions[0]->get_ops().size(), 9U);
ASSERT_EQ(partitions[0]->get_inputs().size(), 7U);
ASSERT_EQ(partitions[0]->get_outputs().size(), 3U);
ASSERT_EQ(partitions[1]->get_ops().size(), 12U);
ASSERT_EQ(partitions[1]->get_inputs().size(), 9U);
ASSERT_EQ(partitions[1]->get_outputs().size(), 3U);
}

TEST(GCPatternTests, BF16MHATrainingPattern) {
REQUIRE_BF16_AMXBF16();
impl::graph_t agraph;
@@ -730,6 +756,31 @@ TEST(GCPatternTests, BF16MHATrainingPattern) {
ASSERT_EQ(partitions[1]->get_outputs().size(), 3U);
}

TEST(GCPatternTests, BF16MHATrainingPattern2) {
REQUIRE_BF16_AMXBF16();
impl::graph_t agraph;
compiler_utils::add_MHA_training_subgraph(&agraph, true, true);
agraph.build_graph();

auto &compiler_backend_ptr
= compiler_impl::compiler_backend_t::get_singleton();
pass::pass_base_ptr fwd_apass
= get_pass(compiler_backend_ptr, "bf16_mha_forward_pattern");
pass::pass_base_ptr bwd_apass
= get_pass(compiler_backend_ptr, "bf16_mha_backward_pattern");
fwd_apass->run(agraph);
bwd_apass->run(agraph);

auto partitions = agraph.get_partitions();
ASSERT_EQ(partitions.size(), 2U);
ASSERT_EQ(partitions[0]->get_ops().size(), 9U);
ASSERT_EQ(partitions[0]->get_inputs().size(), 7U);
ASSERT_EQ(partitions[0]->get_outputs().size(), 3U);
ASSERT_EQ(partitions[1]->get_ops().size(), 12U);
ASSERT_EQ(partitions[1]->get_inputs().size(), 9U);
ASSERT_EQ(partitions[1]->get_outputs().size(), 3U);
}

TEST(GCPatternTests, FP32IdenticalBottleneckPattern1) {
REQUIRE_AVX512();
REQUIRE_SINGLE_THREAD();