diff --git a/tests/gtests/graph/unit/utils/test_pattern_matcher_cpu.cpp b/tests/gtests/graph/unit/utils/test_pattern_matcher_cpu.cpp index 0028f789c49..ffafc22a390 100644 --- a/tests/gtests/graph/unit/utils/test_pattern_matcher_cpu.cpp +++ b/tests/gtests/graph/unit/utils/test_pattern_matcher_cpu.cpp @@ -1266,6 +1266,78 @@ TEST(test_utils_pattern_matcher, ComplexRepetition) { ASSERT_EQ(fusion_ops.size(), 3U); } +TEST(test_utils_pattern_matcher, SharedInput) { + /* Pattern that captures shared input to two MatMuls + | + / \ + MatMul MatMul + \ / + Multiply + | + */ + auto graphp = std::make_shared(); + auto pmm1 = graphp->append_op(MatMul); + auto pmm2 = graphp->append_op(MatMul); + graphp->create_input_port(0, pmm1, 0); + graphp->create_input_port(0, pmm2, 0); + auto pmul = graphp->append_op( + Multiply, {in_edge(0, pmm1, 0), in_edge(1, pmm2, 0)}); + UNUSED(pmul); + + // test with a graph that has the shared input + graph_t agraph; + op_t matmul1 {0, MatMul, "matmul1"}; + op_t matmul2 {1, MatMul, "matmul2"}; + op_t multiply {2, Multiply, "multiply"}; + + std::vector lt_vec = create_logical_tensors(6); + matmul1.add_input(lt_vec[0]); + matmul1.add_input(lt_vec[1]); + matmul1.add_output(lt_vec[2]); + matmul2.add_input(lt_vec[0]); + matmul2.add_input(lt_vec[3]); + matmul2.add_output(lt_vec[4]); + multiply.add_input(lt_vec[2]); + multiply.add_input(lt_vec[4]); + multiply.add_output(lt_vec[5]); + + ASSERT_EQ(agraph.add_op(&matmul1), status::success); + ASSERT_EQ(agraph.add_op(&matmul2), status::success); + ASSERT_EQ(agraph.add_op(&multiply), status::success); + agraph.finalize(); + + std::vector fusion_ops; + EXPECT_TRUE(match_pattern(agraph.get_ops()[0].get(), graphp, fusion_ops)); + ASSERT_EQ(fusion_ops.size(), 3U); + + // test with a graph that does not have the shared input + graph_t agraph2; + op_t matmul3 {0, MatMul, "matmul1"}; + op_t matmul4 {1, MatMul, "matmul2"}; + op_t multiply2 {2, Multiply, "multiply"}; + + std::vector lt_vec2 = create_logical_tensors(7); + matmul3.add_input(lt_vec2[0]); + matmul3.add_input(lt_vec2[1]); + matmul3.add_output(lt_vec2[2]); + matmul4.add_input(lt_vec2[3]); + matmul4.add_input(lt_vec2[4]); + matmul4.add_output(lt_vec2[5]); + multiply2.add_input(lt_vec2[2]); + multiply2.add_input(lt_vec2[5]); + multiply2.add_output(lt_vec2[6]); + + ASSERT_EQ(agraph2.add_op(&matmul3), status::success); + ASSERT_EQ(agraph2.add_op(&matmul4), status::success); + ASSERT_EQ(agraph2.add_op(&multiply2), status::success); + agraph2.finalize(); + + std::vector fusion_ops2; + EXPECT_FALSE( + match_pattern(agraph2.get_ops()[0].get(), graphp, fusion_ops2)); + ASSERT_EQ(fusion_ops2.size(), 0U); +} + TEST(test_utils_pattern_matcher, ParallelMatmul) { auto graphp = std::make_shared(); // Pattern that captures shared input to three MatMuls