This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add stable nrm2 Reducer #11573

Merged · 7 commits · Jul 11, 2018
2 changes: 1 addition & 1 deletion 3rdparty/mshadow
123 changes: 123 additions & 0 deletions src/operator/mshadow_op.h
@@ -697,6 +697,22 @@ struct product {
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& none) { // NOLINT(*)
Reduce(dst, src);
}
/*! \brief combine the results of two reducers */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*! \brief combine the results of two reducers */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*! \brief finalize reduction */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
/*! \brief finalize reduction */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& none) {} // NOLINT(*)
/*!
*\brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of reduction element
@@ -762,6 +778,26 @@ struct nansum {
residual = (t - dst) - y;
dst = t;
}
/*! \brief combine the results of two reducers */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*! \brief combine the results of two reducers */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
// TwoSum: t1 is the rounded sum of the two partial sums; t2 recovers the
// rounding error of t1 and folds in both partials' accumulated residuals.
DType t1 = dst_val + src_val;
DType e = t1 - dst_val;
DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
dst_val = t1 + t2;
dst_residual = t2 - (dst_val - t1);
}
/*! \brief finalize reduction */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
/*! \brief finalize reduction */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
/*!
*\brief set the initial value during reduction
*/
@@ -799,13 +835,30 @@ struct nanprod {
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& none) { // NOLINT(*)
Reduce(dst, src);
}
/*! \brief combine the results of two reducers */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*! \brief combine the results of two reducers */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*! \brief finalize reduction */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
/*! \brief finalize reduction */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& none) {} // NOLINT(*)
/*!
*\brief set the initial value during reduction
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType & initv) { // NOLINT(*)
initv = 1;
}

/*!
*\brief set the initial value during reduction
*/
@@ -815,6 +868,76 @@ struct nanprod {
}
};

/*! \brief compute l2 norm */
struct nrm2 {
/*! \brief do reduction into dst */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& sum_of_squares, volatile DType src) { // NOLINT(*)
sum_of_squares += src * src;
}
/*! \brief do stable reduction into dst */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& sum_of_squares, volatile DType src, volatile DType& scale) { // NOLINT(*)
if (src != 0) {
DType abs = mshadow_op::abs::Map(src);
if (scale < abs) {
sum_of_squares = 1 + sum_of_squares * (scale / abs) * (scale / abs);
scale = abs;
} else {
sum_of_squares = sum_of_squares + (abs / scale) * (abs / scale);
}
}
}
/*! \brief combine the results of two reducers */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
dst_val += src_val;
}
/*! \brief combine the results of two reducers */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_ssq, volatile DType& dst_scale, volatile DType& src_ssq, volatile DType& src_scale) { // NOLINT(*)
if (dst_scale != 0 && dst_scale >= src_scale) {
dst_ssq = dst_ssq + src_ssq * (src_scale / dst_scale) * (src_scale / dst_scale);
[Review thread on the line above]

Member: Can you please elaborate on how this expression was obtained?

@leezu (Contributor, Author) commented on Jul 11, 2018:

Sure. Remember that we use a scaled sum of squares to compute the L2 norm, to avoid the numeric instability caused by squaring, and then taking the square root of, very small or very large numbers.
For efficient reduction on GPU, multiple reducers each compute the reduction of one part of the vector. Each result is a scaled sum of squares. To combine the reducers, we must find a common scale for all of them. Following the implementation of Reduce, I chose the largest scale.

The above expression simply rescales the sum of squares of the reducer that currently uses the smaller scale value, such that in the end

    norm(x) = sqrt(ssq) * scale
            = dst_scale * sqrt(dst_ssq + src_ssq * (src_scale / dst_scale) * (src_scale / dst_scale))
            = sqrt(dst_scale * dst_scale * dst_ssq + src_scale * src_scale * src_ssq)

where we want to avoid computing the rightmost form due to numerical instability; here ssq and scale denote what is written to dst_ssq and dst_scale in the code above.

Member: Thanks for the explanation!

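A concrete example of the rescaling (illustrative numbers added for this write-up, not from the original thread): suppose one partial reduction has processed {3e-30, 4e-30}, leaving dst_ssq = 1.5625 and dst_scale = 4e-30, and another has processed {1.2e-29}, leaving src_ssq = 1 and src_scale = 1.2e-29. Since src_scale is larger, Merge takes its second branch:

    dst_ssq   = 1 + 1.5625 * (4e-30 / 1.2e-29) * (4e-30 / 1.2e-29)
              = 1 + 1.5625 / 9 = 169/144
    dst_scale = 1.2e-29

Finalize then returns 1.2e-29 * sqrt(169/144) = 1.2e-29 * 13/12 = 1.3e-29, which matches sqrt((3e-30)^2 + (4e-30)^2 + (1.2e-29)^2) exactly. Evaluating that last expression directly in float32 would underflow every square to zero.
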
} else if (src_scale != 0 && dst_scale < src_scale) {
dst_ssq = src_ssq + dst_ssq * (dst_scale / src_scale) * (dst_scale / src_scale);
dst_scale = src_scale;
}
}
/*! \brief finalize reduction result */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& sum_of_squares) { // NOLINT(*)
sum_of_squares = math::sqrt(sum_of_squares);
}
/*! \brief finalize reduction result */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& sum_of_squares, volatile DType& scale) { // NOLINT(*)
sum_of_squares = scale * math::sqrt(sum_of_squares);
}
/*!
*\brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of reduction element
*/
template<typename DType>
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
return redsrc / redres;
}
/*!
*\brief set the initial value during reduction
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &sum_of_squares) { // NOLINT(*)
sum_of_squares = 0;
}
/*!
*\brief set the initial value during reduction
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &sum_of_squares, DType &scale) { // NOLINT(*)
SetInitValue(sum_of_squares);
scale = 0;
}
};
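
The three phases compose as follows: each thread Reduces its elements into an (ssq, scale) pair, pairs are Merged across threads, and Finalize returns scale * sqrt(ssq). The standalone sketch below (plain host C++ mirroring the reducer above; an illustration added for this write-up, not code from the PR) runs the phases on concrete numbers to show why the scaled form survives float32 underflow:

// Illustration only: mirrors nrm2's Reduce / Merge / Finalize on the host.
#include <cmath>
#include <cstdio>

// One partial reduction; the invariant is norm == scale * sqrt(ssq).
struct Nrm2Partial {
  float ssq = 0.0f;    // scaled sum of squares
  float scale = 0.0f;  // current scale (largest |x| seen so far)
};

// Same logic as nrm2::Reduce: fold one element into the partial.
void Reduce(Nrm2Partial& p, float src) {
  if (src != 0.0f) {
    float a = std::fabs(src);
    if (p.scale < a) {
      p.ssq = 1.0f + p.ssq * (p.scale / a) * (p.scale / a);
      p.scale = a;
    } else {
      p.ssq += (a / p.scale) * (a / p.scale);
    }
  }
}

// Same logic as nrm2::Merge: rescale to the larger of the two scales.
void Merge(Nrm2Partial& dst, const Nrm2Partial& src) {
  if (dst.scale != 0.0f && dst.scale >= src.scale) {
    dst.ssq += src.ssq * (src.scale / dst.scale) * (src.scale / dst.scale);
  } else if (src.scale != 0.0f && dst.scale < src.scale) {
    dst.ssq = src.ssq + dst.ssq * (dst.scale / src.scale) * (dst.scale / src.scale);
    dst.scale = src.scale;
  }
}

// Same logic as nrm2::Finalize: norm = scale * sqrt(ssq).
float Finalize(const Nrm2Partial& p) { return p.scale * std::sqrt(p.ssq); }

int main() {
  // Two partials, as if two threads each reduced part of the vector
  // {3e-30, 4e-30, 1.2e-29}; the exact norm is 1.3e-29.
  Nrm2Partial a, b;
  Reduce(a, 3e-30f);
  Reduce(a, 4e-30f);
  Reduce(b, 1.2e-29f);
  Merge(a, b);
  // The naive sum of squares underflows to zero in float32.
  float naive = std::sqrt(3e-30f * 3e-30f + 4e-30f * 4e-30f + 1.2e-29f * 1.2e-29f);
  std::printf("scaled: %g  naive: %g\n", Finalize(a), naive);
  return 0;
}

Compiled as C++11 or later, this should print scaled: 1.3e-29 and naive: 0.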

struct nanprod_grad : public mxnet_op::tunable {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
43 changes: 27 additions & 16 deletions src/operator/tensor/broadcast_reduce-inl.cuh
@@ -123,27 +123,32 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
// Fix bx to avoid bank conflicts. Assumes warpSize number of banks
const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx;
const int it0 = tidx + tidy*fbx;
shTile[it0] = val;
shTile[it0 * 2] = val;
shTile[it0 * 2 + 1] = residual;
__syncthreads();
for (int t=1;t < by;t <<= 1) {
DType tmp, residual;
Reducer::SetInitValue(tmp, residual);
if (tidy + t < by) tmp = shTile[it0 + t*fbx];
DType tmp, tmp_residual;
Reducer::SetInitValue(tmp, tmp_residual);
if (tidy + t < by) {
tmp = shTile[(it0 + t*fbx) * 2];
tmp_residual = shTile[(it0 + t*fbx) * 2 + 1];
}
__syncthreads();
Reducer::Reduce(shTile[it0], tmp, residual);
Reducer::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual);
__syncthreads();
}
if (idx < N && tidy == 0) {
assign(&small[idx + m0*N], addto, shTile[tidx]);
Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
assign(&small[idx + m0*N], addto, shTile[tidx * 2]);
}
} else {
if (idx < N) {
Reducer::Finalize(val, residual);
assign(&small[idx + m0*N], addto, val);
}
}
}
}

}

template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2, int unroll>
@@ -207,27 +212,32 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
// Fix bx to avoid bank conflicts. Assumes warpSize number of banks
const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx;
const int it0 = tidx + tidy*fbx;
shTile[it0] = val;
shTile[it0 * 2] = val;
shTile[it0 * 2 + 1] = residual;
__syncthreads();
for (int t=1;t < by;t <<= 1) {
DType tmp, residual;
Reducer::SetInitValue(tmp, residual);
if (tidy + t < by) tmp = shTile[it0 + t*fbx];
DType tmp, tmp_residual;
Reducer::SetInitValue(tmp, tmp_residual);
if (tidy + t < by) {
tmp = shTile[(it0 + t*fbx) * 2];
tmp_residual = shTile[(it0 + t*fbx) * 2 + 1];
}
__syncthreads();
Reducer::Reduce(shTile[it0], tmp, residual);
Reducer::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual);
__syncthreads();
}
if (idx < N && tidy == 0) {
assign(&small[idx + m0*N], addto, shTile[tidx]);
Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
assign(&small[idx + m0*N], addto, shTile[tidx * 2]);
}
} else {
if (idx < N) {
Reducer::Finalize(val, residual);
assign(&small[idx + m0*N], addto, val);
}
}
}
}

}

// Simple reduction of lines when M is small
@@ -244,6 +254,7 @@ __global__ void reduce_lines_kernel(const int N, const int M, const bool addto,
}

if (idx < N) {
Reducer::Finalize(val, residual);
assign(&small_out[idx], addto, val);
}

@@ -453,7 +464,7 @@ ReduceImplConfig<ndim> ConfigureReduceImpl(const TShape& small, const TShape& bi
by++;
}
config.kernel_1.shMemSize = (config.kernel_1.blockDim.x > 1) ?
config.kernel_1.blockDim.x*by*sizeof(DType) : 0;
config.kernel_1.blockDim.x*by*sizeof(DType) * 2 : 0;
// Maximum number of times we want TB to loop in M
// Max size of M-block each TB can handle
int maxMblock = config.kernel_1.blockDim.x*config.maxLoopPerTB;
@@ -464,7 +475,7 @@ ReduceImplConfig<ndim> ConfigureReduceImpl(const TShape& small, const TShape& bi
ceil_idiv<unsigned int>(config.N, config.kernel_1.blockDim.x));
config.kernel_1.gridDim.y = std::min(kBaseGridNum, config.Mnext);
config.kernel_1.shMemSize = (config.kernel_1.blockDim.y > 1) ?
config.kernel_1.blockDim.x*config.kernel_1.blockDim.y*sizeof(DType) : 0;
config.kernel_1.blockDim.x*config.kernel_1.blockDim.y*sizeof(DType) * 2 : 0;
// Maximum number of times we want TB to loop in M
// Max size of M-block each TB can handle
int maxMblock = config.kernel_1.blockDim.y*config.maxLoopPerTB;
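An explanatory note added for this write-up (the kernel changes themselves are only what the diff above shows): each tree-reduction slot in shared memory now stores a (value, auxiliary) pair, the auxiliary being the residual for the stable sum or the scale for nrm2, so that Reducer::Merge can combine full partial states across threads. This is also why ConfigureReduceImpl doubles shMemSize. A minimal host-side sketch of the interleaved indexing, under the assumption of four slots and plain addition standing in for Merge:

// Illustration only: slot i keeps its value at shTile[2*i] and its
// auxiliary state at shTile[2*i + 1], matching the kernel's layout.
#include <cstdio>

int main() {
  const int nslots = 4;                  // stands in for blockDim.y ("by")
  float shTile[nslots * 2];              // twice the slots, hence shMemSize * 2
  for (int i = 0; i < nslots; ++i) {
    shTile[i * 2] = 1.0f * (i + 1);      // value
    shTile[i * 2 + 1] = 0.0f;            // residual / scale
  }
  // Same shape as the kernel's tree loop over t; Merge is plain addition here.
  for (int t = 1; t < nslots; t <<= 1) {
    for (int i = 0; i < nslots; ++i) {   // one iteration per "thread"
      float tmp = 0.0f, tmp_aux = 0.0f;  // Reducer::SetInitValue
      if (i + t < nslots) {
        tmp = shTile[(i + t) * 2];
        tmp_aux = shTile[(i + t) * 2 + 1];
      }
      shTile[i * 2] += tmp;              // Reducer::Merge on the (value, aux) pair
      shTile[i * 2 + 1] += tmp_aux;
    }
  }
  std::printf("reduced: %g\n", shTile[0]);  // prints 10 (= 1+2+3+4)
  return 0;
}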
2 changes: 2 additions & 0 deletions src/operator/tensor/broadcast_reduce-inl.h
@@ -165,6 +165,7 @@ MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool ad
coord = unravel(k, rshape);
Reducer::Reduce(val, OP::Map(big[j + dot(coord, rstride)]), residual);
}
Reducer::Finalize(val, residual);
assign(&small[idx], addto, val);
}

@@ -256,6 +257,7 @@ MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool ad

Reducer::Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual);
}
Reducer::Finalize(val, residual);
assign(&small[idx], addto, val);
}

3 changes: 1 addition & 2 deletions src/operator/tensor/broadcast_reduce_op.h
@@ -1005,9 +1005,8 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs,
ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, mshadow_op::abs>(
ctx, inputs, req, outputs, small);
} else if (param.ord == 2) {
ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, mshadow_op::square>(
ReduceAxesComputeImpl<xpu, mshadow_op::nrm2, false, mshadow_op::identity>(
ctx, inputs, req, outputs, small);
SqRootForL2<xpu>(ctx, req[0], outputs[0]);
}
}

12 changes: 11 additions & 1 deletion tests/python/unittest/test_ndarray.py
@@ -17,6 +17,7 @@

import mxnet as mx
import numpy as np
from distutils.version import LooseVersion
import os
import pickle as pkl
import unittest
@@ -1276,10 +1277,19 @@ def test_ndarray_astype():

@with_seed()
def test_norm(ctx=default_context()):
try:
import scipy
assert LooseVersion(scipy.__version__) >= LooseVersion('0.1')
from scipy.linalg import norm as sp_norm
except (AssertionError, ImportError):
print("Could not import scipy.linalg.norm or scipy is too old. "
"Falling back to numpy.linalg.norm which is not numerically stable.")
from numpy.linalg import norm as sp_norm

def l1norm(input_data, axis=0, keepdims=False):
return np.sum(abs(input_data), axis=axis, keepdims=keepdims)
def l2norm(input_data, axis=0, keepdims=False):
return np.linalg.norm(input_data, axis=axis, keepdims=keepdims)
return sp_norm(input_data, axis=axis, keepdims=keepdims)

in_data_dim = random_sample([4,5,6], 1)[0]
in_data_shape = rand_shape_nd(in_data_dim)