graph: backend: separate checking post-sum from fuser
xinyu-intel authored and TaoLv committed Feb 6, 2023
1 parent 9421fb2 commit b8d21a5
Showing 3 changed files with 51 additions and 105 deletions.
74 changes: 38 additions & 36 deletions src/graph/backend/dnnl/fusion_info.cpp
@@ -141,48 +141,50 @@ dnnl::primitive_attr make_dnnl_primitive_attr(
                     fused_op->get_attr<int64_t>(op_attr::alg_kind));
             dnnl_pops.append_eltwise(alg, alpha, beta);
         } else if (fused_op_kind == op_kind::dnnl_binary) {
+            const auto alg = static_cast<dnnl::algorithm>(
+                    fused_op->get_attr<int64_t>(op_attr::alg_kind));
             const auto &extra_inputs = pop->get_unfused_input_indices();
-            if (pop->is_post_sum()) {
-                // post-sum
-                float scale = pop->get_scale();
-                int32_t zp = pop->get_zp();
-                const auto psrc_dt = op->get_input_value(extra_inputs[0])
-                                             ->get_logical_tensor()
-                                             .data_type;
-                const auto dst_dt = op->get_output_value(0)
-                                            ->get_logical_tensor()
-                                            .data_type;
-                // note that onednn doesn't support float post-sum with u8/s8
-                // dst. use post-binary for such case instead.
-                if (impl::utils::one_of(
-                            dst_dt, impl::data_type::u8, impl::data_type::s8)
-                        && impl::utils::one_of(psrc_dt, impl::data_type::f32,
-                                impl::data_type::bf16)
-                        && scale == 1.f && zp == 0) {
-                    auto input = op->get_input_value(extra_inputs[0]);
-                    auto md = make_dnnl_memory_desc(
-                            input->get_logical_tensor());
-                    dnnl_pops.append_binary(dnnl::algorithm::binary_add, md);
-                    op->remove_attr(op_attr::with_sum);
-                    pop->to_post_binary();
-                } else {
-                    dnnl::memory::data_type sum_dt
-                            = dnnl::memory::data_type::undef;
-                    if (psrc_dt == impl::data_type::s8
-                            && dst_dt == impl::data_type::u8) {
-                        sum_dt = dnnl::memory::data_type::s8;
-                    }
-                    dnnl_pops.append_sum(scale, zp, sum_dt);
-                }
+            float scale = pop->get_scale();
+            int32_t zp = pop->get_zp();
+            const auto psrc = op->get_input_value(extra_inputs[0])
+                                      ->get_logical_tensor();
+            const auto dst = op->get_output_value(0)->get_logical_tensor();
+            // check if can use post-sum, otherwise use binary post ops
+            // algorithm should be binary_add
+            bool is_post_sum = alg == dnnl::algorithm::binary_add;
+            // base_op should not be eltwise or pool
+            is_post_sum = is_post_sum
+                    && !impl::utils::one_of(op->get_kind(),
+                            op_kind::dnnl_eltwise, op_kind::dnnl_pool);
+            // only support one post-sum
+            is_post_sum = is_post_sum
+                    && !(op->has_attr(op_attr::with_sum)
+                            && op->get_attr<bool>(op_attr::with_sum));
+            // post src and dst should have the same shape
+            is_post_sum = is_post_sum
+                    && logical_tensor_wrapper_t(dst).vdims()
+                            == logical_tensor_wrapper_t(psrc).vdims();
+            // dst should have equal or larger memory size than post src
+            is_post_sum = is_post_sum
+                    && (psrc.data_type == dst.data_type
+                            || impl::utils::one_of(psrc.data_type,
+                                    impl::data_type::u8, impl::data_type::s8));
+            if (is_post_sum) {
+                pop->set_post_sum();
+                op->set_attr<bool>(op_attr::with_sum, true);
+                dnnl::memory::data_type sum_dt = dnnl::memory::data_type::undef;
+                if (psrc.data_type == impl::data_type::s8
+                        && dst.data_type == impl::data_type::u8) {
+                    sum_dt = dnnl::memory::data_type::s8;
+                }
+                dnnl_pops.append_sum(scale, zp, sum_dt);
             } else {
                 // post-binary
                 assertm(extra_inputs.size() == 1,
                         "post-binary only has 1 extra input");
-                size_t src1_idx = extra_inputs[0];
-                auto input = op->get_input_value(src1_idx);
-                auto md = make_dnnl_memory_desc(input->get_logical_tensor());
-                const auto alg = static_cast<dnnl::algorithm>(
-                        fused_op->get_attr<int64_t>(op_attr::alg_kind));
+                assertm(scale == 1.f && zp == 0,
+                        "post-binary doesn't support input scale and zp");
+                auto md = make_dnnl_memory_desc(psrc);
                 dnnl_pops.append_binary(alg, md);
             }
         } else if (fused_op_kind == op_kind::dnnl_convolution) {
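The hunk above is the heart of the change: the post-sum eligibility test, which previously lived in the fuser pass and was baked into meta_op_t at construction time, now runs inside make_dnnl_primitive_attr. Distilled into a standalone predicate, the checks look roughly like this — the function name, the enums, and the lt struct are illustrative stand-ins, not part of the patch:

    #include <cstdint>
    #include <vector>

    // Illustrative stand-ins; the real code uses impl::data_type tags,
    // logical_tensor_wrapper_t, and op_kind constants.
    enum class dt { f32, bf16, s8, u8 };
    enum class kind { conv, matmul, eltwise, pool };

    struct lt {
        dt data_type;
        std::vector<int64_t> dims;
    };

    // Hypothetical distillation of the new is_post_sum checks above.
    bool can_use_post_sum(bool is_binary_add, kind base_kind,
            bool already_has_sum, const lt &psrc, const lt &dst) {
        // the algorithm must be binary_add; only an add can be folded
        // into an in-place accumulation
        if (!is_binary_add) return false;
        // eltwise and pool base ops don't take sum post-ops here
        if (base_kind == kind::eltwise || base_kind == kind::pool) return false;
        // only one sum post-op is supported per fused op
        if (already_has_sum) return false;
        // no broadcast: post src and dst must have identical shapes
        if (psrc.dims != dst.dims) return false;
        // dst must be at least as large as post src: same data type,
        // or the post src is a 1-byte integer type
        return psrc.data_type == dst.data_type || psrc.data_type == dt::u8
                || psrc.data_type == dt::s8;
    }

Only when the predicate holds does the real code call pop->set_post_sum() and tag the base op with with_sum; otherwise the add falls through to a generic binary post-op, which also absorbs the old float-post-sum-with-u8/s8-dst special case.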
29 changes: 9 additions & 20 deletions src/graph/backend/dnnl/fusion_info.hpp
@@ -54,16 +54,15 @@ class fusion_info_t {
         meta_op_t(const op_ptr &op) : op_(op) {};
         // for post-eltwise
         meta_op_t(const op_ptr &op, float scale) : op_(op), scale_(scale) {};
-        // for post-sum
+        // for post-sum and post_binary
         meta_op_t(const op_ptr &op,
                 const std::vector<size_t> &extra_input_indices, float scale,
                 int32_t zp)
             : op_(op)
             , scale_(scale)
             , zp_(zp)
-            , unfused_input_indices_(extra_input_indices)
-            , is_post_sum_(true) {};
-        // for post-binary and post-conv
+            , unfused_input_indices_(extra_input_indices) {};
+        // for post-conv
         meta_op_t(const op_ptr &op,
                 const std::vector<size_t> &extra_input_indices)
             : op_(op), unfused_input_indices_(extra_input_indices) {};
@@ -84,11 +83,7 @@ class fusion_info_t {
         return op_->get_kind() == op_kind::dnnl_binary && !is_post_sum_;
     }
 
-    void to_post_binary() {
-        assertm(scale_ == 1.0f && zp_ == 0,
-                "post bianry cannot support scale and zp!");
-        is_post_sum_ = false;
-    }
+    void set_post_sum() { is_post_sum_ = true; }
 
 private:
     std::shared_ptr<op_t> op_;
@@ -158,13 +153,6 @@ class fusion_info_t {
         post_ops_.emplace_back(std::make_shared<meta_op_t>(op, scale));
     }
 
-    void append_post_sum(const op_ptr &op,
-            const std::vector<size_t> &extra_input_indices, float scale = 1.0f,
-            int32_t zp = 0) {
-        post_ops_.emplace_back(std::make_shared<meta_op_t>(
-                op, extra_input_indices, scale, zp));
-    }
-
     // the extra input means the unfused input that has been added to the fused
     // op, like the following case, we fuse a binary mul into the conv, the src1
     // of mul op is unfused, and it becomes the 3rd input of conv. So the extra
@@ -176,10 +164,11 @@ class fusion_info_t {
     //       \   /     |
     //        mul
     //         |
-    void append_post_binary(
-            const op_ptr &op, const std::vector<size_t> &extra_input_indices) {
-        post_ops_.emplace_back(
-                std::make_shared<meta_op_t>(op, extra_input_indices));
+    void append_post_binary(const op_ptr &op,
+            const std::vector<size_t> &extra_input_indices, float scale = 1.0f,
+            int32_t zp = 0) {
+        post_ops_.emplace_back(std::make_shared<meta_op_t>(
+                op, extra_input_indices, scale, zp));
     }
 
     // the meaning of extra input is same as that in append_post_binary function
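With the header change, both flavors are recorded through the same constructor and the same append helper: an entry starts out as a generic binary post-op and is promoted to a sum only after the checks in fusion_info.cpp pass. A trimmed-down mirror of the patched meta_op_t, with oneDNN-internal details (op_ptr, the op-kind queries) stripped out for illustration:

    #include <cstdint>
    #include <utility>
    #include <vector>

    // Trimmed-down mirror of the patched meta_op_t. One constructor now
    // serves post-sum and post-binary alike; is_post_sum_ defaults to
    // false and is flipped only once eligibility is proven.
    class meta_op_t {
    public:
        meta_op_t(std::vector<size_t> extra_input_indices, float scale = 1.0f,
                int32_t zp = 0)
            : scale_(scale)
            , zp_(zp)
            , unfused_input_indices_(std::move(extra_input_indices)) {}

        void set_post_sum() { is_post_sum_ = true; }
        bool is_post_sum() const { return is_post_sum_; }
        float get_scale() const { return scale_; }
        int32_t get_zp() const { return zp_; }

    private:
        float scale_ = 1.0f;
        int32_t zp_ = 0;
        std::vector<size_t> unfused_input_indices_;
        bool is_post_sum_ = false; // no longer decided at construction time
    };

The removed pair — is_post_sum_(true) at construction plus a to_post_binary() downgrade guarded by a scale/zp assert — made the state machine run backwards; now it only ever moves from binary to sum, and the assert migrates to the post-binary branch in fusion_info.cpp.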
53 changes: 4 additions & 49 deletions src/graph/backend/dnnl/passes/transform.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2021-2022 Intel Corporation
+* Copyright 2021-2023 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -937,57 +937,12 @@ status_t fuse_post_ops(std::shared_ptr<subgraph_t> &sg) {
                     rewriter.fuse_op_to_successor(
                             mul_scale_op.shared_from_this());
 
-                    fusion_info.append_post_sum(post_op->shared_from_this(),
+                    fusion_info.append_post_binary(post_op->shared_from_this(),
                             std::vector<size_t> {base_op->num_inputs()},
                             scales[0], zp);
-                    assertm(!base_op->has_attr(op_attr::with_sum)
-                                    || !base_op->get_attr<bool>(
-                                            op_attr::with_sum),
-                            "not support multiple post sum ops "
-                            "currently.");
-                    base_op->set_attr<bool>(op_attr::with_sum, true);
                 } else {
-                    // - the add operation may need broadcast
-                    auto fused_in = post_op->get_input_value(
-                            fuse_op_predecessor_offset);
-                    auto other_in = post_op->get_input_value(
-                            1 - fuse_op_predecessor_offset);
-                    auto dst = post_op->get_output_value(0);
-
-                    if (ltw(fused_in->get_logical_tensor()).vdims()
-                            == ltw(other_in->get_logical_tensor()).vdims()) {
-                        if (base_op->get_kind() == op_kind::dnnl_eltwise
-                                || base_op->get_kind() == op_kind::dnnl_pool) {
-                            fusion_info.append_post_binary(
-                                    post_op->shared_from_this(),
-                                    std::vector<size_t> {
-                                            base_op->num_inputs()});
-                        } else {
-                            // use sum post-ops for no-broadcast add
-                            // map non-first post-sum to post-binary_add
-                            if (base_op->has_attr(op_attr::with_sum)
-                                    && base_op->get_attr<bool>(
-                                            op_attr::with_sum)) {
-                                fusion_info.append_post_binary(
-                                        post_op->shared_from_this(),
-                                        std::vector<size_t> {
-                                                base_op->num_inputs()});
-                            } else {
-                                fusion_info.append_post_sum(
-                                        post_op->shared_from_this(),
-                                        std::vector<size_t> {
-                                                base_op->num_inputs()},
-                                        1.0f, 0);
-                                base_op->set_attr<bool>(
-                                        op_attr::with_sum, true);
-                            }
-                        }
-                    } else {
-                        // use binary post-ops for broadcast add
-                        fusion_info.append_post_binary(
-                                post_op->shared_from_this(),
-                                std::vector<size_t> {base_op->num_inputs()});
-                    }
+                    fusion_info.append_post_binary(post_op->shared_from_this(),
+                            std::vector<size_t> {base_op->num_inputs()});
                 }
             } else if (post_op->get_kind() == op_kind::dnnl_binary
                     && static_cast<dnnl::algorithm>(
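After this simplification the fuser no longer distinguishes broadcast from no-broadcast adds, first from non-first sums, or eltwise/pool base ops: every fused add becomes an append_post_binary record carrying its scale and zero point, and the sum-vs-binary lowering happens exactly once, when the primitive attributes are built. For orientation, here is roughly what the two outcomes look like in the public oneDNN C++ API — an illustrative helper, not code from the patch, which builds the equivalent attr internally:

    #include "oneapi/dnnl/dnnl.hpp"

    // Illustrative helper showing the two lowering outcomes in the public
    // oneDNN C++ API; the patch constructs the equivalent attr inside
    // make_dnnl_primitive_attr.
    dnnl::primitive_attr make_add_attr(
            bool as_sum, const dnnl::memory::desc &src1_md) {
        dnnl::post_ops po;
        if (as_sum) {
            // in-place accumulation: the existing dst contents are added
            // to the primitive's result, no extra input tensor needed
            po.append_sum(/*scale=*/1.f, /*zero_point=*/0,
                    dnnl::memory::data_type::undef);
        } else {
            // generic elementwise add with src1 as an extra input tensor
            po.append_binary(dnnl::algorithm::binary_add, src1_md);
        }
        dnnl::primitive_attr attr;
        attr.set_post_ops(po);
        return attr;
    }

Keeping the decision next to where the primitive_attr is built means the fuser cannot drift out of sync with oneDNN's actual post-op constraints.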
