From 33312e5951ac56abe7549af5d09bf2bfb8af9932 Mon Sep 17 00:00:00 2001
From: moin
Date: Mon, 21 May 2018 20:18:57 +0200
Subject: [PATCH] speed up of topk-operator (#10997)

---
 src/operator/tensor/ordering_op-inl.h | 203 +++++++++++++++++++++++---
 1 file changed, 183 insertions(+), 20 deletions(-)

diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h
index 606406dfe0bd..105ee8b90db8 100644
--- a/src/operator/tensor/ordering_op-inl.h
+++ b/src/operator/tensor/ordering_op-inl.h
@@ -152,6 +152,174 @@ inline void ParseTopKParam(const TShape& src_shape, const TopKParam& param, TSha
     << *element_num << ", get k = " << *k;
 }
 
+using namespace mshadow;
+
+template<typename xpu>
+void TopKSort(const Tensor<xpu, 1, real_t>& dat,
+              const Tensor<xpu, 1, int>& ind,
+              const Tensor<xpu, 1, char>& work,
+              int K, int N, bool is_ascend,
+              Stream<xpu> *s);
+
+template<>
+MSHADOW_FORCE_INLINE void TopKSort<cpu>(const Tensor<cpu, 1, real_t>& dat,
+                                        const Tensor<cpu, 1, int>& ind,
+                                        const Tensor<cpu, 1, char>& work,
+                                        int K, int N, bool is_ascend,
+                                        Stream<cpu> *s) {
+  // Use full sort when K is relatively large.
+  const bool full_sort(K*8 > N);
+  // Batch size.
+  const int M(dat.size(0)/N);
+  const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < M; ++i) {
+    real_t *vals = dat.dptr_;
+    int *indices = ind.dptr_+i*N;
+    if (is_ascend) {
+      if (full_sort) {
+        std::sort(indices, indices+N,
+                  [&](const int& i1, const int& i2){ return vals[i1] < vals[i2]; });
+      } else {
+        std::partial_sort(indices, indices+K, indices+N,
+                          [&](const int& i1, const int& i2){ return vals[i1] < vals[i2]; });
+      }
+    } else {
+      if (full_sort) {
+        std::sort(indices, indices+N,
+                  [&](const int& i1, const int& i2){ return vals[i1] > vals[i2]; });
+      } else {
+        std::partial_sort(indices, indices+K, indices+N,
+                          [&](const int& i1, const int& i2){ return vals[i1] > vals[i2]; });
+      }
+    }
+    real_t *buff = reinterpret_cast<real_t*>(work.dptr_)+i*K;
+    for (int j = 0; j < K; ++j) {
+      buff[j] = vals[indices[j]];
+    }
+    std::copy(buff, buff+K, &vals[i*N]);
+  }
+}
+
+#ifdef __CUDACC__
+
+template<typename DType>
+MSHADOW_XINLINE bool TopKCompare(DType val1, int ind1, DType val2, int ind2, bool is_ascend) {
+  // Negative indices denote undefined values which are considered arbitrarily small resp. large.
+  return (ind2 < 0) || (ind1 >= 0 && ((is_ascend && val1 < val2) || (!is_ascend && val1 > val2)));
+}
+
+template<typename DType>
+MSHADOW_XINLINE void MergeTopK(int K, DType *val1, int *ind1, DType *val2, int *ind2,
+                               bool is_ascend) {
+  // In-place merge of two sorted top-K lists into val1/ind1. First determine the intervals
+  // [0,..,i1], [0,..,i2] of the two lists that will be part of the merged list.
+  int i1(K-1), i2(K-1);
+  for (int i = 0; i < K; ++i) {
+    if (TopKCompare(val1[i1], ind1[i1], val2[i2], ind2[i2], is_ascend)) {
+      --i2;
+    } else {
+      --i1;
+    }
+  }
+  // Now merge the lists from back to front.
+  for (int i = K; i--;) {
+    if (i2 < 0 || (i1 >= 0 && TopKCompare(val2[i2], ind2[i2], val1[i1], ind1[i1], is_ascend))) {
+      val1[i] = val1[i1];
+      ind1[i] = ind1[i1];
+      --i1;
+    } else {
+      val1[i] = val2[i2];
+      ind1[i] = ind2[i2];
+      --i2;
+    }
+  }
+}
+
+template<typename DType>
+__global__ void PartialSortSmallK(int K, int N, DType *val, int *ind, bool is_ascend) {
+  // Buffer for pairwise reduction.
+  extern __shared__ int buff[];
+  // Start of buffer sections associated with this thread.
+  const int offset(threadIdx.x*K);
+  int *ind_buff = &buff[offset];
+  DType *val_buff = reinterpret_cast<DType*>(&buff[blockDim.x*K])+offset;
+  // Initialize top-K values for this thread.
+  for (int i = 0; i < K; ++i) {
+    ind_buff[i] = -1;
+  }
+  // Range of values this thread cares about. Each thread block processes
+  // a different batch item (i.e. a different set of ind/val where we
+  // have to select the top-K elements). All threads within the same
+  // block work on the same batch item.
+  const int first(blockIdx.x*N+threadIdx.x), last((blockIdx.x+1)*N);
+  // Select top-K from this range and store it sorted in the buffer.
+  // We assume a small K, so linear insertion is o.k.
+  for (int i = first; i < last; i += blockDim.x) {
+    DType cur_val(val[i]);
+    int cur_ind(ind[i]);
+    for (int j = K; j-- && TopKCompare(cur_val, cur_ind, val_buff[j], ind_buff[j], is_ascend); ) {
+      if (j+1 < K) {
+        val_buff[j+1] = val_buff[j];
+        ind_buff[j+1] = ind_buff[j];
+      }
+      val_buff[j] = cur_val;
+      ind_buff[j] = cur_ind;
+    }
+  }
+  // Recursive merge of the sorted per-thread lists within this thread block. Note that
+  // blockDim.x is not necessarily a power of two, hence the additional checks against last_s.
+  for (unsigned int s = (blockDim.x+1)/2, last_s = blockDim.x;
+       last_s > 1; last_s = s, s = (s+1)/2) {
+    __syncthreads();
+    if (threadIdx.x < s && threadIdx.x+s < last_s) {
+      MergeTopK(K, val_buff, ind_buff, val_buff+s*K, ind_buff+s*K, is_ascend);
+    }
+  }
+  // Final updates on master thread.
+  if (threadIdx.x == 0) {
+    for (int i = 0; i < K; ++i) {
+      ind[blockIdx.x*N+i] = ind_buff[i];
+      val[blockIdx.x*N+i] = val_buff[i];
+    }
+  }
+}
+
+template<>
+MSHADOW_FORCE_INLINE void TopKSort<gpu>(const Tensor<gpu, 1, real_t>& dat,
+                                        const Tensor<gpu, 1, int>& ind,
+                                        const Tensor<gpu, 1, char>& work,
+                                        int K, int N, bool is_ascend,
+                                        Stream<gpu> *s) {
+  // Use full sort for all but very small K for which we
+  // can do a partial sort entirely within shared memory.
+  const bool full_sort(K > 5);
+  // Batch size.
+  const int M(dat.size(0)/N);
+  if (full_sort) {
+    // Divide workspace into two parts. The first one is needed to store batch ids.
+    const int id_size(sizeof(int)*ind.size(0));
+    Tensor<gpu, 1, int> batch_id(reinterpret_cast<int*>(work.dptr_), Shape1(ind.size(0)), s);
+    Tensor<gpu, 1, char> sort_work(work.dptr_+id_size, Shape1(work.size(0)-id_size), s);
+    mxnet::op::SortByKey(dat, ind, is_ascend, &sort_work);
+    if (M > 1) {
+      // Back-to-back sorting. Note that mxnet::op::SortByKey is a stable sort.
+      batch_id = ind / N;
+      mxnet::op::SortByKey(batch_id, dat, true, &sort_work);
+      batch_id = ind / N;
+      mxnet::op::SortByKey(batch_id, ind, true, &sort_work);
+    }
+  } else {
+    const int nthreads(mshadow::cuda::kBaseThreadNum);
+    PartialSortSmallK<<<M, nthreads, nthreads*K*(sizeof(real_t)+sizeof(int)),
+                        mshadow::Stream<gpu>::GetStream(s)>>>
+      (K, N, dat.dptr_, ind.dptr_, is_ascend);
+  }
+}
+
+#endif
+
+
 /*!
  * \brief Implementation of the TopK operation
  *
@@ -180,7 +348,7 @@ void TopKImpl(RunContext ctx,
   Tensor<xpu, 1, char> workspace;
   Tensor<xpu, 1, char> temp_workspace;
   Tensor<xpu, 1, real_t> sorted_dat;
-  Tensor<xpu, 1, int> indices, batch_id, sel_indices;
+  Tensor<xpu, 1, int> indices, sel_indices;
   Tensor<xpu, 1, real_t> mask_val;
   int batch_size, element_num;  // number of batches + the size of each batch
   int axis = 0;
@@ -191,10 +359,16 @@ void TopKImpl(RunContext ctx,
   ParseTopKParam(src.shape_, param,
                  &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
   Tensor<xpu, 3, real_t> dat = src.FlatTo3D<xpu, real_t>(axis, axis, s);
-  size_t temp_size = mxnet::op::SortByKeyWorkspaceSize<real_t, int, xpu>(src.Size());
+  size_t temp_size = 0;
+  // Temp space needed by the gpu-based full sorts.
+  temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<real_t, int, xpu>(src.Size()));
   temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<int, real_t, xpu>(src.Size()));
   temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<int, int, xpu>(src.Size()));
-  size_t workspace_size = temp_size + sizeof(real_t) * src.Size() + sizeof(int) * src.Size() * 2;
+  // Additional temp space for gpu full sorts for batch ids.
+  temp_size += sizeof(int) * src.Size();
+  // Temp space for cpu sorts.
+  temp_size = std::max(temp_size, sizeof(real_t) * src.Size());
+  size_t workspace_size = temp_size + sizeof(real_t) * src.Size() + sizeof(int) * src.Size();
   if (param.ret_typ == topk_enum::kReturnMask) {
     workspace_size += sizeof(int) * batch_size * k + sizeof(real_t) * batch_size * k;
   }
@@ -206,9 +380,6 @@ void TopKImpl(RunContext ctx,
   indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
                                 Shape1(src.Size()), s);  // indices in the original matrix
   workspace_curr_ptr += sizeof(int) * src.Size();
-  batch_id = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
-                                 Shape1(src.Size()), s);  // batch id in the original matrix
-  workspace_curr_ptr += sizeof(int) * src.Size();
   if (do_transpose) {
     sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
   } else {
@@ -232,19 +403,11 @@ void TopKImpl(RunContext ctx,
   }
   temp_workspace = Tensor<xpu, 1, char>(workspace_curr_ptr, Shape1(temp_size), s);  // temp space
   workspace_curr_ptr += temp_size;
-  // 2. Perform inplace batch sort using the `SortByKey` in MShadow
+
+  // 2. Perform inplace batch sort.
   // After sorting, each batch in `sorted_dat` will be sorted in the corresponding order
-  // and the `indices` will contain the corresponding index in `sorted_dat`
-  // Sort the data and keep record of the correspondence to global indices.
-  mxnet::op::SortByKey(sorted_dat, indices, is_ascend, &temp_workspace);
-  // Calculate the corresponding batch indices of the elements
-  batch_id = indices / element_num;
-  // Since the SortByKey performs stable sort, the second SortByKey will reorder
-  // the sorted_dat based on the order of the batch_id
-  mxnet::op::SortByKey(batch_id, sorted_dat, true, &temp_workspace);
-  // Reorder the indices
-  batch_id = indices / element_num;
-  mxnet::op::SortByKey(batch_id, indices, true, &temp_workspace);
+  // up to the k-th element and the `indices` will contain the corresponding index in `sorted_dat`
+  TopKSort(sorted_dat, indices, temp_workspace, k, element_num, is_ascend, s);
 
   // 3. Assign results to the ret blob
   if (param.ret_typ == topk_enum::kReturnMask) {
@@ -264,7 +427,7 @@
     }
     IndexFill(ret_mask, sel_indices, mask_val);
   } else if (param.ret_typ == topk_enum::kReturnIndices) {
-    indices -= batch_id * element_num;
+    indices = F<mshadow_op::mod>(indices, element_num);
    if (do_transpose) {
      Tensor<xpu, 3, real_t> ret_indices = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
      ret_indices = tcast<real_t>(transpose(
@@ -281,7 +444,7 @@
          inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k));
    }
  } else {
-    indices -= batch_id * element_num;
+    indices = F<mshadow_op::mod>(indices, element_num);
    if (do_transpose) {
      Tensor<xpu, 3, real_t> ret_value = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
      Tensor<xpu, 3, real_t> ret_indices = ret[1].FlatTo3D<xpu, real_t>(axis, axis, s);
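
For reference (not part of the patch): the full-sort branch of TopKSort<gpu> builds a batched
argsort out of back-to-back stable sorts -- one global sort by value, then a stable sort by
batch id, which leaves every batch contiguous and internally ordered by value. Below is a
minimal host-side sketch of that idea using std::stable_sort and a single index array; the
names and the simplified single-array setup are illustrative only and do not come from the
patch.

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const int N = 4;                                    // elements per batch
  std::vector<float> val = {3, 1, 2, 0, 9, 7, 8, 6};  // two batches, flattened
  std::vector<int> idx(val.size());
  for (size_t i = 0; i < idx.size(); ++i) idx[i] = static_cast<int>(i);
  // Pass 1: sort indices globally by value (ascending).
  std::stable_sort(idx.begin(), idx.end(),
                   [&](int a, int b) { return val[a] < val[b]; });
  // Pass 2: stable sort by batch id; the per-batch value order from pass 1 survives.
  std::stable_sort(idx.begin(), idx.end(),
                   [&](int a, int b) { return a / N < b / N; });
  for (int i : idx) std::printf("%d ", i);            // prints: 3 1 2 0 7 5 6 4
  std::printf("\n");
  return 0;
}

The patch needs two extra SortByKey passes (one for dat, one for ind) and recomputes
batch_id from ind before each of them, because SortByKey permutes its key tensor in place.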
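
Likewise, the small-K branch keeps a sorted K-slot buffer per thread, with ind = -1 marking
empty slots, and adds each scanned element by linear insertion before the per-block merge.
The host-side sketch below exercises just that insertion step with the same comparison rule
as TopKCompare; it is an illustrative test harness, not code from the patch.

#include <cstdio>
#include <vector>

// Same ordering rule as TopKCompare in the patch: a negative index marks an empty slot,
// which loses against any real element.
static bool TopKCompareHost(float v1, int i1, float v2, int i2, bool is_ascend) {
  return (i2 < 0) || (i1 >= 0 && ((is_ascend && v1 < v2) || (!is_ascend && v1 > v2)));
}

// Linear insertion into a sorted K-slot buffer, mirroring the kernel's inner loop
// (acceptable because K is small).
static void Insert(int K, float *val_buff, int *ind_buff, float v, int idx, bool is_ascend) {
  for (int j = K; j-- && TopKCompareHost(v, idx, val_buff[j], ind_buff[j], is_ascend); ) {
    if (j + 1 < K) {
      val_buff[j + 1] = val_buff[j];
      ind_buff[j + 1] = ind_buff[j];
    }
    val_buff[j] = v;
    ind_buff[j] = idx;
  }
}

int main() {
  const int K = 3;
  const bool ascend = false;  // select the K largest values
  std::vector<float> data = {0.5f, 2.0f, 1.5f, 3.0f, 0.1f, 2.5f};
  float val_buff[K] = {};
  int ind_buff[K] = {-1, -1, -1};  // -1 marks "empty"
  for (size_t i = 0; i < data.size(); ++i) {
    Insert(K, val_buff, ind_buff, data[i], static_cast<int>(i), ascend);
  }
  for (int j = 0; j < K; ++j) {
    std::printf("val=%.1f ind=%d\n", val_buff[j], ind_buff[j]);  // 3.0/3, 2.5/5, 2.0/1
  }
  return 0;
}

In the kernel, MergeTopK then combines two such buffers pairwise across the thread block
until only the top-K list of thread 0 remains.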