speed up of topk-operator (apache#10997)
asmushetzel authored and Jin Huang committed May 29, 2018
1 parent 81fcba9 commit 33312e5
Showing 1 changed file with 183 additions and 20 deletions.
203 changes: 183 additions & 20 deletions src/operator/tensor/ordering_op-inl.h
@@ -152,6 +152,174 @@ inline void ParseTopKParam(const TShape& src_shape, const TopKParam& param, TSha
<< *element_num << ", get k = " << *k;
}

using namespace mshadow;

template<typename xpu>
void TopKSort(const Tensor<xpu, 1, real_t>& dat,
const Tensor<xpu, 1, int>& ind,
const Tensor<xpu, 1, char>& work,
int K, int N, bool is_ascend,
Stream<xpu> *s);

template<>
MSHADOW_FORCE_INLINE void TopKSort<cpu>(const Tensor<cpu, 1, real_t>& dat,
const Tensor<cpu, 1, int>& ind,
const Tensor<cpu, 1, char>& work,
int K, int N, bool is_ascend,
Stream<cpu> *s) {
// Use full sort when K is relatively large.
const bool full_sort(K*8 > N);
// Batch size.
const int M(dat.size(0)/N);
const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < M; ++i) {
real_t *vals = dat.dptr_;
int *indices = ind.dptr_+i*N;
if (is_ascend) {
if (full_sort) {
std::sort(indices, indices+N,
[&](const int& i1, const int& i2){ return vals[i1] < vals[i2]; });
} else {
std::partial_sort(indices, indices+K, indices+N,
[&](const int& i1, const int& i2){ return vals[i1] < vals[i2]; });
}
} else {
if (full_sort) {
std::sort(indices, indices+N,
[&](const int& i1, const int& i2){ return vals[i1] > vals[i2]; });
} else {
std::partial_sort(indices, indices+K, indices+N,
[&](const int& i1, const int& i2){ return vals[i1] > vals[i2]; });
}
}
real_t *buff = reinterpret_cast<real_t*>(work.dptr_)+i*K;
for (int j = 0; j < K; ++j) {
buff[j] = vals[indices[j]];
}
std::copy(buff, buff+K, &vals[i*N]);
}
}
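The CPU specialization above selects the top-K of each batch by sorting index arrays with std::sort or std::partial_sort instead of moving the values themselves. A minimal standalone sketch of the same argsort-style selection on a single batch, assuming a plain std::vector input (the helper name topk_indices is illustrative, not part of the operator):

#include <algorithm>
#include <numeric>
#include <vector>

// Return the indices of the K smallest (is_ascend) or largest elements of vals,
// sorted in the requested order; mirrors the partial-sort branch above.
std::vector<int> topk_indices(const std::vector<float>& vals, int K, bool is_ascend) {
  std::vector<int> idx(vals.size());
  std::iota(idx.begin(), idx.end(), 0);  // 0, 1, 2, ...
  auto cmp = [&](int i1, int i2) {
    return is_ascend ? vals[i1] < vals[i2] : vals[i1] > vals[i2];
  };
  std::partial_sort(idx.begin(), idx.begin() + K, idx.end(), cmp);
  idx.resize(K);
  return idx;
}
// e.g. topk_indices({3.f, 1.f, 4.f, 1.5f}, 2, /*is_ascend=*/false) yields {2, 0}.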

#ifdef __CUDACC__

template<typename DType>
MSHADOW_XINLINE bool TopKCompare(DType val1, int ind1, DType val2, int ind2, bool is_ascend) {
// Negative indices denote undefined values; they always compare as worse than any defined value
// (arbitrarily large for ascending order, arbitrarily small for descending order).
return (ind2 < 0) || (ind1 >= 0 && ((is_ascend && val1 < val2) || (!is_ascend && val1 > val2)));
}

template<typename DType>
MSHADOW_XINLINE void MergeTopK(int K, DType *val1, int *ind1, DType *val2, int *ind2,
bool is_ascend) {
// In-place merge of two sorted top-K lists into val1/ind1. First determine the intervals
// [0,..,i1], [0,..i2] of the two lists that will be part of the merged list.
int i1(K-1), i2(K-1);
for (int i = 0; i < K; ++i) {
if (TopKCompare(val1[i1], ind1[i1], val2[i2], ind2[i2], is_ascend)) {
--i2;
} else {
--i1;
}
}
// Now merge the lists from back to front.
for (int i = K; i--;) {
if (i2 < 0 || (i1 >= 0 && TopKCompare(val2[i2], ind2[i2], val1[i1], ind1[i1], is_ascend))) {
val1[i] = val1[i1];
ind1[i] = ind1[i1];
--i1;
} else {
val1[i] = val2[i2];
ind1[i] = ind2[i2];
--i2;
}
}
}
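For intuition, here is the same two-pass, back-to-front merge on plain sorted float lists, without the index bookkeeping and assuming both lists are fully populated (no -1 sentinels); the helper merge_topk is illustrative only:

#include <cassert>
#include <vector>

// Merge two individually sorted top-K lists a and b, keeping the best K in a.
// Same strategy as MergeTopK above, on plain floats.
void merge_topk(std::vector<float>* a, const std::vector<float>& b,
                int K, bool is_ascend) {
  assert(static_cast<int>(a->size()) == K && static_cast<int>(b.size()) == K);
  auto beats = [&](float x, float y) { return is_ascend ? x < y : x > y; };
  // Pass 1: decide how many tail elements of each list drop out of the top-K.
  int i1 = K - 1, i2 = K - 1;
  for (int i = 0; i < K; ++i) {
    if (beats((*a)[i1], b[i2])) {
      --i2;  // b's current worst element is eliminated
    } else {
      --i1;  // a's current worst element is eliminated
    }
  }
  // Pass 2: merge from the back so a can be overwritten in place.
  for (int i = K; i--;) {
    if (i2 < 0 || (i1 >= 0 && beats(b[i2], (*a)[i1]))) {
      (*a)[i] = (*a)[i1--];
    } else {
      (*a)[i] = b[i2--];
    }
  }
}
// e.g. with K = 3, is_ascend = false: a = {9, 7, 5}, b = {8, 6, 4} gives a = {9, 8, 7}.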

template<typename DType>
__global__ void PartialSortSmallK(int K, int N, DType *val, int *ind, bool is_ascend) {
// Buffer for pairwise reduction.
extern __shared__ int buff[];
// Start of buffer sections associated with this thread.
const int offset(threadIdx.x*K);
int *ind_buff = &buff[offset];
DType *val_buff = reinterpret_cast<DType*>(&buff[blockDim.x*K])+offset;
// Initialize top-K values for this thread.
for (int i = 0; i < K; ++i) {
ind_buff[i] = -1;
}
// Range of values this thread cares about. Each thread block processes
// a different batch item (i.e. a different set of ind/val where we
// have to select the top-K elements). All threads within the same
// block work on the same batch item.
const int first(blockIdx.x*N+threadIdx.x), last((blockIdx.x+1)*N);
// Select top-K from this range and store it sorted in the buffer.
// We assume a small K, so linear insertion is o.k.
for (int i = first; i < last; i += blockDim.x) {
DType cur_val(val[i]);
int cur_ind(ind[i]);
for (int j = K; j-- && TopKCompare(cur_val, cur_ind, val_buff[j], ind_buff[j], is_ascend); ) {
if (j+1 < K) {
val_buff[j+1] = val_buff[j];
ind_buff[j+1] = ind_buff[j];
}
val_buff[j] = cur_val;
ind_buff[j] = cur_ind;
}
}
// Recursive merge of the per-thread sorted lists within this thread block. Note that blockDim.x
// is not necessarily a power of two, hence the additional check against last_s.
for (unsigned int s = (blockDim.x+1)/2, last_s = blockDim.x;
last_s > 1; last_s = s, s = (s+1)/2) {
__syncthreads();
if (threadIdx.x < s && threadIdx.x+s < last_s) {
MergeTopK(K, val_buff, ind_buff, val_buff+s*K, ind_buff+s*K, is_ascend);
}
}
// Final updates on master thread.
if (threadIdx.x == 0) {
for (int i = 0; i < K; ++i) {
ind[blockIdx.x*N+i] = ind_buff[i];
val[blockIdx.x*N+i] = val_buff[i];
}
}
}
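The (s, last_s) loop in the kernel realizes a tree reduction that also handles block sizes that are not powers of two. A small host-side sketch that prints which per-thread buffers get merged at each step (print_merge_schedule is illustrative, not part of the kernel):

#include <cstdio>

// Print which thread pairs would call MergeTopK at each reduction step for a
// given block size; same (s, last_s) schedule as the kernel loop above.
void print_merge_schedule(unsigned int nthreads) {
  for (unsigned int s = (nthreads + 1) / 2, last_s = nthreads;
       last_s > 1; last_s = s, s = (s + 1) / 2) {
    for (unsigned int tid = 0; tid < s; ++tid) {
      if (tid + s < last_s) {
        std::printf("merge buffers of thread %u and thread %u\n", tid, tid + s);
      }
    }
    std::printf("--\n");
  }
}
// For nthreads = 7 this prints merges (0,4) (1,5) (2,6), then (0,2) (1,3),
// then (0,1): every buffer ends up folded into thread 0's buffer.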

template<>
MSHADOW_FORCE_INLINE void TopKSort<gpu>(const Tensor<gpu, 1, real_t>& dat,
const Tensor<gpu, 1, int>& ind,
const Tensor<gpu, 1, char>& work,
int K, int N, bool is_ascend,
Stream<gpu> *s) {
// Use full sort for all but very small K for which we
// can do a partial sort entirely within shared memory.
const bool full_sort(K > 5);
// Batch size.
const int M(dat.size(0)/N);
if (full_sort) {
// Divide workspace into two parts. The first one is needed to store batch ids.
const int id_size(sizeof(int)*ind.size(0));
Tensor<gpu, 1, int> batch_id(reinterpret_cast<int*>(work.dptr_), Shape1(ind.size(0)), s);
Tensor<gpu, 1, char> sort_work(work.dptr_+id_size, Shape1(work.size(0)-id_size), s);
mxnet::op::SortByKey(dat, ind, is_ascend, &sort_work);
if (M > 1) {
// Back-to-back sorting: mxnet::op::SortByKey is a stable sort, so re-sorting by batch id
// keeps each batch internally sorted by value.
batch_id = ind / N;
mxnet::op::SortByKey(batch_id, dat, true, &sort_work);
batch_id = ind / N;
mxnet::op::SortByKey(batch_id, ind, true, &sort_work);
}
} else {
const int nthreads(mshadow::cuda::kBaseThreadNum);
PartialSortSmallK<<<M, nthreads, nthreads*K*(sizeof(int)+sizeof(real_t)),
mshadow::Stream<gpu>::GetStream(s)>>>
(K, N, dat.dptr_, ind.dptr_, is_ascend);
}
}

#endif
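The full-sort branch of the gpu specialization relies on SortByKey being stable: sorting all (value, index) pairs by value and then stable-sorting by batch id leaves every batch segment internally sorted. A host-side sketch of the same trick using std::stable_sort on std::pair storage instead of the mshadow tensors (batched_sort is illustrative only):

#include <algorithm>
#include <utility>
#include <vector>

// Sort all (value, global index) pairs by value, then stable-sort by batch id
// (index / N). Because the second sort is stable, each batch of N elements
// ends up internally sorted by value.
void batched_sort(std::vector<std::pair<float, int>>* items, int N, bool is_ascend) {
  std::stable_sort(items->begin(), items->end(),
                   [&](const std::pair<float, int>& a, const std::pair<float, int>& b) {
                     return is_ascend ? a.first < b.first : a.first > b.first;
                   });
  std::stable_sort(items->begin(), items->end(),
                   [&](const std::pair<float, int>& a, const std::pair<float, int>& b) {
                     return a.second / N < b.second / N;
                   });
}
// e.g. values {3, 1, 9, 2} with global indices {0, 1, 2, 3}, N = 2, descending:
// the value sort gives (9,2) (3,0) (2,3) (1,1); the batch-id sort then gives
// (3,0) (1,1) | (9,2) (2,3), i.e. each batch sorted in descending order.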


/*!
* \brief Implementation of the TopK operation
*
@@ -180,7 +348,7 @@ void TopKImpl(RunContext ctx,
Tensor<xpu, 1, char> workspace;
Tensor<xpu, 1, char> temp_workspace;
Tensor<xpu, 1, real_t> sorted_dat;
- Tensor<xpu, 1, int> indices, batch_id, sel_indices;
Tensor<xpu, 1, int> indices, sel_indices;
Tensor<xpu, 2, real_t> mask_val;
int batch_size, element_num; // number of batches + the size of each batch
int axis = 0;
@@ -191,10 +359,16 @@ void TopKImpl(RunContext ctx,
ParseTopKParam(src.shape_, param,
&target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
Tensor<xpu, 3, real_t> dat = src.FlatTo3D<xpu, real_t>(axis, axis, s);
- size_t temp_size = mxnet::op::SortByKeyWorkspaceSize<int, int, xpu>(src.Size());
size_t temp_size = 0;
// Temp space needed by the gpu-based full sorts.
temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<int, int, xpu>(src.Size()));
temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<int, real_t, xpu>(src.Size()));
temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<real_t, int, xpu>(src.Size()));
- size_t workspace_size = temp_size + sizeof(real_t) * src.Size() + sizeof(int) * src.Size() * 2;
// Additional temp space for the batch ids used by the gpu full sorts.
temp_size += sizeof(int) * src.Size();
// Temp space for cpu sorts.
temp_size = std::max(temp_size, sizeof(real_t) * src.Size());
size_t workspace_size = temp_size + sizeof(real_t) * src.Size() + sizeof(int) * src.Size();
if (param.ret_typ == topk_enum::kReturnMask) {
workspace_size += sizeof(int) * batch_size * k + sizeof(real_t) * batch_size * k;
}
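Putting the arithmetic above together, a rough sketch of the workspace budget for an input with S = src.Size() elements, assuming real_t is a 4-byte float; sort_bytes stands in for the largest of the three backend-dependent SortByKeyWorkspaceSize queries, and the helper name topk_workspace_bytes is hypothetical:

#include <algorithm>
#include <cstddef>

// Mirrors the workspace computation above: gpu sort scratch plus batch ids,
// or the cpu sort buffer, whichever is larger, followed by space for
// sorted_dat and indices, plus the optional mask outputs.
size_t topk_workspace_bytes(size_t S, size_t sort_bytes, bool return_mask,
                            size_t batch_size, size_t k) {
  size_t temp = sort_bytes + sizeof(int) * S;   // gpu full-sort scratch + batch ids
  temp = std::max(temp, sizeof(float) * S);     // cpu sort buffer
  size_t total = temp + sizeof(float) * S       // sorted_dat
                      + sizeof(int) * S;        // indices
  if (return_mask) total += (sizeof(int) + sizeof(float)) * batch_size * k;
  return total;
}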
@@ -206,9 +380,6 @@ void TopKImpl(RunContext ctx,
indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
Shape1(src.Size()), s); // indices in the original matrix
workspace_curr_ptr += sizeof(int) * src.Size();
- batch_id = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
- Shape1(src.Size()), s); // batch id in the original matrix
- workspace_curr_ptr += sizeof(int) * src.Size();
if (do_transpose) {
sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
} else {
@@ -232,19 +403,11 @@ void TopKImpl(RunContext ctx,
}
temp_workspace = Tensor<xpu, 1, char>(workspace_curr_ptr, Shape1(temp_size), s); // temp space
workspace_curr_ptr += temp_size;
- // 2. Perform inplace batch sort using the `SortByKey` in MShadow

// 2. Perform inplace batch sort.
// After sorting, each batch in `sorted_dat` will be sorted in the corresponding order
- // and the `indices` will contain the corresponding index in `sorted_dat`
- // Sort the data and keep record of the correspondence to global indices.
- mxnet::op::SortByKey(sorted_dat, indices, is_ascend, &temp_workspace);
- // Calculate the corresponding batch indices of the elements
- batch_id = indices / element_num;
- // Since the SortByKey performs stable sort, the second SortByKey will reorder
- // the sorted_dat based on the order of the batch_id
- mxnet::op::SortByKey(batch_id, sorted_dat, true, &temp_workspace);
- // Reorder the indices
- batch_id = indices / element_num;
- mxnet::op::SortByKey(batch_id, indices, true, &temp_workspace);
// up to the k-th element and the `indices` will contain the corresponding index in `sorted_dat`
TopKSort(sorted_dat, indices, temp_workspace, k, element_num, is_ascend, s);

// 3. Assign results to the ret blob
if (param.ret_typ == topk_enum::kReturnMask) {
@@ -264,7 +427,7 @@ void TopKImpl(RunContext ctx,
}
IndexFill(ret_mask, sel_indices, mask_val);
} else if (param.ret_typ == topk_enum::kReturnIndices) {
- indices -= batch_id * element_num;
indices = F<mshadow_op::mod>(indices, element_num);
if (do_transpose) {
Tensor<xpu, 3, real_t> ret_indices = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
ret_indices = tcast<real_t>(transpose(
@@ -281,7 +444,7 @@ void TopKImpl(RunContext ctx,
inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k));
}
} else {
- indices -= batch_id * element_num;
indices = F<mshadow_op::mod>(indices, element_num);
if (do_transpose) {
Tensor<xpu, 3, real_t> ret_value = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
Tensor<xpu, 3, real_t> ret_indices = ret[1].FlatTo3D<xpu, real_t>(axis, axis, s);
