Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
support for add_n(dense, csr, dense) = dense with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Hao Jin committed Jun 19, 2018
1 parent 47e2b89 commit fedab5c
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 10 deletions.
30 changes: 30 additions & 0 deletions src/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,36 @@ inline bool ContainsOnlyStorage(const std::vector<NDArray>& ndarrays,
return false;
}

/*! \brief returns true if the storage type of any array in `ndarrays`
 * is the same as the target `stype`. false is returned for empty inputs.
 * \param ndarrays arrays whose storage types are inspected
 * \param stype    target storage type to look for
 * \return true iff at least one array has storage type `stype`
 */
inline bool ContainsStorageType(const std::vector<NDArray>& ndarrays,
                                const NDArrayStorageType stype) {
  // No explicit empty() guard needed: the loop body never runs for an
  // empty vector, so false is returned naturally.
  for (const auto& nd : ndarrays) {
    if (nd.storage_type() == stype) {
      return true;
    }
  }
  return false;
}

/*! \brief returns true if any storage type `ndstype` in `ndstypes`
 * is the same as the target `stype`. false is returned for empty inputs.
 * \param ndstypes storage-type codes (int-encoded) to inspect
 * \param stype    target storage type to look for
 * \return true iff at least one entry equals `stype`
 */
inline bool ContainsStorageType(const std::vector<int>& ndstypes,
                                const NDArrayStorageType stype) {
  // No explicit empty() guard needed: the loop body never runs for an
  // empty vector, so false is returned naturally.
  for (const auto& ndstype : ndstypes) {
    if (ndstype == stype) {
      return true;
    }
  }
  return false;
}

/*! \brief get string representation of dispatch_mode */
inline std::string dispatch_mode_string(const DispatchMode x) {
switch (x) {
Expand Down
62 changes: 61 additions & 1 deletion src/ndarray/ndarray_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include "./ndarray_function.h"
#include "./ndarray_function-inl.h"
#include "../common/utils.h"
#include "../operator/mxnet_op.h"
#include "../operator/tensor/elemwise_binary_op-inl.h"

namespace mxnet {
namespace ndarray {
Expand Down Expand Up @@ -165,6 +167,61 @@ void ElementwiseSumRsp(mshadow::Stream<cpu>* s,
});
}

/*!
 * \brief CPU impl of elementwise sum over a mix of dense/CSR/row-sparse
 *        inputs, accumulating into a dense output.
 *        The output is first zeroed, then each input is added in place.
 * \param s   cpu stream to launch kernels on
 * \param rsc temp-space resource (currently unused; kept for interface parity)
 * \param nds input arrays; each may be kDefaultStorage, kCSRStorage or
 *            kRowSparseStorage. NOTE(review): the code reads shape()[0]/[1],
 *            so inputs are assumed 2-D — confirm for dense operands.
 * \param out dense output array, overwritten (kWriteTo semantics)
 */
void ElementwiseSumContainsDnsImpl(mshadow::Stream<cpu>* s,
                                   const Resource& rsc,
                                   const std::vector<NDArray>& nds,
                                   NDArray* out) {
  using namespace mxnet::op;
  using namespace mxnet::op::mxnet_op;
  const TBlob& out_data = out->data();
  MSHADOW_TYPE_SWITCH(out->dtype(), DType, {  // data type
    // Start from zero so every input (including the first) can be added
    // uniformly with out += nd.
    Kernel<set_zero, cpu>::Launch(s, out_data.Size(), out_data.dptr<DType>());
    for (const auto& nd : nds) {
      const nnvm::dim_t num_rows = nd.shape()[0];
      const nnvm::dim_t num_cols = nd.shape()[1];
      const TBlob& nd_data = nd.data();
      switch (nd.storage_type()) {
        case kDefaultStorage: {
          // dense: out = out + nd, elementwise over the full buffer
          Kernel<op_with_req<mshadow_op::plus, kWriteTo>, cpu>::Launch(
            s, out_data.Size(), out_data.dptr<DType>(), out_data.dptr<DType>(),
            nd_data.dptr<DType>());
          break;
        }
        case kCSRStorage: {
          const TBlob& nd_indices = nd.aux_data(csr::kIdx);
          const TBlob& nd_indptr = nd.aux_data(csr::kIndPtr);
          MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, {  // indices type
            MSHADOW_IDX_TYPE_SWITCH(nd_indptr.type_flag_, CType, {  // indptr type
              // skip uninitialized (all-zero) sparse inputs
              if (nd.storage_initialized()) {
                Kernel<ElemwiseDnsCsrDnsKernel<kWriteTo, mshadow_op::plus>, cpu>::Launch(
                  s, num_rows, out_data.dptr<DType>(), out_data.dptr<DType>(),
                  nd_data.dptr<DType>(), nd_indices.dptr<IType>(),
                  nd_indptr.dptr<CType>(), num_rows, num_cols);
              }
            });
          });
          break;
        }
        case kRowSparseStorage: {
          const TBlob& nd_indices = nd.aux_data(rowsparse::kIdx);
          MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, {  // indices type
            if (nd.storage_initialized()) {
              // one kernel thread per stored (non-zero-row) element
              const nnvm::dim_t nz_rows = nd_indices.Size();
              Kernel<ElemwiseDnsRspDnsKernel<kWriteTo, mshadow_op::plus>, cpu>::Launch(
                s, nz_rows * num_cols, out_data.dptr<DType>(),
                out_data.dptr<DType>(), nd_data.dptr<DType>(), nd_indices.dptr<IType>(),
                num_rows, nz_rows, num_cols);
            }
          });
          break;
        }
        default:
          // fix: add the missing space so the message does not read "1encountered..."
          LOG(FATAL) << "unknown storage type " << nd.storage_type() << " encountered...";
      }
    }
  });
}

