backend: dnnl: use post binary for float psrc and int dst
xinyu-intel authored and TaoLv committed Jan 10, 2023
1 parent 8dbca04 commit 26a9a5b
Showing 3 changed files with 51 additions and 85 deletions.
src/backend/dnnl/fusion_info.cpp (37 additions, 21 deletions)
@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright 2022 Intel Corporation
+ * Copyright 2022-2023 Intel Corporation
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -103,34 +103,50 @@ dnnl::primitive_attr make_dnnl_primitive_attr(
                     fused_op->get_attr<int64_t>(op_attr::alg_kind));
             dnnl_pops.append_eltwise(scale, alg, alpha, beta);
         } else if (fused_op_kind == op_kind::dnnl_binary) {
+            const auto alg = static_cast<dnnl::algorithm>(
+                    fused_op->get_attr<int64_t>(op_attr::alg_kind));
             const auto &extra_inputs = pop->get_unfused_input_indices();
-            if (pop->is_post_sum()) {
-                // post-sum
-                float scale = pop->get_scale();
-                int32_t zp = pop->get_zp();
+            float scale = pop->get_scale();
+            int32_t zp = pop->get_zp();
+            const auto psrc = op->get_input_value(extra_inputs[0])
+                                      ->get_logical_tensor();
+            const auto dst = op->get_output_value(0)->get_logical_tensor();
+            // check if can use post-sum, otherwise use binary post ops
+            // algorithm should be binary_add
+            bool is_post_sum = alg == dnnl::algorithm::binary_add;
+            // base_op should not be eltwise or pool
+            is_post_sum = is_post_sum
+                    && !impl::utils::one_of(op->get_kind(),
+                            op_kind::dnnl_eltwise, op_kind::dnnl_pool);
+            // only support one post-sum
+            is_post_sum = is_post_sum
+                    && !(op->has_attr(op_attr::with_sum)
+                            && op->get_attr<bool>(op_attr::with_sum));
+            // post src and dst should have the same shape
+            is_post_sum = is_post_sum
+                    && impl::logical_tensor_wrapper_t(dst).vdims()
+                            == impl::logical_tensor_wrapper_t(psrc).vdims();
+            // dst should have equal or larger memory size than post src
+            is_post_sum = is_post_sum
+                    && (psrc.data_type == dst.data_type
+                            || impl::utils::one_of(psrc.data_type,
+                                    impl::data_type::u8, impl::data_type::s8));
+            if (is_post_sum) {
+                pop->set_post_sum();
+                op->set_attr<bool>(op_attr::with_sum, true);
                 dnnl::memory::data_type sum_dt = dnnl::memory::data_type::undef;
-                if (op->get_kind() == op_kind::dnnl_convolution) {
-                    const auto psrc_dt = op->get_input_value(extra_inputs[0])
-                                                 ->get_logical_tensor()
-                                                 .data_type;
-                    const auto dst_dt = op->get_output_value(0)
-                                                ->get_logical_tensor()
-                                                .data_type;
-                    if (psrc_dt == impl::data_type::s8
-                            && dst_dt == impl::data_type::u8) {
-                        sum_dt = dnnl::memory::data_type::s8;
-                    }
+                if (psrc.data_type == impl::data_type::s8
+                        && dst.data_type == impl::data_type::u8) {
+                    sum_dt = dnnl::memory::data_type::s8;
                 }
                 dnnl_pops.append_sum(scale, zp, sum_dt);
             } else {
                 // post-binary
                 assertm(extra_inputs.size() == 1,
                         "post-binary only has 1 extra input");
-                size_t src1_idx = extra_inputs[0];
-                auto input = op->get_input_value(src1_idx);
-                auto md = make_dnnl_memory_desc(input->get_logical_tensor());
-                const auto alg = static_cast<dnnl::algorithm>(
-                        fused_op->get_attr<int64_t>(op_attr::alg_kind));
+                assertm(scale == 1.f && zp == 0,
+                        "post-binary doesn't support input scale and zp");
+                auto md = make_dnnl_memory_desc(psrc);
                 dnnl_pops.append_binary(alg, md);
             }
         } else if (fused_op_kind == op_kind::dnnl_convolution) {
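The hunk above is the core of the change: instead of trusting a sum-vs-binary decision made earlier in the pass pipeline, make_dnnl_primitive_attr now re-derives whether the fused add can legally run as an in-place post-sum and otherwise falls back to a binary post-op. Below is a minimal standalone sketch of the dtype/shape part of that gate (the op-kind and single-sum checks are left out), assuming the oneDNN v2.x C++ API; the helper name make_attr_with_add is illustrative and not from this patch:

// Illustrative helper (not part of the patch): choose between sum and
// binary post-ops the way the rewritten make_dnnl_primitive_attr does.
#include <cassert>
#include <dnnl.hpp>

dnnl::primitive_attr make_attr_with_add(const dnnl::memory::desc &psrc_md,
        const dnnl::memory::desc &dst_md, float scale, int32_t zp) {
    dnnl::post_ops pops;
    // post-sum accumulates in place into dst, so the post src must match
    // dst's shape, and dst must be able to hold it: same data type, or an
    // int8 post src (an f32 psrc with an int8 dst fails this test)
    const bool same_shape = psrc_md.dims() == dst_md.dims();
    const bool dst_can_hold_psrc = psrc_md.data_type() == dst_md.data_type()
            || psrc_md.data_type() == dnnl::memory::data_type::u8
            || psrc_md.data_type() == dnnl::memory::data_type::s8;
    if (same_shape && dst_can_hold_psrc) {
        pops.append_sum(scale, zp);
    } else {
        // a binary post-op reads the post src through its own memory
        // descriptor, so mixed f32/int8 is fine, but scale/zp must be neutral
        assert(scale == 1.f && zp == 0);
        pops.append_binary(dnnl::algorithm::binary_add, psrc_md);
    }
    dnnl::primitive_attr attr;
    attr.set_post_ops(pops);
    return attr;
}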
src/backend/dnnl/fusion_info.hpp (10 additions, 15 deletions)
@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright 2022 Intel Corporation
+ * Copyright 2022-2023 Intel Corporation
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -60,9 +60,8 @@ class fusion_info_t {
             : op_(op)
             , scale_(scale)
             , zp_(zp)
-            , unfused_input_indices_(extra_input_indices)
-            , is_post_sum_(true) {};
-    // for post-binary and post-conv
+            , unfused_input_indices_(extra_input_indices) {};
+    // for post-conv
     meta_op_t(const op_ptr &op,
             const std::vector<size_t> &extra_input_indices)
         : op_(op), unfused_input_indices_(extra_input_indices) {};
@@ -83,6 +82,8 @@ class fusion_info_t {
         return op_->get_kind() == op_kind::dnnl_binary && !is_post_sum_;
     }
 
+    void set_post_sum() { is_post_sum_ = true; }
+
 private:
     std::shared_ptr<impl::op_t> op_;
     // used to represent post-eltwise and post-sum's scale
@@ -139,13 +140,6 @@ class fusion_info_t {
         post_ops_.emplace_back(std::make_shared<meta_op_t>(op, scale));
     }
 
-    void append_post_sum(const op_ptr &op,
-            const std::vector<size_t> &extra_input_indices, float scale = 1.0f,
-            int32_t zp = 0) {
-        post_ops_.emplace_back(std::make_shared<meta_op_t>(
-                op, extra_input_indices, scale, zp));
-    }
-
     // the extra input means the unfused input that has been added to the fused
     // op, like the following case, we fuse a binary mul into the conv, the src1
     // of mul op is unfused, and it becomes the 3rd input of conv. So the extra
@@ -157,10 +151,11 @@ class fusion_info_t {
     //       \   /       |
     //        mul
     //         |
-    void append_post_binary(
-            const op_ptr &op, const std::vector<size_t> &extra_input_indices) {
-        post_ops_.emplace_back(
-                std::make_shared<meta_op_t>(op, extra_input_indices));
+    void append_post_binary(const op_ptr &op,
+            const std::vector<size_t> &extra_input_indices, float scale = 1.0f,
+            int32_t zp = 0) {
+        post_ops_.emplace_back(std::make_shared<meta_op_t>(
+                op, extra_input_indices, scale, zp));
     }
 
     // the meaning of extra input is same as that in append_post_binary function
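Taken together, the header edits invert the meta_op_t lifecycle: append_post_sum is gone, every fused add is recorded through append_post_binary (which now also carries the scale and zero point), and is_post_sum_ starts false until make_dnnl_primitive_attr promotes the op through the new set_post_sum(). A condensed sketch of that deferred-decision flow, with hypothetical names standing in for the real classes:

#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical stand-in for meta_op_t: the sum-vs-binary question is
// open at construction time and only answered later via set_post_sum().
struct meta_post_op {
    float scale = 1.0f;
    int32_t zp = 0;
    bool is_post_sum = false;
    void set_post_sum() { is_post_sum = true; }
};

// Hypothetical stand-in for fusion_info_t: the fusion pass records every
// add as a binary post-op, keeping scale/zp for a possible sum promotion.
struct fusion_record {
    std::vector<std::shared_ptr<meta_post_op>> post_ops;
    void append_post_binary(float scale = 1.0f, int32_t zp = 0) {
        post_ops.emplace_back(std::make_shared<meta_post_op>(
                meta_post_op {scale, zp, false}));
    }
};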
src/backend/dnnl/passes/transform.cpp (4 additions, 49 deletions)
@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright 2021-2022 Intel Corporation
+ * Copyright 2021-2023 Intel Corporation
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -1244,57 +1244,12 @@ status_t fuse_post_ops(std::shared_ptr<subgraph_t> &sg) {
                         rewriter.fuse_op_to_successor(
                                 mul_scale_op.shared_from_this());
 
-                        fusion_info.append_post_sum(post_op->shared_from_this(),
+                        fusion_info.append_post_binary(post_op->shared_from_this(),
                                 std::vector<size_t> {base_op->num_inputs()},
                                 scales[0], zp);
-                        assertm(!base_op->has_attr(op_attr::with_sum)
-                                        || !base_op->get_attr<bool>(
-                                                op_attr::with_sum),
-                                "not support multiple post sum ops "
-                                "currently.");
-                        base_op->set_attr<bool>(op_attr::with_sum, true);
                     } else {
-                        // - the add operation may need broadcast
-                        auto fused_in = post_op->get_input_value(
-                                fuse_op_predecessor_offset);
-                        auto other_in = post_op->get_input_value(
-                                1 - fuse_op_predecessor_offset);
-                        auto dst = post_op->get_output_value(0);
-
-                        if (ltw(fused_in->get_logical_tensor()).vdims()
-                                == ltw(other_in->get_logical_tensor()).vdims()) {
-                            if (base_op->get_kind() == op_kind::dnnl_eltwise
-                                    || base_op->get_kind() == op_kind::dnnl_pool) {
-                                fusion_info.append_post_binary(
-                                        post_op->shared_from_this(),
-                                        std::vector<size_t> {
-                                                base_op->num_inputs()});
-                            } else {
-                                // use sum post-ops for no-broadcast add
-                                // map non-first post-sum to post-binary_add
-                                if (base_op->has_attr(op_attr::with_sum)
-                                        && base_op->get_attr<bool>(
-                                                op_attr::with_sum)) {
-                                    fusion_info.append_post_binary(
-                                            post_op->shared_from_this(),
-                                            std::vector<size_t> {
-                                                    base_op->num_inputs()});
-                                } else {
-                                    fusion_info.append_post_sum(
-                                            post_op->shared_from_this(),
-                                            std::vector<size_t> {
-                                                    base_op->num_inputs()},
-                                            1.0f, 0);
-                                    base_op->set_attr<bool>(
-                                            op_attr::with_sum, true);
-                                }
-                            }
-                        } else {
-                            // use binary post-ops for broadcast add
-                            fusion_info.append_post_binary(
-                                    post_op->shared_from_this(),
-                                    std::vector<size_t> {base_op->num_inputs()});
-                        }
+                        fusion_info.append_post_binary(post_op->shared_from_this(),
+                                std::vector<size_t> {base_op->num_inputs()});
                     }
                 } else if (post_op->get_kind() == op_kind::dnnl_binary
                         && static_cast<dnnl::algorithm>(
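With the shape and dtype analysis deleted from the pass, fuse_post_ops now records every fused add as a post-binary and leaves the sum-vs-binary choice entirely to make_dnnl_primitive_attr. For the case in the commit title, a float post src feeding an int destination, the deferred check rejects post-sum and keeps binary_add. A hypothetical driver for the make_attr_with_add sketch after the first file, with made-up tensor shapes:

// Same shape on both sides, but a u8 dst cannot hold an f32 post src,
// so this routes to append_binary rather than append_sum.
dnnl::memory::desc psrc_md({8, 64, 28, 28}, dnnl::memory::data_type::f32,
        dnnl::memory::format_tag::nhwc);
dnnl::memory::desc dst_md({8, 64, 28, 28}, dnnl::memory::data_type::u8,
        dnnl::memory::format_tag::nhwc);
auto attr = make_attr_with_add(psrc_md, dst_md, /*scale=*/1.0f, /*zp=*/0);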
