diff --git a/src/graph/backend/dnnl/patterns/layernorm_fusion.cpp b/src/graph/backend/dnnl/patterns/layernorm_fusion.cpp
index 598cc950d8c..e1ddf9db749 100644
--- a/src/graph/backend/dnnl/patterns/layernorm_fusion.cpp
+++ b/src/graph/backend/dnnl/patterns/layernorm_fusion.cpp
@@ -54,6 +54,9 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, layernorm_post_ops_fusion_cpu)
                                     1>);
                     layernorm_base->append_decision_function(
                             check_begin_norm_axis_attr);
+                    // primitive only support 2-5D data tensor for layernorm
+                    layernorm_base->append_decision_function(
+                            check_input_ndim_from_offset<0, 2, 5>);
 
                     // optional typecast
                     auto tc_graph = std::make_shared();
diff --git a/src/graph/backend/dnnl/patterns/single_op_pattern.cpp b/src/graph/backend/dnnl/patterns/single_op_pattern.cpp
index ff47f7c7bba..4f9a48b1b80 100644
--- a/src/graph/backend/dnnl/patterns/single_op_pattern.cpp
+++ b/src/graph/backend/dnnl/patterns/single_op_pattern.cpp
@@ -194,6 +194,9 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, ln_pass)
                                     1>);
                     p_layernorm->append_decision_function(
                             check_begin_norm_axis_attr);
+                    // primitive only support 2-5D data tensor for layernorm
+                    p_layernorm->append_decision_function(
+                            check_input_ndim_from_offset<0, 2, 5>);
                 })
         .set_attr("FCreateKernel", []() -> kernel_ptr {
             return std::make_shared();
@@ -212,6 +215,9 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, ln_bw_pass)
                                     2>);
                     p_layernorm_bwd->append_decision_function(
                             check_begin_norm_axis_attr);
+                    // primitive only support 2-5D data tensor for layernorm
+                    p_layernorm_bwd->append_decision_function(
+                            check_input_ndim_from_offset<0, 2, 5>);
                 })
         .set_attr("FCreateKernel", []() -> kernel_ptr {
             return std::make_shared();
diff --git a/src/graph/backend/dnnl/patterns/utils.hpp b/src/graph/backend/dnnl/patterns/utils.hpp
index 2c04c401fca..424c069e2b2 100644
--- a/src/graph/backend/dnnl/patterns/utils.hpp
+++ b/src/graph/backend/dnnl/patterns/utils.hpp
@@ -122,6 +122,22 @@ inline bool check_begin_norm_axis_attr(const op_t *op) {
     return true;
 }
 
+// Decision function: accepts the op only when
+// MIN <= input[OFFSET]->ndims() <= MAX.
+// Returns false if the op has no input at OFFSET. An input whose ndims is
+// DNNL_GRAPH_UNKNOWN_NDIMS is accepted without a range check.
+template <size_t OFFSET, int32_t MIN, int32_t MAX>
+inline bool check_input_ndim_from_offset(const op_t *op) {
+    if (OFFSET >= op->num_inputs()) return false;
+    const logical_tensor_t &src_lt
+            = op->get_input_value(OFFSET)->get_logical_tensor();
+    const auto src_lt_wrapper = logical_tensor_wrapper_t(src_lt);
+    const auto ndims = src_lt_wrapper.ndims();
+
+    if (ndims == DNNL_GRAPH_UNKNOWN_NDIMS) return true;
+    return ndims >= MIN && ndims <= MAX;
+}
+
 inline const std::vector &get_unary_ops() {
     const static std::vector unary = {
             graph::op_kind::Abs,
diff --git a/tests/gtests/graph/unit/backend/dnnl/test_layer_norm.cpp b/tests/gtests/graph/unit/backend/dnnl/test_layer_norm.cpp
index 89e422c5550..05ed1d6113e 100644
--- a/tests/gtests/graph/unit/backend/dnnl/test_layer_norm.cpp
+++ b/tests/gtests/graph/unit/backend/dnnl/test_layer_norm.cpp
@@ -24,6 +24,54 @@
 namespace graph = dnnl::impl::graph;
 namespace utils = dnnl::graph::tests::unit::utils;
 
+// the layernorm primitive only supports 2-5D data tensors
+TEST(test_layer_norm_execute, LayernormNDimCheck) {
+    graph::engine_t *engine = get_engine();
+
+    std::vector<std::vector<int64_t>> src_shapes {
+            {}, {2}, {2, 3, 4, 5, 6, 7}, {2, 3}};
+    std::vector<size_t> expected_partition_num {0, 0, 0, 1, 1};
+
+    graph::logical_tensor_t scale_lt
+            = utils::logical_tensor_init(1, graph::data_type::f32);
+    graph::logical_tensor_t shift_lt
+            = utils::logical_tensor_init(2, graph::data_type::f32);
+    graph::logical_tensor_t dst_lt
+            = utils::logical_tensor_init(3, graph::data_type::f32);
+    graph::logical_tensor_t mean_lt
+            = utils::logical_tensor_init(4, graph::data_type::f32);
+    graph::logical_tensor_t variance_lt
+            = utils::logical_tensor_init(5, graph::data_type::f32);
+
+    // the last ndim is DNNL_GRAPH_UNKNOWN_NDIMS
+    for (size_t i = 0; i < src_shapes.size() + 1; i++) {
+        // src_shapes has no entry for the extra unknown-ndims iteration, so
+        // it is indexed only in the known-shape branch below
+        graph::logical_tensor_t src_lt;
+        if (i == src_shapes.size())
+            src_lt = utils::logical_tensor_init(0, graph::data_type::f32);
+        else
+            src_lt = utils::logical_tensor_init(
+                    0, src_shapes[i], graph::data_type::f32);
+        graph::graph_t g(engine->kind());
+
+        graph::op_t layernorm_op(graph::op_kind::LayerNorm);
+        layernorm_op.add_input(src_lt);
+        layernorm_op.add_input(scale_lt);
+        layernorm_op.add_input(shift_lt);
+        layernorm_op.add_output(dst_lt);
+        layernorm_op.add_output(mean_lt);
+        layernorm_op.add_output(variance_lt);
+
+        ASSERT_EQ(g.add_op(&layernorm_op), graph::status::success);
+        g.finalize();
+
+        graph::pass::pass_base_ptr apass = get_pass("ln_pass");
+        apass->run(g);
+        ASSERT_EQ(g.get_num_partitions(), expected_partition_num[i]);
+    }
+}
+
 TEST(test_layer_norm_execute, LayernormTraining) {
     graph::engine_t *eng = get_engine();
 