Fix MKLDNN sigmoid/softrelu issue (apache#10336)
* Fix MKLDNN sigmoid/softrelu issue
* Enable Sigmoid and SoftRelu for MKLDNN
* Add activation kData for backward calculation for MKLDNN
* Add tanh support for MKLDNN activation
* Adjust rtol to pass tanh tests for MKLDNN
jinhuang415 authored and lanking520 committed on Apr 2, 2018
1 parent 499a186 · commit 1dcdc44
Showing 4 changed files with 10 additions and 17 deletions.
src/operator/nn/activation-inl.h (2 changes: 1 addition & 1 deletion)
@@ -201,7 +201,7 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
                            const std::vector<TBlob>& inputs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
-#if MXNET_USE_CUDNN == 1
+#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
   CHECK_EQ(inputs.size(), 3U);
 #else
   CHECK_EQ(inputs.size(), 2U);
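
For context on the 3U check: with cuDNN, and now with MKLDNN as well, the activation backward op receives three inputs rather than two: the output gradient, the forward output, and the forward input (the kData entry this commit adds). A minimal sketch of that layout, where TBlob is an illustrative stand-in for the MXNet type:

    #include <vector>

    struct TBlob {};  // stand-in for mxnet::TBlob, for illustration only

    // Backward input layout when MXNET_USE_CUDNN or MXNET_USE_MKLDNN is on,
    // inferred from the ActivationGrad change in activation.cc below.
    void UnpackActBackwardInputs(const std::vector<TBlob>& inputs) {
      const TBlob& out_grad = inputs[0];  // dL/dy from the layer above
      const TBlob& out_data = inputs[1];  // forward output y
      const TBlob& in_data  = inputs[2];  // forward input x (activation::kData)
      (void)out_grad; (void)out_data; (void)in_data;  // silence unused warnings
    }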
src/operator/nn/activation.cc (12 changes: 4 additions & 8 deletions)
@@ -44,7 +44,7 @@ struct ActivationGrad {
                          const std::vector<nnvm::NodeEntry>& ograds) const {
     std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
     heads.emplace_back(nnvm::NodeEntry{n, activation::kOut, 0});
-#if MXNET_USE_CUDNN == 1
+#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
     heads.push_back(n->inputs[activation::kData]);
 #endif
     return MakeGradNode(op_name, n, heads, n->attrs.dict);
@@ -74,15 +74,11 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
-#if MXNET_USE_CUDNN == 1
   CHECK_EQ(inputs.size(), 3U);
-#else
-  CHECK_EQ(inputs.size(), 2U);
-#endif
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
   if (SupportMKLDNN(inputs[0])) {
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    MKLDNNActivationBackward(attrs, ctx, inputs[0], inputs[1], req[0],
+    MKLDNNActivationBackward(attrs, ctx, inputs[0], inputs[2], req[0],
                              outputs[0]);
     MKLDNN_OPCHECK_RUN(ActivationGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
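
The switch from inputs[1] to inputs[2] hands MKLDNNActivationBackward the forward input x instead of the forward output y. ReLU, sigmoid, and tanh gradients can be written in terms of y alone, but softrelu's cannot, which is why the commit threads kData through to the backward pass. Scalar illustrations of the assumed formulas (not MXNet code):

    #include <cmath>

    // d/dx sigmoid(x), written via the output y = sigmoid(x).
    float sigmoid_grad_from_output(float y) { return y * (1.0f - y); }

    // d/dx tanh(x), written via the output y = tanh(x).
    float tanh_grad_from_output(float y) { return 1.0f - y * y; }

    // softrelu(x) = log(1 + exp(x)); its derivative needs the input x:
    // d/dx log(1 + exp(x)) = 1 / (1 + exp(-x)).
    float softrelu_grad_from_input(float x) { return 1.0f / (1.0f + std::exp(-x)); }

MKLDNN's eltwise backward primitives are formulated against the forward source, so passing x keeps all of the enabled activations on one code path.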
@@ -116,13 +112,13 @@ inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs,
                                           DispatchMode* dispatch_mode,
                                           std::vector<int> *in_attrs,
                                           std::vector<int> *out_attrs) {
-#if MXNET_USE_CUDNN == 1
+#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
   CHECK_EQ(in_attrs->size(), 3U);
 #else
   CHECK_EQ(in_attrs->size(), 2U);
 #endif
   CHECK_EQ(out_attrs->size(), 1U);
-#if MXNET_USE_CUDNN == 1
+#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
   bool ret = ElemwiseStorageType<3, 1, false, false, false>(attrs, dev_mask,
                                                             dispatch_mode,
                                                             in_attrs, out_attrs);
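
A hedged reading of the storage-type dispatch above: the leading template arguments 3, 1 mirror the input/output arity asserted by the CHECK_EQs, and the three false flags leave the sparse-storage paths disabled, so the backward op is inferred as dense-only. Assuming the collapsed #else branch mirrors the visible one with two inputs, the full conditional would look like:

    // Assumption: the hidden #else branch dispatches over 2 inputs; the
    // boolean template flags (all false) keep sparse dispatch disabled.
    #if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
      bool ret = ElemwiseStorageType<3, 1, false, false, false>(
          attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
    #else
      bool ret = ElemwiseStorageType<2, 1, false, false, false>(
          attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
    #endif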
src/operator/nn/mkldnn/mkldnn_act.cc (9 changes: 3 additions & 6 deletions)
@@ -43,13 +43,10 @@ namespace mxnet {
 namespace op {
 
 bool SupportMKLDNNAct(const ActivationParam& param) {
-  // We only enable ReLU for now. It seems other activations have some precision
-  // problems.
-  return param.act_type == activation::kReLU;
-#if 0
   return param.act_type == activation::kReLU
       || param.act_type == activation::kSigmoid
-      || param.act_type == activation::kSoftReLU;
-#endif
+      || param.act_type == activation::kSoftReLU
+      || param.act_type == activation::kTanh;
 }
 
 static inline mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
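
SupportMKLDNNAct now admits sigmoid, softrelu, and tanh alongside ReLU; each of these must map to an MKLDNN eltwise algorithm in GetMKLDNNActAlgo, whose body sits in the collapsed part of this diff. A plausible version of that mapping, assuming the MKLDNN 0.x C++ API (the enum and function names below are illustrative stand-ins):

    #include <mkldnn.hpp>
    #include <stdexcept>

    // Stand-in for the mxnet::op::activation enum values.
    enum ActType { kReLU, kSigmoid, kTanh, kSoftReLU };

    // Assumed act_type -> eltwise algorithm mapping; the real switch is
    // GetMKLDNNActAlgo in mkldnn_act.cc.
    inline mkldnn::algorithm GetEltwiseAlgo(ActType act_type) {
      switch (act_type) {
        case kReLU:     return mkldnn::algorithm::eltwise_relu;
        case kSigmoid:  return mkldnn::algorithm::eltwise_logistic;
        case kTanh:     return mkldnn::algorithm::eltwise_tanh;
        case kSoftReLU: return mkldnn::algorithm::eltwise_soft_relu;
        default: throw std::invalid_argument("unsupported activation type");
      }
    }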
tests/python/unittest/test_gluon.py (4 changes: 2 additions & 2 deletions)
@@ -717,8 +717,8 @@ def test_lambda():
 
     input_data = mx.nd.random.uniform(shape=(2, 3, 5, 7))
     out1, out2, out3 = net1(input_data), net2(input_data), net3(input_data)
-    assert_almost_equal(out1.asnumpy(), out2.asnumpy())
-    assert_almost_equal(out1.asnumpy(), out3.asnumpy())
+    assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-3)
+    assert_almost_equal(out1.asnumpy(), out3.asnumpy(), rtol=1e-3)
 
 
 @with_seed()
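
The relaxed rtol=1e-3 acknowledges that MKLDNN's tanh and sigmoid kernels are not bit-identical to the default CPU implementation, only close. For reference, the numpy-style closeness rule that assert_almost_equal applies has roughly this shape (a sketch of the comparison, not the test framework's code; the atol default is an assumption):

    #include <cmath>

    // numpy-style check: |a - b| <= atol + rtol * |b|.
    // rtol = 1e-3 tolerates about 0.1% relative deviation between backends.
    inline bool almost_equal(float a, float b,
                             float rtol = 1e-3f, float atol = 1e-20f) {
      return std::fabs(a - b) <= atol + rtol * std::fabs(b);
    }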
