This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-486] Create CPP test for concat MKLDNN operator #11371

Merged 91 commits on Jun 29, 2018

Changes from 87 commits (91 commits total)

Commits
2da1979
create func to generate new outputs for concat output
azai91 Jun 19, 2018
99d697d
provide type for vectors
azai91 Jun 19, 2018
e310303
fix concat dict
azai91 Jun 20, 2018
4f0873a
change value to string
azai91 Jun 20, 2018
ba1f4e0
use different datatype
azai91 Jun 20, 2018
29e2e88
create GetExpandedMemPD helper
azai91 Jun 20, 2018
a50030f
update input num
azai91 Jun 20, 2018
246949d
fix dim range
azai91 Jun 20, 2018
ce96e07
use only new shape
azai91 Jun 20, 2018
e772ea0
add comments
azai91 Jun 20, 2018
b8e5dbe
replace all with new_shape
azai91 Jun 20, 2018
0d0002b
consolidate testop
azai91 Jun 20, 2018
ac18966
remove init interface
azai91 Jun 20, 2018
716d8b1
use GetTestOutputArraysConcat for concat op
azai91 Jun 20, 2018
eae53fb
out arrays in correct scope
azai91 Jun 20, 2018
179b9b9
add VerifyConcatResult
azai91 Jun 20, 2018
e29b43b
noop for kWriteInPlace for concat
azai91 Jun 21, 2018
39f1090
refactor GetTestOutputArrays and GetTestOutputArraysConcat into one m…
azai91 Jun 21, 2018
9938492
create temp ndarrays in same scope as assert
azai91 Jun 21, 2018
7248b59
add message for GetTestOutputArraysConcat
azai91 Jun 21, 2018
77e41db
filter when dim too large
azai91 Jun 21, 2018
a48b2a5
fix print message
azai91 Jun 21, 2018
210d528
reshape concat output so it can be read
azai91 Jun 21, 2018
e41ccfe
check concat on diff dim
azai91 Jun 21, 2018
3ce7f3f
add VerifyConcatBackwardsResult bp
azai91 Jun 21, 2018
1573f24
reshape if view and mkldnn
azai91 Jun 21, 2018
41230d3
finish VerifyConcatBackwardsResult
azai91 Jun 21, 2018
fc7a812
reverse input output for concat backwards
azai91 Jun 21, 2018
4822952
do not rand output for concat backwards
azai91 Jun 21, 2018
5d86f89
make mulitple copies of inputs for ops that need mult unique inputs
azai91 Jun 22, 2018
ed91fc0
swap input/output msg
azai91 Jun 22, 2018
c9accdf
create test inputs can create expanded inputs
azai91 Jun 22, 2018
e3ac132
add verify msg to test
azai91 Jun 22, 2018
b9c55a9
fix slice of input
azai91 Jun 22, 2018
1eaa229
remove unused test
azai91 Jun 22, 2018
976393e
missing assignment
azai91 Jun 22, 2018
f46640d
fix slice amount for diff dim concat
azai91 Jun 22, 2018
adbdfd9
shrink outputs for concat backwards
azai91 Jun 22, 2018
44925a8
revert switching input/output for concat/backwards
azai91 Jun 22, 2018
4d96d88
create multiple copies of output
azai91 Jun 22, 2018
c806266
reorder concat input grad
azai91 Jun 22, 2018
bfe588e
increase num of input for concat backwards
azai91 Jun 22, 2018
2f463d5
concat dst is smaller array
azai91 Jun 22, 2018
881f3b2
use output vs input mem to determine shape and as tmp storage
azai91 Jun 22, 2018
b9b82d2
do not support mkldnn concat if mkl layout diff from nd layout
azai91 Jun 22, 2018
1547f53
reorder if view /mkldnn
azai91 Jun 22, 2018
3d9d812
exclude views from concat
azai91 Jun 22, 2018
d29ca85
remove unused header
azai91 Jun 22, 2018
1b73c9a
remove check for view in mkldnn_concat
azai91 Jun 22, 2018
3da955f
remove unused heaeder
azai91 Jun 22, 2018
f26a9ba
skip test
azai91 Jun 22, 2018
063f820
rename target_shape to shape
azai91 Jun 22, 2018
3ed397c
remove rand var and default outputs to rand
azai91 Jun 22, 2018
51adef6
rename target_pd to pd
azai91 Jun 22, 2018
641cddd
fix lint issues
azai91 Jun 22, 2018
eda5267
add space to error msg
azai91 Jun 22, 2018
48436fe
do not use mkldnn for forward concat if layout mismatch
azai91 Jun 23, 2018
6b6b8c1
create temp shape var
azai91 Jun 23, 2018
de3fe6c
do not check if view in concat
azai91 Jun 23, 2018
3b5968e
convert dim to unsigned int
azai91 Jun 23, 2018
331a4cc
fix lint
azai91 Jun 24, 2018
0125d63
check view first
azai91 Jun 25, 2018
3ac93bd
check type before creating mem
azai91 Jun 25, 2018
29d9490
check all inputs for concat mkldnn
azai91 Jun 27, 2018
02ab157
remove getshapestring
azai91 Jun 27, 2018
42ff5be
add comments for verify concat helpres
azai91 Jun 27, 2018
d2e9392
revert adding USE_MKLDNN flag
azai91 Jun 27, 2018
366bca2
use reference for arrays in concat mkldnn check
azai91 Jun 27, 2018
49b0c90
fix indent
azai91 Jun 27, 2018
9b0cad3
set default num_inputs to 1
azai91 Jun 27, 2018
7b00246
revert change to test_ctc_loss_train
azai91 Jun 27, 2018
70d8e96
add error message to check
azai91 Jun 27, 2018
362910d
use reference of arr in loops
azai91 Jun 27, 2018
b5220c8
merge from master
azai91 Jun 27, 2018
bb4e0ee
remove extra space
azai91 Jun 27, 2018
a01657a
use default num_inputs
azai91 Jun 27, 2018
708c206
use reference for all loops
azai91 Jun 27, 2018
58cf2a0
fix lint
azai91 Jun 27, 2018
b53aa28
Merge branch 'master' into test/concat
azai91 Jun 27, 2018
1596717
use separate concat test
azai91 Jun 27, 2018
4c3089f
remove reference from pd
azai91 Jun 27, 2018
dd3e594
do not use reference for shape
azai91 Jun 27, 2018
332f6e7
change conditional in gettestinputarray
azai91 Jun 27, 2018
362e6d8
remove reference
azai91 Jun 27, 2018
6c24960
fix lint
azai91 Jun 27, 2018
d8b7490
Merge branch 'master' into test/concat
azai91 Jun 27, 2018
11e0860
Merge branch 'master' into test/concat
azai91 Jun 27, 2018
035c3a4
increase num_inputs to 3
azai91 Jun 28, 2018
ce98f8d
remove extra out_arr var
azai91 Jun 28, 2018
050632e
retrigger
azai91 Jun 28, 2018
9dbb0da
increase num_inputs
azai91 Jun 28, 2018
20 changes: 15 additions & 5 deletions src/operator/nn/concat.cc
@@ -157,7 +157,19 @@ inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
   return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                              dispatch_mode, wanted_mode);
 }
-
+#if MXNET_USE_MKLDNN == 1
+bool SupportMKLDNNConcat(const std::vector<NDArray> &arrs) {
+  for (auto &arr : arrs) {
+    if (arr.IsView()) return false;
+    if (arr.dtype() != mshadow::kFloat32) return false;
+    unsigned ndim = arr.shape().ndim();
+    unsigned mkldnn_ndims =
+        static_cast<unsigned>(arr.GetMKLDNNData()->get_primitive_desc().desc().data.ndims);
+    if (!(ndim == 2 || ndim == 4) || ndim != mkldnn_ndims) return false;
+  }
+  return true;
+}
+#endif
 static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
                                const OpContext& op_ctx,
                                const std::vector<NDArray>& inputs,
@@ -171,8 +183,7 @@ static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
              outputs[0].storage_type() == kCSRStorage) {
     ConcatCSRImpl<cpu>(attrs, op_ctx, inputs, req, outputs);
 #if MXNET_USE_MKLDNN == 1
-  } else if ((inputs[0].shape().ndim() == 2 || inputs[0].shape().ndim() == 4)
-             && inputs[0].dtype() == mshadow::kFloat32) {
+  } else if (SupportMKLDNNConcat(inputs)) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     MKLDNNConcatForward(attrs, op_ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(ConcatCompute<cpu>, attrs, op_ctx, inputs, req, outputs);
@@ -190,8 +201,7 @@ static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const std::vector<NDArray>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
-  if ((inputs[0].shape().ndim() == 2 || inputs[0].shape().ndim() == 4)
-      && inputs[0].dtype() == mshadow::kFloat32) {
+  if (SupportMKLDNNConcat(inputs)) {
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
     MKLDNNConcatBackward(attrs, ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(ConcatGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
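The change above replaces a check of `inputs[0]` alone with a helper that validates every input before taking the MKLDNN path. That guard pattern can be sketched in isolation; the `TensorDesc` struct and `SupportFastConcat` name below are hypothetical stand-ins for `NDArray` and `SupportMKLDNNConcat`, and the MKLDNN descriptor-dimension comparison is elided:

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the NDArray properties the check inspects.
enum class DType { kFloat32, kFloat64 };

struct TensorDesc {
  unsigned ndim;   // logical number of dimensions
  DType dtype;     // element type
  bool is_view;    // whether the array is a view into another array
};

// Mirrors the structure of SupportMKLDNNConcat: the optimized path is taken
// only if *every* input is a non-view float32 array with 2 or 4 dimensions.
bool SupportFastConcat(const std::vector<TensorDesc> &arrs) {
  for (const auto &arr : arrs) {
    if (arr.is_view) return false;
    if (arr.dtype != DType::kFloat32) return false;
    if (!(arr.ndim == 2 || arr.ndim == 4)) return false;
  }
  return true;
}
```

Checking all inputs matters because concat receives arbitrarily many arrays: a single view or non-float32 input anywhere in the list makes the MKLDNN kernel unsafe, even if `inputs[0]` looks fine.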
12 changes: 6 additions & 6 deletions src/operator/nn/mkldnn/mkldnn_concat.cc
@@ -107,7 +107,7 @@ void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   std::vector<const mkldnn::memory *> data_mem;
   data_md.reserve(num_in_data);
   data_mem.reserve(num_in_data);
-  for (int i =0; i < num_in_data; i++) {
+  for (int i = 0; i < num_in_data; i++) {
     const mkldnn::memory *tmp_mem = in_data[i].GetMKLDNNData();
     mkldnn::memory::primitive_desc tmp_pd = tmp_mem->get_primitive_desc();
     data_md.push_back(tmp_pd);
@@ -138,11 +138,11 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   mkldnn::memory::dims offsets = {0, 0, 0, 0};
   for (int i = 0; i < num_in_data; i++) {
     mkldnn::memory::dims diff_src_tz
-        = {static_cast<int>(inputs[i+1].shape()[0]),
-           static_cast<int>(inputs[i+1].shape()[1]),
-           static_cast<int>(inputs[i+1].shape()[2]),
-           static_cast<int>(inputs[i+1].shape()[3])};
-    auto diff_src_mpd = inputs[i+1].GetMKLDNNData()->get_primitive_desc();
+        = {static_cast<int>(outputs[i].shape()[0]),
+           static_cast<int>(outputs[i].shape()[1]),
+           static_cast<int>(outputs[i].shape()[2]),
+           static_cast<int>(outputs[i].shape()[3])};
+    auto diff_src_mpd = outputs[i].GetMKLDNNData()->get_primitive_desc();
     auto gradi_mem_ = CreateMKLDNNMem(outputs[i], diff_src_mpd, req[i]);
     // create view from gy to gxs[i]
     std::shared_ptr<mkldnn::view::primitive_desc> view_pd;
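The backward hunk carves per-input gradient slices out of the concatenated output gradient, advancing a running offset and sizing each slice from the output (gradient) arrays rather than the inputs. The offset bookkeeping can be sketched with plain buffers in place of MKLDNN memory views; `SplitGradient` is a hypothetical helper for illustration, not MXNet code:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Splits a flat concatenated-gradient buffer back into per-input chunks.
// Mirrors the bookkeeping in MKLDNNConcatBackward: each slice starts where
// the previous one ended, and its extent comes from the per-output sizes.
std::vector<std::vector<float>> SplitGradient(
    const std::vector<float> &grad_concat,
    const std::vector<std::size_t> &chunk_sizes) {
  std::vector<std::vector<float>> grads;
  grads.reserve(chunk_sizes.size());
  std::size_t offset = 0;  // running offset into the concatenated buffer
  for (std::size_t sz : chunk_sizes) {
    // Copy the [offset, offset + sz) window into its own gradient array.
    grads.emplace_back(grad_concat.begin() + offset,
                       grad_concat.begin() + offset + sz);
    offset += sz;  // advance past the chunk just copied
  }
  return grads;
}
```

Sizing each slice from the outputs is what the corrected diff enforces: the input list to backward holds the upstream gradient, so only the output arrays carry the original per-input shapes.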