/*!
* \brief Parallel cpu impl of elemwise sum for sparse tensors.
* Currently only support row sparse sum.
Expand All @@ -175,8 +232,11 @@ void ElementwiseSum<cpu>(mshadow::Stream<cpu>* s,
const std::vector<NDArray>& nds,
NDArray* out) {
if (nds.empty()) return;
if (nds[0].storage_type() == kRowSparseStorage) {
if (common::ContainsOnlyStorage(nds, kRowSparseStorage)) {
ElementwiseSumRsp(s, rsc, nds, out);
} else if (common::ContainsStorageType(nds, kDefaultStorage) &&
out->storage_type() == kDefaultStorage) {
ElementwiseSumContainsDnsImpl(s, rsc, nds, out);
} else {
LOG(FATAL) << "ElementwiseSum<cpu> has not been implemented for storage_type = << "
<< nds[0].storage_type();
Expand Down
61 changes: 60 additions & 1 deletion src/ndarray/ndarray_function.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <cub/cub.cuh>
#include <dmlc/logging.h>
#include "../operator/mxnet_op.h"
#include "../operator/tensor/elemwise_binary_op-inl.h"
#include "../operator/tensor/init_op.h"
#include "../operator/tensor/util/tensor_util-inl.h"
#include "../operator/tensor/util/tensor_util-inl.cuh"
Expand Down Expand Up @@ -185,6 +186,61 @@ void ElementwiseSumRspImpl(mshadow::Stream<gpu>* s,
});
}

/*!
 * \brief GPU impl of elementwise sum over a mix of dense/CSR/row-sparse
 *        inputs, accumulating into a dense output.
 *        The output is first zeroed, then each input is added in place.
 * \param s   gpu stream to launch kernels on
 * \param rsc temp-space resource (currently unused; kept for interface parity)
 * \param nds input arrays; each may be kDefaultStorage, kCSRStorage or
 *            kRowSparseStorage. NOTE(review): the code reads shape()[0]/[1],
 *            so inputs are assumed 2-D — confirm for dense operands.
 * \param out dense output array, overwritten (kWriteTo semantics)
 */
