Skip to content

Commit

Permalink
Fix axis Bug in MKLDNN Softmax (apache#11335)
Browse files Browse the repository at this point in the history
* add softmax imporvement

* reuse CheckAxis code

* update comment

* add tests with negative axis
  • Loading branch information
xinyu-intel authored and szha committed Jun 20, 2018
1 parent 383e33f commit da58c44
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
5 changes: 4 additions & 1 deletion src/operator/nn/mkldnn/mkldnn_softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "../softmax-inl.h"
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"
#include "../../tensor/broadcast_reduce_op.h"

#if MXNET_USE_MKLDNN == 1
namespace mxnet {
Expand All @@ -38,11 +39,13 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
auto input_mem = in_data.GetMKLDNNData();
mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
mkldnn::memory::desc data_md = data_mpd.desc();
int axis = CheckAxis(param.axis, in_data.shape().ndim());

auto cpu_engine = data_mpd.get_engine();
auto prop = ctx.is_train
? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
mkldnn::softmax_forward::desc desc = mkldnn::softmax_forward::desc(prop,
data_md, param.axis);
data_md, axis);
mkldnn::softmax_forward::primitive_desc pdesc(desc, cpu_engine);

auto output_memory = out_data.GetMKLDNNData();
Expand Down
4 changes: 1 addition & 3 deletions src/operator/nn/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,8 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
// It seems MKLDNN softmax doesn't support training.
// and it only supports non-negative axis.
if (SupportMKLDNN(inputs[0]) && !ctx.is_train && param.axis >= 0) {
if (SupportMKLDNN(inputs[0]) && !ctx.is_train) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNSoftmaxForward(attrs, ctx, inputs[0], req[0], outputs[0]);
auto fn = SoftmaxCompute<cpu, mxnet_op::softmax_fwd>;
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4098,7 +4098,7 @@ def test_new_softmax():
for ndim in range(1, 5):
for _ in range(5):
shape = np.random.randint(1, 5, size=ndim)
axis = np.random.randint(0, ndim)
axis = np.random.randint(-ndim, ndim)
data = np.random.uniform(-2, 2, size=shape)
sym = mx.sym.softmax(axis=axis)
check_symbolic_forward(sym, [data], [np_softmax(data, axis=axis)])
Expand Down

0 comments on commit da58c44

Please sign in to comment.