Skip to content

Commit

Permalink
Refactor scan_tuning policy hub so that it accepts OffsetT
Browse files Browse the repository at this point in the history
  • Loading branch information
gonidelis committed Dec 13, 2024
1 parent 73152a9 commit 629c558
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 9 deletions.
1 change: 1 addition & 0 deletions cub/benchmarks/bench/scan/exclusive/base.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <cuda/std/__functional/invoke.h>

#include <look_back_helper.cuh>
#include <nvbench_helper.cuh>

#if !TUNE_BASE
# if TUNE_TRANSPOSE == 0
Expand Down
2 changes: 1 addition & 1 deletion cub/cub/device/device_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1322,7 +1322,7 @@ struct DeviceScan
detail::InputValue<InitValueT>,
OffsetT,
AccumT,
detail::scan::policy_hub<AccumT, ScanOpT>,
detail::scan::policy_hub<AccumT, OffsetT, ScanOpT>,
ForceInclusive>::Dispatch(d_temp_storage,
temp_storage_bytes,
d_in,
Expand Down
2 changes: 1 addition & 1 deletion cub/cub/device/dispatch/dispatch_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ template <typename InputIteratorT,
::cuda::std::_If<std::is_same<InitValueT, NullType>::value,
cub::detail::value_t<InputIteratorT>,
typename InitValueT::value_type>>,
typename SelectedPolicy = detail::scan::policy_hub<AccumT, ScanOpT>,
typename SelectedPolicy = detail::scan::policy_hub<AccumT, OffsetT, ScanOpT>,
bool ForceInclusive = false>
struct DispatchScan : SelectedPolicy
{
Expand Down
16 changes: 11 additions & 5 deletions cub/cub/device/dispatch/tuning/tuning_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,18 @@ constexpr accum_size classify_accum_size()
: accum_size::unknown;
}

template <int Threads, int Items, int L2B, int L2W>
// Classifies OffsetT by its byte width: 8-byte types map to offset_size::_8,
// 4-byte types map to offset_size::_4, and any other width maps to
// offset_size::unknown.
template <class OffsetT>
constexpr offset_size classify_offset_size()
{
  return sizeof(OffsetT) == 8 ? offset_size::_8
       : sizeof(OffsetT) == 4 ? offset_size::_4
                              : offset_size::unknown;
}

// Compile-time bundle of tuning parameters.
// Threads and Items are re-exported as the `threads` and `items` constants.
// DelayConstructor is re-exported as the `delay_constructor` alias; when it is
// not supplied it defaults to fixed_delay_constructor_t<L2B, L2W>.
template <int Threads, int Items, int L2B, int L2W, typename DelayConstructor = fixed_delay_constructor_t<L2B, L2W>>
struct tuning
{
static constexpr int threads = Threads;
static constexpr int items = Items;
// NOTE(review): the two aliases below look like the pre-change and post-change
// versions of the same line as rendered by the diff view. Only the second one
// honors the DelayConstructor parameter -- confirm that only it is present in
// the actual file, since declaring the member alias twice would not compile.
using delay_constructor = fixed_delay_constructor_t<L2B, L2W>;
using delay_constructor = DelayConstructor;
};

template <class AccumT,
Expand Down Expand Up @@ -223,7 +229,7 @@ struct sm90_tuning<__uint128_t, primitive_op::yes, primitive_accum::no, accum_si
#endif
// clang-format on

template <typename AccumT, typename ScanOpT>
template <typename AccumT, typename OffsetT, typename ScanOpT>
struct policy_hub
{
// For large values, use timesliced loads/stores to fit shared memory.
Expand Down Expand Up @@ -293,7 +299,7 @@ struct policy_hub
} // namespace detail

// TODO(bgruber): deprecate this at some point when we have a better way to allow users to supply tunings
// Alias exposing detail::scan::policy_hub under the DeviceScanPolicy name.
// NOTE(review): the first template/using pair below appears to be the
// pre-change form (AccumT, ScanOpT) and the second pair the post-change form
// (AccumT, OffsetT, ScanOpT) as rendered by the diff view.
// NOTE(review): because OffsetT is inserted as the second parameter and
// ScanOpT keeps its default, an existing two-argument use such as
// DeviceScanPolicy<AccumT, MyOp> would now bind MyOp to OffsetT and fall back
// to ::cuda::std::plus<> for ScanOpT -- confirm whether callers of this alias
// need a migration note or a compatibility shim.
template <typename AccumT, typename ScanOpT = ::cuda::std::plus<>>
using DeviceScanPolicy = detail::scan::policy_hub<AccumT, ScanOpT>;
template <typename AccumT, typename OffsetT, typename ScanOpT = ::cuda::std::plus<>>
using DeviceScanPolicy = detail::scan::policy_hub<AccumT, OffsetT, ScanOpT>;

CUB_NAMESPACE_END
4 changes: 2 additions & 2 deletions thrust/thrust/system/cuda/detail/scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
InputValueT,
std::int32_t,
AccumT,
cub::detail::scan::policy_hub<AccumT, ScanOp>,
cub::detail::scan::policy_hub<AccumT, std::int32_t, ScanOp>,
ForceInclusive>;
using Dispatch64 =
cub::DispatchScan<InputIt,
Expand All @@ -139,7 +139,7 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
InputValueT,
std::int64_t,
AccumT,
cub::detail::scan::policy_hub<AccumT, ScanOp>,
cub::detail::scan::policy_hub<AccumT, std::int64_t, ScanOp>,
ForceInclusive>;

cudaStream_t stream = thrust::cuda_cub::stream(policy);
Expand Down

0 comments on commit 629c558

Please sign in to comment.