From b8d21a58d8247097ed26816b730e3cd4c19f61c2 Mon Sep 17 00:00:00 2001
From: "Chen, Xinyu1"
Date: Thu, 5 Jan 2023 11:53:38 +0800
Subject: [PATCH] graph: backend: separate post-sum checking from the fuser

---
 src/graph/backend/dnnl/fusion_info.cpp      | 74 +++++++++++----------
 src/graph/backend/dnnl/fusion_info.hpp      | 29 +++-----
 src/graph/backend/dnnl/passes/transform.cpp | 53 ++-------
 3 files changed, 51 insertions(+), 105 deletions(-)

diff --git a/src/graph/backend/dnnl/fusion_info.cpp b/src/graph/backend/dnnl/fusion_info.cpp
index f13da5e5833..2778e9ee27f 100644
--- a/src/graph/backend/dnnl/fusion_info.cpp
+++ b/src/graph/backend/dnnl/fusion_info.cpp
@@ -141,48 +141,50 @@ dnnl::primitive_attr make_dnnl_primitive_attr(
                     fused_op->get_attr<int64_t>(op_attr::alg_kind));
             dnnl_pops.append_eltwise(alg, alpha, beta);
         } else if (fused_op_kind == op_kind::dnnl_binary) {
+            const auto alg = static_cast<dnnl::algorithm>(
+                    fused_op->get_attr<int64_t>(op_attr::alg_kind));
             const auto &extra_inputs = pop->get_unfused_input_indices();
-            if (pop->is_post_sum()) {
-                // post-sum
-                float scale = pop->get_scale();
-                int32_t zp = pop->get_zp();
-                const auto psrc_dt = op->get_input_value(extra_inputs[0])
-                                             ->get_logical_tensor()
-                                             .data_type;
-                const auto dst_dt = op->get_output_value(0)
-                                            ->get_logical_tensor()
-                                            .data_type;
-                // note that onednn doesn't support float post-sum with u8/s8
-                // dst. use post-binary for such case instead.
-                if (impl::utils::one_of(
-                            dst_dt, impl::data_type::u8, impl::data_type::s8)
-                        && impl::utils::one_of(psrc_dt, impl::data_type::f32,
-                                impl::data_type::bf16)
-                        && scale == 1.f && zp == 0) {
-                    auto input = op->get_input_value(extra_inputs[0]);
-                    auto md = make_dnnl_memory_desc(
-                            input->get_logical_tensor());
-                    dnnl_pops.append_binary(dnnl::algorithm::binary_add, md);
-                    op->remove_attr(op_attr::with_sum);
-                    pop->to_post_binary();
-                } else {
-                    dnnl::memory::data_type sum_dt
-                            = dnnl::memory::data_type::undef;
-                    if (psrc_dt == impl::data_type::s8
-                            && dst_dt == impl::data_type::u8) {
-                        sum_dt = dnnl::memory::data_type::s8;
-                    }
-                    dnnl_pops.append_sum(scale, zp, sum_dt);
+            float scale = pop->get_scale();
+            int32_t zp = pop->get_zp();
+            const auto psrc = op->get_input_value(extra_inputs[0])
+                                      ->get_logical_tensor();
+            const auto dst = op->get_output_value(0)->get_logical_tensor();
+            // check if post-sum can be used; otherwise fall back to a
+            // binary post-op. the algorithm should be binary_add
+            bool is_post_sum = alg == dnnl::algorithm::binary_add;
+            // base_op should not be eltwise or pool
+            is_post_sum = is_post_sum
+                    && !impl::utils::one_of(op->get_kind(),
+                            op_kind::dnnl_eltwise, op_kind::dnnl_pool);
+            // only one post-sum is supported
+            is_post_sum = is_post_sum
+                    && !(op->has_attr(op_attr::with_sum)
+                            && op->get_attr<bool>(op_attr::with_sum));
+            // post src and dst should have the same shape
+            is_post_sum = is_post_sum
+                    && logical_tensor_wrapper_t(dst).vdims()
+                            == logical_tensor_wrapper_t(psrc).vdims();
+            // dst should have a memory size equal to or larger than post src
+            is_post_sum = is_post_sum
+                    && (psrc.data_type == dst.data_type
+                            || impl::utils::one_of(psrc.data_type,
+                                    impl::data_type::u8, impl::data_type::s8));
+            if (is_post_sum) {
+                pop->set_post_sum();
+                op->set_attr<bool>(op_attr::with_sum, true);
+                dnnl::memory::data_type sum_dt
+                        = dnnl::memory::data_type::undef;
+                if (psrc.data_type == impl::data_type::s8
+                        && dst.data_type == impl::data_type::u8) {
+                    sum_dt = dnnl::memory::data_type::s8;
                 }
+                dnnl_pops.append_sum(scale, zp, sum_dt);
             } else {
                 // post-binary
                 assertm(extra_inputs.size() == 1,
                         "post-binary only has 1 extra input");
-                size_t src1_idx = extra_inputs[0];
-                auto input = op->get_input_value(src1_idx);
-                auto md = make_dnnl_memory_desc(input->get_logical_tensor());
-                const auto alg = static_cast<dnnl::algorithm>(
-                        fused_op->get_attr<int64_t>(op_attr::alg_kind));
+                assertm(scale == 1.f && zp == 0,
+                        "post-binary doesn't support input scale and zp");
+                auto md = make_dnnl_memory_desc(psrc);
                 dnnl_pops.append_binary(alg, md);
             }
         } else if (fused_op_kind == op_kind::dnnl_convolution) {
diff --git a/src/graph/backend/dnnl/fusion_info.hpp b/src/graph/backend/dnnl/fusion_info.hpp
index 46688a0d252..8785f0532ea 100644
--- a/src/graph/backend/dnnl/fusion_info.hpp
+++ b/src/graph/backend/dnnl/fusion_info.hpp
@@ -54,16 +54,15 @@ class fusion_info_t {
     meta_op_t(const op_ptr &op) : op_(op) {};
     // for post-eltwise
     meta_op_t(const op_ptr &op, float scale) : op_(op), scale_(scale) {};
-    // for post-sum
+    // for post-sum and post-binary
     meta_op_t(const op_ptr &op,
             const std::vector<size_t> &extra_input_indices, float scale,
             int32_t zp)
         : op_(op)
         , scale_(scale)
         , zp_(zp)
-        , unfused_input_indices_(extra_input_indices)
-        , is_post_sum_(true) {};
-    // for post-binary and post-conv
+        , unfused_input_indices_(extra_input_indices) {};
+    // for post-conv
     meta_op_t(const op_ptr &op,
             const std::vector<size_t> &extra_input_indices)
         : op_(op), unfused_input_indices_(extra_input_indices) {};
@@ -84,11 +83,7 @@ class fusion_info_t {
         return op_->get_kind() == op_kind::dnnl_binary && !is_post_sum_;
     }

-    void to_post_binary() {
-        assertm(scale_ == 1.0f && zp_ == 0,
-                "post bianry cannot support scale and zp!");
-        is_post_sum_ = false;
-    }
+    void set_post_sum() { is_post_sum_ = true; }

 private:
     std::shared_ptr<op_t> op_;
@@ -158,13 +153,6 @@ class fusion_info_t {
         post_ops_.emplace_back(std::make_shared<meta_op_t>(op, scale));
     }

-    void append_post_sum(const op_ptr &op,
-            const std::vector<size_t> &extra_input_indices, float scale = 1.0f,
-            int32_t zp = 0) {
-        post_ops_.emplace_back(std::make_shared<meta_op_t>(
-                op, extra_input_indices, scale, zp));
-    }
-
    // the extra input means the unfused input that has been added to the fused
    // op, like the following case, we fuse a binary mul into the conv, the src1
    // of mul op is unfused, and it becomes the 3rd input of conv. So the extra
    //      \   /   |
    //       mul
    //        |
-    void append_post_binary(
-            const op_ptr &op, const std::vector<size_t> &extra_input_indices) {
-        post_ops_.emplace_back(
-                std::make_shared<meta_op_t>(op, extra_input_indices));
+    void append_post_binary(const op_ptr &op,
+            const std::vector<size_t> &extra_input_indices, float scale = 1.0f,
+            int32_t zp = 0) {
+        post_ops_.emplace_back(std::make_shared<meta_op_t>(
+                op, extra_input_indices, scale, zp));
     }

    // the meaning of extra input is same as that in append_post_binary function
diff --git a/src/graph/backend/dnnl/passes/transform.cpp b/src/graph/backend/dnnl/passes/transform.cpp
index c8d6108b4a1..9009447c37a 100644
--- a/src/graph/backend/dnnl/passes/transform.cpp
+++ b/src/graph/backend/dnnl/passes/transform.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright 2021-2022 Intel Corporation
+ * Copyright 2021-2023 Intel Corporation
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -937,57 +937,12 @@ status_t fuse_post_ops(std::shared_ptr<subgraph_t> &sg) {
                     rewriter.fuse_op_to_successor(
                             mul_scale_op.shared_from_this());
-                    fusion_info.append_post_sum(post_op->shared_from_this(),
+                    fusion_info.append_post_binary(post_op->shared_from_this(),
                             std::vector<size_t> {base_op->num_inputs()},
                             scales[0], zp);
-                    assertm(!base_op->has_attr(op_attr::with_sum)
-                                    || !base_op->get_attr<bool>(
-                                            op_attr::with_sum),
-                            "not support multiple post sum ops "
-                            "currently.");
-                    base_op->set_attr<bool>(op_attr::with_sum, true);
                 } else {
-                    // - the add operation may need broadcast
-                    auto fused_in = post_op->get_input_value(
-                            fuse_op_predecessor_offset);
-                    auto other_in = post_op->get_input_value(
-                            1 - fuse_op_predecessor_offset);
-                    auto dst = post_op->get_output_value(0);
-
-                    if (ltw(fused_in->get_logical_tensor()).vdims()
-                            == ltw(other_in->get_logical_tensor()).vdims()) {
-                        if (base_op->get_kind() == op_kind::dnnl_eltwise
-                                || base_op->get_kind() == op_kind::dnnl_pool) {
-                            fusion_info.append_post_binary(
-                                    post_op->shared_from_this(),
-                                    std::vector<size_t> {
-                                            base_op->num_inputs()});
-                        } else {
-                            // use sum post-ops for no-broadcast add
-                            // map non-first post-sum to post-binary_add
-                            if (base_op->has_attr(op_attr::with_sum)
-                                    && base_op->get_attr<bool>(
-                                            op_attr::with_sum)) {
-                                fusion_info.append_post_binary(
-                                        post_op->shared_from_this(),
-                                        std::vector<size_t> {
-                                                base_op->num_inputs()});
-                            } else {
-                                fusion_info.append_post_sum(
-                                        post_op->shared_from_this(),
-                                        std::vector<size_t> {
-                                                base_op->num_inputs()},
-                                        1.0f, 0);
-                                base_op->set_attr<bool>(
-                                        op_attr::with_sum, true);
-                            }
-                        }
-                    } else {
-                        // use binary post-ops for broadcast add
-                        fusion_info.append_post_binary(
-                                post_op->shared_from_this(),
-                                std::vector<size_t> {base_op->num_inputs()});
-                    }
+                    fusion_info.append_post_binary(post_op->shared_from_this(),
+                            std::vector<size_t> {base_op->num_inputs()});
                 }
             } else if (post_op->get_kind() == op_kind::dnnl_binary
                     && static_cast<dnnl::algorithm>(
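
Note on the new logic: after this patch, fuse_post_ops() in transform.cpp always records a fused add as post-binary, and make_dnnl_primitive_attr() in fusion_info.cpp becomes the single place that decides whether a binary_add post-op can be promoted to a oneDNN sum post-op. The sketch below condenses that decision into one standalone predicate for illustration. It is a minimal, self-contained approximation: the types (data_type, op_kind, algorithm, tensor_desc) and the helper name can_use_post_sum are simplified stand-ins, not the real oneDNN graph classes.

#include <cstdint>
#include <iostream>
#include <vector>

// Simplified stand-ins for the oneDNN graph types referenced in the patch.
enum class data_type { undef, f32, bf16, s8, u8 };
enum class op_kind { convolution, matmul, eltwise, pool };
enum class algorithm { binary_add, binary_mul };

struct tensor_desc {
    data_type dt = data_type::undef;
    std::vector<int64_t> dims; // logical shape
};

// Mirrors the conditions checked in make_dnnl_primitive_attr() before a
// fused binary_add is lowered to a sum post-op (hypothetical helper).
bool can_use_post_sum(algorithm alg, op_kind base_kind, bool already_has_sum,
        const tensor_desc &post_src, const tensor_desc &dst) {
    // the fused binary algorithm must be an add
    if (alg != algorithm::binary_add) return false;
    // eltwise and pool base ops keep using binary post-ops
    if (base_kind == op_kind::eltwise || base_kind == op_kind::pool)
        return false;
    // only one sum post-op is supported per primitive
    if (already_has_sum) return false;
    // sum cannot broadcast: post src and dst shapes must match exactly
    if (post_src.dims != dst.dims) return false;
    // dst memory must be at least as large as post src memory: identical
    // data types always qualify, and a 1-byte s8/u8 post src fits any dst
    return post_src.dt == dst.dt || post_src.dt == data_type::s8
            || post_src.dt == data_type::u8;
}

int main() {
    tensor_desc psrc {data_type::s8, {8, 16}};
    tensor_desc dst {data_type::u8, {8, 16}};
    // s8 add into a u8 conv dst with matching shapes: promoted to post-sum
    std::cout << can_use_post_sum(algorithm::binary_add, op_kind::convolution,
                         /*already_has_sum=*/false, psrc, dst)
              << "\n"; // prints 1
    // shape mismatch (broadcast add): stays a binary post-op
    dst.dims = {8, 1};
    std::cout << can_use_post_sum(algorithm::binary_add, op_kind::convolution,
                         /*already_has_sum=*/false, psrc, dst)
              << "\n"; // prints 0
    return 0;
}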