void ElementwiseSumContainsDnsImpl(mshadow::Stream<gpu>* s,
                                   const Resource& rsc,
                                   const std::vector<NDArray>& nds,
                                   NDArray* out) {
  using namespace mxnet::op;
  using namespace mxnet::op::mxnet_op;
  const TBlob& out_data = out->data();
  MSHADOW_TYPE_SWITCH(out->dtype(), DType, {  // data type
    // Start from zero so every input (including the first) can be added
    // uniformly with out += nd.
    Kernel<set_zero, gpu>::Launch(s, out_data.Size(), out_data.dptr<DType>());
    for (const auto& nd : nds) {
      const nnvm::dim_t num_rows = nd.shape()[0];
      const nnvm::dim_t num_cols = nd.shape()[1];
      const TBlob& nd_data = nd.data();
      switch (nd.storage_type()) {
        case kDefaultStorage: {
          // dense: out = out + nd, elementwise over the full buffer
          Kernel<op_with_req<mshadow_op::plus, kWriteTo>, gpu>::Launch(
            s, out_data.Size(), out_data.dptr<DType>(), out_data.dptr<DType>(),
            nd_data.dptr<DType>());
          break;
        }
        case kCSRStorage: {
          const TBlob& nd_indices = nd.aux_data(csr::kIdx);
          const TBlob& nd_indptr = nd.aux_data(csr::kIndPtr);
          MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, {  // indices type
            MSHADOW_IDX_TYPE_SWITCH(nd_indptr.type_flag_, CType, {  // indptr type
              // skip uninitialized (all-zero) sparse inputs
              if (nd.storage_initialized()) {
                Kernel<ElemwiseDnsCsrDnsKernel<kWriteTo, mshadow_op::plus>, gpu>::Launch(
                  s, num_rows, out_data.dptr<DType>(), out_data.dptr<DType>(),
                  nd_data.dptr<DType>(), nd_indices.dptr<IType>(),
                  nd_indptr.dptr<CType>(), num_rows, num_cols);
              }
            });
          });
          break;
        }
        case kRowSparseStorage: {
          const TBlob& nd_indices = nd.aux_data(rowsparse::kIdx);
          MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, {  // indices type
            if (nd.storage_initialized()) {
              // one kernel thread per stored (non-zero-row) element
              const nnvm::dim_t nz_rows = nd_indices.Size();
              Kernel<ElemwiseDnsRspDnsKernel<kWriteTo, mshadow_op::plus>, gpu>::Launch(
                s, nz_rows * num_cols, out_data.dptr<DType>(),
                out_data.dptr<DType>(), nd_data.dptr<DType>(), nd_indices.dptr<IType>(),
                num_rows, nz_rows, num_cols);
            }
          });
          break;
        }
        default:
          // fix: add the missing space so the message does not read "1encountered..."
          LOG(FATAL) << "unknown storage type " << nd.storage_type() << " encountered...";
      }
    }
  });
}

