From 9890739b38c06eed25118f22b737db3cf7d429a2 Mon Sep 17 00:00:00 2001
From: Zhang-Shu
Date: Tue, 15 May 2018 03:29:52 +0800
Subject: [PATCH] [WIP][MXNET-107] Fused LSTM implementation for CPU (#10104)

* register RNN fused-API with nnvm, finish single-layer && unidirectional LSTM forward function
* fix coding style and lint complaints
* add single-layer && unidirectional LSTM backward function
* make interface universal for other RNN modes
* share intermediate results between forward and backward in a tricky way
* add comments for important parameters
* modify testcase
* fix coding style and error messages
* fix openmp collapse error
* fix const
* remove rnn.cu and skip related testcases temporarily for building on GPU
* support multi-layer and bidirectional for lstm inference
* remove some testcases in test_gluon_rnn.py to build on GPU
* remove the testcase comparing fp32 and fp64 temporarily
* retrigger ci
* fix some logs
* use a better way to share memory
* fix cudnn registration
* fix invariant calculations and enable some gpu testcases
* add thread local cache for cudnn rnn op
* add thread local cache for rnn op
* fix bugs
* remove some testcases to check segfault
* remove cudnn registration to check segfault
* support multi-layer for LSTM training
* modify lstm testcase
* add bidirectional support for lstm
* fix gluon and coding style
* fix bugs
* remove nnvm registration
* enable gpu testcases
* add detailed descriptions
* add dropout check
* fix workspace size
* dropout is not supported, add unit test for it
* fix review comments
---
 python/mxnet/gluon/rnn/rnn_layer.py    |   4 +-
 src/operator/cudnn_rnn-inl.h           |   3 +-
 src/operator/rnn-inl.h                 | 624 ++++++++++++++++---------
 src/operator/rnn.cc                    |  48 +-
 src/operator/rnn_impl.h                | 457 ++++++++++++++++++
 tests/python/gpu/test_operator_gpu.py  |  17 -
 tests/python/unittest/test_operator.py |  83 ++++
 7 files changed, 991 insertions(+), 245 deletions(-)
 create mode 100644 src/operator/rnn_impl.h

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 59dd74754ed2..34ad05d5cc90 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -23,7 +23,6 @@ from __future__ import print_function
 __all__ = ['RNN', 'LSTM', 'GRU']
-from ...autograd import is_training
 from ... import ndarray
 from .. import Block
 from .
import rnn_cell @@ -186,8 +185,7 @@ def forward(self, inputs, states=None): for i in range(self._dir): self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2]) self.i2h_weight[i]._finish_deferred_init() - if inputs.context.device_type == 'gpu' or \ - (not is_training() and self._mode == 'lstm'): + if inputs.context.device_type == 'gpu' or self._mode == 'lstm': out = self._forward_kernel(inputs, states) else: out = self._forward(inputs, states) diff --git a/src/operator/cudnn_rnn-inl.h b/src/operator/cudnn_rnn-inl.h index 1a54b73660c7..033d30e40dc8 100644 --- a/src/operator/cudnn_rnn-inl.h +++ b/src/operator/cudnn_rnn-inl.h @@ -38,7 +38,7 @@ namespace mxnet { namespace op { #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 template -class CuDNNRNNOp : public Operator { +class CuDNNRNNOp : public Operator{ public: explicit CuDNNRNNOp(RNNParam param) { this->param_ = param; @@ -101,6 +101,7 @@ class CuDNNRNNOp : public Operator { CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_)); Storage::Get()->Free(dropout_states_); Storage::Get()->Free(reserve_space_); + init_cudnn_ = false; } } diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 13c077dd9e35..eded6aeed8a9 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -21,7 +21,7 @@ * Copyright (c) 2015 by Contributors * \file rnn-inl.h * \brief - * \author Sebastian Bodenstein + * \author Sebastian Bodenstein, Shu Zhang */ #ifndef MXNET_OPERATOR_RNN_INL_H_ #define MXNET_OPERATOR_RNN_INL_H_ @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -37,8 +38,7 @@ #include "./math.h" #include "./math_functions-inl.h" #include "./operator_common.h" -#include "./mshadow_op.h" -#include "./linalg.h" +#include "./rnn_impl.h" namespace mxnet { namespace op { @@ -50,18 +50,37 @@ namespace rnn_enum { enum RNNOpResource {kTempSpace}; } -// A utility function to calculate input size -inline int rnn_single_param_size(int inputSize, - int hiddenSize, - int mode) { - int size = hiddenSize * (hiddenSize + inputSize + 2); - // Different RNN's have different num weights +inline int GetRnnParamSize(int num_layer, + int input_size, + int state_size, + int direction, + int mode) { + int size = state_size * direction; switch (mode) { case rnn_enum::kRnnRelu: - size *= 1; + case rnn_enum::kRnnTanh: break; + case rnn_enum::kLstm: + size *= 4; + break; + case rnn_enum::kGru: + size *= 3; + break; + } + int size1 = (input_size + state_size + 2) * size; // first layer size + int size2 = (state_size * direction + state_size + 2) * size; // other layers size + int param_size = size1 + (num_layer - 1) * size2; + return param_size; +} + +inline int GetRnnBiasSize(int num_layer, + int state_size, + int direction, + int mode) { + int size = 2 * state_size * direction * num_layer; + switch (mode) { + case rnn_enum::kRnnRelu: case rnn_enum::kRnnTanh: - size *= 1; break; case rnn_enum::kLstm: size *= 4; @@ -73,19 +92,48 @@ inline int rnn_single_param_size(int inputSize, return size; } -inline int rnn_param_size(int layerNum, - int inputSize, - int hiddenSize, - bool bidirectional, - int mode) { - // get size of first layer - int size = rnn_single_param_size(inputSize, hiddenSize, mode); - // get size of remaining layers - if (bidirectional) { - size += (layerNum - 1) * rnn_single_param_size(2 * hiddenSize, hiddenSize, mode); - size *= 2; - } else { - size += (layerNum - 1) * rnn_single_param_size(hiddenSize, hiddenSize, mode); +inline size_t GetRNNWorkspaceSize(int seq_length, + int 
batch_size, + int hidden_size, + int direction, + int mode) { + size_t size = 0; + switch (mode) { + case rnn_enum::kRnnRelu: + case rnn_enum::kRnnTanh: + case rnn_enum::kGru: + LOG(FATAL) << "Only LSTM is supported at the moment"; + break; + case rnn_enum::kLstm: + size = (seq_length + 1) * batch_size * hidden_size * 4 + batch_size * hidden_size * 2 + + seq_length * batch_size * hidden_size * direction; + break; + default: + LOG(FATAL) << "unknown RNN mode " << mode; + break; + } + return size; +} + +inline size_t GetRNNReserveSpaceSize(int num_layer, + int direction, + int seq_length, + int batch_size, + int hidden_size, + int mode) { + size_t size = 0; + switch (mode) { + case rnn_enum::kRnnRelu: + case rnn_enum::kRnnTanh: + case rnn_enum::kGru: + LOG(FATAL) << "Only LSTM is supported at the moment"; + break; + case rnn_enum::kLstm: + size = num_layer * direction * seq_length * batch_size * hidden_size * 6; + break; + default: + LOG(FATAL) << "unknown RNN mode " << mode; + break; } return size; } @@ -125,51 +173,153 @@ struct RNNParam : public dmlc::Parameter { } }; -template -class RNNOp : public Operator { - public: - explicit RNNOp(RNNParam p) { - } - virtual void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { - using namespace mshadow; - using namespace mshadow::expr; - // TODO(sbodenstein): add MShadow implementation +/** + * @params: ws: Temp workspace for gemm's output storage. + * rs: Reserve space for forward intermediate data used in training. + * num_layers: The number of recurrent layers. + * direction: 2 if using bidirectional recurrent layers, otherwise 1. + * seq_length: The number of iterations to unroll over. + * batch_size: The batch size. + * input_size: The number of expected input features. + * state_size: The number of hidden state features. + * x_ptr: Pointer of tensor x containing the features of the input sequence. + * x's shape is [seq_length, batch_size, input_size] + * hx_ptr: Pointer of tensor hx containing the initial hidden state. + * hx's shape is [num_layers, batch_size, state_size] + * cx_ptr: Only used in lstm mode. Pointer of tensor cx containing the initial cell state. + * cx's shape is [num_layers, batch_size, state_size] + * w_ptr: Pointer of tensor w containing weights. + * b_ptr: Pointer of tensor b containing the bias. + * y_ptr: Pointer of tensor y containing the output features from the + * last layer of the RNN. y's shape is [seq_length, batch_size, state_size] + * hy_ptr: Pointer of tensor hy containing the hidden state for t=seq_length. + * hy's shape is [num_layers, batch_size, state_size] + * cy_ptr: Only used in lstm mode. Pointer of tensor cy containing the cell state + * for t=seq_length. cy's shape is [num_layers, batch_size, state_size] + * mode: Specifies the type of RNN to compute.
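 * Illustration (an assumed example derived from the shape notes above, not part of the
 * original interface description): for a single-layer, unidirectional LSTM
 * (num_layers = 1, direction = 1) with seq_length = T, batch_size = N, input_size = I
 * and state_size = H, the shapes are x: [T, N, I], hx/cx: [1, N, H], y: [T, N, H] and
 * hy/cy: [1, N, H]. The weight blob w packs the 4H x I input-to-hidden weights, the
 * 4H x H hidden-to-hidden weights and two 4H-element bias vectors, i.e.
 * 4 * H * (I + H + 2) elements in total, matching GetRnnParamSize above.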
+ */ +template +void RNNForwardTraining(DType* ws, + DType* rs, + bool state_outputs, + const int num_layers, + const int direction, + const int seq_length, + const int batch_size, + const int input_size, + const int state_size, + DType* x_ptr, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* b_ptr, + DType* y_ptr, + DType* hy_ptr, + DType* cy_ptr, + int mode) { + switch (mode) { + case rnn_enum::kRnnTanh: + case rnn_enum::kRnnRelu: + case rnn_enum::kGru: + LOG(FATAL) << "Only LSTM is supported at the moment"; + break; + case rnn_enum::kLstm: + LstmForwardTraining(ws, rs, state_outputs, num_layers, direction, seq_length, + batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr, + w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr); + break; + default: + LOG(FATAL) << "unknown RNN mode " << mode; + break; } +} - virtual void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_args) { - using namespace mshadow; - using namespace mshadow::expr; - // TODO(sbodenstein): add MShadow implementation +template +void RNNForwardInference(DType* ws, + bool state_outputs, + const int num_layers, + const int direction, + const int seq_length, + const int batch_size, + const int input_size, + const int state_size, + DType* x_ptr, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* b_ptr, + DType* y_ptr, + DType* hy_ptr, + DType* cy_ptr, + int mode) { + switch (mode) { + case rnn_enum::kRnnRelu: + case rnn_enum::kRnnTanh: + case rnn_enum::kGru: + LOG(FATAL) << "Only LSTM is supported at the moment"; + break; + case rnn_enum::kLstm: + LstmForwardInference(ws, state_outputs, num_layers, direction, seq_length, + batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr, + w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr); + break; + default: + LOG(FATAL) << "unknown RNN mode" << mode; + break; } +} - private: - RNNParam param_; -}; // class RNNOp +template +void RNNBackward(DType* ws, + DType* rs, + const int num_layers, + const int direction, + const int seq_length, + const int batch_size, + const int input_size, + const int state_size, + DType* x_ptr, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* y_ptr, + DType* dy_ptr, + DType* dhy_ptr, + DType* dcy_ptr, + DType* dx_ptr, + DType* dhx_ptr, + DType* dcx_ptr, + DType* dw_ptr, + DType* db_ptr, + int mode) { + switch (mode) { + case rnn_enum::kRnnRelu: + case rnn_enum::kRnnTanh: + case rnn_enum::kGru: + break; + case rnn_enum::kLstm: + LstmBackward(ws, rs, num_layers, direction, seq_length, batch_size, + input_size, state_size, x_ptr, hx_ptr, cx_ptr, w_ptr, y_ptr, + dy_ptr, dhy_ptr, dcy_ptr, dx_ptr, dhx_ptr, dcx_ptr, dw_ptr, db_ptr); + break; + default: + LOG(FATAL) << "unknown RNN mode" << mode; + break; + } +} template -class RNNOp : public Operator { +class RNNOp : public Operator{ public: - explicit RNNOp(RNNParam param) { - this->param_ = param; - // RNN Mode - param_.lstm_q_ = false; - switch (param_.mode) { - case rnn_enum::kLstm: - param_.lstm_q_ = true; - break; - default: - LOG(FATAL) << "only LSTM is implmented on CPU"; + explicit RNNOp(RNNParam p) + :param_(p), init_space_(false), reserve_space_size_(0) + {} + + ~RNNOp() { + if (init_space_) { + Storage::Get()->Free(reserve_space_); + init_space_ = false; } } @@ -178,189 +328,221 @@ class RNNOp : public Operator { const std::vector &req, const std::vector &out_data, const std::vector &aux_args) { - // Layout TNC - CHECK(!ctx.is_train) << "only 
inference mode is available" - "for cpu at the moment."; - size_t in_expected = param_.lstm_q_ ? 4 : 3; - size_t out_expected = param_.lstm_q_ ? 3 : 2; - - if (!param_.state_outputs) - LOG(FATAL) << "no state outputs is currently not supported for cpu."; + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(param_.mode, rnn_enum::kLstm) << "Only lstm mode is supported at the moment."; + CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment."; - CHECK_EQ(req[rnn_enum::kOut], kWriteTo); + size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3; + size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2; + if (!param_.state_outputs) { + out_expected = 1; + } CHECK_EQ(in_data.size(), in_expected); CHECK_EQ(out_data.size(), out_expected); - - mshadow::Stream *s = ctx.get_stream(); - // get input + output tensors - // w layout i2h_w, h2h_w, i2h_b, h2h_b - Tensor x = - in_data[rnn_enum::kData].get(s); // TNC + Stream *s = ctx.get_stream(); + // get input + output tensor + Tensor x = in_data[rnn_enum::kData].get(s); Tensor w = in_data[rnn_enum::kParams].get(s); - Tensor hx = - in_data[rnn_enum::kState].get(s); // LNC - Tensor y = - out_data[rnn_enum::kOut].get(s); // TNC - int64_t seq_len = x.shape_[0]; - int64_t num_layers = hx.shape_[0]; - int64_t batch_size = x.shape_[1]; - int64_t h_channel = hx.shape_[2]; - int64_t in_channel = x.shape_[2]; - Tensor x_flatten = in_data[rnn_enum::kData] - .get_with_shape( - mshadow::Shape2(seq_len * batch_size, in_channel), s); // (T*N)C - Tensor y_flatten = out_data[rnn_enum::kOut] - .get_with_shape( - mshadow::Shape2( - y.shape_[0] * y.shape_[1], y.shape_[2]), s); // (T*N)C - + Tensor hx = in_data[rnn_enum::kState].get(s); + Tensor y = out_data[rnn_enum::kOut].get(s); CHECK(x.CheckContiguous()); CHECK(w.CheckContiguous()); CHECK(hx.CheckContiguous()); CHECK(y.CheckContiguous()); + param_.seq_length_ = x.shape_[0]; + param_.batch_size_ = x.shape_[1]; + param_.input_size_ = x.shape_[2]; + + const int direction = param_.bidirectional ? 
2 : 1; + const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode); + DType* b_ptr = w.dptr_ + w.shape_[0] - bsize; + + DType* hy_ptr = NULL; + if (param_.state_outputs) { + hy_ptr = out_data[rnn_enum::kStateOut].dptr(); + } + DType* cx_ptr = NULL; + DType* cy_ptr = NULL; - if (param_.lstm_q_) { - const size_t kNumMat = 4; - int64_t fused_h_ch = kNumMat * h_channel; - int64_t h_size = batch_size * fused_h_ch; - int64_t num_dir = 1 + param_.bidirectional; - int64_t h2h_w_size = h_channel * fused_h_ch; - - Tensor cx = - in_data[rnn_enum::kStateCell].get(s); - CHECK(cx.CheckContiguous()); - - Tensor cy = - out_data[rnn_enum::kStateCellOut].get(s); - Tensor hy = - out_data[rnn_enum::kStateOut].get(s); - CHECK(cy.CheckContiguous()); - CHECK(hy.CheckContiguous()); - - DType* workspace_addr = - static_cast(ctx.requested[rnn_enum::kTempSpace] - .get_host_space_internal(sizeof(DType) * - (seq_len * h_size + h_size - + y.shape_[0] * y.shape_[1] * y.shape_[2]))); - Tensor i2h_y( - workspace_addr, mshadow::Shape3(seq_len, batch_size, fused_h_ch)); - Tensor i2h_y_flatten( - workspace_addr, mshadow::Shape2(seq_len * batch_size, fused_h_ch)); - Tensor h2h_y(workspace_addr - + seq_len * h_size, mshadow::Shape2(batch_size, fused_h_ch)); - Tensor y_tmp(workspace_addr - + (seq_len + 1) * h_size, y.shape_); - Tensor y_flatten_tmp(workspace_addr - + (seq_len + 1) * h_size, y_flatten.shape_); - CHECK(i2h_y.CheckContiguous()); - CHECK(h2h_y.CheckContiguous()); - CHECK(y_tmp.CheckContiguous()); - - for (int64_t layer = 0; layer < num_layers; layer++) { - int reverse_dir = 0; - int out_tmp = 0; - if (param_.bidirectional && layer % 2) - reverse_dir = 1; - if (layer / num_dir % 2 == 0) - out_tmp = 1; - mshadow::Shape<2> i2h_w_shape = mshadow::Shape2(fused_h_ch, - (layer < num_dir) ? in_channel : num_dir * h_channel); - mshadow::Shape<2> h2h_w_shape = mshadow::Shape2(fused_h_ch, h_channel); - int64_t start = layer < num_dir ? - (layer * (in_channel * fused_h_ch + h2h_w_size)) : // input layer - (num_dir * (in_channel * fused_h_ch + h2h_w_size) - + (layer - num_dir) * (h2h_w_size * num_dir + h2h_w_size)); - Tensor i2h_w(w.dptr_ + start, i2h_w_shape); - start += layer < num_dir ? - in_channel * fused_h_ch : h2h_w_size * num_dir; - Tensor h2h_w(w.dptr_ + start, h2h_w_shape); - start = num_dir * (in_channel * fused_h_ch + h2h_w_size) - + (num_layers - num_dir) * (h2h_w_size * (num_dir + 1)) - + layer * fused_h_ch * 2; - Tensor i2h_b = w.Slice(start, start + fused_h_ch); - start += fused_h_ch; - Tensor h2h_b = w.Slice(start, start + fused_h_ch); - if (out_tmp) { - linalg_gemm(layer < num_dir ? x_flatten:y_flatten, i2h_w, - i2h_y_flatten, false, true, s); - } else { - linalg_gemm(layer < num_dir ? x_flatten:y_flatten_tmp, i2h_w, - i2h_y_flatten, false, true, s); - } - i2h_y_flatten += repmat(i2h_b, seq_len * batch_size); - for (int64_t t = 0; t < seq_len; t++) { - int64_t timestep = t; - if (reverse_dir) - timestep = seq_len - 1 - t; - linalg_gemm(t == 0 ? hx[layer]:hy[layer], h2h_w, h2h_y, - false, true, s); - h2h_y += repmat(h2h_b, batch_size); - // fused element-wise ops - LSTMFusedElementWiseCPUOps(i2h_y[timestep], cx[layer], h2h_y, - y[timestep], out_tmp ? 
y_tmp[timestep]: y[timestep], - hy[layer], cy[layer], batch_size, h_channel, t, - reverse_dir, out_tmp && (layer == num_layers - 1)); - } + if (param_.mode == rnn_enum::kLstm) { + cx_ptr = in_data[rnn_enum::kStateCell].dptr(); + if (param_.state_outputs) { + cy_ptr = out_data[rnn_enum::kStateCellOut].dptr(); } + } + + // allocate temp space + const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, + param_.state_size, direction, param_.mode); + Tensor workspace = ctx.requested[rnn_enum::kTempSpace] + .get_space_typed(Shape1(workspace_size), s); + + if (ctx.is_train) { + const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction, + param_.seq_length_, param_.batch_size_, + param_.state_size, param_.mode); + if (init_space_ && reserve_space_size_ < r_size) { + Storage::Get()->Free(reserve_space_); + init_space_ = false; + } + + if (!init_space_) { + reserve_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU()); + reserve_space_size_ = r_size; + init_space_ = true; + } + + DType* reserve_space_ptr = static_cast(reserve_space_.dptr); + RNNForwardTraining(workspace.dptr_, + reserve_space_ptr, + param_.state_outputs, + param_.num_layers, + direction, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + param_.state_size, + x.dptr_, + hx.dptr_, + cx_ptr, + w.dptr_, + b_ptr, + y.dptr_, + hy_ptr, + cy_ptr, + param_.mode); } else { - LOG(FATAL) << "only LSTM is available for cpu at the moment."; + RNNForwardInference(workspace.dptr_, + param_.state_outputs, + param_.num_layers, + direction, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + param_.state_size, + x.dptr_, + hx.dptr_, + cx_ptr, + w.dptr_, + b_ptr, + y.dptr_, + hy_ptr, + cy_ptr, + param_.mode); } } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, - const std::vector &out_data, + const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_args) { - LOG(FATAL) << "LSTM backward is not available for cpu at the moment."; - } + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(param_.mode, rnn_enum::kLstm) << "Only lstm mode is supported at the moment."; + CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment."; + size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3; + size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 
3 : 2; + if (!param_.state_outputs) { + out_expected = 1; + } + CHECK_EQ(in_data.size(), in_expected); + CHECK_EQ(out_data.size(), out_expected); + CHECK_EQ(in_grad.size(), in_expected); + CHECK_EQ(out_grad.size(), out_expected); + CHECK_EQ(req.size(), in_expected); + CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data"; + CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state"; + mshadow::Stream *s = ctx.get_stream(); + // get input + output tensors + Tensor x = in_data[rnn_enum::kData].get(s); + Tensor w = in_data[rnn_enum::kParams].get(s); + Tensor hx = in_data[rnn_enum::kState].get(s); + Tensor y = out_data[rnn_enum::kOut].get(s); + Tensor dx = in_grad[rnn_enum::kData].get(s); + Tensor dw = in_grad[rnn_enum::kParams].get(s); + Tensor dhx = in_grad[rnn_enum::kState].get(s); + Tensor dy = out_grad[rnn_enum::kOut].get(s); + CHECK(x.CheckContiguous()); + CHECK(w.CheckContiguous()); + CHECK(hx.CheckContiguous()); + CHECK(y.CheckContiguous()); + CHECK(dx.CheckContiguous()); + CHECK(dw.CheckContiguous()); + CHECK(dhx.CheckContiguous()); + CHECK(dy.CheckContiguous()); + param_.seq_length_ = x.shape_[0]; + param_.batch_size_ = x.shape_[1]; + param_.input_size_ = x.shape_[2]; + + const int direction = param_.bidirectional ? 2 : 1; + const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode); + DType* db_ptr = dw.dptr_ + w.shape_[0] - bsize; + + DType * dhy_ptr = NULL; + if (param_.state_outputs) { + dhy_ptr = out_grad[rnn_enum::kStateOut].dptr(); + } - private: - RNNParam param_; + DType * cx_ptr = NULL; + DType * dcx_ptr = NULL; + DType * dcy_ptr = NULL; - void LSTMFusedElementWiseCPUOps(const Tensor &i2h_y, - const Tensor &cx, - const Tensor &h2h_y, - const Tensor &y, - // holding intermediate layer output - const Tensor &tmp, - const Tensor &hy, - const Tensor &cy, - const int64_t batch_size, - const int64_t h_channel, - const int64_t t, - const int reverse_dir, - const int copy_tmp2y) { - int64_t length = batch_size * h_channel; - #pragma omp parallel for - for (int64_t ji = 0; ji < length; ++ji) { - int64_t j = ji / h_channel; // batch dim - int64_t i = ji % h_channel; - int64_t f = i + h_channel; - int64_t c = i + h_channel * 2; - int64_t o = i + h_channel * 3; - int64_t j_pos = j * h_channel * 4; - h2h_y.dptr_[j_pos + i] += i2h_y.dptr_[j_pos + i]; - h2h_y.dptr_[j_pos + f] += i2h_y.dptr_[j_pos + f]; - h2h_y.dptr_[j_pos + o] += i2h_y.dptr_[j_pos + o]; - h2h_y.dptr_[j_pos + c] += i2h_y.dptr_[j_pos + c]; - h2h_y.dptr_[j_pos + i] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + i])); - h2h_y.dptr_[j_pos + f] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + f])); - h2h_y.dptr_[j_pos + o] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + o])); - h2h_y.dptr_[j_pos + c] = tanh(h2h_y.dptr_[j_pos + c]); - cy[j][i] = h2h_y.dptr_[j_pos + f] * (t == 0 ? 
cx[j][i]:cy[j][i]) - + h2h_y.dptr_[j_pos + i] * h2h_y.dptr_[j_pos + c]; - hy[j][i] = h2h_y.dptr_[j_pos + o] * tanh(cy[j][i]); - tmp[j][i + h_channel * reverse_dir] = hy[j][i]; - if (copy_tmp2y) { - y[j][i] = tmp[j][i]; - if (reverse_dir) - y[j][i + h_channel] = tmp[j][i + h_channel]; + if (param_.mode == rnn_enum::kLstm) { + CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell"; + cx_ptr = in_data[rnn_enum::kStateCell].dptr(); + dcx_ptr = in_grad[rnn_enum::kStateCell].dptr(); + if (param_.state_outputs) { + dcy_ptr = out_grad[rnn_enum::kStateCellOut].dptr(); } } + + // allocate temp space + const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, + param_.state_size, direction, param_.mode); + Tensor workspace = ctx.requested[rnn_enum::kTempSpace] + .get_space_typed(Shape1(workspace_size), s); + + size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction, + param_.seq_length_, param_.batch_size_, + param_.state_size, param_.mode); + if (!init_space_ || reserve_space_size_ != r_size) { + LOG(FATAL) << "Check forward init error"; + } + + DType* reserve_space_ptr = static_cast(reserve_space_.dptr); + RNNBackward(workspace.dptr_, + reserve_space_ptr, + param_.num_layers, + direction, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + param_.state_size, + x.dptr_, + hx.dptr_, + cx_ptr, + w.dptr_, + y.dptr_, + dy.dptr_, + dhy_ptr, + dcy_ptr, + dx.dptr_, + dhx.dptr_, + dcx_ptr, + dw.dptr_, + db_ptr, + param_.mode); } + + private: + RNNParam param_; + bool init_space_; + size_t reserve_space_size_; + Storage::Handle reserve_space_; }; // class RNNOp template @@ -429,10 +611,10 @@ class RNNProp : public OperatorProperty { Shape3(total_layers, batch_size, param_.state_size)); // calculate parameter vector length - int param_size = rnn_param_size(param_.num_layers, + int param_size = GetRnnParamSize(param_.num_layers, input_size, param_.state_size, - param_.bidirectional, + numDirections, param_.mode); SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size)); diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index a60adbcd2fbc..6da367d3b80b 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -23,7 +23,6 @@ * \brief * \author Sebastian Bodenstein */ - #include "./rnn-inl.h" namespace mxnet { @@ -32,7 +31,7 @@ template<> Operator *CreateOp(RNNParam param, int dtype) { Operator *op = NULL; MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new RNNOp(param); + op = new RNNOp(param); }); return op; } @@ -46,7 +45,50 @@ Operator *RNNProp::CreateOperatorEx(Context ctx, DMLC_REGISTER_PARAMETER(RNNParam); MXNET_REGISTER_OP_PROPERTY(RNN, RNNProp) -.describe("Applies a recurrent layer to input.") +.describe(R"code(Applies recurrent layers to input. +Currently, vanilla RNN, LSTM and GRU are implemented, with + both multi-layer and bidirectional support. +**Vanilla RNN** +Applies a single-gate recurrent layer to input X. Two kinds of + activation functions are supported: ReLU and Tanh. + +ReLU activation function: + +.. math:: + h_t = \mathrm{ReLU}(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh}) + +Tanh activation function: + +.. math:: + h_t = \tanh(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh}) + +Reference paper: Finding structure in time - Elman, 1988. + https://crl.ucsd.edu/~elman/Papers/fsit.pdf + +**LSTM** +Long Short-Term Memory - Hochreiter, 1997. + +..
math:: + \begin{array}{ll} + i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\ + f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\ + g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hc} h_{(t-1)} + b_{hg}) \\ + o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\ + c_t = f_t * c_{(t-1)} + i_t * g_t \\ + h_t = o_t * \tanh(c_t) + \end{array} + +**GRU** +Gated Recurrent Unit - Cho et al. 2014. +http://arxiv.org/abs/1406.1078 + +.. math:: +\begin{array}{ll} + r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ + z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ + n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ + h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\ + \end{array})code") .add_argument("data", "NDArray-or-Symbol", "Input data to RNN") .add_argument("parameters", "NDArray-or-Symbol", "Vector of all RNN trainable parameters concatenated") diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h new file mode 100644 index 000000000000..2ee374bbf569 --- /dev/null +++ b/src/operator/rnn_impl.h @@ -0,0 +1,457 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file rnn_impl.h + * \brief + * \author Shu Zhang +*/ +#ifndef MXNET_OPERATOR_RNN_IMPL_H_ +#define MXNET_OPERATOR_RNN_IMPL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "./math.h" +#include "./math_functions-inl.h" +#include "./operator_common.h" +#include "./mshadow_op.h" +#include "./linalg.h" + +template +inline DType sigmoid(DType x) { + return 1.0f / (1.0f + exp(-x)); +} + +template +void LstmForwardTrainingSingleLayer(DType* ws, + DType* rs, + bool state_outputs, + bool bid, + const int T, + const int N, + const int I, + const int H, + const Tensor &x, + const Tensor &hx, + const Tensor &cx, + const Tensor &y, + DType* w_ptr, + DType* b_ptr, + DType* hy_ptr, + DType* cy_ptr) { + using namespace mshadow; + const Tensor wx(w_ptr, Shape2(H * 4, I)); + const Tensor wh(w_ptr + I * H * 4, Shape2(H * 4, H)); + const Tensor bx(b_ptr, Shape2(4, H)); + const Tensor bh(b_ptr + H * 4, Shape2(4, H)); + const Tensor yx_flat(ws, Shape2(T * N, 4 * H)); + const Tensor yh_flat(ws + T * N * H * 4, Shape2(N, 4 * H)); + const Tensor yx(yx_flat.dptr_, Shape4(T, N, 4, H)); + const Tensor yh(yh_flat.dptr_, Shape3(N, 4, H)); + Tensor h(yh_flat.dptr_ + N * H * 4, Shape2(N, H)); + DType *c_ptr = bid ? rs + T * N * H * 7 : rs; + Tensor c(c_ptr, Shape3(T, N, H)); + Tensor ifgo(c_ptr + T * N * H, Shape4(T, N, H, 4)); + + const int offset = bid ? 
H : 0; + const DType alpha = 1.0; + const DType beta = 0.0; + const int cell_size = N * H; + linalg_gemm(x, wx, yx_flat, alpha, beta, false, true); + + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + for (int i = 0; i < T; ++i) { + int t = bid ? T - 1 - i : i; + linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true); + #pragma omp parallel for num_threads(omp_threads) + for (int jk = 0; jk < cell_size; ++jk) { + int j = jk / H; + int k = jk % H; + DType it = sigmoid(yx[t][j][0][k] + yh[j][0][k] + bx[0][k] + bh[0][k]); + DType ft = sigmoid(yx[t][j][1][k] + yh[j][1][k] + bx[1][k] + bh[1][k]); + DType gt = tanh(yx[t][j][2][k] + yh[j][2][k] + bx[2][k] + bh[2][k]); + DType ot = sigmoid(yx[t][j][3][k] + yh[j][3][k] + bx[3][k] + bh[3][k]); + DType ct = (i ? c[i-1][j][k] : cx[j][k]) * ft + it * gt; + DType ht = ot * tanh(ct); + h[j][k] = ht; + // reserve + y[t][j][k + offset] = ht; + c[i][j][k] = ct; + ifgo[i][j][k][0] = it; + ifgo[i][j][k][1] = ft; + ifgo[i][j][k][2] = gt; + ifgo[i][j][k][3] = ot; + if (i == T - 1 && state_outputs) { + hy_ptr[jk] = ht; + cy_ptr[jk] = ct; + } + } + } +} + +template +void LstmForwardTraining(DType* ws, + DType* rs, + bool state_outputs, + const int L, + const int D, + const int T, + const int N, + const int I, + const int H, + DType* x_ptr, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* b_ptr, + DType* y_ptr, + DType* hy_ptr, + DType* cy_ptr) { + const int total_layers = D * L; + Tensor hx(hx_ptr, Shape3(total_layers, N, H)); + Tensor cx(cx_ptr, Shape3(total_layers, N, H)); + const int b_size = 2 * H * 4; + const int r_size = D * T * N * H * 6; + const int y_offset = T * N * H * 5; + const int cell_size = N * H; + int idx = 0; // state & cell state's idx; + for (int i = 0; i < L; ++i) { + const int input_size = i ? H * D : I; + const int w_size = (input_size + H) * H * 4; + Tensor x(x_ptr, Shape2(T * N, input_size)); + Tensor y(rs + y_offset, Shape3(T, N, H * D)); + LstmForwardTrainingSingleLayer(ws, rs, state_outputs, false, T, N, input_size, H, x, + hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr); + if (D == 2) { + w_ptr += w_size; + b_ptr += b_size; + ++idx; + if (state_outputs) { + hy_ptr += cell_size; + cy_ptr += cell_size; + } + LstmForwardTrainingSingleLayer(ws, rs, state_outputs, true, T, N, input_size, H, x, + hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr); + } + if (i != L - 1) { + w_ptr += w_size; + b_ptr += b_size; + x_ptr = y.dptr_; + rs += r_size; + ++idx; + if (state_outputs) { + hy_ptr += cell_size; + cy_ptr += cell_size; + } + } + } + memcpy(y_ptr, rs + y_offset, T * N * H * D * sizeof(DType)); +} + +template +void LstmForwardInferenceSingleLayer(DType* ws, + bool state_outputs, + bool bid, + const int T, + const int N, + const int I, + const int H, + const Tensor &x, + const Tensor &hx, + const Tensor &cx, + const Tensor &y, + DType* w_ptr, + DType* b_ptr, + DType* hy_ptr, + DType* cy_ptr) { + using namespace mshadow; + const Tensor wx(w_ptr, Shape2(H * 4, I)); + const Tensor wh(w_ptr + I * H * 4, Shape2(H * 4, H)); + const Tensor bx(b_ptr, Shape2(4, H)); + const Tensor bh(b_ptr + H * 4, Shape2(4, H)); + Tensor yx_flat(ws, Shape2(T * N, H * 4)); + Tensor yh_flat(ws + T * N * H * 4, Shape2(N, H * 4)); + const Tensor yx(yx_flat.dptr_, Shape4(T, N, 4, H)); + const Tensor yh(yh_flat.dptr_, Shape3(N, 4, H)); + Tensor h(yh_flat.dptr_ + N * H * 4, Shape2(N, H)); + Tensor c(h.dptr_ + N * H, Shape2(N, H)); + const int offset = bid ? 
H : 0; + const DType alpha = 1.0; + const DType beta = 0.0; + const int cell_size = N * H; + linalg_gemm(x, wx, yx_flat, alpha, beta, false, true); + + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + for (int i = 0; i < T; ++i) { + int t = bid ? T - 1 - i : i; + linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true); + #pragma omp parallel for num_threads(omp_threads) + for (int jk = 0; jk < cell_size; ++jk) { + int j = jk / H; + int k = jk % H; + DType it = sigmoid(yx[t][j][0][k] + yh[j][0][k] + bx[0][k] + bh[0][k]); + DType ft = sigmoid(yx[t][j][1][k] + yh[j][1][k] + bx[1][k] + bh[1][k]); + DType gt = tanh(yx[t][j][2][k] + yh[j][2][k] + bx[2][k] + bh[2][k]); + DType ot = sigmoid(yx[t][j][3][k] + yh[j][3][k] + bx[3][k] + bh[3][k]); + DType ct = (i ? c[j][k] : cx[j][k]) * ft + it * gt; + DType ht = ot * tanh(ct); + y[t][j][k + offset] = ht; + if (i == T - 1 && state_outputs) { + hy_ptr[jk] = ht; + cy_ptr[jk] = ct; + } else { + h[j][k] = ht; + c[j][k] = ct; + } + } + } +} + +template +void LstmForwardInference(DType* ws, + bool state_outputs, + const int L, + const int D, + const int T, + const int N, + const int I, + const int H, + DType* x_ptr, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* b_ptr, + DType* y_ptr, + DType* hy_ptr, + DType* cy_ptr) { + const int total_layers = D * L; + Tensor hx(hx_ptr, Shape3(total_layers, N, H)); + Tensor cx(cx_ptr, Shape3(total_layers, N, H)); + const int b_size = 2 * H * 4; + const int cell_size = N * H; + DType* y_tmp_ptr = ws + (T + 1) * cell_size * 4 + cell_size * 2; + DType* y_cur_ptr = y_ptr; + int idx = 0; // state & cell state's idx; + bool flag = L % 2 ? false : true; + for (int i = 0; i < L; ++i) { + const int input_size = i ? H * D : I; + const int w_size = (input_size + H) * H * 4; + // If bidirectional, need space to save current layer output y. + if (D == 2) { + y_cur_ptr = flag ? y_tmp_ptr : y_ptr; + flag = !flag; + } + Tensor x(x_ptr, Shape2(T * N, input_size)); + Tensor y(y_cur_ptr, Shape3(T, N, H * D)); + LstmForwardInferenceSingleLayer(ws, state_outputs, false, T, N, input_size, H, + x, hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr); + // If bidirectional, then calculate the reverse direction's forward result. + if (D == 2) { + w_ptr += w_size; + b_ptr += b_size; + ++idx; + if (state_outputs) { + hy_ptr += cell_size; + cy_ptr += cell_size; + } + LstmForwardInferenceSingleLayer(ws, state_outputs, true, T, N, input_size, H, + x, hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr); + } + // Don't need to move pointer in the last layer. + if (i != L - 1) { + w_ptr += w_size; + b_ptr += b_size; + x_ptr = y_cur_ptr; + ++idx; + if (state_outputs) { + hy_ptr += cell_size; + cy_ptr += cell_size; + } + } + } +} + +template +void LstmBackwardSingleLayer(DType* ws, + DType* rs, + bool bid, + const int T, + const int N, + const int I, + const int H, + const Tensor &x, + const Tensor &hx, + const Tensor &cx, + const Tensor &y, + const Tensor &dy, + const Tensor &dx, + const Tensor &dhx, + const Tensor &dcx, + DType* dhy_ptr, + DType* dcy_ptr, + DType* w_ptr, + DType* dw_ptr, + DType* db_ptr) { + using namespace mshadow; + const Tensor wx(w_ptr, Shape2(H * 4, I)); + const Tensor wh(w_ptr + I * H * 4, Shape2(H * 4, H)); + Tensor dwx(dw_ptr, Shape2(H * 4, I)); + Tensor dwh(dw_ptr + I * H * 4, Shape2(H * 4, H)); + Tensor dbx(db_ptr, Shape1(H * 4)); + Tensor dbh(dbx.dptr_ + H * 4, Shape1(H * 4)); + DType *c_ptr = bid ? 
rs + T * N * H * 7 : rs; + const Tensor c(c_ptr, Shape3(T, N, H)); + const Tensor ifgo(c_ptr + T * N * H, Shape4(T, N, H, 4)); + memset(dwh.dptr_, 0, H * H * 4 * sizeof(DType)); + memset(dbx.dptr_, 0, H * 4 * sizeof(DType)); + memset(dbh.dptr_, 0, H * 4 * sizeof(DType)); + Tensor difgo(ws, Shape4(T, N, 4, H)); + Tensor dh(ws + T * N * H * 4, Shape2(N, H)); + Tensor dc(dh.dptr_ + N * H, Shape2(N, H)); + Tensor htmp(dc.dptr_ + N * H, Shape2(N, H)); + const int offset = bid ? H : 0; + const DType alpha = 1.0; + const DType beta0 = 0.0; + const DType beta1 = 1.0; + const int cell_size = N * H; + if (dhy_ptr != NULL) { + memcpy(dh.dptr_, dhy_ptr, cell_size * sizeof(DType)); + } + if (dcy_ptr != NULL) { + memcpy(dc.dptr_, dcy_ptr, cell_size * sizeof(DType)); + } + + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + for (int i = T - 1; i >= 0; --i) { + int t = bid ? T - 1 - i : i; + int tnext = bid ? t + 1 : t - 1; + const Tensor& dhnext = i ? dh : dhx; + const Tensor& dcnext = i ? dc : dcx; + const Tensor& hnext = i ? htmp : hx; + const Tensor& cnext = i ? c[i - 1] : cx; + #pragma omp parallel for num_threads(omp_threads) + for (int jk = 0; jk < cell_size; ++jk) { + int j = jk / H; + int k = jk % H; + DType tc = tanh(c[i][j][k]); + DType it = ifgo[i][j][k][0]; + DType ft = ifgo[i][j][k][1]; + DType gt = ifgo[i][j][k][2]; + DType ot = ifgo[i][j][k][3]; + dh[j][k] += dy[t][j][k + offset]; + dc[j][k] += dh[j][k] * ot * (1 - tc * tc); + difgo[t][j][0][k] = dc[j][k] * gt * it * (1 - it); + difgo[t][j][1][k] = dc[j][k] * cnext[j][k] * ft * (1 - ft); + difgo[t][j][2][k] = dc[j][k] * it * (1 - gt * gt); + difgo[t][j][3][k] = dh[j][k] * tc * ot * (1 - ot); + dcnext[j][k] = dc[j][k] * ft; + if (i) { + htmp[j][k] = y[tnext][j][k + offset]; + } + } + Tensor dyh(difgo[t].dptr_, Shape2(N, H * 4)); + linalg_gemm(dyh, wh, dhnext, alpha, beta0, false, false); + linalg_gemm(dyh, hnext, dwh, alpha, beta1, true, false); + } + Tensor dyx(difgo.dptr_, Shape2(T * N, H * 4)); + linalg_gemm(dyx, wx, dx, alpha, bid ? beta1 : beta0, false, false); + linalg_gemm(dyx, x, dwx, alpha, beta0, true, false); + const int row = T * N; + const int col = H * 4; + for (int i = 0; i < row; ++i) { + for (int j = 0; j < col; ++j) { + dbx[j] += dyx[i][j]; + dbh[j] = dbx[j]; + } + } +} + +template +void LstmBackward(DType* ws, + DType* rs, + const int L, + const int D, + const int T, + const int N, + const int I, + const int H, + DType* x_ptr, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* y_ptr, + DType* dy_ptr, + DType* dhy_ptr, + DType* dcy_ptr, + DType* dx_ptr, + DType* dhx_ptr, + DType* dcx_ptr, + DType* dw_ptr, + DType* db_ptr) { + const int total_layers = D * L; + Tensor hx(hx_ptr, Shape3(total_layers, N, H)); + Tensor cx(cx_ptr, Shape3(total_layers, N, H)); + Tensor dhx(dhx_ptr, Shape3(total_layers, N, H)); + Tensor dcx(dcx_ptr, Shape3(total_layers, N, H)); + const int b_size = 2 * H * 4; + const int r_size = D * T * N * H * 6; + const int y_offset = T * N * H * 5; + const int w_size1 = (I + H) * H * 4; // first layer + const int w_size2 = (D * H + H) * H * 4; // other layers + const int cell_size = N * H; + DType* dy_tmp_ptr = ws + T * cell_size * 4 + cell_size * 3; + for (int i = L - 1; i >= 0; --i) { + const int input_size = i ? H * D : I; + const int w_size = i ? w_size2 : w_size1; + int idx = i * D; + DType* w_cur_ptr = i ? w_ptr + (w_size1 + (i - 1) * w_size2) * D : w_ptr; + DType* dw_cur_ptr = i ? 
dw_ptr + (w_size1 + (i - 1) * w_size2) * D : dw_ptr; + DType* db_cur_ptr = db_ptr + i * b_size * D; + DType* rs_cur_ptr = rs + i * r_size; + DType* dhy_cur_ptr = dhy_ptr ? dhy_ptr + i * cell_size * D : NULL; + DType* dcy_cur_ptr = dcy_ptr ? dcy_ptr + i * cell_size * D : NULL; + Tensor y(rs_cur_ptr + y_offset, Shape3(T, N, H * D)); + Tensor dy(dy_ptr, Shape3(T, N, H * D)); + Tensor x(i ? y.dptr_ - r_size : x_ptr, Shape2(T * N, input_size)); + Tensor dx(i ? dy_tmp_ptr : dx_ptr, Shape2(T * N, input_size)); + LstmBackwardSingleLayer(ws, rs_cur_ptr, false, T, N, input_size, H, + x, hx[idx], cx[idx], y, dy, dx, dhx[idx], dcx[idx], + dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr); + if (D == 2) { + w_cur_ptr += w_size; + dw_cur_ptr += w_size; + db_cur_ptr += b_size; + ++idx; + dhy_cur_ptr = dhy_ptr ? dhy_cur_ptr + cell_size : NULL; + dcy_cur_ptr = dcy_ptr ? dcy_cur_ptr + cell_size : NULL; + LstmBackwardSingleLayer(ws, rs_cur_ptr, true, T, N, input_size, H, + x, hx[idx], cx[idx], y, dy, dx, dhx[idx], dcx[idx], + dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr); + } + dy_ptr = dx.dptr_; + } +} +#endif // MXNET_OPERATOR_RNN_IMPL_H_ diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 43c62e18845b..d356e7892898 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1272,19 +1272,6 @@ def test_rnn(): check_rnn_consistency(fused, stack) check_rnn_consistency(stack, fused) - -@with_seed() -def test_lstm(): - fused = mx.rnn.FusedRNNCell(100, num_layers=2, mode='lstm', prefix='') - - stack = mx.rnn.SequentialRNNCell() - stack.add(mx.rnn.LSTMCell(100, prefix='l0_')) - stack.add(mx.rnn.LSTMCell(100, prefix='l1_')) - - check_rnn_consistency(fused, stack) - check_rnn_consistency(stack, fused) - - @with_seed() def test_lstm_forget_bias(): forget_bias = 2.0 @@ -1306,7 +1293,6 @@ def test_lstm_forget_bias(): expected_bias = forget_bias * np.ones(10, ) assert_allclose(args[bias_name].asnumpy(), expected_bias) - @with_seed() def test_gru(): fused = mx.rnn.FusedRNNCell(100, num_layers=2, mode='gru', prefix='') @@ -1318,7 +1304,6 @@ def test_gru(): check_rnn_consistency(fused, stack) check_rnn_consistency(stack, fused) - @with_seed() def test_bidirectional(): fused = mx.rnn.FusedRNNCell(100, num_layers=2, mode='gru', prefix='', @@ -1337,7 +1322,6 @@ def test_bidirectional(): check_rnn_consistency(fused, stack) check_rnn_consistency(stack, fused) - @with_seed() def test_unfuse(): for mode in ['rnn_tanh', 'rnn_relu', 'lstm', 'gru']: @@ -1519,7 +1503,6 @@ def test_deformable_convolution_options(): sym = mx.sym.contrib.DeformableConvolution(num_filter=4, kernel=(3,3), num_deformable_group=2, name='deformable_conv') - @with_seed() def test_residual_fused(): cell = mx.rnn.ResidualCell( diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 06e0675c0b82..0a6de8e7a1b8 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -28,6 +28,89 @@ from common import setup_module, with_seed import unittest +def check_rnn_consistency(cell1, cell2, T, N, I, H): + dshape = (N, T, I) + data = mx.sym.Variable('data') + + Y1, _ = cell1.unroll(T, data, layout='NTC', merge_outputs=True) + mod1 = mx.mod.Module(Y1, label_names=None, context=default_context()) + mod1.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True) + + Y2, _ = cell2.unroll(T, data, layout='NTC', merge_outputs=True) + mod2 = mx.mod.Module(Y2, 
label_names=None, context=default_context()) + mod2.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True) + + mod1.init_params() + args, auxs = mod1.get_params() + args = cell1.unpack_weights(args) + args = cell2.pack_weights(args) + mod2.set_params(args, auxs) + + x = mx.random.uniform(shape=dshape) + batch=mx.io.DataBatch(data=[x]) + # check inference + mod1.forward(batch, is_train=False) + mod2.forward(batch, is_train=False) + assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=1e-2, atol=1e-4) + + # check training + mod1.forward(batch, is_train=True) + mod2.forward(batch, is_train=True) + assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=1e-2, atol=1e-4) + + dy = mx.random.uniform(shape=mod1.get_outputs()[0].shape) + mod1.backward(out_grads=[dy]) + mod2.backward(out_grads=[dy]) + assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=1e-2, atol=1e-4) + +@with_seed() +def test_lstm_sym(): + T, N, I, H = 5, 32, 800, 800 + fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='lstm', get_next_state=True, prefix='') + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.LSTMCell(H, prefix='l0_')) + stack.add(mx.rnn.LSTMCell(H, prefix='l1_')) + stack.add(mx.rnn.LSTMCell(H, prefix='l2_')) + check_rnn_consistency(fused, stack, T, N, I, H) + check_rnn_consistency(stack, fused, T, N, I, H) + +@with_seed() +def test_lstm_bidirectional(): + T, N, I, H = 5, 20, 800, 800 + fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='lstm', + bidirectional=True, get_next_state=True, prefix='') + + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.LSTMCell(H, prefix='l0_'), + mx.rnn.LSTMCell(H, prefix='r0_'), + output_prefix='bi_lstm_0_')) + stack.add(mx.rnn.BidirectionalCell( + mx.rnn.LSTMCell(H, prefix='l1_'), + mx.rnn.LSTMCell(H, prefix='r1_'), + output_prefix='bi_lstm_1_')) + + check_rnn_consistency(stack, fused, T, N, I, H) + check_rnn_consistency(fused, stack, T, N, I, H) + +# Currently, fused LSTM operator doesn't support dropout. +# Will change this test after dropout is supported +@with_seed() +def test_lstm_dropout(): + X = mx.sym.Variable('x') + Params = mx.sym.Variable('params') + HX = mx.sym.Variable('state') + CX = mx.sym.Variable('state_cell') + T, N, I, H = 300, 20, 800, 800 + rnn = mx.sym.RNN(data=X, parameters=Params, state=HX, state_cell=CX, + state_size=H, num_layers=5, mode='lstm', p=0.5, state_outputs=True, name='LSTM') + exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I)) + try: + out = exe.forward(is_train=False) + out[0].wait_to_read() + assert False # should not reach here + except mx.base.MXNetError as err: + assert str(err).find('Dropout is not supported at the moment') != -1 def np_softmax(x, axis=-1): # fix for old numpy on Travis not supporting keepdims
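# The snippet below is an illustrative usage sketch, not part of the patch: it mirrors the
# symbolic API exercised by test_lstm_dropout above, but leaves dropout at its default of 0
# so the fused CPU LSTM inference path runs to completion. The function name and the shape
# values are arbitrary, and `mxnet` is assumed to be imported as `mx` as elsewhere in this file.
def fused_lstm_inference_sketch():
    T, N, I, H = 10, 4, 16, 32   # seq_length, batch_size, input_size, state_size (arbitrary)
    data = mx.sym.Variable('x')
    params = mx.sym.Variable('params')
    hx = mx.sym.Variable('state')
    cx = mx.sym.Variable('state_cell')
    # Fused RNN symbol in LSTM mode; parameter and state shapes are inferred from
    # state_size, num_layers and the input shape given at bind time.
    rnn = mx.sym.RNN(data=data, parameters=params, state=hx, state_cell=cx,
                     state_size=H, num_layers=1, mode='lstm',
                     state_outputs=True, name='LSTM')
    exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I))
    out = exe.forward(is_train=False)   # runs the fused CPU LSTM inference kernel
    out[0].wait_to_read()
    return out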