gtests: graph: unit: add gtests for verifying multi-consumers in pm
ElaineBao authored and TaoLv committed Dec 23, 2024
1 parent 3d64643 commit 2f2c380
Showing 1 changed file with 72 additions and 0 deletions.
tests/gtests/graph/unit/utils/test_pattern_matcher_cpu.cpp
@@ -1266,6 +1266,78 @@ TEST(test_utils_pattern_matcher, ComplexRepetition) {
    ASSERT_EQ(fusion_ops.size(), 3U);
}

TEST(test_utils_pattern_matcher, SharedInput) {
    /* Pattern that captures shared input to two MatMuls
                  |
                /   \
           MatMul   MatMul
                \   /
               Multiply
                  |
    */
    auto graphp = std::make_shared<pb_graph_t>();
    auto pmm1 = graphp->append_op(MatMul);
    auto pmm2 = graphp->append_op(MatMul);
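    // Map pattern input port 0 to input 0 of both MatMul pattern nodes; the
    // pattern should therefore only match when both MatMuls consume the same
    // producer tensor (the multi-consumer case exercised below).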
    graphp->create_input_port(0, pmm1, 0);
    graphp->create_input_port(0, pmm2, 0);
    auto pmul = graphp->append_op(
            Multiply, {in_edge(0, pmm1, 0), in_edge(1, pmm2, 0)});
    UNUSED(pmul);

    // test with a graph that has the shared input
    graph_t agraph;
    op_t matmul1 {0, MatMul, "matmul1"};
    op_t matmul2 {1, MatMul, "matmul2"};
    op_t multiply {2, Multiply, "multiply"};

    std::vector<logical_tensor_t> lt_vec = create_logical_tensors(6);
    matmul1.add_input(lt_vec[0]);
    matmul1.add_input(lt_vec[1]);
    matmul1.add_output(lt_vec[2]);
    matmul2.add_input(lt_vec[0]);
    matmul2.add_input(lt_vec[3]);
    matmul2.add_output(lt_vec[4]);
    multiply.add_input(lt_vec[2]);
    multiply.add_input(lt_vec[4]);
    multiply.add_output(lt_vec[5]);

    ASSERT_EQ(agraph.add_op(&matmul1), status::success);
    ASSERT_EQ(agraph.add_op(&matmul2), status::success);
    ASSERT_EQ(agraph.add_op(&multiply), status::success);
    agraph.finalize();

    std::vector<op_t *> fusion_ops;
    EXPECT_TRUE(match_pattern(agraph.get_ops()[0].get(), graphp, fusion_ops));
    ASSERT_EQ(fusion_ops.size(), 3U);

    // test with a graph that does not have the shared input
    graph_t agraph2;
    op_t matmul3 {0, MatMul, "matmul1"};
    op_t matmul4 {1, MatMul, "matmul2"};
    op_t multiply2 {2, Multiply, "multiply"};

    std::vector<logical_tensor_t> lt_vec2 = create_logical_tensors(7);
    matmul3.add_input(lt_vec2[0]);
    matmul3.add_input(lt_vec2[1]);
    matmul3.add_output(lt_vec2[2]);
    matmul4.add_input(lt_vec2[3]);
    matmul4.add_input(lt_vec2[4]);
    matmul4.add_output(lt_vec2[5]);
    multiply2.add_input(lt_vec2[2]);
    multiply2.add_input(lt_vec2[5]);
    multiply2.add_output(lt_vec2[6]);

    ASSERT_EQ(agraph2.add_op(&matmul3), status::success);
    ASSERT_EQ(agraph2.add_op(&matmul4), status::success);
    ASSERT_EQ(agraph2.add_op(&multiply2), status::success);
    agraph2.finalize();

    std::vector<op_t *> fusion_ops2;
    EXPECT_FALSE(
            match_pattern(agraph2.get_ops()[0].get(), graphp, fusion_ops2));
    ASSERT_EQ(fusion_ops2.size(), 0U);
}

TEST(test_utils_pattern_matcher, ParallelMatmul) {
    auto graphp = std::make_shared<pb_graph_t>();
    // Pattern that captures shared input to three MatMuls
