[MXNET-486] Create CPP test for concat MKLDNN operator (#11371)
* create func to generate new outputs for concat output

* provide type for vectors

* fix concat dict

* change value to string

* use different datatype

* create GetExpandedMemPD helper

* update input num

* fix dim range

* use only new shape

* add comments

* replace all with new_shape

* consolidate testop

* remove init interface

* use GetTestOutputArraysConcat for concat op

* out arrays in correct scope

* add VerifyConcatResult (see the sketch after this list)

* noop for kWriteInPlace for concat

* refactor GetTestOutputArrays and GetTestOutputArraysConcat into one method

* create temp ndarrays in same scope as assert

* add message for GetTestOutputArraysConcat

* filter when dim too large

* fix print message

* reshape concat output so it can be read

* check concat on diff dim

* add VerifyConcatBackwardsResult for backprop

* reshape if view and mkldnn

* finish VerifyConcatBackwardsResult

* reverse input output for concat backwards

* do not rand output for concat backwards

* make multiple copies of inputs for ops that need multiple unique inputs

* swap input/output msg

* test input creation can now create expanded inputs

* add verify msg to test

* fix slice of input

* remove unused test

* missing assignment

* fix slice amount for diff dim concat

* shrink outputs for concat backwards

* revert switching input/output for concat/backwards

* create multiple copies of output

* reorder concat input grad

* increase num of input for concat backwards

* concat dst is smaller array

* use output vs input mem to determine shape and as tmp storage

* do not support mkldnn concat if mkl layout diff from nd layout

* reorder if view /mkldnn

* exclude views from concat

* remove unused header

* remove check for view in mkldnn_concat

* remove unused header

* skip test

* rename target_shape to shape

* remove rand var and default outputs to rand

* rename target_pd to pd

* fix lint issues

* add space to error msg

* do not use mkldnn for forward concat if layout mismatch

* create temp shape var

* do not check if view in concat

* convert dim to unsigned int

* fix lint

* check view first

* check type before creating mem

* check all inputs for concat mkldnn

* remove getshapestring

* add comments for verify concat helpers

* revert adding USE_MKLDNN flag

* use reference for arrays in concat mkldnn check

* fix indent

* set default num_inputs to 1

* revert change to test_ctc_loss_train

* add error message to check

* use reference of arr in loops

* remove extra space

* use default num_inputs

* use reference for all loops

* fix lint

* use separate concat test

* remove reference from pd

* do not use reference for shape

* change conditional in gettestinputarray

* remove reference

* fix lint

* increase num_inputs to 3

* remove extra out_arr var

* retrigger

* increase num_inputs
azai91 authored and szha committed Jun 29, 2018
1 parent ca60b94 commit af4a600
Showing 3 changed files with 305 additions and 69 deletions.
20 changes: 15 additions & 5 deletions src/operator/nn/concat.cc
```diff
@@ -157,7 +157,19 @@ inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
   return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                              dispatch_mode, wanted_mode);
 }
-
+#if MXNET_USE_MKLDNN == 1
+bool SupportMKLDNNConcat(const std::vector<NDArray> &arrs) {
+  for (auto &arr : arrs) {
+    if (arr.IsView()) return false;
+    if (arr.dtype() != mshadow::kFloat32) return false;
+    unsigned ndim = arr.shape().ndim();
+    unsigned mkldnn_ndims =
+        static_cast<unsigned>(arr.GetMKLDNNData()->get_primitive_desc().desc().data.ndims);
+    if (!(ndim == 2 || ndim == 4) || ndim != mkldnn_ndims) return false;
+  }
+  return true;
+}
+#endif
 static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
                                const OpContext& op_ctx,
                                const std::vector<NDArray>& inputs,
```
```diff
@@ -171,8 +183,7 @@ static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
       outputs[0].storage_type() == kCSRStorage) {
     ConcatCSRImpl<cpu>(attrs, op_ctx, inputs, req, outputs);
 #if MXNET_USE_MKLDNN == 1
-  } else if ((inputs[0].shape().ndim() == 2 || inputs[0].shape().ndim() == 4)
-             && inputs[0].dtype() == mshadow::kFloat32) {
+  } else if (SupportMKLDNNConcat(inputs)) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     MKLDNNConcatForward(attrs, op_ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(ConcatCompute<cpu>, attrs, op_ctx, inputs, req, outputs);
@@ -190,8 +201,7 @@ static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const std::vector<NDArray>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
-  if ((inputs[0].shape().ndim() == 2 || inputs[0].shape().ndim() == 4)
-      && inputs[0].dtype() == mshadow::kFloat32) {
+  if (SupportMKLDNNConcat(inputs)) {
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
     MKLDNNConcatBackward(attrs, ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(ConcatGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
```
12 changes: 6 additions & 6 deletions src/operator/nn/mkldnn/mkldnn_concat.cc
```diff
@@ -107,7 +107,7 @@ void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   std::vector<const mkldnn::memory *> data_mem;
   data_md.reserve(num_in_data);
   data_mem.reserve(num_in_data);
-  for (int i =0; i < num_in_data; i++) {
+  for (int i = 0; i < num_in_data; i++) {
     const mkldnn::memory *tmp_mem = in_data[i].GetMKLDNNData();
     mkldnn::memory::primitive_desc tmp_pd = tmp_mem->get_primitive_desc();
     data_md.push_back(tmp_pd);
@@ -138,11 +138,11 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   mkldnn::memory::dims offsets = {0, 0, 0, 0};
   for (int i = 0; i < num_in_data; i++) {
     mkldnn::memory::dims diff_src_tz
-        = {static_cast<int>(inputs[i+1].shape()[0]),
-           static_cast<int>(inputs[i+1].shape()[1]),
-           static_cast<int>(inputs[i+1].shape()[2]),
-           static_cast<int>(inputs[i+1].shape()[3])};
-    auto diff_src_mpd = inputs[i+1].GetMKLDNNData()->get_primitive_desc();
+        = {static_cast<int>(outputs[i].shape()[0]),
+           static_cast<int>(outputs[i].shape()[1]),
+           static_cast<int>(outputs[i].shape()[2]),
+           static_cast<int>(outputs[i].shape()[3])};
+    auto diff_src_mpd = outputs[i].GetMKLDNNData()->get_primitive_desc();
     auto gradi_mem_ = CreateMKLDNNMem(outputs[i], diff_src_mpd, req[i]);
     // create view from gy to gxs[i]
     std::shared_ptr<mkldnn::view::primitive_desc> view_pd;
```
[The diff for the third changed file, the new C++ test for the MKL-DNN concat operator, is not shown here.]

