
[MXNET-555] Add subgraph storage type inference to CachedOp #11306

Merged: 9 commits, Jun 21, 2018
147 changes: 119 additions & 28 deletions src/imperative/cached_op.cc
@@ -22,6 +22,7 @@
#include "./cached_op.h"
#include "../executor/exec_pass.h"
#include "../profiler/profiler.h"
#include "../operator/operator_common.h"


namespace mxnet {
@@ -95,7 +96,6 @@ CachedOp::CachedOp(
using namespace imperative;
static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
static const auto _copy = Op::Get("_copy");

config_.Init(flags);

if (config_.static_shape) {
@@ -204,26 +204,17 @@ CachedOp::CachedOp(
size_t num_forward_outputs = num_outputs();
for (uint32_t i = 0; i < ograd_entries_.size(); ++i) {
if (!idx.exist(ograd_entries_[i].node.get())) continue;
auto eid = idx.entry_id(ograd_entries_[i]);
if (ref_count[eid] > 0) {
bwd_ograd_dep_.push_back(i);
}
bwd_ograd_dep_.push_back(i);
}
save_inputs_.resize(num_forward_inputs, false);
for (uint32_t i = 0; i < num_forward_inputs; ++i) {
auto eid = idx.entry_id(idx.input_nodes()[i], 0);
if (ref_count[eid] > 0) {
save_inputs_[i] = true;
bwd_in_dep_.push_back(i);
}
save_inputs_[i] = true;
bwd_in_dep_.push_back(i);
}
save_outputs_.resize(idx.outputs().size(), false);
for (uint32_t i = 0; i < num_forward_outputs; ++i) {
auto eid = idx.entry_id(idx.outputs()[i]);
if (ref_count[eid] > 0) {
save_outputs_[i] = true;
bwd_out_dep_.push_back(i);
}
save_outputs_[i] = true;
bwd_out_dep_.push_back(i);
}
}
}
@@ -233,7 +224,7 @@ CachedOp::~CachedOp() {

std::vector<nnvm::NodeEntry> CachedOp::Gradient(
const nnvm::NodePtr& node,
const std::vector<nnvm::NodeEntry>& ograds) {
const std::vector<nnvm::NodeEntry>& ograds) const {
using namespace nnvm;
static const auto _backward_CachedOp = Op::Get("_backward_CachedOp");
static const auto _NoGrad = Op::Get("_NoGradient");
@@ -328,6 +319,27 @@ bool CachedOp::SetForwardGraph(
return false;
}

// Utility function to set backward input eids
void SetBackwardInputEid(const std::vector<uint32_t>& bwd_in_dep,
const std::vector<uint32_t>& bwd_out_dep,
const std::vector<uint32_t>& bwd_ograd_dep,
const std::vector<nnvm::NodeEntry>& ograd_entries,
const nnvm::IndexedGraph& idx,
std::vector<uint32_t> *bwd_input_eid) {
for (const auto& i : bwd_ograd_dep) {
auto eid = idx.entry_id(ograd_entries[i]);
bwd_input_eid->push_back(eid);
}
for (const auto& i : bwd_in_dep) {
auto eid = idx.entry_id(idx.input_nodes()[i], 0);
bwd_input_eid->push_back(eid);
}
for (const auto& i : bwd_out_dep) {
auto eid = idx.entry_id(idx.outputs()[i]);
bwd_input_eid->push_back(eid);
}
}

bool CachedOp::SetBackwardGraph(
GraphInfo* info,
const std::vector<OpReqType>& reqs,
@@ -356,18 +368,8 @@ bool CachedOp::SetBackwardGraph(

if (info->bwd_input_eid.size() != inputs.size()) {
info->bwd_input_eid.clear();
for (const auto& i : bwd_ograd_dep_) {
auto eid = idx.entry_id(ograd_entries_[i]);
info->bwd_input_eid.push_back(eid);
}
for (const auto& i : bwd_in_dep_) {
auto eid = idx.entry_id(idx.input_nodes()[i], 0);
info->bwd_input_eid.push_back(eid);
}
for (const auto& i : bwd_out_dep_) {
auto eid = idx.entry_id(idx.outputs()[i]);
info->bwd_input_eid.push_back(eid);
}
SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_,
ograd_entries_, idx, &info->bwd_input_eid);
CHECK_EQ(inputs.size(), info->bwd_input_eid.size());
}

@@ -1019,6 +1021,79 @@ void CachedOp::Backward(
Engine::Get()->set_bulk_size(prev_bulk_size);
}

bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
using namespace imperative;
nnvm::Graph g(fwd_graph_);
const auto& idx = g.indexed_graph();
const auto &outputs = idx.outputs();

// Prepare stypes and contexts based on inputs
StorageTypeVector storage_type_inputs;
storage_type_inputs.reserve(in_attrs->size());
for (size_t i = 0; i < in_attrs->size(); ++i) {
storage_type_inputs.emplace_back(in_attrs->at(i));
}
exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);

// Forward graph storage type inference
CheckAndInferStorageType(&g, std::move(dev_masks), std::move(storage_type_inputs), true);
// Retrieve result and set outputs
const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
for (size_t i = 0; i < out_attrs->size(); i++) {
const auto eid = idx.entry_id(outputs[i]);
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
}
DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
return true;
}

bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
using namespace imperative;
nnvm::Graph g(full_graph_);
const auto& idx = g.indexed_graph();
const auto &outputs = idx.outputs();
const size_t num_forward_outputs = fwd_graph_.outputs.size();
CHECK_EQ(outputs.size(), num_forward_outputs + out_attrs->size());

// Construct bwd_input_eid
std::vector<uint32_t> bwd_input_eid;
SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_,
ograd_entries_, idx, &bwd_input_eid);
CHECK_EQ(in_attrs->size(), bwd_input_eid.size());

// Prepare stypes and contexts based on inputs
StorageTypeVector stypes(idx.num_node_entries(), -1);
for (size_t i = 0; i < in_attrs->size(); ++i) {
stypes[bwd_input_eid[i]] = in_attrs->at(i);
}
// Some out_attr is known ahead of time (e.g. the grad stype is given by users).
// Prepare these before invoking storage type inference on the subgraph
for (size_t i = 0; i < out_attrs->size(); i++) {
const auto eid = idx.entry_id(outputs[i + num_forward_outputs]);
stypes[eid] = out_attrs->at(i);
}
exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);

// Full graph storage type inference
CheckAndInferStorageType(&g, std::move(dev_masks), std::move(stypes), false);
Review thread on this line:

Contributor: should we use exec::InferShape(std::move(g))? Is it guaranteed that the inference works in one invocation?

Member (Author): i haven't seen cases where storage type inference requires multiple runs since it doesn't use outputs to infer inputs. did you?

Contributor: i don't know. but it's easy to make it work for multiple runs.
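For readers following the exchange above: as the author notes, storage type inference only propagates from inputs to outputs, which is why a single CheckAndInferStorageType invocation is expected to settle every entry. The reviewer's multi-run variant amounts to iterating the pass to a fixed point, which would only become necessary if inference also used outputs to infer inputs. Below is a minimal, self-contained C++ sketch of such a fixed-point loop; InferOncePass, kUnknownStype, and the toy propagation rule are hypothetical stand-ins for illustration, not MXNet's actual API.

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical sentinel for "storage type not yet inferred" (a stand-in, not MXNet's constant).
constexpr int kUnknownStype = -1;

// Hypothetical single inference pass. The toy rule deliberately infers an entry
// from its successor ("using outputs to infer inputs"), the one case where a
// single left-to-right sweep is not guaranteed to be enough.
size_t InferOncePass(std::vector<int>* stypes) {
  for (size_t i = 0; i + 1 < stypes->size(); ++i) {
    if ((*stypes)[i] == kUnknownStype && (*stypes)[i + 1] != kUnknownStype) {
      (*stypes)[i] = (*stypes)[i + 1];
    }
  }
  size_t unknown = 0;
  for (int s : *stypes) {
    if (s == kUnknownStype) ++unknown;
  }
  return unknown;  // number of entries still undetermined after this pass
}

int main() {
  std::vector<int> stypes = {kUnknownStype, kUnknownStype, kUnknownStype, 2};
  size_t prev = stypes.size() + 1;
  size_t unknown = stypes.size();
  size_t passes = 0;
  // Iterate to a fixed point: stop once everything is known or a pass makes no progress.
  while (unknown > 0 && unknown < prev) {
    prev = unknown;
    unknown = InferOncePass(&stypes);
    ++passes;
  }
  std::cout << "passes: " << passes << ", still unknown: " << unknown << std::endl;
  return 0;
}

Under the forward-only assumption such a loop exits after a single pass, matching the one-shot call in the diff above; with the successor-based toy rule it takes three passes to converge, which is the situation the reviewer's suggestion would cover.
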
// Retrieve result and set outputs
const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
for (size_t i = 0; i < out_attrs->size(); i++) {
const auto eid = idx.entry_id(outputs[i + num_forward_outputs]);
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
}
DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
return true;
}


NNVM_REGISTER_OP(_CachedOp)
.set_num_inputs([](const NodeAttrs& attrs) {
@@ -1029,6 +1104,14 @@ NNVM_REGISTER_OP(_CachedOp)
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->num_outputs();
})
.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
})
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(n->attrs.parsed);
@@ -1044,6 +1127,14 @@ NNVM_REGISTER_OP(_backward_CachedOp)
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->num_inputs() - op->mutable_input_nodes().size();
})
.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
})
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<bool>("TIsBackward", true);

24 changes: 19 additions & 5 deletions src/imperative/cached_op.h
@@ -71,13 +71,13 @@ class CachedOp {
const nnvm::Symbol& sym,
const std::vector<std::pair<std::string, std::string> >& flags);
~CachedOp();
uint32_t num_inputs() {
uint32_t num_inputs() const {
return fwd_graph_.indexed_graph().input_nodes().size();
}
uint32_t num_outputs() {
uint32_t num_outputs() const {
return fwd_graph_.outputs.size();
}
uint32_t num_backward_inputs() {
uint32_t num_backward_inputs() const {
return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
}
std::vector<bool>& save_inputs() {
@@ -86,12 +86,12 @@ class CachedOp {
std::vector<bool>& save_outputs() {
return save_outputs_;
}
const std::unordered_set<uint32_t>& mutable_input_nodes() {
const std::unordered_set<uint32_t>& mutable_input_nodes() const {
return fwd_graph_.indexed_graph().mutable_input_nodes();
}
std::vector<nnvm::NodeEntry> Gradient(
const nnvm::NodePtr& node,
const std::vector<nnvm::NodeEntry>& ograds);
const std::vector<nnvm::NodeEntry>& ograds) const;
void Forward(
const std::shared_ptr<CachedOp>& op_ptr,
const std::vector<NDArray*>& inputs,
@@ -102,6 +102,20 @@
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs);
// forward storage type inference
bool ForwardStorageType(
const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs);
// backward storage type inference
bool BackwardStorageType(
const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs);

private:
struct GraphInfo;
1 change: 0 additions & 1 deletion src/imperative/imperative_utils.h
@@ -668,7 +668,6 @@ inline bool CheckAndInferStorageType(nnvm::Graph* p_g, exec::DevMaskVector&& dev
g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(storage_types));
g = exec::InferStorageType(std::move(g));
}

CHECK_EQ(g.GetAttr<size_t>("storage_type_num_unknown_nodes"), 0U);
return false;
}
4 changes: 2 additions & 2 deletions src/operator/operator_common.h
@@ -256,7 +256,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
*/
#define STORAGE_TYPE_ASSIGN_CHECK(type_array, index, type) \
{ \
if (!type_assign(&(type_array)[index], type)) { \
if (!::mxnet::op::type_assign(&(type_array)[index], type)) { \
std::ostringstream os; \
os << "Storage type inconsistent, Provided = " \
<< common::stype_string((type_array)[index]) << ',' \
@@ -274,7 +274,7 @@
*/
#define DISPATCH_MODE_ASSIGN_CHECK(type_array, index, type) \
{ \
if (!dispatch_mode_assign(&(type_array)[index], type)) { \
if (!::mxnet::op::dispatch_mode_assign(&(type_array)[index], type)) { \
std::ostringstream os; \
os << "Dispatch mode inconsistent, Provided = " \
<< common::dispatch_mode_string((type_array)[index]) << ',' \
57 changes: 57 additions & 0 deletions tests/python/unittest/test_gluon.py
@@ -1293,6 +1293,63 @@ def test_legacy_save_params():
model.load_params('test.params', ctx=mx.cpu())


@with_seed()
def test_sparse_hybrid_block_grad():
class Embedding(mx.gluon.HybridBlock):
def __init__(self, num_tokens, embedding_size):
super(Embedding, self).__init__()
self.num_tokens = num_tokens

with self.name_scope():
self.embedding = mx.gluon.nn.Embedding(
num_tokens, embedding_size, sparse_grad=True)

def hybrid_forward(self, F, words):
emb = self.embedding(words)
return emb + F.ones_like(emb)

embedding = Embedding(20, 3)
embedding.initialize()
embedding.hybridize()

with mx.autograd.record():
emb0 = embedding(mx.nd.arange(10)).sum()
emb1 = embedding(mx.nd.arange(10)).sum()
loss = emb0 + emb1
loss.backward()
grad = embedding.embedding.weight.grad().asnumpy()
assert (grad[:10] == 2).all()
assert (grad[10:] == 0).all()

@with_seed()
def test_sparse_hybrid_block():
class Linear(mx.gluon.HybridBlock):
def __init__(self, units):
super(Linear, self).__init__()
with self.name_scope():
self.w = self.params.get('w', shape=(units, units))

def hybrid_forward(self, F, x, w):
return F.dot(x, w)

class SparseBlock(mx.gluon.HybridBlock):
def __init__(self, units):
super(SparseBlock, self).__init__()
with self.name_scope():
self.net = Linear(units)

def hybrid_forward(self, F, x):
return self.net(x) * x

block = SparseBlock(2)
block.initialize()
block.hybridize()
x = mx.nd.ones((2,2)).tostype('csr')
with mx.autograd.record():
z = block(x) + block(x)
z.backward()
assert (block.net.w.grad().asnumpy() == 4).all()

def test_hybrid_static_memory_recording():
net = gluon.model_zoo.vision.get_resnet(
1, 18, pretrained=True, ctx=mx.context.current_context())