From af4a6005c67cc391567156686dd0ee9867ac7c0c Mon Sep 17 00:00:00 2001 From: Alexander Zai Date: Fri, 29 Jun 2018 08:54:50 -0700 Subject: [PATCH] [MXNET-486] Create CPP test for concat MKLDNN operator (#11371) * create func to generate new outputs for concat output * provide type for vectors * fix concat dict * change value to string * use different datatype * create GetExpandedMemPD helper * update input num * fix dim range * use only new shape * add comments * replace all with new_shape * consolidate testop * remove init interface * use GetTestOutputArraysConcat for concat op * out arrays in correct scope * add VerifyConcatResult * noop for kWriteInPlace for concat * refactor GetTestOutputArrays and GetTestOutputArraysConcat into one method * create temp ndarrays in same scope as assert * add message for GetTestOutputArraysConcat * filter when dim too large * fix print message * reshape concat output so it can be read * check concat on diff dim * add VerifyConcatBackwardsResult bp * reshape if view and mkldnn * finish VerifyConcatBackwardsResult * reverse input output for concat backwards * do not rand output for concat backwards * make mulitple copies of inputs for ops that need mult unique inputs * swap input/output msg * create test inputs can create expanded inputs * add verify msg to test * fix slice of input * remove unused test * missing assignment * fix slice amount for diff dim concat * shrink outputs for concat backwards * revert switching input/output for concat/backwards * create multiple copies of output * reorder concat input grad * increase num of input for concat backwards * concat dst is smaller array * use output vs input mem to determine shape and as tmp storage * do not support mkldnn concat if mkl layout diff from nd layout * reorder if view /mkldnn * exclude views from concat * remove unused header * remove check for view in mkldnn_concat * remove unused heaeder * skip test * rename target_shape to shape * remove rand var and default outputs to rand * rename target_pd to pd * fix lint issues * add space to error msg * do not use mkldnn for forward concat if layout mismatch * create temp shape var * do not check if view in concat * convert dim to unsigned int * fix lint * check view first * check type before creating mem * check all inputs for concat mkldnn * remove getshapestring * add comments for verify concat helpres * revert adding USE_MKLDNN flag * use reference for arrays in concat mkldnn check * fix indent * set default num_inputs to 1 * revert change to test_ctc_loss_train * add error message to check * use reference of arr in loops * remove extra space * use default num_inputs * use reference for all loops * fix lint * use separate concat test * remove reference from pd * do not use reference for shape * change conditional in gettestinputarray * remove reference * fix lint * increase num_inputs to 3 * remove extra out_arr var * retrigger * increase num_inputs --- src/operator/nn/concat.cc | 20 +- src/operator/nn/mkldnn/mkldnn_concat.cc | 12 +- tests/cpp/operator/mkldnn.cc | 342 ++++++++++++++++++++---- 3 files changed, 305 insertions(+), 69 deletions(-) diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc index 04332456cda3..266ccb1b1a14 100644 --- a/src/operator/nn/concat.cc +++ b/src/operator/nn/concat.cc @@ -157,7 +157,19 @@ inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs, return storage_type_assign(out_attrs, mxnet::kDefaultStorage, dispatch_mode, wanted_mode); } - +#if MXNET_USE_MKLDNN == 1 +bool 
SupportMKLDNNConcat(const std::vector &arrs) { + for (auto &arr : arrs) { + if (arr.IsView()) return false; + if (arr.dtype() != mshadow::kFloat32) return false; + unsigned ndim = arr.shape().ndim(); + unsigned mkldnn_ndims = + static_cast(arr.GetMKLDNNData()->get_primitive_desc().desc().data.ndims); + if (!(ndim == 2 || ndim == 4) || ndim != mkldnn_ndims) return false; + } + return true; +} +#endif static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs, const OpContext& op_ctx, const std::vector& inputs, @@ -171,8 +183,7 @@ static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs, outputs[0].storage_type() == kCSRStorage) { ConcatCSRImpl(attrs, op_ctx, inputs, req, outputs); #if MXNET_USE_MKLDNN == 1 - } else if ((inputs[0].shape().ndim() == 2 || inputs[0].shape().ndim() == 4) - && inputs[0].dtype() == mshadow::kFloat32) { + } else if (SupportMKLDNNConcat(inputs)) { MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); MKLDNNConcatForward(attrs, op_ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(ConcatCompute, attrs, op_ctx, inputs, req, outputs); @@ -190,8 +201,7 @@ static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - if ((inputs[0].shape().ndim() == 2 || inputs[0].shape().ndim() == 4) - && inputs[0].dtype() == mshadow::kFloat32) { + if (SupportMKLDNNConcat(inputs)) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); MKLDNNConcatBackward(attrs, ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(ConcatGradCompute, attrs, ctx, inputs, req, outputs); diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc index dbc0e94c630f..af81e1fe3ee3 100644 --- a/src/operator/nn/mkldnn/mkldnn_concat.cc +++ b/src/operator/nn/mkldnn/mkldnn_concat.cc @@ -107,7 +107,7 @@ void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, std::vector data_mem; data_md.reserve(num_in_data); data_mem.reserve(num_in_data); - for (int i =0; i < num_in_data; i++) { + for (int i = 0; i < num_in_data; i++) { const mkldnn::memory *tmp_mem = in_data[i].GetMKLDNNData(); mkldnn::memory::primitive_desc tmp_pd = tmp_mem->get_primitive_desc(); data_md.push_back(tmp_pd); @@ -138,11 +138,11 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, mkldnn::memory::dims offsets = {0, 0, 0, 0}; for (int i = 0; i < num_in_data; i++) { mkldnn::memory::dims diff_src_tz - = {static_cast(inputs[i+1].shape()[0]), - static_cast(inputs[i+1].shape()[1]), - static_cast(inputs[i+1].shape()[2]), - static_cast(inputs[i+1].shape()[3])}; - auto diff_src_mpd = inputs[i+1].GetMKLDNNData()->get_primitive_desc(); + = {static_cast(outputs[i].shape()[0]), + static_cast(outputs[i].shape()[1]), + static_cast(outputs[i].shape()[2]), + static_cast(outputs[i].shape()[3])}; + auto diff_src_mpd = outputs[i].GetMKLDNNData()->get_primitive_desc(); auto gradi_mem_ = CreateMKLDNNMem(outputs[i], diff_src_mpd, req[i]); // create view from gy to gxs[i] std::shared_ptr view_pd; diff --git a/tests/cpp/operator/mkldnn.cc b/tests/cpp/operator/mkldnn.cc index e593d00a0de4..8e01216527c8 100644 --- a/tests/cpp/operator/mkldnn.cc +++ b/tests/cpp/operator/mkldnn.cc @@ -160,6 +160,17 @@ static mkldnn::memory::primitive_desc GetMemPD(const TShape s, int dtype, return mkldnn::memory::primitive_desc(desc, CpuEngine::Get()->get_engine()); } +static mkldnn::memory::primitive_desc GetExpandedMemPD( + mkldnn::memory::primitive_desc pd, float num_input, int dim = 0) { + CHECK(dim < 
pd.desc().data.ndims) << "dimension cannot be larger than total dimensions of input"; + nnvm::TShape s(pd.desc().data.ndims); + for (size_t i = 0; i < pd.desc().data.ndims; i++) + s[i] = pd.desc().data.dims[i]; + s[dim] = static_cast(s[dim] * num_input); + return GetMemPD(s, mshadow::DataType::kFlag, + static_cast(pd.desc().data.format)); +} + // This function gets special MKLDNN formats without knowing the specific // hardware configuration. Certainly, it potentially misses some format if // it's specific for certain array shapes. It covers at least one special format @@ -359,9 +370,9 @@ struct OpAttrs { OpAttrs GetCopyOp() { OpAttrs attrs; attrs.attrs.op = Op::Get("_copy"); - attrs.dispatches.resize(2); attrs.num_inputs = 1; attrs.num_outputs = 1; + attrs.dispatches.resize(2); attrs.dispatches[0] = DispatchMode::kFCompute; attrs.dispatches[1] = DispatchMode::kFComputeEx; return attrs; @@ -407,9 +418,9 @@ OpAttrs GetReluBackwardsOp() { OpAttrs GetSumOp() { OpAttrs attrs; attrs.attrs.op = Op::Get("elemwise_add"); - attrs.dispatches.resize(2); attrs.num_inputs = 2; attrs.num_outputs = 1; + attrs.dispatches.resize(2); attrs.dispatches[0] = DispatchMode::kFCompute; attrs.dispatches[1] = DispatchMode::kFComputeEx; return attrs; @@ -426,6 +437,42 @@ OpAttrs GetSumBackwardsOp() { return attrs; } +OpAttrs GetConcatOp(int num_args, int dim) { + OpAttrs attrs; + attrs.attrs.op = Op::Get("concat"); + attrs.num_inputs = num_args; + attrs.num_outputs = 1; + attrs.attrs.dict.insert({"num_args" , std::to_string(num_args)}); + attrs.attrs.dict.insert({"dim" , std::to_string(dim)}); + attrs.attrs.op->attr_parser(&attrs.attrs); + attrs.dispatches.resize(2); + attrs.dispatches[0] = DispatchMode::kFCompute; + attrs.dispatches[1] = DispatchMode::kFComputeEx; + return attrs; +} + +OpAttrs GetConcatBackwardsOp(int num_args, int dim) { + OpAttrs attrs; + attrs.attrs.op = Op::Get("_backward_Concat"); + attrs.num_inputs = 2; + attrs.num_outputs = num_args; + attrs.attrs.dict.insert({"num_args" , std::to_string(num_args)}); + attrs.attrs.dict.insert({"dim" , std::to_string(dim)}); + attrs.attrs.op->attr_parser(&attrs.attrs); + attrs.dispatches.resize(2); + attrs.dispatches[0] = DispatchMode::kFCompute; + attrs.dispatches[1] = DispatchMode::kFComputeEx; + return attrs; +} + +void PrintVerifyMsg(const NDArrayAttrs &arr1, const NDArrayAttrs &arr2) { + TShape t1 = arr1.arr.shape(); + TShape t2 = arr2.arr.shape(); + std::stringstream ss; + std::cout << "Verifying: " << arr1.desc.c_str() << " " << + t1 << " with " << arr2.desc.c_str() << " " << t2 << "\n"; +} + /* * We want to get a few types of NDArrays for testing: * 1. Normal NDArray @@ -446,20 +493,37 @@ OpAttrs GetSumBackwardsOp() { * In the inference mode, the MKLDNN memory in the weight array will be * reordered to 5 dimensions. * + * num_inputs / dim arguments used to scale shape (used for concat backwards to enlarge input shapes) */ -std::vector GetTestInputArrays(bool rand = false) { +std::vector GetTestInputArrays(bool rand = false, int num_inputs = 1, int dim = 0) { TestArrayShapes tas = GetTestArrayShapes(); std::vector shapes = tas.shapes; std::vector pds = tas.pds; std::vector in_arrs; std::string desc; + + int slice_amount = 1; + if (dim == 0) + slice_amount = num_inputs; for (auto shape : shapes) { + if (dim >= shape.ndim()) + continue; + shape[dim] = shape[dim] * num_inputs; + // Type 1. 
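+    // Type 1 here is the "normal" NDArray in the default layout listed in the
+    // comment above; the MKLDNN-backed variants are built further down.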
NDArray arr(shape, Context()); in_arrs.emplace_back(arr, "Normal NDArray"); InitDefaultArray(&in_arrs.back().arr, rand); for (auto pd : pds) { + if (num_inputs > 1) { + // preserve if matching layout else just expand on 0 dim + if (shape.ndim() == pd.desc().data.ndims) + pd = GetExpandedMemPD(pd, num_inputs, dim); + else + pd = GetExpandedMemPD(pd, num_inputs); + } + if (shape.Size() != pd.get_size() / sizeof(mshadow::default_real_t)) continue; @@ -472,8 +536,8 @@ std::vector GetTestInputArrays(bool rand = false) { shape.ndim() << "/" << pd.desc().data.ndims; desc = ss.str(); } + InitMKLDNNArray(&arr, pd); in_arrs.emplace_back(arr, desc); - InitMKLDNNArray(&in_arrs.back().arr, pd); // Type 4, 5, 6. arr = NDArray(shape, Context()); @@ -485,31 +549,12 @@ std::vector GetTestInputArrays(bool rand = false) { desc = ss.str(); } InitMKLDNNArray(&arr, pd); - in_arrs.emplace_back(arr.Slice(1, arr.shape()[0] - 1), desc); + in_arrs.emplace_back(arr.Slice(slice_amount, arr.shape()[0] - slice_amount), desc); } } return in_arrs; } -TEST(MKLDNN_NDArray, GetTestInputArrays) { - std::vector in_arrs = GetTestInputArrays(); - int mkldnn_count = 0, mkldnn_view_count = 0; - for (auto arr : in_arrs) { - if (arr.arr.IsView() && arr.arr.IsMKLDNNData()) { - mkldnn_view_count++; - continue; - } - - if (arr.arr.IsMKLDNNData()) { - mkldnn_count++; - continue; - } - } - - EXPECT_GT(mkldnn_view_count, 0); - EXPECT_GT(mkldnn_count, 0); -} - /* * We want to get a few types of NDArrays for testing: * 1. Normal NDArray @@ -527,9 +572,17 @@ TEST(MKLDNN_NDArray, GetTestInputArrays) { * 7. Reused reshaped/sliced NDArray. * 8. Reused NDArray with MKLDNN layout. * 9. Reused NDArray with MKLDNN layout of different dimensions. + * + * Optional num_inputs / dim args can be passed to modify input shape (used for Concat test) */ -std::vector GetTestOutputArrays(const TShape &shape, - const std::vector &pds) { +std::vector GetTestOutputArrays( + const TShape &shp, + const std::vector &pds, + float num_inputs = 0, int dim = 0) { + TShape shape = shp; + if (num_inputs != 0) + shape[dim] = static_cast(shape[dim] * num_inputs); + std::vector in_arrs; std::string desc; // Type 1. @@ -568,11 +621,14 @@ std::vector GetTestOutputArrays(const TShape &shape, InitDefaultArray(&arr3, true); in_arrs.emplace_back(arr3.Slice(1, shape[0] + 1), "Reused+Reshaped NDArray"); - for (auto pd : pds) { if (shape.Size() != pd.get_size() / sizeof(mshadow::default_real_t)) continue; + if (num_inputs != 0) + pd = GetExpandedMemPD(pd, num_inputs); + + // Type 2, 3. 
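+    // Types 2 and 3 are NDArrays backed by MKLDNN memory, either with the same
+    // dimensions as the array's shape or with a different number of dimensions.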
arr = NDArray(shape, Context()); @@ -605,6 +661,47 @@ std::vector GetTestOutputArrays(const TShape &shape, return in_arrs; } +TEST(MKLDNN_NDArray, GetTestInputArraysConcat) { + auto in_arrs = GetTestInputArrays(); + for (int dim = 0; dim < 5; dim++) { + for (int num_inputs = 2; num_inputs < 5; num_inputs++) { + std::vector expanded_arrs = GetTestInputArrays(false, num_inputs, dim); + int i = 0; + for (auto &arr : in_arrs) { + if (dim >= arr.arr.shape().ndim()) + continue; + auto ex_arr = expanded_arrs[i]; + PrintVerifyMsg(arr, ex_arr); + EXPECT_EQ(arr.arr.shape().Size() * num_inputs, ex_arr.arr.shape().Size()); + EXPECT_EQ(arr.arr.shape()[dim] * num_inputs, ex_arr.arr.shape()[dim]); + i++; + } + } + } +} + +TEST(MKLDNN_NDArray, GetTestOutputArraysConcat) { + auto shapes_pds = GetTestArrayShapes(); + std::vector shapes; shapes = shapes_pds.shapes; + std::vector pds = shapes_pds.pds; + for (auto &shape : shapes) { + for (int dim = 0; dim < 5; dim++) { + for (int num_inputs = 2; num_inputs < 5; num_inputs++) { + if (shape.ndim() <= dim) + continue; + std::cout << "Extending " << shape << " dim " << + dim << " and " << num_inputs << "num_inputs\n"; + auto output_arrs = GetTestOutputArrays(shape, pds, num_inputs, dim); + for (auto &out_arr : output_arrs) { + auto out_shape = out_arr.arr.shape(); + EXPECT_EQ(shape.Size() * num_inputs, out_arr.arr.shape().Size()); + EXPECT_EQ(shape[dim] * num_inputs, out_arr.arr.shape()[dim]); + } + } + } + } +} + void VerifyCopyResult(const std::vector &in_arrs, const std::vector &out_arrs) { NDArray tmp1 = in_arrs[0]->Reorder2Default(); @@ -676,17 +773,77 @@ void VerifySumBackwardsResult(const std::vector &in_arrs, } } -void PrintVerifyMsg(const NDArrayAttrs &arr1, const NDArrayAttrs &arr2) { - TShape t1 = arr1.arr.shape(); - TShape t2 = arr2.arr.shape(); +/* + * Determines axis ndarrays are concatenated by + * Used to verify concat/concat backwards operator + */ +int GetDim(TShape input_shape, TShape output_shape) { + CHECK(input_shape.Size() != output_shape.Size()); + for (size_t i = 0; i < input_shape.ndim(); i++) { + if (input_shape[i] != output_shape[i]) + return i; + } + return -1; +} - printf("Verifying: %s (", arr1.desc.c_str()); - for (size_t i = 0; i < t1.ndim(); i++) - printf("%ld, ", t1[i]); - printf(") with %s (", arr2.desc.c_str()); - for (size_t i = 0; i < t2.ndim(); i++) - printf("%ld, ", t2[i]); - printf(")\n"); +/* + * Calculates the size of continuous block of array inside arger concatenated array + * Used to verify concat/concat backwards operator + */ +int GetBlockSize(TShape shape, int dim) { + int block_size = 1; + for (int i = shape.ndim() - 1; i >= dim; i--) + block_size *= shape[i]; + return block_size; +} + +void VerifyConcatResult(const std::vector &in_arrs, + const std::vector &out_arrs) { + int num_inputs = in_arrs.size(); + int input_size = in_arrs[0]->shape().Size(); + TShape input_shape = in_arrs[0]->shape(); + NDArray output = out_arrs[0]->Reorder2Default(); + size_t total_size = output.shape().Size(); + EXPECT_EQ(input_size * num_inputs, total_size); + mshadow::default_real_t *out_data = output.data().dptr(); + + int dim = GetDim(input_shape, output.shape()); + int block_size = GetBlockSize(input_shape, dim); + int num_blocks = input_size / block_size; + for (size_t input_num = 0; input_num < num_inputs; input_num++) { + NDArray tmp = in_arrs[input_num]->Reorder2Default(); + mshadow::default_real_t* data = tmp.data().dptr(); + for (size_t block_num = 0; block_num < num_blocks; block_num++) { + for (size_t i = 0; i < 
block_size; i++) + ASSERT_EQ(data[block_num * block_size + i], + out_data[(block_num * num_inputs + input_num) * block_size + i]); + } + } +} + +void VerifyConcatBackwardsResult(const std::vector &in_arrs, + const std::vector &out_arrs) { + // in_arrs is larger array, out_arr is ammler + int num_inputs = out_arrs.size(); + int input_size = out_arrs[0]->shape().Size(); + TShape input_shape = out_arrs[0]->shape(); + NDArray output = in_arrs[0]->Reorder2Default(); + size_t total_size = output.shape().Size(); + EXPECT_EQ(input_size * num_inputs, total_size); + mshadow::default_real_t *out_data = output.data().dptr(); + + int dim = GetDim(input_shape, output.shape()); + int block_size = GetBlockSize(input_shape, dim); + int num_blocks = input_size / block_size; + for (size_t input_num = 0; input_num < num_inputs; input_num++) { + NDArray tmp = out_arrs[input_num]->Reorder2Default(); + mshadow::default_real_t* data = tmp.data().dptr(); + for (size_t block_num = 0; block_num < num_blocks; block_num++) { + for (size_t i = 0; i < block_size; i++) + ASSERT_EQ(data[block_num * block_size + i], + out_data[(block_num * num_inputs + input_num) * block_size + i]); + } + } } void VerifyAddRequest(const std::vector &in_arrs, @@ -703,11 +860,11 @@ TEST(MKLDNN_NDArray, CopyFrom) { std::vector pds = tas.pds; std::vector in_arrs = GetTestInputArrays(); - for (auto in_arr : in_arrs) { + for (auto &in_arr : in_arrs) { + if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView()) + continue; std::vector out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds); - for (auto out_arr : out_arrs) { - if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView()) - in_arr.arr = in_arr.arr.Reorder2Default(); + for (auto &out_arr : out_arrs) { const mkldnn::memory *mem = in_arr.arr.GetMKLDNNData(); out_arr.arr.CopyFrom(*mem); MKLDNNStream::Get()->Submit(); @@ -728,29 +885,30 @@ void TestOp(const OpAttrs &attrs, VerifyFunc verify_fn) { std::vector pds = tas.pds; std::vector in_arrs = GetTestInputArrays(); - for (auto in_arr : in_arrs) { - for (auto dispatch : dispatches) { - std::vector out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds); - for (auto out_arr : out_arrs) { - for (int i = 0; i < attrs.num_inputs; i++) - inputs[i] = &in_arr.arr; + for (auto &in_arr : in_arrs) { + for (auto &dispatch : dispatches) { + std::vector> out_arrs(attrs.num_outputs); + for (int i = 0; i < attrs.num_outputs; i++) + out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds); + for (int i = 0; i < attrs.num_inputs; i++) + inputs[i] = &in_arr.arr; + for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) { for (int i = 0; i < attrs.num_outputs; i++) { req[i] = kWriteTo; - outputs[i] = &out_arr.arr; + outputs[i] = &out_arrs[i][output_i].arr; } - PrintVerifyMsg(in_arr, out_arr); + PrintVerifyMsg(in_arr, out_arrs[0][output_i]); Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req, dispatch, mxnet::OpStatePtr()); - for (auto output : outputs) - output->WaitToRead(); + Engine::Get()->WaitForAll(); verify_fn(inputs, outputs); } } } - for (auto dispatch : dispatches) { + for (auto &dispatch : dispatches) { in_arrs = GetTestInputArrays(); - for (auto arr : in_arrs) { + for (auto &arr : in_arrs) { // If the array is a view, we shouldn't write data to it. 
if (arr.arr.IsView()) continue; @@ -764,8 +922,7 @@ void TestOp(const OpAttrs &attrs, VerifyFunc verify_fn) { PrintVerifyMsg(orig, arr); Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req, dispatch, mxnet::OpStatePtr()); - for (auto output : outputs) - output->WaitToRead(); + Engine::Get()->WaitForAll(); std::vector orig_inputs(attrs.num_inputs); for (int i = 0; i < attrs.num_inputs; i++) orig_inputs[i] = &orig.arr; @@ -774,6 +931,57 @@ void TestOp(const OpAttrs &attrs, VerifyFunc verify_fn) { } } +void TestConcatOp(const OpAttrs &attrs, VerifyFunc verify_fn, + bool backwards = false) { + std::vector inputs(attrs.num_inputs); + std::vector outputs(attrs.num_outputs); + std::vector req(attrs.num_outputs); + std::vector dispatches = attrs.dispatches; + + TestArrayShapes tas = GetTestArrayShapes(); + std::vector pds = tas.pds; + + std::vector in_arrs = GetTestInputArrays(); + + // concat backwards uses scaled up inputs + if (backwards) { + std::string str_dim = const_cast(attrs).attrs.dict["dim"]; + int dim = std::stoi(str_dim); + in_arrs = GetTestInputArrays(false, attrs.num_outputs, dim); + } + + for (auto &in_arr : in_arrs) { + for (auto &dispatch : dispatches) { + std::vector> out_arrs(attrs.num_outputs); + + std::string str_dim = const_cast(attrs).attrs.dict["dim"]; + int dim = std::stoi(str_dim); + if (dim >= in_arr.arr.shape().ndim()) + continue; + float scale = backwards ? 1 / static_cast(attrs.num_outputs) : + static_cast(attrs.num_inputs); + for (int i = 0; i < attrs.num_outputs; i++) + out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds, scale, dim); + + for (int i = 0; i < attrs.num_inputs; i++) + inputs[i] = &in_arr.arr; + + for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) { + for (int i = 0; i < attrs.num_outputs; i++) { + req[i] = kWriteTo; + outputs[i] = &out_arrs[i][output_i].arr; + } + + PrintVerifyMsg(in_arr, out_arrs[0][output_i]); + Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, + outputs, req, dispatch, mxnet::OpStatePtr()); + Engine::Get()->WaitForAll(); + verify_fn(inputs, outputs); + } + } + } +} + TEST(IMPERATIVE, CopyOp) { OpAttrs attrs = GetCopyOp(); TestOp(attrs, VerifyCopyResult); @@ -804,6 +1012,24 @@ TEST(IMPERATIVE, SumBackwardsOp) { TestOp(attrs, VerifySumBackwardsResult); } +TEST(IMPERATIVE, ConcatOp) { + for (int num_inputs = 2; num_inputs < 4; num_inputs++) { + for (int dim = 0; dim < 5; dim++) { + OpAttrs attrs = GetConcatOp(num_inputs, dim); + TestConcatOp(attrs, VerifyConcatResult); + } + } +} + +TEST(IMPERATIVE, ConcatBackwardsOp) { + for (int num_inputs = 2; num_inputs < 4; num_inputs++) { + for (int dim = 0; dim < 5; dim++) { + OpAttrs attrs = GetConcatBackwardsOp(num_inputs, dim); + TestConcatOp(attrs, VerifyConcatBackwardsResult, true); + } + } +} + TEST(MKLDNN_BASE, MKLDNNSum) { std::vector in_arrs = GetTestInputArrays(); std::vector in_arrs2 = GetTestInputArrays(true); @@ -819,7 +1045,7 @@ TEST(MKLDNN_BASE, MKLDNNSum) { continue; } std::vector out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds); - for (auto out_arr : out_arrs) { + for (auto &out_arr : out_arrs) { auto in_mem1 = in_arr.arr.GetMKLDNNData(); auto in_mem2 = in_arr2.arr.GetMKLDNNData(); if (out_arr.arr.IsView()) @@ -870,7 +1096,7 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) { continue; } std::vector out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds); - for (auto out_arr : out_arrs) { + for (auto &out_arr : out_arrs) { auto in_mem = in_arr.arr.GetMKLDNNData(); auto in_mem2 = in_arr2.arr.GetMKLDNNData(); NDArray orig_output 
= out_arr.arr.Copy(out_arr.arr.ctx()); @@ -919,7 +1145,7 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) { continue; } std::vector out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds); - for (auto out_arr : out_arrs) { + for (auto &out_arr : out_arrs) { auto in_mem = in_arr.arr.GetMKLDNNData(); auto in_mem2 = in_arr2.arr.GetMKLDNNData(); NDArray orig_output = out_arr.arr.Copy(out_arr.arr.ctx());
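
The new verification helpers (GetDim, GetBlockSize, VerifyConcatResult, VerifyConcatBackwardsResult) all rely on the same block-interleaving index: concatenating along axis dim splits each equally-shaped input into contiguous blocks of prod(shape[dim:]) elements, and block block_num of input input_num lands in the output at offset (block_num * num_inputs + input_num) * block_size. Below is a minimal standalone sketch of that indexing in plain C++, with no MXNet/MKLDNN dependencies; it is illustrative only and not part of the patch.

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const std::size_t num_inputs = 2, rows = 2, cols = 3;
  // Two row-major 2x3 inputs, concatenated along dim = 1 into a 2x6 output.
  std::vector<std::vector<float>> in = {
      {0, 1, 2, 3, 4, 5},
      {10, 11, 12, 13, 14, 15}};

  std::vector<float> out;
  for (std::size_t r = 0; r < rows; r++)
    for (std::size_t n = 0; n < num_inputs; n++)
      for (std::size_t c = 0; c < cols; c++)
        out.push_back(in[n][r * cols + c]);

  // For dim = 1: block_size = prod(shape[1:]) = cols, num_blocks = rows.
  const std::size_t block_size = cols, num_blocks = rows;
  for (std::size_t input_num = 0; input_num < num_inputs; input_num++)
    for (std::size_t block_num = 0; block_num < num_blocks; block_num++)
      for (std::size_t i = 0; i < block_size; i++)
        assert(in[input_num][block_num * block_size + i] ==
               out[(block_num * num_inputs + input_num) * block_size + i]);
  return 0;
}

The backward verifier applies the same formula with the roles of in_arrs and out_arrs swapped, since the gradient of concat scatters the single large gradient array back into the smaller per-input arrays.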