diff --git a/3rdparty/mshadow b/3rdparty/mshadow index a8c650ce8a70..463c0dffe3ea 160000 --- a/3rdparty/mshadow +++ b/3rdparty/mshadow @@ -1 +1 @@ -Subproject commit a8c650ce8a708608a282c4d1e251c57873a8db25 +Subproject commit 463c0dffe3eae8c39caf7989c85b7244823df27e diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 5953568c7faf..81a55c4a0137 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -697,6 +697,22 @@ struct product { MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& none) { // NOLINT(*) Reduce(dst, src); } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*) + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*) + Reduce(dst_val, src_val); + } + /*! \brief finalize reduction */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*) + /*! \brief finalize reduction */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& none) {} // NOLINT(*) /*! *\brief calculate gradient of redres with respect to redsrc, * redres: reduced result, redsrc: one of reduction element @@ -762,6 +778,26 @@ struct nansum { residual = (t - dst) - y; dst = t; } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*) + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*) + DType t1 = dst_val + src_val; + DType e = t1 - src_val; + DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual; + dst_val = t1 + t2; + dst_residual = t2 - (dst_val - t1); + } + /*! \brief finalize reduction */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*) + /*! \brief finalize reduction */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*) /*! *\brief set the initial value during reduction */ @@ -799,6 +835,22 @@ struct nanprod { MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& none) { // NOLINT(*) Reduce(dst, src); } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*) + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*) + Reduce(dst_val, src_val); + } + /*! \brief finalize reduction */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*) + /*! \brief finalize reduction */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& none) {} // NOLINT(*) /*! *\brief set the initial value during reduction */ @@ -806,6 +858,7 @@ struct nanprod { MSHADOW_XINLINE static void SetInitValue(DType & initv) { // NOLINT(*) initv = 1; } + /*! *\brief set the initial value during reduction */ @@ -815,6 +868,76 @@ struct nanprod { } }; +/*! \brief compute l2 norm */ +struct nrm2 { + /*! \brief do reduction into dst */ + template + MSHADOW_XINLINE static void Reduce(volatile DType& sum_of_squares, volatile DType src) { // NOLINT(*) + sum_of_squares += src * src; + } + /*! \brief do stable reduction into dst */ + template + MSHADOW_XINLINE static void Reduce(volatile DType& sum_of_squares, volatile DType src, volatile DType& scale) { // NOLINT(*) + if (src != 0) { + DType abs = mshadow_op::abs::Map(src); + if (scale < abs) { + sum_of_squares = 1 + sum_of_squares * (scale / abs) * (scale / abs); + scale = abs; + } else { + sum_of_squares = sum_of_squares + (abs / scale) * (abs / scale); + } + } + } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*) + dst_val += src_val; + } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_ssq, volatile DType& dst_scale, volatile DType& src_ssq, volatile DType& src_scale) { // NOLINT(*) + if (dst_scale != 0 && dst_scale >= src_scale) { + dst_ssq = dst_ssq + src_ssq * (src_scale / dst_scale) * (src_scale / dst_scale); + } else if (src_scale != 0 && dst_scale < src_scale) { + dst_ssq = src_ssq + dst_ssq * (dst_scale / src_scale) * (dst_scale / src_scale); + dst_scale = src_scale; + } + } + /*! \brief finalize reduction result */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& sum_of_squares) { // NOLINT(*) + sum_of_squares = math::sqrt(sum_of_squares); + } + /*! \brief finalize reduction result */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& sum_of_squares, volatile DType& scale) { // NOLINT(*) + sum_of_squares = scale * math::sqrt(sum_of_squares); + } + /*! + *\brief calculate gradient of redres with respect to redsrc, + * redres: reduced result, redsrc: one of reduction element + */ + template + MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) { + return redsrc / redres; + } + /*! + *\brief set the initial value during reduction + */ + template + MSHADOW_XINLINE static void SetInitValue(DType &sum_of_squares) { // NOLINT(*) + sum_of_squares = 0; + } + /*! + *\brief set the initial value during reduction + */ + template + MSHADOW_XINLINE static void SetInitValue(DType &sum_of_squares, DType &scale) { // NOLINT(*) + SetInitValue(sum_of_squares); + scale = 0; + } +}; + struct nanprod_grad : public mxnet_op::tunable { template MSHADOW_XINLINE static DType Map(DType a, DType b) { diff --git a/src/operator/tensor/broadcast_reduce-inl.cuh b/src/operator/tensor/broadcast_reduce-inl.cuh index b6bb39a19847..5c9b45f547fc 100644 --- a/src/operator/tensor/broadcast_reduce-inl.cuh +++ b/src/operator/tensor/broadcast_reduce-inl.cuh @@ -123,27 +123,32 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, // Fix bx to avoid bank conflicts. Assumes warpSize number of banks const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx; const int it0 = tidx + tidy*fbx; - shTile[it0] = val; + shTile[it0 * 2] = val; + shTile[it0 * 2 + 1] = residual; __syncthreads(); for (int t=1;t < by;t <<= 1) { - DType tmp, residual; - Reducer::SetInitValue(tmp, residual); - if (tidy + t < by) tmp = shTile[it0 + t*fbx]; + DType tmp, tmp_residual; + Reducer::SetInitValue(tmp, tmp_residual); + if (tidy + t < by) { + tmp = shTile[(it0 + t*fbx) * 2]; + tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; + } __syncthreads(); - Reducer::Reduce(shTile[it0], tmp, residual); + Reducer::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); __syncthreads(); } if (idx < N && tidy == 0) { - assign(&small[idx + m0*N], addto, shTile[tidx]); + Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); + assign(&small[idx + m0*N], addto, shTile[tidx * 2]); } } else { if (idx < N) { + Reducer::Finalize(val, residual); assign(&small[idx + m0*N], addto, val); } } } } - } template @@ -207,27 +212,32 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto, // Fix bx to avoid bank conflicts. Assumes warpSize number of banks const int fbx = (do_transpose && ((bx & (warpSize - 1)) == 0)) ? (bx + 1) : bx; const int it0 = tidx + tidy*fbx; - shTile[it0] = val; + shTile[it0 * 2] = val; + shTile[it0 * 2 + 1] = residual; __syncthreads(); for (int t=1;t < by;t <<= 1) { - DType tmp, residual; - Reducer::SetInitValue(tmp, residual); - if (tidy + t < by) tmp = shTile[it0 + t*fbx]; + DType tmp, tmp_residual; + Reducer::SetInitValue(tmp, tmp_residual); + if (tidy + t < by) { + tmp = shTile[(it0 + t*fbx) * 2]; + tmp_residual = shTile[(it0 + t*fbx) * 2 + 1]; + } __syncthreads(); - Reducer::Reduce(shTile[it0], tmp, residual); + Reducer::Merge(shTile[it0 * 2], shTile[it0 * 2 + 1], tmp, tmp_residual); __syncthreads(); } if (idx < N && tidy == 0) { - assign(&small[idx + m0*N], addto, shTile[tidx]); + Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]); + assign(&small[idx + m0*N], addto, shTile[tidx * 2]); } } else { if (idx < N) { + Reducer::Finalize(val, residual); assign(&small[idx + m0*N], addto, val); } } } } - } // Simple reduction of lines when M is small @@ -244,6 +254,7 @@ __global__ void reduce_lines_kernel(const int N, const int M, const bool addto, } if (idx < N) { + Reducer::Finalize(val, residual); assign(&small_out[idx], addto, val); } @@ -453,7 +464,7 @@ ReduceImplConfig ConfigureReduceImpl(const TShape& small, const TShape& bi by++; } config.kernel_1.shMemSize = (config.kernel_1.blockDim.x > 1) ? - config.kernel_1.blockDim.x*by*sizeof(DType) : 0; + config.kernel_1.blockDim.x*by*sizeof(DType) * 2 : 0; // Maximum number of times we want TB to loop in M // Max size of M-block each TB can handle int maxMblock = config.kernel_1.blockDim.x*config.maxLoopPerTB; @@ -464,7 +475,7 @@ ReduceImplConfig ConfigureReduceImpl(const TShape& small, const TShape& bi ceil_idiv(config.N, config.kernel_1.blockDim.x)); config.kernel_1.gridDim.y = std::min(kBaseGridNum, config.Mnext); config.kernel_1.shMemSize = (config.kernel_1.blockDim.y > 1) ? - config.kernel_1.blockDim.x*config.kernel_1.blockDim.y*sizeof(DType) : 0; + config.kernel_1.blockDim.x*config.kernel_1.blockDim.y*sizeof(DType) * 2 : 0; // Maximum number of times we want TB to loop in M // Max size of M-block each TB can handle int maxMblock = config.kernel_1.blockDim.y*config.maxLoopPerTB; diff --git a/src/operator/tensor/broadcast_reduce-inl.h b/src/operator/tensor/broadcast_reduce-inl.h index 76ec92a9e724..713e3f1ac602 100644 --- a/src/operator/tensor/broadcast_reduce-inl.h +++ b/src/operator/tensor/broadcast_reduce-inl.h @@ -165,6 +165,7 @@ MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool ad coord = unravel(k, rshape); Reducer::Reduce(val, OP::Map(big[j + dot(coord, rstride)]), residual); } + Reducer::Finalize(val, residual); assign(&small[idx], addto, val); } @@ -256,6 +257,7 @@ MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool ad Reducer::Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual); } + Reducer::Finalize(val, residual); assign(&small[idx], addto, val); } diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index ac7199a94823..d9a749e0db82 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -1005,9 +1005,8 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs, ReduceAxesComputeImpl( ctx, inputs, req, outputs, small); } else if (param.ord == 2) { - ReduceAxesComputeImpl( + ReduceAxesComputeImpl( ctx, inputs, req, outputs, small); - SqRootForL2(ctx, req[0], outputs[0]); } } diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index cf5906ae4546..b57e71d73b2a 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -17,6 +17,7 @@ import mxnet as mx import numpy as np +from distutils.version import LooseVersion import os import pickle as pkl import unittest @@ -1276,10 +1277,19 @@ def test_ndarray_astype(): @with_seed() def test_norm(ctx=default_context()): + try: + import scipy + assert LooseVersion(scipy.__version__) >= LooseVersion('0.1') + from scipy.linalg import norm as sp_norm + except (AssertionError, ImportError): + print("Could not import scipy.linalg.norm or scipy is too old. " + "Falling back to numpy.linalg.norm which is not numerically stable.") + from numpy.linalg import norm as sp_norm + def l1norm(input_data, axis=0, keepdims=False): return np.sum(abs(input_data), axis=axis, keepdims=keepdims) def l2norm(input_data, axis=0, keepdims=False): - return np.linalg.norm(input_data, axis=axis, keepdims=keepdims) + return sp_norm(input_data, axis=axis, keepdims=keepdims) in_data_dim = random_sample([4,5,6], 1)[0] in_data_shape = rand_shape_nd(in_data_dim) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index ae5cba21711a..5a2067eab4ad 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -23,6 +23,7 @@ import math import random import itertools +from distutils.version import LooseVersion from numpy.testing import assert_allclose, assert_array_equal from mxnet.test_utils import * from mxnet.base import py_str, MXNetError, _as_list @@ -3031,13 +3032,22 @@ def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5): grad_nodes={'data': req, 'gamma': req, 'beta': req}, numeric_eps=1e-2, rtol=1e-2, atol=1e-2) -@unittest.skip("Flaky test: https://github.com/apache/incubator-mxnet/issues/11509") @with_seed() def test_norm(): + try: + import scipy + assert LooseVersion(scipy.__version__) >= LooseVersion('0.1') + from scipy.linalg import norm as sp_norm + except (AssertionError, ImportError): + print("Could not import scipy.linalg.norm or scipy is too old. " + "Falling back to numpy.linalg.norm which is not numerically stable.") + from numpy.linalg import norm as sp_norm + def l1norm(input_data, axis=0, keepdims=True): return np.sum(abs(input_data), axis=axis, keepdims=keepdims) - def l2norm(input_data, axis=0, keepdims=True): - return np.linalg.norm(input_data, axis=axis, keepdims=keepdims) + + def l2norm(input_data, axis=0, keepdims=True): + return sp_norm(input_data, axis=axis, keepdims=keepdims) ctx = default_context() data = mx.symbol.Variable('data') @@ -3051,7 +3061,7 @@ def l2norm(input_data, axis=0, keepdims=True): for i in range(in_data_dim): norm_sym = mx.symbol.norm(data=data, ord=order, axis=i, keepdims=True) npy_out = l1norm(in_data, i) if order is 1 else l2norm(in_data, i) - npy_out_backward = np.sign(in_data) if order is 1 else in_data/npy_out + npy_out_backward = np.sign(in_data) if order is 1 else in_data/npy_out check_symbolic_forward(norm_sym, [in_data], [npy_out], rtol=1e-2 if dtype is np.float16 else 1e-5, atol=1e-2 if dtype is np.float16 else 1e-5, ctx=ctx) @@ -3059,22 +3069,23 @@ def l2norm(input_data, axis=0, keepdims=True): [npy_out_backward], rtol=1e-2 if dtype is np.float16 else 1e-5, atol=1e-2 if dtype is np.float16 else 1e-5, ctx=ctx) - # check gradient - check_numeric_gradient(norm_sym, [in_data], numeric_eps=epsilon, rtol=1e-2, atol=1e-3) - if i < in_data_dim-1: - norm_sym = mx.symbol.norm(data=data, ord=order, axis=(i, i+1), keepdims=True) - npy_out = l1norm(in_data, (i, i+1)) if order is 1 else l2norm(in_data, (i, i+1)) - npy_out_backward = np.sign(in_data) if order is 1 else in_data/npy_out - check_symbolic_forward(norm_sym, [in_data], [npy_out], - rtol=1e-2 if dtype is np.float16 else 1e-5, - atol=1e-2 if dtype is np.float16 else 1e-5, ctx=ctx) - check_symbolic_backward(norm_sym, [in_data], [np.ones(npy_out.shape)], - [npy_out_backward], - rtol=1e-2 if dtype is np.float16 else 1e-5, - atol=1e-2 if dtype is np.float16 else 1e-5, ctx=ctx) - # check gradient - check_numeric_gradient(norm_sym, [in_data], numeric_eps=epsilon, rtol=1e-2, atol=1e-3) - + # Disable numeric gradient https://github.com/apache/incubator-mxnet/issues/11509 + # # check gradient + # check_numeric_gradient(norm_sym, [in_data], numeric_eps=epsilon, rtol=1e-2, atol=1e-3) + # if i < in_data_dim-1: + # norm_sym = mx.symbol.norm(data=data, ord=order, axis=(i, i+1), keepdims=True) + # npy_out = l1norm(in_data, (i, i+1)) if order is 1 else l2norm(in_data, (i, i+1)) + # npy_out_backward = np.sign(in_data) if order is 1 else in_data/npy_out + # check_symbolic_forward(norm_sym, [in_data], [npy_out], + # rtol=1e-2 if dtype is np.float16 else 1e-5, + # atol=1e-2 if dtype is np.float16 else 1e-5, ctx=ctx) + # check_symbolic_backward(norm_sym, [in_data], [np.ones(npy_out.shape)], + # [npy_out_backward], + # rtol=1e-2 if dtype is np.float16 else 1e-5, + # atol=1e-2 if dtype is np.float16 else 1e-5, ctx=ctx) + # # check gradient + # check_numeric_gradient(norm_sym, [in_data], numeric_eps=epsilon, rtol=1e-2, atol=1e-3) + def test_layer_norm(): for dtype, forward_check_eps in zip([np.float16, np.float32, np.float64],