From 4ecef067314aef01bf463e31ae9d2b611b7525f2 Mon Sep 17 00:00:00 2001 From: Bryce Adelstein Lelbach aka wash Date: Wed, 11 Dec 2024 11:45:44 -0800 Subject: [PATCH] [cuda.cooperative] Add inclusive_sum to cuda.cooperative. --- .../experimental/block/__init__.py | 2 +- .../experimental/block/_block_scan.py | 25 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/python/cuda_cooperative/cuda/cooperative/experimental/block/__init__.py b/python/cuda_cooperative/cuda/cooperative/experimental/block/__init__.py index f51c3dccfb6..26d655abe6e 100644 --- a/python/cuda_cooperative/cuda/cooperative/experimental/block/__init__.py +++ b/python/cuda_cooperative/cuda/cooperative/experimental/block/__init__.py @@ -4,6 +4,6 @@ from cuda.cooperative.experimental.block._block_merge_sort import merge_sort_keys from cuda.cooperative.experimental.block._block_reduce import reduce, sum -from cuda.cooperative.experimental.block._block_scan import exclusive_sum +from cuda.cooperative.experimental.block._block_scan import exclusive_sum, inclusive_sum from cuda.cooperative.experimental.block._block_radix_sort import radix_sort_keys, radix_sort_keys_descending from cuda.cooperative.experimental.block._block_load_store import load, store diff --git a/python/cuda_cooperative/cuda/cooperative/experimental/block/_block_scan.py b/python/cuda_cooperative/cuda/cooperative/experimental/block/_block_scan.py index 59d170163b6..b9776118c04 100644 --- a/python/cuda_cooperative/cuda/cooperative/experimental/block/_block_scan.py +++ b/python/cuda_cooperative/cuda/cooperative/experimental/block/_block_scan.py @@ -67,3 +67,28 @@ def exclusive_sum(dtype, threads_in_block, items_per_thread, prefix_op=None): return Invocable(temp_files=[make_binary_tempfile(ltoir, '.ltoir') for ltoir in specialization.get_lto_ir()], temp_storage_bytes=specialization.get_temp_storage_bytes(), algorithm=specialization) + +def inclusive_sum(dtype, threads_in_block, items_per_thread, prefix_op=None): 
+ template = Algorithm('BlockScan', + 'InclusiveSum', + 'block_scan', + ['cub/block/block_scan.cuh'], + [TemplateParameter('T'), + TemplateParameter('BLOCK_DIM_X')], + [[Pointer(numba.uint8), + DependentArray(Dependency( + 'T'), Dependency('ITEMS_PER_THREAD')), + DependentArray(Dependency( + 'T'), Dependency('ITEMS_PER_THREAD')), + DependentOperator(Dependency('T'), [Dependency('T')], Dependency('PrefixOp'))], + [Pointer(numba.uint8), + DependentArray(Dependency( + 'T'), Dependency('ITEMS_PER_THREAD')), + DependentArray(Dependency('T'), Dependency('ITEMS_PER_THREAD'))]]) + specialization = template.specialize({'T': dtype, + 'BLOCK_DIM_X': threads_in_block, + 'ITEMS_PER_THREAD': items_per_thread, + 'PrefixOp': prefix_op}) + return Invocable(temp_files=[make_binary_tempfile(ltoir, '.ltoir') for ltoir in specialization.get_lto_ir()], + temp_storage_bytes=specialization.get_temp_storage_bytes(), + algorithm=specialization)