Skip to content

Commit

Permalink
change the interface (apache#13)
Browse files Browse the repository at this point in the history
* remove SubgraphOperator.

* change the interface of SubgraphSelector.
  • Loading branch information
zheng-da authored and reminisce committed Jun 20, 2018
1 parent c585398 commit 382c792
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 110 deletions.
32 changes: 14 additions & 18 deletions src/operator/subgraph/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,27 +120,23 @@ void LabelSubgraph(const Graph&g,
cur_node->label = label;
subgraph_nodes->push_back(cur_node);
// get qualified adjacent input nodes
if (select_func->UseIncomingEdges()) {
for (auto& e : cur_node->node->inputs) {
if (select_func->Select(*e.node)) {
const auto nid = indexed_graph.node_id(e.node.get());
CHECK_LT(nid, simple_nodes.size());
// this node has not been visited yet
if (simple_nodes[nid]->label == -1)
node_queue.push(simple_nodes[nid].get());
}
for (auto& e : cur_node->node->inputs) {
if (select_func->SelectInput(*cur_node->node, *e.node)) {
const auto nid = indexed_graph.node_id(e.node.get());
CHECK_LT(nid, simple_nodes.size());
// this node has not been visited yet
if (simple_nodes[nid]->label == -1)
node_queue.push(simple_nodes[nid].get());
}
}
// get qualified output nodes
if (select_func->UseOutgoingEdges()) {
for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) {
if (select_func->Select(*it->first)) {
const auto nid = indexed_graph.node_id(it->first);
CHECK_LT(nid, simple_nodes.size());
// this node has not been visited yet
if (simple_nodes[nid]->label == -1)
node_queue.push(simple_nodes[nid].get());
}
for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) {
if (select_func->SelectOutput(*cur_node->node, *it->first)) {
const auto nid = indexed_graph.node_id(it->first);
CHECK_LT(nid, simple_nodes.size());
// this node has not been visited yet
if (simple_nodes[nid]->label == -1)
node_queue.push(simple_nodes[nid].get());
}
}
}
Expand Down
47 changes: 8 additions & 39 deletions src/operator/subgraph/subgraph_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,15 @@
namespace mxnet {
namespace op {

static std::unordered_map<std::string, SubgraphPropertyPtr> subg_props;

void RegisterSubgraphProperty(SubgraphPropertyPtr property) {
auto ret = subg_props.insert(std::pair<std::string, SubgraphPropertyPtr>(
property->GetType(), property));
CHECK(!ret.second) << "The subgraph property for " << property->GetType()
<< " has been registered";
}

class DefaultSubgraphOperator: public SubgraphOperator {
class DefaultSubgraphOperator {
public:
// TODO: initialize uuid
DefaultSubgraphOperator(const Symbol& sym) : SubgraphOperator(sym),
subgraph_uuid_("dfasdfadsmxdfw324"),
DefaultSubgraphOperator(const Symbol& sym) : subgraph_uuid_("dfasdfadsmxdfw324"),
immutable_data_names_(sym.ListInputNames(Symbol::kReadOnlyArgs)),
mutable_data_names_(sym.ListInputNames(Symbol::kAuxiliaryStates)),
//input_data_names_(sym.ListInputNames(Symbol::kAll)),
output_data_names_(sym.ListOutputNames()) {
this->subg_sym = sym;
const std::vector<std::string> input_data_names = sym.ListInputNames(Symbol::kAll);
//const std::vector<std::string> immutable_data_names = sym.ListInputNames(Symbol::kReadOnlyArgs);
//const std::vector<std::string> mutable_data_names = sym.ListInputNames(Symbol::kAuxiliaryStates);
Expand Down Expand Up @@ -82,6 +73,7 @@ class DefaultSubgraphOperator: public SubgraphOperator {
}

private:
nnvm::Symbol subg_sym;
std::string subgraph_uuid_;
// this variable records the NDArrays' var versions of the last run.
std::vector<int64_t> ndarray_var_versions_;
Expand All @@ -94,10 +86,6 @@ class DefaultSubgraphOperator: public SubgraphOperator {
std::shared_ptr<Executor> subgraph_executor_;
};

SubgraphOperatorPtr SimpleSubgraphProperty::CreateSubgraphOperator(const nnvm::Symbol &sym) const {
return std::make_shared<DefaultSubgraphOperator>(sym);
}

void DefaultSubgraphOperator::Forward(const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
Expand All @@ -114,7 +102,7 @@ void DefaultSubgraphOperator::Forward(const OpContext& ctx,
}
std::vector<NDArray> grad_store(arg_arrays.size());
std::vector<OpReqType> grad_req(arg_arrays.size(), kNullOp);
this->subgraph_executor_.reset(Executor::Bind(this->GetSubgraph(),
this->subgraph_executor_.reset(Executor::Bind(subg_sym,
ctx.run_ctx.ctx, std::map<std::string, Context>(), arg_arrays, grad_store,
grad_req, aux_arrays));
}
Expand Down Expand Up @@ -148,31 +136,12 @@ void DefaultSubgraphOperator::Forward(const OpContext& ctx,
}
}

struct SubgraphOpState {
SubgraphOperatorPtr op;

SubgraphOpState(SubgraphOperatorPtr op) {
this->op = op;
}
};

OpStatePtr CreateSubgraphOpState(const NodeAttrs& attrs,
Context ctx,
const std::vector<TShape>& in_shapes,
const std::vector<int>& in_types) {
const Symbol& subgraph_sym = nnvm::get<Symbol>(attrs.parsed);
auto it = attrs.dict.find("exec_type");
if (it == attrs.dict.end()) {
auto op = std::make_shared<DefaultSubgraphOperator>(subgraph_sym);
return OpStatePtr::Create<SubgraphOpState>(op);
}

std::string exec_name = it->second;
auto prop_iter = subg_props.find(exec_name);
CHECK(prop_iter != subg_props.end()) << "We don't support the execution type: "
<< exec_name;
auto op = prop_iter->second->CreateSubgraphOperator(subgraph_sym);
return OpStatePtr::Create<SubgraphOpState>(op);
return OpStatePtr::Create<DefaultSubgraphOperator>(subgraph_sym);
}

bool SubgraphOpShape(const nnvm::NodeAttrs& attrs,
Expand Down Expand Up @@ -267,8 +236,8 @@ void SubgraphOpForward(const OpStatePtr& state_ptr,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
SubgraphOpState& state = state_ptr.get_state<SubgraphOpState>();
state.op->Forward(ctx, inputs, req, outputs);
DefaultSubgraphOperator& op = state_ptr.get_state<DefaultSubgraphOperator>();
op.Forward(ctx, inputs, req, outputs);
}

NNVM_REGISTER_OP(_subgraph_op)
Expand Down
64 changes: 11 additions & 53 deletions src/operator/subgraph/subgraph_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,48 +56,16 @@ class SubgraphSelector {
public:
virtual ~SubgraphSelector() {
}
/*
* Given a set of nodes that have been selected so far for a subgraph, determine
* if the input node should be selected for a subgraph.
*/
// Determine if the node should be selected for a subgraph.
virtual bool Select(const nnvm::Node &n) = 0;
virtual bool UseIncomingEdges() const = 0;
virtual bool UseOutgoingEdges() const = 0;
// Determine if the input node should be selected for a subgraph.
virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) = 0;
// Determine if the output node should be selected for a subgraph.
virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) = 0;
};

using SubgraphSelectorPtr = std::shared_ptr<SubgraphSelector>;

/*
* This is the interface of the subgraph operator that executes the computation
* in the subgraph.
*/
class SubgraphOperator {
public:
SubgraphOperator(const nnvm::Symbol &sym) {
this->subgraph_sym_ = sym;
}

virtual ~SubgraphOperator() {
}

const nnvm::Symbol &GetSubgraph() const {
return subgraph_sym_;
}

virtual void Forward(const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) = 0;
virtual void Backward(const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) = 0;
private:
nnvm::Symbol subgraph_sym_;
};

using SubgraphOperatorPtr = std::shared_ptr<SubgraphOperator>;

/*
* This provides a set of properties for partitioning a graph into subgraphs,
* reconstructing a new graph from the subgraphs and creating a subgraph
Expand All @@ -110,10 +78,6 @@ class SubgraphProperty {
// create an nnvm node for a given subgraph. Here users can customize how to
// execute the operators in the subgraph.
virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &s) const = 0;
// Create a subgraph operator for execution.
virtual SubgraphOperatorPtr CreateSubgraphOperator(const nnvm::Symbol &sym) const = 0;
// The type of the subgraph.
virtual std::string GetType() const = 0;
};

using SubgraphPropertyPtr = std::shared_ptr<SubgraphProperty>;
Expand All @@ -132,16 +96,16 @@ class ContainOpSelector: public SubgraphSelector {
this->op_names = op_names;
}

virtual bool UseIncomingEdges() const {
return true;
virtual bool Select(const nnvm::Node &n) {
return !n.is_variable() && op_names->count(n.op()->name);
}

virtual bool UseOutgoingEdges() const {
return true;
virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) {
return !new_node.is_variable() && op_names->count(new_node.op()->name);
}

virtual bool Select(const nnvm::Node &n) {
return !n.is_variable() && op_names->count(n.op()->name);
virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) {
return !new_node.is_variable() && op_names->count(new_node.op()->name);
}
};

Expand All @@ -158,19 +122,13 @@ class SimpleSubgraphProperty: public SubgraphProperty {
nnvm::NodePtr n = nnvm::Node::Create();
n->attrs.op = Op::Get("_subgraph_op");
n->attrs.name = "_subgraph_op";
n->attrs.dict.insert(std::pair<std::string, std::string>("exec_type", GetType()));
n->attrs.parsed = sym;
return n;
}
virtual SubgraphSelectorPtr CreateSubgraphSelector() const {
return std::make_shared<ContainOpSelector>(op_names);
}

virtual SubgraphOperatorPtr CreateSubgraphOperator(const nnvm::Symbol &sym) const;
virtual std::string GetType() const {
return "default";
}

private:
std::shared_ptr<const std::unordered_set<std::string>> op_names;
};
Expand Down

0 comments on commit 382c792

Please sign in to comment.