graph: backend: remove fold_sum_scales pass
xinyu-intel authored and TaoLv committed Feb 6, 2023
1 parent b8d21a5 commit 699ba75
Showing 1 changed file with 0 additions and 82 deletions.
82 changes: 0 additions & 82 deletions src/graph/backend/dnnl/passes/transform.cpp
@@ -461,88 +461,6 @@ status_t fold_mul_scales(std::shared_ptr<subgraph_t> &sg) {
    return status::success;
}

status_t fold_sum_scales(std::shared_ptr<subgraph_t> &sg) {
    std::set<op_t *> visited;
    subgraph_rewriter_t rewriter(sg);

    for (auto &cur_op : sg->get_ops()) {
        // only handle binary add ops that have not been visited yet
        if (!(cur_op->get_kind() == op_kind::dnnl_binary
                    && static_cast<dnnl::algorithm>(
                               cur_op->get_attr<int64_t>(op_attr::alg_kind))
                            == dnnl::algorithm::binary_add)
                || visited.count(cur_op.get()))
            continue;

        visited.insert(cur_op.get());
        // 2 is a sentinel meaning "no foldable mul_scales input found"
        size_t mul_scale_op_offset = 2;
        auto lhs_val = cur_op->get_input_value(0);
        auto rhs_val = cur_op->get_input_value(1);

        if (!lhs_val->has_producer() || !rhs_val->has_producer()) { continue; }
        const auto &l_op = lhs_val->get_producer();
        const auto &r_op = rhs_val->get_producer();

        auto consumers = cur_op->get_output_values()[0]->get_consumers();
        if (consumers.empty()
                || consumers[0].get_op().get_kind()
                        != op_kind::dnnl_mul_scales) {
            continue;
        }

        if (l_op.get_kind() != op_kind::dnnl_mul_scales
                || r_op.get_kind() != op_kind::dnnl_mul_scales) {
            continue;
        }
        // fold the mul_scales on the side opposite the reorder chain
        if (l_op.num_inputs() > 0 && l_op.get_input_value(0)->has_producer()
                && l_op.get_input_value(0)->get_producer().get_kind()
                        == op_kind::dnnl_reorder) {
            mul_scale_op_offset = 1;
        } else if (r_op.num_inputs() > 0
                && r_op.get_input_value(0)->has_producer()
                && r_op.get_input_value(0)->get_producer().get_kind()
                        == op_kind::dnnl_reorder) {
            mul_scale_op_offset = 0;
        }

        // fold only if a candidate was found and both inputs have the
        // same shape
        if (mul_scale_op_offset != 2
                && ltw(lhs_val->get_logical_tensor()).vdims()
                        == ltw(rhs_val->get_logical_tensor()).vdims()) {
            auto in_val = cur_op->get_input_value(mul_scale_op_offset);
            auto &mul_scale_op = in_val->get_producer();
            auto scales = mul_scale_op.get_attr<std::vector<float>>(
                    op_attr::scales);
            assert(scales.size() == 1); // per tensor

            auto tmp = mul_scale_op.get_input_value(0);
            auto &add_zps_op = tmp->get_producer();
            auto zps = add_zps_op.get_attr<std::vector<int64_t>>(op_attr::zps);
            assert(scales.size() == zps.size());

            auto out_val = cur_op->get_output_values()[0];
            auto consumers = out_val->get_consumers();
            auto &next_op = consumers[0].get_op();
            // set sum post-op's second input scale
            float tmp_scale
                    = next_op.get_attr<std::vector<float>>(op_attr::scales)[0];
            scales[0] *= tmp_scale;
            mul_scale_op.set_attr<std::vector<float>>(op_attr::scales, scales);

            // update the output scales
            auto other_val = cur_op->get_input_value(1 - mul_scale_op_offset);
            auto &oscales_op = other_val->get_producer();
            auto oscales
                    = oscales_op.get_attr<std::vector<float>>(op_attr::scales);
            for (auto &v : oscales)
                v *= tmp_scale;
            oscales_op.set_attr<std::vector<float>>(op_attr::scales, oscales);
            rewriter.fuse_op_to_predecessor(next_op.shared_from_this());
        }
    }

    rewriter.run();
    return status::success;
}

// FIXME(xx) This pass works correctly only when all inputs/outputs scales/zps
// are the same, since we simply ignore the scales and zps. We can improve
// this pass to support different per-tensor scale since oneDNN concat primitive
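Context for the removal: the deleted pass folded the per-tensor output scale of a trailing dnnl_mul_scales op back into the mul_scales ops feeding a binary add, so the trailing op could be fused away. The following is a minimal, self-contained C++ sketch of that algebraic identity only; the struct and function names here are hypothetical and are not part of the oneDNN graph API.

#include <cassert>
#include <cstdio>
#include <vector>

// Pattern the deleted pass matched:  out = s_out * (s_l * x + s_r * y)
// After folding:                     out = (s_out * s_l) * x + (s_out * s_r) * y
// which lets the trailing mul_scales op be removed.
struct mul_scales_t {
    std::vector<float> scales; // per-tensor quantization scale (one element)
};

void fold_sum_scales(mul_scales_t &lhs, mul_scales_t &rhs,
        const mul_scales_t &out) {
    assert(out.scales.size() == 1); // per-tensor, as the deleted pass asserted
    const float s_out = out.scales[0];
    for (auto &v : lhs.scales) v *= s_out; // fold into the left input scale
    for (auto &v : rhs.scales) v *= s_out; // fold into the right input scale
    // The real pass then fused the trailing mul_scales op into its
    // predecessor (rewriter.fuse_op_to_predecessor in the deleted code above).
}

int main() {
    mul_scales_t lhs {{0.5f}}, rhs {{0.25f}}, out {{2.0f}};
    const float x = 8.0f, y = 16.0f;
    const float before
            = out.scales[0] * (lhs.scales[0] * x + rhs.scales[0] * y);
    fold_sum_scales(lhs, rhs, out);
    const float after = lhs.scales[0] * x + rhs.scales[0] * y;
    std::printf("before=%g after=%g\n", before, after); // both print 16
    return 0;
}

Both values agree because s_out * (s_l * x + s_r * y) == (s_out * s_l) * x + (s_out * s_r) * y; the pass exploited exactly this identity when both add inputs had identical shapes and per-tensor scales.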
