diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index c0e5e83eb905..5a3d44c04cea 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -22,6 +22,7 @@
 #include "./cached_op.h"
 #include "../executor/exec_pass.h"
 #include "../profiler/profiler.h"
+#include "../operator/operator_common.h"
 
 namespace mxnet {
 
@@ -95,7 +96,6 @@ CachedOp::CachedOp(
   using namespace imperative;
   static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
   static const auto _copy = Op::Get("_copy");
-
   config_.Init(flags);
 
   if (config_.static_shape) {
@@ -204,26 +204,17 @@ CachedOp::CachedOp(
     size_t num_forward_outputs = num_outputs();
     for (uint32_t i = 0; i < ograd_entries_.size(); ++i) {
       if (!idx.exist(ograd_entries_[i].node.get())) continue;
-      auto eid = idx.entry_id(ograd_entries_[i]);
-      if (ref_count[eid] > 0) {
-        bwd_ograd_dep_.push_back(i);
-      }
+      bwd_ograd_dep_.push_back(i);
     }
     save_inputs_.resize(num_forward_inputs, false);
     for (uint32_t i = 0; i < num_forward_inputs; ++i) {
-      auto eid = idx.entry_id(idx.input_nodes()[i], 0);
-      if (ref_count[eid] > 0) {
-        save_inputs_[i] = true;
-        bwd_in_dep_.push_back(i);
-      }
+      save_inputs_[i] = true;
+      bwd_in_dep_.push_back(i);
     }
     save_outputs_.resize(idx.outputs().size(), false);
     for (uint32_t i = 0; i < num_forward_outputs; ++i) {
-      auto eid = idx.entry_id(idx.outputs()[i]);
-      if (ref_count[eid] > 0) {
-        save_outputs_[i] = true;
-        bwd_out_dep_.push_back(i);
-      }
+      save_outputs_[i] = true;
+      bwd_out_dep_.push_back(i);
     }
   }
 }
@@ -233,7 +224,7 @@ CachedOp::~CachedOp() {
 
 std::vector<nnvm::NodeEntry> CachedOp::Gradient(
     const nnvm::NodePtr& node,
-    const std::vector<nnvm::NodeEntry>& ograds) {
+    const std::vector<nnvm::NodeEntry>& ograds) const {
   using namespace nnvm;
   static const auto _backward_CachedOp = Op::Get("_backward_CachedOp");
   static const auto _NoGrad = Op::Get("_NoGradient");
@@ -328,6 +319,27 @@ bool CachedOp::SetForwardGraph(
   return false;
 }
 
+// Utility function to set backward input eids
+void SetBackwardInputEid(const std::vector<uint32_t>& bwd_in_dep,
+                         const std::vector<uint32_t>& bwd_out_dep,
+                         const std::vector<uint32_t>& bwd_ograd_dep,
+                         const std::vector<nnvm::NodeEntry>& ograd_entries,
+                         const nnvm::IndexedGraph& idx,
+                         std::vector<uint32_t> *bwd_input_eid) {
+  for (const auto& i : bwd_ograd_dep) {
+    auto eid = idx.entry_id(ograd_entries[i]);
+    bwd_input_eid->push_back(eid);
+  }
+  for (const auto& i : bwd_in_dep) {
+    auto eid = idx.entry_id(idx.input_nodes()[i], 0);
+    bwd_input_eid->push_back(eid);
+  }
+  for (const auto& i : bwd_out_dep) {
+    auto eid = idx.entry_id(idx.outputs()[i]);
+    bwd_input_eid->push_back(eid);
+  }
+}
+
 bool CachedOp::SetBackwardGraph(
     GraphInfo* info,
     const std::vector<OpReqType>& reqs,
@@ -356,18 +368,8 @@ bool CachedOp::SetBackwardGraph(
 
   if (info->bwd_input_eid.size() != inputs.size()) {
     info->bwd_input_eid.clear();
-    for (const auto& i : bwd_ograd_dep_) {
-      auto eid = idx.entry_id(ograd_entries_[i]);
-      info->bwd_input_eid.push_back(eid);
-    }
-    for (const auto& i : bwd_in_dep_) {
-      auto eid = idx.entry_id(idx.input_nodes()[i], 0);
-      info->bwd_input_eid.push_back(eid);
-    }
-    for (const auto& i : bwd_out_dep_) {
-      auto eid = idx.entry_id(idx.outputs()[i]);
-      info->bwd_input_eid.push_back(eid);
-    }
+    SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_,
+                        ograd_entries_, idx, &info->bwd_input_eid);
     CHECK_EQ(inputs.size(), info->bwd_input_eid.size());
   }
 
@@ -1019,6 +1021,79 @@ void CachedOp::Backward(
   Engine::Get()->set_bulk_size(prev_bulk_size);
 }
 
+bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs,
+                                  const int dev_mask,
+                                  DispatchMode* dispatch_mode,
+                                  std::vector<int> *in_attrs,
+                                  std::vector<int> *out_attrs) {
+  using namespace imperative;
+  nnvm::Graph g(fwd_graph_);
+  const auto& idx = g.indexed_graph();
+  const auto &outputs = idx.outputs();
+
+  // Prepare stypes and contexts based on inputs
+  StorageTypeVector storage_type_inputs;
+  storage_type_inputs.reserve(in_attrs->size());
+  for (size_t i = 0; i < in_attrs->size(); ++i) {
+    storage_type_inputs.emplace_back(in_attrs->at(i));
+  }
+  exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
+
+  // Forward graph storage type inference
+  CheckAndInferStorageType(&g, std::move(dev_masks), std::move(storage_type_inputs), true);
+  // Retrieve result and set outputs
+  const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  for (size_t i = 0; i < out_attrs->size(); i++) {
+    const auto eid = idx.entry_id(outputs[i]);
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
+  }
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  return true;
+}
+
+bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
+                                   const int dev_mask,
+                                   DispatchMode* dispatch_mode,
+                                   std::vector<int> *in_attrs,
+                                   std::vector<int> *out_attrs) {
+  using namespace imperative;
+  nnvm::Graph g(full_graph_);
+  const auto& idx = g.indexed_graph();
+  const auto &outputs = idx.outputs();
+  const size_t num_forward_outputs = fwd_graph_.outputs.size();
+  CHECK_EQ(outputs.size(), num_forward_outputs + out_attrs->size());
+
+  // Construct bwd_input_eid
+  std::vector<uint32_t> bwd_input_eid;
+  SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_,
+                      ograd_entries_, idx, &bwd_input_eid);
+  CHECK_EQ(in_attrs->size(), bwd_input_eid.size());
+
+  // Prepare stypes and contexts based on inputs
+  StorageTypeVector stypes(idx.num_node_entries(), -1);
+  for (size_t i = 0; i < in_attrs->size(); ++i) {
+    stypes[bwd_input_eid[i]] = in_attrs->at(i);
+  }
+  // Some out_attrs are known ahead of time (e.g. the grad stype is given by the user).
+  // Prepare these before invoking storage type inference on the subgraph
+  for (size_t i = 0; i < out_attrs->size(); i++) {
+    const auto eid = idx.entry_id(outputs[i + num_forward_outputs]);
+    stypes[eid] = out_attrs->at(i);
+  }
+  exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
+
+  // Full graph storage type inference
+  CheckAndInferStorageType(&g, std::move(dev_masks), std::move(stypes), false);
+  // Retrieve result and set outputs
+  const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  for (size_t i = 0; i < out_attrs->size(); i++) {
+    const auto eid = idx.entry_id(outputs[i + num_forward_outputs]);
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
+  }
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  return true;
+}
+
 NNVM_REGISTER_OP(_CachedOp)
 .set_num_inputs([](const NodeAttrs& attrs) {
@@ -1029,6 +1104,14 @@ NNVM_REGISTER_OP(_CachedOp)
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
     return op->num_outputs();
   })
+.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
+                                                     const int dev_mask,
+                                                     DispatchMode* dispatch_mode,
+                                                     std::vector<int> *in_attrs,
+                                                     std::vector<int> *out_attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
+  })
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(n->attrs.parsed);
@@ -1044,6 +1127,14 @@ NNVM_REGISTER_OP(_backward_CachedOp)
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
     return op->num_inputs() - op->mutable_input_nodes().size();
   })
+.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
+                                                     const int dev_mask,
+                                                     DispatchMode* dispatch_mode,
+                                                     std::vector<int> *in_attrs,
+                                                     std::vector<int> *out_attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
+  })
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true);
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 60a40c5e4a52..6b94c67a94e2 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -71,13 +71,13 @@ class CachedOp {
       const nnvm::Symbol& sym,
       const std::vector<std::pair<std::string, std::string> >& flags);
   ~CachedOp();
-  uint32_t num_inputs() {
+  uint32_t num_inputs() const {
     return fwd_graph_.indexed_graph().input_nodes().size();
   }
-  uint32_t num_outputs() {
+  uint32_t num_outputs() const {
     return fwd_graph_.outputs.size();
   }
-  uint32_t num_backward_inputs() {
+  uint32_t num_backward_inputs() const {
     return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
   }
   std::vector<bool>& save_inputs() {
@@ -86,12 +86,12 @@ class CachedOp {
   std::vector<bool>& save_outputs() {
     return save_outputs_;
   }
-  const std::unordered_set<uint32_t>& mutable_input_nodes() {
+  const std::unordered_set<uint32_t>& mutable_input_nodes() const {
     return fwd_graph_.indexed_graph().mutable_input_nodes();
   }
   std::vector<nnvm::NodeEntry> Gradient(
       const nnvm::NodePtr& node,
-      const std::vector<nnvm::NodeEntry>& ograds);
+      const std::vector<nnvm::NodeEntry>& ograds) const;
   void Forward(
       const std::shared_ptr<CachedOp>& op_ptr,
       const std::vector<NDArray*>& inputs,
@@ -102,6 +102,20 @@ class CachedOp {
       const std::vector<NDArray*>& inputs,
       const std::vector<OpReqType>& reqs,
       const std::vector<NDArray*>& outputs);
+  // forward storage type inference
+  bool ForwardStorageType(
+      const nnvm::NodeAttrs& attrs,
+      const int dev_mask,
+      DispatchMode* dispatch_mode,
+      std::vector<int> *in_attrs,
+      std::vector<int> *out_attrs);
 
  private:
   struct GraphInfo;
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 726531d02994..faff5f173fe1 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -668,7 +668,6 @@ inline bool CheckAndInferStorageType(nnvm::Graph* p_g, exec::DevMaskVector&& dev
     g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(storage_types));
     g = exec::InferStorageType(std::move(g));
   }
-  CHECK_EQ(g.GetAttr<size_t>("storage_type_num_unknown_nodes"), 0U);
 
   return false;
 }
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 0a9cd08db81b..02130eb32e51 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -256,7 +256,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
  */
 #define STORAGE_TYPE_ASSIGN_CHECK(type_array, index, type)           \
   {                                                                  \
-    if (!type_assign(&(type_array)[index], type)) {                  \
+    if (!::mxnet::op::type_assign(&(type_array)[index], type)) {     \
       std::ostringstream os;                                         \
       os << "Storage type inconsistent, Provided = "                 \
          << common::stype_string((type_array)[index]) << ','         \
@@ -274,7 +274,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
  */
 #define DISPATCH_MODE_ASSIGN_CHECK(type_array, index, type)          \
   {                                                                  \
-    if (!dispatch_mode_assign(&(type_array)[index], type)) {         \
+    if (!::mxnet::op::dispatch_mode_assign(&(type_array)[index], type)) { \
       std::ostringstream os;                                         \
       os << "Dispatch mode inconsistent, Provided = "                \
          << common::dispatch_mode_string((type_array)[index]) << ',' \
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 6fafb3671ffe..cd3cc685bdd6 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1293,6 +1293,63 @@ def test_legacy_save_params():
     model.load_params('test.params', ctx=mx.cpu())
 
 
+@with_seed()
+def test_sparse_hybrid_block_grad():
+    class Embedding(mx.gluon.HybridBlock):
+        def __init__(self, num_tokens, embedding_size):
+            super(Embedding, self).__init__()
+            self.num_tokens = num_tokens
+
+            with self.name_scope():
+                self.embedding = mx.gluon.nn.Embedding(
+                    num_tokens, embedding_size, sparse_grad=True)
+
+        def hybrid_forward(self, F, words):
+            emb = self.embedding(words)
+            return emb + F.ones_like(emb)
+
+    embedding = Embedding(20, 3)
+    embedding.initialize()
+    embedding.hybridize()
+
+    with mx.autograd.record():
+        emb0 = embedding(mx.nd.arange(10)).sum()
+        emb1 = embedding(mx.nd.arange(10)).sum()
+        loss = emb0 + emb1
+        loss.backward()
+    grad = embedding.embedding.weight.grad().asnumpy()
+    assert (grad[:10] == 2).all()
+    assert (grad[10:] == 0).all()
+
+@with_seed()
+def test_sparse_hybrid_block():
+    class Linear(mx.gluon.HybridBlock):
+        def __init__(self, units):
+            super(Linear, self).__init__()
+            with self.name_scope():
+                self.w = self.params.get('w', shape=(units, units))
+
+        def hybrid_forward(self, F, x, w):
+            return F.dot(x, w)
+
+    class SparseBlock(mx.gluon.HybridBlock):
+        def __init__(self, units):
+            super(SparseBlock, self).__init__()
+            with self.name_scope():
+                self.net = Linear(units)
+
+        def hybrid_forward(self, F, x):
+            return self.net(x) * x
+
+    block = SparseBlock(2)
+    block.initialize()
+    block.hybridize()
+    x = mx.nd.ones((2,2)).tostype('csr')
+    with mx.autograd.record():
+        z = block(x) + block(x)
+    z.backward()
+    assert (block.net.w.grad().asnumpy() == 4).all()
+
 def test_hybrid_static_memory_recording():
     net = gluon.model_zoo.vision.get_resnet(
         1, 18, pretrained=True, ctx=mx.context.current_context())
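
Note: the new tests above exercise the FInferStorageType registration on _CachedOp and
_backward_CachedOp end to end. As a quick illustration of the user-visible effect, here is
a minimal sketch (not part of the patch; the 20x3 embedding and the 'row_sparse'
expectation are illustrative, based on the sparse_grad=True behavior used in the tests):

    import mxnet as mx

    # A hybridized block runs its forward pass through _CachedOp, which with
    # this change infers storage types instead of falling back to dense.
    embedding = mx.gluon.nn.Embedding(20, 3, sparse_grad=True)
    embedding.initialize()
    embedding.hybridize()

    with mx.autograd.record():
        loss = embedding(mx.nd.arange(10)).sum()
    loss.backward()

    # With sparse_grad=True the weight gradient should stay row_sparse
    # rather than being densified inside the cached graph.
    print(embedding.weight.grad().stype)  # expected: 'row_sparse'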