Skip to content

Commit

Permalink
Replace std::swap_ranges with memcpy (apache#10351)
Browse files Browse the repository at this point in the history
  • Loading branch information
asitstands authored and piiswrong committed Apr 2, 2018
1 parent 97fa0f9 commit f0e1c76
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/operator/random/shuffle_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <algorithm>
#include <random>
#include <vector>
#include <cstring>
#ifdef USE_GNU_PARALLEL_SHUFFLE
#include <parallel/algorithm>
#endif
Expand All @@ -55,18 +56,24 @@ void Shuffle1D(DType* const out, const index_t size, Rand* const prnd) {

template<typename DType, typename Rand>
void ShuffleND(DType* const out, const index_t size, const index_t first_axis_len,
Rand* const prnd) {
Rand* const prnd, const OpContext& ctx) {
// Fisher-Yates shuffling
using namespace mxnet_op;
const index_t stride = size / first_axis_len;
auto rand_n = [prnd](index_t n) {
std::uniform_int_distribution<index_t> dist(0, n - 1);
return dist(*prnd);
};
CHECK_GT(first_axis_len, 0U);
const size_t stride_bytes = sizeof(DType) * stride;
Tensor<cpu, 1, char> buf =
ctx.requested[1].get_space_typed<cpu, 1, char>(Shape1(stride_bytes), ctx.get_stream<cpu>());
for (index_t i = first_axis_len - 1; i > 0; --i) {
const index_t j = rand_n(i + 1);
if (i != j) {
std::swap_ranges(out + stride * i, out + stride * (i + 1), out + stride * j);
std::memcpy(buf.dptr_, out + stride * i, stride_bytes);
std::memcpy(out + stride * i, out + stride * j, stride_bytes);
std::memcpy(out + stride * j, buf.dptr_, stride_bytes);
}
}
}
Expand Down Expand Up @@ -97,7 +104,7 @@ void ShuffleForwardCPU(const nnvm::NodeAttrs& attrs,
if (input_shape.ndim() == 1) {
Shuffle1D(out.dptr_, size, &prnd);
} else {
ShuffleND(out.dptr_, size, first_axis_len, &prnd);
ShuffleND(out.dptr_, size, first_axis_len, &prnd, ctx);
}
});
}
Expand Down

0 comments on commit f0e1c76

Please sign in to comment.