Add stable nrm2 Reducer
leezu committed Jul 5, 2018
1 parent e870890 commit c75e5ca
Showing 4 changed files with 77 additions and 3 deletions.
73 changes: 73 additions & 0 deletions src/operator/mshadow_op.h
@@ -697,6 +697,12 @@ struct product {
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& none) { // NOLINT(*)
Reduce(dst, src);
}
/*! \brief finalize reduction */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
/*! \brief finalize reduction */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& none) {} // NOLINT(*)
/*!
*\brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of reduction element
@@ -762,6 +768,12 @@ struct nansum {
residual = (t - dst) - y;
dst = t;
}
/*! \brief finalize reduction */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
/*! \brief finalize reduction */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
/*!
*\brief set the initial value during reduction
*/
@@ -799,6 +811,12 @@ struct nanprod {
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& none) { // NOLINT(*)
Reduce(dst, src);
}
/*! \brief finalize reduction */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
/*! \brief finalize reduction */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& none) {} // NOLINT(*)
/*!
*\brief set the initial value during reduction
*/
@@ -815,6 +833,61 @@ struct nanprod {
}
};

/*! \brief compute l2 norm */
struct nrm2 {
/*! \brief do reduction into dst */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& sum_of_squares, volatile DType src) { // NOLINT(*)
sum_of_squares += src * src;
}
/*! \brief do stable reduction into dst */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& sum_of_squares, volatile DType src, volatile DType& scale) { // NOLINT(*)
if (src != 0) {
DType abs = mshadow_op::abs::Map(src);
if (scale < abs) {
sum_of_squares = 1 + sum_of_squares * (scale / abs) * (scale / abs);
scale = abs;
} else {
sum_of_squares = sum_of_squares + (abs / scale) * (abs / scale);
}
}
}
/*! \brief finalize reduction result */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& sum_of_squares) { // NOLINT(*)
sum_of_squares = std::sqrt(sum_of_squares);
}
/*! \brief finalize reduction result */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& sum_of_squares, volatile DType& scale) { // NOLINT(*)
sum_of_squares = scale * std::sqrt(sum_of_squares);
}
/*!
*\brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of reduction element
*/
template<typename DType>
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
return redsrc / redres;
}
/*!
*\brief set the initial value during reduction
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &sum_of_squares) { // NOLINT(*)
sum_of_squares = 0;
}
/*!
*\brief set the initial value during reduction
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &sum_of_squares, DType &scale) { // NOLINT(*)
SetInitValue(sum_of_squares);
scale = 0;
}
};

struct nanprod_grad : public mxnet_op::tunable {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
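The two-argument Reduce overload of the new nrm2 reducer above is the classic scaled accumulation familiar from the BLAS nrm2/LAPACK lassq routines: rather than summing raw squares, it tracks the largest magnitude seen so far in `scale` and accumulates `(x / scale)^2`, so the running sum stays near 1 even when the raw squares would overflow or underflow. A minimal standalone sketch of the same scheme (plain C++ written for this note, not mshadow code; `stable_l2_norm` and the demo values are illustrative):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

double stable_l2_norm(const std::vector<double>& x) {
  double sum_of_squares = 0.0;  // plays the role of the reducer's dst
  double scale = 0.0;           // plays the role of the reducer's residual
  for (double v : x) {          // one Reducer::Reduce(dst, src, residual) per element
    if (v == 0.0) continue;
    double a = std::fabs(v);
    if (scale < a) {
      sum_of_squares = 1.0 + sum_of_squares * (scale / a) * (scale / a);
      scale = a;
    } else {
      sum_of_squares += (a / scale) * (a / scale);
    }
  }
  return scale * std::sqrt(sum_of_squares);  // Reducer::Finalize(dst, residual)
}

int main() {
  // Each element squared would overflow a double, but the norm is representable.
  std::vector<double> x(4, 1e200);
  std::printf("%g\n", stable_l2_norm(x));  // prints 2e+200
  return 0;
}
```

With the initial values arranged by SetInitValue (`sum_of_squares = 0`, `scale = 0`), the first nonzero element sets `scale` to its magnitude and leaves the sum at 1, after which every later element is accumulated relative to the largest magnitude seen so far.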
3 changes: 3 additions & 0 deletions src/operator/tensor/broadcast_reduce-inl.h
@@ -37,6 +37,7 @@ namespace op {
namespace broadcast {
using namespace mshadow;


const int MAX_DIM = 5;

template<int ndim>
@@ -165,6 +166,7 @@ MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool ad
coord = unravel(k, rshape);
Reducer::Reduce(val, OP::Map(big[j + dot(coord, rstride)]), residual);
}
Reducer::Finalize(val, residual);
assign(&small[idx], addto, val);
}

@@ -256,6 +258,7 @@ MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool ad

Reducer::Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual);
}
Reducer::Finalize(val, residual);
assign(&small[idx], addto, val);
}

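Both reduction kernels now call `Reducer::Finalize(val, residual)` once after the accumulation loop, giving the reducer a chance to post-process the accumulated value before it is written to the output. A rough sketch of the resulting call sequence, reduced to a 1-D loop with a toy reducer (names of my own choosing; this is not the actual `seq_reduce_assign` code):

```cpp
#include <cstdio>

// Toy reducer with the same four hooks, for demonstration only (not mshadow).
struct plain_sum {
  static void SetInitValue(double &dst, double &residual) { dst = 0; residual = 0; }
  static void Reduce(double &dst, double src, double &) { dst += src; }
  static void Finalize(double &, double &) {}  // nothing to post-process
};

// Schematic of the call sequence used by the reduction kernels above,
// simplified to a 1-D loop; not the real seq_reduce_assign implementation.
template <typename Reducer, typename DType>
DType reduce_1d(const DType *data, int n) {
  DType val, residual;
  Reducer::SetInitValue(val, residual);       // e.g. sum_of_squares = 0, scale = 0
  for (int i = 0; i < n; ++i) {
    Reducer::Reduce(val, data[i], residual);  // accumulate one element
  }
  Reducer::Finalize(val, residual);           // no-op here, scale * sqrt(...) for nrm2
  return val;
}

int main() {
  double data[] = {1.0, 2.0, 3.0, 4.0};
  std::printf("%g\n", reduce_1d<plain_sum>(data, 4));  // prints 10
  return 0;
}
```

This unconditional Finalize call is also why `product`, `nansum`, and `nanprod` in mshadow_op.h gained empty Finalize overloads: every reducer must provide the hook even when it has nothing to do, while `nrm2` uses it for the `scale * sqrt(sum_of_squares)` step.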
3 changes: 1 addition & 2 deletions src/operator/tensor/broadcast_reduce_op.h
@@ -1005,9 +1005,8 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs,
ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, mshadow_op::abs>(
ctx, inputs, req, outputs, small);
} else if (param.ord == 2) {
- ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, mshadow_op::square>(
+ ReduceAxesComputeImpl<xpu, mshadow_op::nrm2, false, mshadow_op::identity>(
ctx, inputs, req, outputs, small);
- SqRootForL2<xpu>(ctx, req[0], outputs[0]);
}
}

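For ord == 2 the old path squared every element (mshadow_op::square), summed with mshadow::red::sum, and then ran the separate SqRootForL2 pass over the result. The new path feeds raw values through mshadow_op::identity and lets the nrm2 reducer do both the scaled accumulation and the square root in Finalize, so the extra pass disappears and the intermediate squares are never formed. A small illustration of why materializing the squares is the fragile part (plain C++ for this note, not MXNet code; the "stable" values quoted in the comments come from the scaled scheme sketched in the mshadow_op.h section above):

```cpp
// Square-then-sum-then-sqrt: each intermediate square over- or underflows a
// double even though the norm itself is representable. Illustrative only.
#include <cmath>
#include <cstdio>
#include <vector>

double l2_via_sum_of_squares(const std::vector<double> &x) {
  double ssq = 0.0;
  for (double v : x) ssq += v * v;  // old path: square map + sum reduce
  return std::sqrt(ssq);            // old path: separate SqRootForL2 step
}

int main() {
  std::vector<double> big(2, 1e200), tiny(2, 1e-200);
  std::printf("%g\n", l2_via_sum_of_squares(big));   // inf, stable result ~1.41421e+200
  std::printf("%g\n", l2_via_sum_of_squares(tiny));  // 0,   stable result ~1.41421e-200
  return 0;
}
```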
1 change: 0 additions & 1 deletion tests/python/unittest/test_operator.py
@@ -3031,7 +3031,6 @@ def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
grad_nodes={'data': req, 'gamma': req, 'beta': req},
numeric_eps=1e-2, rtol=1e-2, atol=1e-2)

- @unittest.skip("Flaky test: https://github.com/apache/incubator-mxnet/issues/11509")
@with_seed()
def test_norm():
def l1norm(input_data, axis=0, keepdims=True):
