Skip to content

Commit

Permalink
graph: utils: pm: verify multi-consumers number
Browse files Browse the repository at this point in the history
  • Loading branch information
ElaineBao authored and TaoLv committed Dec 23, 2024
1 parent 1b30f9f commit 3d64643
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 12 deletions.
56 changes: 48 additions & 8 deletions src/graph/utils/pm/nested_matcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,7 @@ bool match_pattern(op_t *first_op, const std::shared_ptr<pb_graph_t> &pattern,
DEBUG(DEBUGINFO_PM, "matching failed \n");
return false;
}
if (!verify_global_in_map(&init_ctx)) { return false; }

fusion_ops = reorder_matched_list(matched_op_map);
DEBUG(DEBUGINFO_PM, "finish matching, matched ops: ");
Expand All @@ -675,6 +676,28 @@ bool match_pattern(op_t *first_op, const std::shared_ptr<pb_graph_t> &pattern,
return true;
}

bool verify_global_in_map(match_context_t *ctx) {
pb_graph_t *graph = ctx->get_graph();
if (!graph) return true;

auto inner_cons = graph->get_inner_consumers();
if (inner_cons.empty()) return true;

for (size_t i = 0; i < inner_cons.size(); ++i) {
if (inner_cons[i].second.size() != ctx->in_port_map.count(i)) {
DEBUG(DEBUGINFO_PM,
"expected graph input %zu consumers size: %zu, actual "
"consumers size: %zu",
i, inner_cons[i].second.size(), ctx->in_port_map.count(i));
VPATTERN_MATCHER(
"matching failed: number of inputs check failed,%s:%i \n",
__FILE__, __LINE__);
return false;
}
}
return true;
}

inline std::vector<op_t *> reorder_matched_list(
const std::unordered_map<op_t *, pb_op_t *> &matched_op_map) {
// split ops and pb_op_ts
Expand Down Expand Up @@ -797,8 +820,18 @@ void fill_local_in_map(match_context_t *local_ctx, pb_node_t *cur_node,
for (size_t j = 0; j < inner_cons[i].second.size(); ++j) {
size_t iport = inner_cons[i].first;
const std::shared_ptr<consumer_t> &con = inner_cons[i].second[j];
if (con->first == cur_node)
local_ctx->in_port_map[iport] = {cur_op, cur_op_port};
if (con->first == cur_node) {
// check if the input port has been filled, if filled, the existing
// input should be the same as the current op
auto it = local_ctx->in_port_map.find(iport);
if (it != local_ctx->in_port_map.end()) {
if (it->second.first->get_input_value(it->second.second)
!= cur_op->get_input_value(cur_op_port)) {
return;
}
}
local_ctx->in_port_map.insert({iport, {cur_op, cur_op_port}});
}
}
}
}
Expand Down Expand Up @@ -931,8 +964,9 @@ bool match_alternation(const binding_t &bind_arg, match_context_t *ctx,
} else {
// alternation is restricted to have only 1 in port
if (local_ctx.in_port_map.size() != 1) return false;
op_t *current_op = local_ctx.in_port_map[0].first;
size_t current_port = local_ctx.in_port_map[0].second;
op_t *current_op = local_ctx.in_port_map.find(0)->second.first;
size_t current_port
= local_ctx.in_port_map.find(0)->second.second;
binding_t current_bind(BIND_OUT, current_op, current_port,
bind_arg.bind_node, bind_arg.bind_port);
return match_node_inputs(current_bind, ctx, matched_op_map);
Expand Down Expand Up @@ -1037,7 +1071,8 @@ bool repetition_matcher_t::prepare_next_matching_round(
// for next round's match
iport_t iport = pmap_.second;
// start op for last round's match
op_t *start_op = local_cached_ctx.in_port_map.at(iport).first;
op_t *start_op
= local_cached_ctx.in_port_map.find(iport)->second.first;
pb_op_t *start_pb_op = updated_op_map_[start_op];
op_t *next_op = nullptr;
size_t next_op_iport = 0;
Expand Down Expand Up @@ -1070,7 +1105,8 @@ bool repetition_matcher_t::prepare_next_matching_round(
} else { // backward matching
single_iter_bind_.bind_kind = BIND_OUT;
iport_t iport = pmap_.second;
op_t *current_op = local_cached_ctx.in_port_map.at(iport).first;
op_t *current_op
= local_cached_ctx.in_port_map.find(iport)->second.first;
if (iport >= current_op->num_inputs()) return true;
auto in_value = current_op->get_input_value(iport);
single_iter_bind_.bind_op = &(in_value->get_producer());
Expand Down Expand Up @@ -1143,9 +1179,11 @@ bool repetition_matcher_t::match_next_op(const binding_t &bind_arg) {
assertm(bind_arg.bind_node->get_inputs().size() <= 1,
"repetition is restricted to have only 1 input");
if (bind_arg.bind_node->get_inputs().size() == 1) {
op_t *current_op = rep_global_ctx_.in_port_map[pmap_.second].first;
op_t *current_op = rep_global_ctx_.in_port_map.find(pmap_.second)
->second.first;
size_t current_port
= rep_global_ctx_.in_port_map[pmap_.second].second;
= rep_global_ctx_.in_port_map.find(pmap_.second)
->second.second;
binding_t current_bind(BIND_OUT, current_op, current_port,
rep_node_, bind_arg.bind_port);
if (!match_node_inputs(current_bind, parent_ctx_, updated_op_map_))
Expand Down Expand Up @@ -1195,6 +1233,8 @@ size_t repetition_matcher_t::match_repetition_blocks() {
size_t num_rep = 0;
while (true) {
match_context_t local_cached_ctx {rep_global_ctx_};
local_cached_ctx.in_port_map.clear();
local_cached_ctx.out_port_map.clear();
std::unordered_map<op_t *, pb_op_t *> local_op_map = updated_op_map_;
if (forward_match_) {
// Same reason as adding condition of
Expand Down
15 changes: 11 additions & 4 deletions src/graph/utils/pm/nested_matcher.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2023 Intel Corporation
* Copyright 2021-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -92,7 +92,12 @@ class binding_t {
size_t hint_op_port = 0;
};

using graph_port_map = std::unordered_map<size_t, std::pair<op_t *, size_t>>;
// one input port can have multiple consumers
using graph_in_port_map
= std::unordered_multimap<size_t, std::pair<op_t *, size_t>>;
// one output port corresponds to one producer
using graph_out_port_map
= std::unordered_map<size_t, std::pair<op_t *, size_t>>;

//
// match context tracks a pattern graph match progress
Expand All @@ -111,8 +116,8 @@ class match_context_t {
match_context_t *get_parent_context() { return parent_ctx; };
pb_graph_t *get_graph() { return graph_; };

graph_port_map in_port_map;
graph_port_map out_port_map;
graph_in_port_map in_port_map;
graph_out_port_map out_port_map;

protected:
match_context_t *parent_ctx;
Expand Down Expand Up @@ -208,6 +213,8 @@ bool match_pattern(op_t *first_op, const std::shared_ptr<pb_graph_t> &pattern,
inline std::vector<op_t *> reorder_matched_list(
const std::unordered_map<op_t *, pb_op_t *> &matched_op_map);

// verify if all the in_ports of the graph have been filled
bool verify_global_in_map(match_context_t *ctx);
//
// fill the current match_context's in/out port map
// to pattern match_context. Useful for nested patterns
Expand Down

0 comments on commit 3d64643

Please sign in to comment.