/*!
* \brief Parallel gpu impl of elemwise sum for sparse tensors.
* Currently only support row sparse sum.
Expand All @@ -195,8 +251,11 @@ void ElementwiseSum<gpu>(mshadow::Stream<gpu>* s,
const std::vector<NDArray>& nds,
NDArray* out) {
if (nds.empty()) return;
if (nds[0].storage_type() == kRowSparseStorage) {
if (common::ContainsOnlyStorage(nds, kRowSparseStorage)) {
ElementwiseSumRspImpl(s, rsc, nds, out);
} else if (common::ContainsStorageType(nds, kDefaultStorage) &&
out->storage_type() == kDefaultStorage) {
ElementwiseSumContainsDnsImpl(s, rsc, nds, out);
} else {
LOG(FATAL) << "ElementwiseSum<gpu> has not been implemented for storage_type = << "
<< nds[0].storage_type();
Expand Down
5 changes: 5 additions & 0 deletions src/operator/elemwise_op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ inline bool ElemwiseStorageAttr(const nnvm::NodeAttrs& attrs,
dispatched = storage_type_assign(out_attrs, kCSRStorage,
dispatch_mode, dispatch_ex);
}
if (!dispatched && ContainsStorageType(*in_attrs, kDefaultStorage)) {
// *, dense, * -> dense
dispatched = storage_type_assign(out_attrs, kDefaultStorage,
dispatch_mode, dispatch_ex);
}
if (!dispatched) {
dispatch_fallback(out_attrs, dispatch_mode);
}
Expand Down
5 changes: 4 additions & 1 deletion src/operator/tensor/elemwise_sum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ void ElementWiseSumComputeExCPU(const nnvm::NodeAttrs& attrs,
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
if (req[0] == kNullOp) return;
if (inputs[0].storage_type() == kRowSparseStorage) {
if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) ||
(common::ContainsStorageType(inputs, kDefaultStorage) &&
outputs[0].storage_type() == kDefaultStorage)) {
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
Resource rsc = ResourceManager::Get()->Request(ctx.run_ctx.get_ctx(),
ResourceRequest(ResourceRequest::kTempSpace));
Expand Down Expand Up @@ -145,6 +147,7 @@ MXNET_ADD_SPARSE_OP_ALIAS(ElementWiseSum)
The storage type of ``add_n`` output depends on storage types of inputs
- add_n(row_sparse, row_sparse, ..) = row_sparse
- add_n([default, csr, row_sparse]*, default, [default, csr, row_sparse]*) = default
- otherwise, ``add_n`` generates output with default storage
)doc" ADD_FILELINE)
Expand Down
4 changes: 3 additions & 1 deletion src/operator/tensor/elemwise_sum.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ void ElementWiseSumComputeExGPU(const nnvm::NodeAttrs& attrs,
CHECK_EQ(req.size(), 1U);
if (req[0] == kNullOp) return;
CHECK_EQ(req[0], kWriteTo) << "ElementWiseSumComputeExGPU only supports req = kWriteTo";
if (inputs[0].storage_type() == kRowSparseStorage) {
if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) ||
(common::ContainsStorageType(inputs, kDefaultStorage) &&
outputs[0].storage_type() == kDefaultStorage)) {
mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
NDArray out_nd = outputs[0];
mxnet::ndarray::ElementwiseSum<gpu>(s, ctx.requested[0], inputs, &out_nd);
Expand Down
23 changes: 17 additions & 6 deletions tests/python/unittest/test_sparse_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1698,14 +1698,14 @@ def check_operator_with_temp_resource(shape, stype):

@with_seed()
def test_sparse_elementwise_sum():
def check_sparse_elementwise_sum_with_shape(stype, shape, n):
def check_sparse_elementwise_sum_with_shape(stypes, shape, n):
# forward
inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)]
out = mx.symbol.sparse.add_n(*inputs, name='esum')
arr = []
arr_grad = [mx.nd.empty(shape, stype=stype) for _ in range(n)]
arr_grad = [mx.nd.empty(shape, stype=stype) for stype in stypes]
densities = [0, 0.01, 0.5, 1.0]
for i in range(n):
for stype in stypes:
arr.append(rand_ndarray(shape, stype, densities[np.random.randint(0, len(densities))]))

exec1 = out.bind(default_context(),
Expand All @@ -1714,18 +1714,29 @@ def check_sparse_elementwise_sum_with_shape(stype, shape, n):
exec1.forward(is_train=True)
out1 = exec1.outputs[0].asnumpy()
out = sum(a.asnumpy() for a in arr)
assert_almost_equal(out, out1)
assert_almost_equal(out, out1, atol=1e-5)

out_grad = mx.nd.empty(shape)
out_grad[:] = np.random.uniform(-10, 10, shape)
# backward
exec1.backward([out_grad])
for a in arr_grad:
assert_almost_equal(a.asnumpy(), out_grad.asnumpy())
assert_almost_equal(a.asnumpy(), out_grad.asnumpy(), atol=1e-5)

all_stypes = ['default', 'csr', 'row_sparse']
for dim in range(2, 4):
    shape = tuple(np.random.randint(5, 10, size=dim))
    rsp_test_cnt = np.random.randint(1, 9)
    # all-row_sparse case (supported for any dim)
    check_sparse_elementwise_sum_with_shape(['row_sparse'] * rsp_test_cnt, shape, rsp_test_cnt)
    # fix: compare small ints with ==, not `is` (identity of int literals is
    # implementation-defined and raises SyntaxWarning on modern CPython)
    if dim == 2:
        test_len = np.random.randint(5, 10)
        # at least one default type so the output storage infers to dense
        stypes = ['default']
        for i in range(test_len):
            # randomly prepend or append a random storage type
            pick_side = np.random.randint(2)
            pick_type = np.random.randint(3)
            stypes = ([all_stypes[pick_type]] if pick_side == 0 else []) + stypes + ([all_stypes[pick_type]] if pick_side == 1 else [])
        check_sparse_elementwise_sum_with_shape(stypes, shape, test_len + 1)


@with_seed()
Expand Down

0 comments on commit fedab5c

Please sign in to comment.