graph: backend: dnnl: patterns: check input ndims for LayerNorm
rongzha1 authored and TaoLv committed Jul 24, 2024
1 parent 730b976 commit f704f09
Showing 4 changed files with 70 additions and 0 deletions.
src/graph/backend/dnnl/patterns/layernorm_fusion.cpp (3 additions, 0 deletions)
@@ -54,6 +54,9 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, layernorm_post_ops_fusion_cpu)
1>);
layernorm_base->append_decision_function(
check_begin_norm_axis_attr);
+// the primitive only supports 2-5D data tensors for layernorm
+layernorm_base->append_decision_function(
+        check_input_ndim_from_offset<0, 2, 5>);

// optional typecast
auto tc_graph = std::make_shared<pb_graph_t>();
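
For context, append_decision_function chains predicates onto a pattern node: during matching, every attached function is evaluated against the candidate op, and all of them must return true for the pattern to produce a fused partition. Below is a minimal standalone sketch of that chaining; toy_op_t and toy_pattern_node_t are simplified stand-ins for the real op_t and pb_op_t types, not the actual oneDNN Graph API.

    #include <functional>
    #include <utility>
    #include <vector>

    struct toy_op_t { int ndims = -1; }; // stand-in for op_t

    // Stand-in for a pattern node such as pb_op_t.
    struct toy_pattern_node_t {
        std::vector<std::function<bool(const toy_op_t &)>> checks;
        void append_decision_function(std::function<bool(const toy_op_t &)> f) {
            checks.push_back(std::move(f));
        }
        // A candidate op matches only if every appended predicate accepts it.
        bool matches(const toy_op_t &op) const {
            for (const auto &check : checks)
                if (!check(op)) return false;
            return true;
        }
    };

    int main() {
        toy_pattern_node_t node;
        // same spirit as check_input_ndim_from_offset<0, 2, 5>
        node.append_decision_function([](const toy_op_t &op) {
            return op.ndims >= 2 && op.ndims <= 5;
        });
        toy_op_t ok {3}, too_big {6};
        return node.matches(ok) && !node.matches(too_big) ? 0 : 1;
    }
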
src/graph/backend/dnnl/patterns/single_op_pattern.cpp (6 additions, 0 deletions)
@@ -194,6 +194,9 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, ln_pass)
1>);
p_layernorm->append_decision_function(
check_begin_norm_axis_attr);
+// the primitive only supports 2-5D data tensors for layernorm
+p_layernorm->append_decision_function(
+        check_input_ndim_from_offset<0, 2, 5>);
})
.set_attr<FCreateKernel>("FCreateKernel", []() -> kernel_ptr {
return std::make_shared<layernorm_fwd_t>();
@@ -212,6 +215,9 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, ln_bw_pass)
2>);
p_layernorm_bwd->append_decision_function(
check_begin_norm_axis_attr);
+// the primitive only supports 2-5D data tensors for layernorm
+p_layernorm_bwd->append_decision_function(
+        check_input_ndim_from_offset<0, 2, 5>);
})
.set_attr<FCreateKernel>("FCreateKernel", []() -> kernel_ptr {
return std::make_shared<layernorm_bwd_t>();
src/graph/backend/dnnl/patterns/utils.hpp (13 additions, 0 deletions)
@@ -122,6 +122,19 @@ inline bool check_begin_norm_axis_attr(const op_t *op) {
return true;
}

+// MIN <= input[OFFSET]->ndims() <= MAX
+template <size_t OFFSET, int32_t MIN, int32_t MAX>
+inline bool check_input_ndim_from_offset(const op_t *op) {
+    if (OFFSET >= op->num_inputs()) return false;
+    const logical_tensor_t &src_lt
+            = op->get_input_value(OFFSET)->get_logical_tensor();
+    const auto src_lt_wrapper = logical_tensor_wrapper_t(src_lt);
+    const auto ndims = src_lt_wrapper.ndims();
+
+    if (ndims == DNNL_GRAPH_UNKNOWN_NDIMS) return true;
+    return ndims >= MIN && ndims <= MAX;
+}
+
inline const std::vector<op_kind_t> &get_unary_ops() {
const static std::vector<op_kind_t> unary = {
graph::op_kind::Abs,
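
To show the accept/reject behaviour of the new check in isolation, here is a self-contained sketch that mirrors its logic over a toy op type; check_ndims, toy_op_t, and UNKNOWN_NDIMS are assumptions standing in for check_input_ndim_from_offset, op_t with logical_tensor_wrapper_t, and DNNL_GRAPH_UNKNOWN_NDIMS.

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    constexpr std::int32_t UNKNOWN_NDIMS = -1; // stand-in sentinel

    struct toy_op_t { // stand-in that tracks only the rank of each input
        std::vector<std::int32_t> input_ndims;
        std::size_t num_inputs() const { return input_ndims.size(); }
    };

    // Mirrors the check above: accept iff MIN <= ndims <= MAX, or the
    // rank is not yet known (rejecting would be premature at match time).
    template <std::size_t OFFSET, std::int32_t MIN, std::int32_t MAX>
    bool check_ndims(const toy_op_t &op) {
        if (OFFSET >= op.num_inputs()) return false;
        const std::int32_t ndims = op.input_ndims[OFFSET];
        if (ndims == UNKNOWN_NDIMS) return true;
        return ndims >= MIN && ndims <= MAX;
    }

    int main() {
        assert((!check_ndims<0, 2, 5>({{1}}))); // 1D src: rejected
        assert((check_ndims<0, 2, 5>({{2}}))); // 2D src: accepted
        assert((check_ndims<0, 2, 5>({{5}}))); // 5D src: accepted
        assert((!check_ndims<0, 2, 5>({{6}}))); // 6D src: rejected
        assert((check_ndims<0, 2, 5>({{UNKNOWN_NDIMS}}))); // unknown: accepted
        return 0;
    }
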
tests/gtests/graph/unit/backend/dnnl/test_layer_norm.cpp (48 additions, 0 deletions)
@@ -24,6 +24,54 @@
namespace graph = dnnl::impl::graph;
namespace utils = dnnl::graph::tests::unit::utils;

+// the primitive only supports 2-5D data tensors for layernorm
+TEST(test_layer_norm_execute, LayernormNDimCheck) {
+    graph::engine_t *engine = get_engine();
+
+    std::vector<std::vector<int64_t>> src_shapes {
+            {}, {2}, {2, 3, 4, 5, 6, 7}, {2, 3}};
+    std::vector<size_t> expected_partition_num {0, 0, 0, 1, 1};
+
+    graph::logical_tensor_t scale_lt
+            = utils::logical_tensor_init(1, graph::data_type::f32);
+    graph::logical_tensor_t shift_lt
+            = utils::logical_tensor_init(2, graph::data_type::f32);
+    graph::logical_tensor_t dst_lt
+            = utils::logical_tensor_init(3, graph::data_type::f32);
+    graph::logical_tensor_t mean_lt
+            = utils::logical_tensor_init(4, graph::data_type::f32);
+    graph::logical_tensor_t variance_lt
+            = utils::logical_tensor_init(5, graph::data_type::f32);
+
+    // the extra final iteration covers DNNL_GRAPH_UNKNOWN_NDIMS
+    for (size_t i = 0; i < src_shapes.size() + 1; i++) {
+        graph::logical_tensor_t src_lt;
+        if (i == src_shapes.size())
+            // ndims is DNNL_GRAPH_UNKNOWN_NDIMS
+            src_lt = utils::logical_tensor_init(0, graph::data_type::f32);
+        else
+            src_lt = utils::logical_tensor_init(
+                    0, src_shapes[i], graph::data_type::f32);
+        graph::graph_t g(engine->kind());
+
+        graph::op_t layernorm_op(graph::op_kind::LayerNorm);
+        layernorm_op.add_input(src_lt);
+        layernorm_op.add_input(scale_lt);
+        layernorm_op.add_input(shift_lt);
+        layernorm_op.add_output(dst_lt);
+        layernorm_op.add_output(mean_lt);
+        layernorm_op.add_output(variance_lt);
+
+        ASSERT_EQ(g.add_op(&layernorm_op), graph::status::success);
+        g.finalize();
+
+        graph::pass::pass_base_ptr apass = get_pass("ln_pass");
+        apass->run(g);
+        ASSERT_EQ(g.get_num_partitions(), expected_partition_num[i]);
+    }
+}
+
TEST(test_layer_norm_execute, LayernormTraining) {
graph::engine_t *eng = get_engine();

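
In summary, the cases the new LayernormNDimCheck test exercises (1 partition means ln_pass matched the LayerNorm op; 0 means the new ndims check rejected it and no partition is created):

    src shape           src ndims                  expected partitions
    {}                  0                          0 (below 2D)
    {2}                 1                          0 (below 2D)
    {2, 3, 4, 5, 6, 7}  6                          0 (above 5D)
    {2, 3}              2                          1 (within 2-5D)
    (no shape given)    DNNL_GRAPH_UNKNOWN_NDIMS   1 (check deferred)
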
