From c75e5ca744ab0cc0665af170279156a2e270a4a7 Mon Sep 17 00:00:00 2001
From: Leonard Lausen
Date: Wed, 4 Jul 2018 07:28:33 +0000
Subject: [PATCH] Add stable nrm2 Reducer

---
 src/operator/mshadow_op.h                  | 73 ++++++++++++++++++++++
 src/operator/tensor/broadcast_reduce-inl.h |  3 +
 src/operator/tensor/broadcast_reduce_op.h  |  3 +-
 tests/python/unittest/test_operator.py     |  1 -
 4 files changed, 77 insertions(+), 3 deletions(-)

diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 5953568c7faf..022418b1caf7 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -697,6 +697,12 @@ struct product {
   MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& none) { // NOLINT(*)
     Reduce(dst, src);
   }
+  /*! \brief finalize reduction */
+  template<typename DType>
+  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
+  /*! \brief finalize reduction */
+  template<typename DType>
+  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& none) {} // NOLINT(*)
   /*!
    *\brief calculate gradient of redres with respect to redsrc,
    * redres: reduced result, redsrc: one of reduction element
@@ -762,6 +768,12 @@ struct nansum {
     residual = (t - dst) - y;
     dst = t;
   }
+  /*! \brief finalize reduction */
+  template<typename DType>
+  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
+  /*! \brief finalize reduction */
+  template<typename DType>
+  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
   /*!
    *\brief set the initial value during reduction
    */
@@ -799,6 +811,12 @@ struct nanprod {
   MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& none) { // NOLINT(*)
     Reduce(dst, src);
   }
+  /*! \brief finalize reduction */
+  template<typename DType>
+  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
+  /*! \brief finalize reduction */
+  template<typename DType>
+  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& none) {} // NOLINT(*)
   /*!
    *\brief set the initial value during reduction
    */
@@ -815,6 +833,61 @@ struct nanprod {
   }
 };
 
+/*! \brief compute l2 norm */
+struct nrm2 {
+  /*! \brief do reduction into dst */
+  template<typename DType>
+  MSHADOW_XINLINE static void Reduce(volatile DType& sum_of_squares, volatile DType src) { // NOLINT(*)
+    sum_of_squares += src * src;
+  }
+  /*! \brief do stable reduction into dst */
+  template<typename DType>
+  MSHADOW_XINLINE static void Reduce(volatile DType& sum_of_squares, volatile DType src, volatile DType& scale) { // NOLINT(*)
+    if (src != 0) {
+      DType abs = mshadow_op::abs::Map(src);
+      if (scale < abs) {
+        sum_of_squares = 1 + sum_of_squares * (scale / abs) * (scale / abs);
+        scale = abs;
+      } else {
+        sum_of_squares = sum_of_squares + (abs / scale) * (abs / scale);
+      }
+    }
+  }
+  /*! \brief finalize reduction result */
+  template<typename DType>
+  MSHADOW_XINLINE static void Finalize(volatile DType& sum_of_squares) { // NOLINT(*)
+    sum_of_squares = std::sqrt(sum_of_squares);
+  }
+  /*! \brief finalize reduction result */
+  template<typename DType>
+  MSHADOW_XINLINE static void Finalize(volatile DType& sum_of_squares, volatile DType& scale) { // NOLINT(*)
+    sum_of_squares = scale * std::sqrt(sum_of_squares);
+  }
+  /*!
+   *\brief calculate gradient of redres with respect to redsrc,
+   * redres: reduced result, redsrc: one of reduction element
+   */
+  template<typename DType>
+  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
+    return redsrc / redres;
+  }
+  /*!
+   *\brief set the initial value during reduction
+   */
+  template<typename DType>
+  MSHADOW_XINLINE static void SetInitValue(DType &sum_of_squares) { // NOLINT(*)
+    sum_of_squares = 0;
+  }
+  /*!
+   *\brief set the initial value during reduction
+   */
+  template<typename DType>
+  MSHADOW_XINLINE static void SetInitValue(DType &sum_of_squares, DType &scale) { // NOLINT(*)
+    SetInitValue(sum_of_squares);
+    scale = 0;
+  }
+};
+
 struct nanprod_grad : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a, DType b) {
diff --git a/src/operator/tensor/broadcast_reduce-inl.h b/src/operator/tensor/broadcast_reduce-inl.h
index 76ec92a9e724..974047504523 100644
--- a/src/operator/tensor/broadcast_reduce-inl.h
+++ b/src/operator/tensor/broadcast_reduce-inl.h
@@ -37,6 +37,7 @@ namespace op {
 namespace broadcast {
 using namespace mshadow;
+
 const int MAX_DIM = 5;
 
 template<int ndim>
@@ -165,6 +166,7 @@ MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool ad
     coord = unravel(k, rshape);
     Reducer::Reduce(val, OP::Map(big[j + dot(coord, rstride)]), residual);
   }
+  Reducer::Finalize(val, residual);
   assign(&small[idx], addto, val);
 }
 
@@ -256,6 +258,7 @@ MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool ad
     Reducer::Reduce(val, OP1::Map(big[idx_big],
                         OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual);
   }
+  Reducer::Finalize(val, residual);
   assign(&small[idx], addto, val);
 }
 
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index ac7199a94823..d9a749e0db82 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -1005,9 +1005,8 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs,
     ReduceAxesComputeImpl<xpu, mshadow_op::sum, false, mshadow_op::abs>(
         ctx, inputs, req, outputs, small);
   } else if (param.ord == 2) {
-    ReduceAxesComputeImpl<xpu, mshadow_op::sum, false, mshadow_op::square>(
+    ReduceAxesComputeImpl<xpu, mshadow_op::nrm2, false>(
         ctx, inputs, req, outputs, small);
-    SqRootForL2<xpu>(ctx, req[0], outputs[0]);
   }
 }
 
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index ae5cba21711a..0befc6be5f8a 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3031,7 +3031,6 @@ def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
                            grad_nodes={'data': req, 'gamma': req, 'beta': req},
                            numeric_eps=1e-2, rtol=1e-2, atol=1e-2)
 
-@unittest.skip("Flaky test: https://github.com/apache/incubator-mxnet/issues/11509")
 @with_seed()
 def test_norm():
     def l1norm(input_data, axis=0, keepdims=True):
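Note on the algorithm (not part of the patch itself): the stable nrm2 reducer uses the classic BLAS/LAPACK *NRM2 scheme. Instead of accumulating raw squares, it tracks the largest magnitude seen so far ("scale") together with the sum of squares of all elements divided by that scale, and Finalize() multiplies the square root of that sum back by the scale. Intermediate values therefore stay well scaled even when the true sum of squares would overflow or underflow. Adding the Finalize() hook (a no-op for the existing product/nansum/nanprod reducers) is what lets the square root and rescaling happen once per output element, replacing the separate SqRootForL2 pass over the result. The following self-contained C++ sketch mimics the SetInitValue/Reduce/Finalize protocol outside of mshadow; all names in it are illustrative only.

// Minimal standalone sketch (illustrative names, not part of the patch) of
// how the new Finalize() hook is driven by the reduction loop, using the
// same scaled sum-of-squares scheme as the nrm2 reducer above.
#include <cmath>
#include <cstdio>

struct Nrm2Sketch {
  static void SetInitValue(double& sum_of_squares, double& scale) {
    sum_of_squares = 0.0;
    scale = 0.0;  // largest |x| seen so far
  }
  static void Reduce(double& sum_of_squares, double src, double& scale) {
    if (src == 0.0) return;
    double abs = std::fabs(src);
    if (scale < abs) {
      // Rescale the running sum of (x / scale)^2 to the new, larger scale.
      sum_of_squares = 1.0 + sum_of_squares * (scale / abs) * (scale / abs);
      scale = abs;
    } else {
      sum_of_squares += (abs / scale) * (abs / scale);
    }
  }
  static void Finalize(double& sum_of_squares, double& scale) {
    // Turn the (scale, sum of squares) pair back into the actual L2 norm.
    sum_of_squares = scale * std::sqrt(sum_of_squares);
  }
};

int main() {
  // 1e200 squared overflows a double, but the scaled accumulation still works.
  const double data[] = {1e200, 0.0, 1e200};
  double val, scale;
  Nrm2Sketch::SetInitValue(val, scale);        // as in seq_reduce_assign
  for (double x : data) Nrm2Sketch::Reduce(val, x, scale);
  Nrm2Sketch::Finalize(val, scale);            // the step this patch adds
  std::printf("%g\n", val);                    // ~1.41421e200
  return 0;
}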