Fix broken Pooling CUDA NHWC Ops and ensure NCHW / NHWC parity. (#19889)
Fixed all broken CUDA NHWC Pooling operations and enabled the NHWC CUDA
pooling tests. Disabled the pooling tests that are not supported by the
CUDA EP.

This ensures parity between the CUDA NHWC and NCHW kernels and works
towards enabling 100% of the tests for the CUDA EP / CUDA NHWC EP.

---------

Co-authored-by: Tianlei Wu <[email protected]>
2 people authored and rachguo committed Mar 21, 2024
1 parent a13e5d5 commit 247f8c5
Showing 8 changed files with 253 additions and 98 deletions.
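
For context, these NHWC kernels are only exercised when the CUDA EP is configured to prefer NHWC layouts. Below is a minimal sketch of opting in from the C++ API; it assumes a build with CUDA NHWC ops enabled and that the CUDA provider option key is "prefer_nhwc" (both are assumptions, not part of this change):

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "nhwc-demo");
  Ort::SessionOptions session_options;

  // Assumption: the CUDA EP exposes a "prefer_nhwc" option when built with NHWC ops enabled.
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&cuda_options));
  const char* keys[] = {"device_id", "prefer_nhwc"};
  const char* values[] = {"0", "1"};
  Ort::ThrowOnError(Ort::GetApi().UpdateCUDAProviderOptions(cuda_options, keys, values, 2));

  session_options.AppendExecutionProvider_CUDA_V2(*cuda_options);
  Ort::GetApi().ReleaseCUDAProviderOptions(cuda_options);

  Ort::Session session(env, "model.onnx", session_options);  // hypothetical model path
  return 0;
}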
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
@@ -70,6 +70,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kM
MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, float, MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, int8_t, MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, uint8_t, MaxPool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16,
@@ -135,6 +137,7 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, float, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, MLFloat16, MaxPool)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
float, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
@@ -147,6 +150,10 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
float, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12,
MLFloat16, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12,
int8_t, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12,
uint8_t, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
float, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
29 changes: 22 additions & 7 deletions onnxruntime/core/providers/cuda/cudnn_common.cc
@@ -37,13 +37,28 @@ Status CudnnTensor::Set(gsl::span<const int64_t> input_dims, cudnnDataType_t dat
TensorPitches pitches(input_dims);
InlinedVector<int, kTensorShapeSmallBufferElementsSize> dims(rank);
InlinedVector<int, kTensorShapeSmallBufferElementsSize> strides(rank);
for (int i = 0; i < rank; i++) {
dims[i] = gsl::narrow_cast<int>(input_dims[i]);
strides[i] = gsl::narrow_cast<int>(pitches[i]);
}
if (is_nhwc) {
std::swap(dims[1], dims[rank - 1]);
std::swap(strides[1], strides[rank - 1]);

if (!is_nhwc) {
for (int i = 0; i < rank; i++) {
dims[i] = gsl::narrow_cast<int>(input_dims[i]);
strides[i] = gsl::narrow_cast<int>(pitches[i]);
}
} else {
// NHWDC <-> NCHWD

// N
dims[0] = gsl::narrow_cast<int>(input_dims[0]);
strides[0] = gsl::narrow_cast<int>(pitches[0]);

// HWD
for (int i = 1; i < rank - 1; i++) {
dims[i + 1] = gsl::narrow_cast<int>(input_dims[i]);
strides[i + 1] = gsl::narrow_cast<int>(pitches[i]);
}

// C
dims[1] = gsl::narrow_cast<int>(input_dims[rank - 1]);
strides[1] = gsl::narrow_cast<int>(pitches[rank - 1]);
}
CUDNN_RETURN_IF_ERROR(cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast<int>(rank), dims.data(), strides.data()));
return Status::OK();
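
For clarity, here is a minimal standalone sketch (not part of the diff) of what the new NHWC branch above produces for a hypothetical 4-D input: cuDNN always receives NCHW-ordered dims, while the strides keep describing the NHWC memory layout.

#include <array>
#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical NHWC input: N=2, H=4, W=5, C=3.
  const std::array<int64_t, 4> input_dims = {2, 4, 5, 3};

  // Row-major pitches over the NHWC layout (innermost dim has stride 1).
  std::array<int64_t, 4> pitches{};
  pitches[3] = 1;
  for (int i = 2; i >= 0; --i) pitches[i] = pitches[i + 1] * input_dims[i + 1];
  // pitches == {60, 15, 3, 1}

  // Reorder into the NCHW-ordered dims/strides expected by cudnnSetTensorNdDescriptor:
  // dims    = {N, C, H, W} = {2, 3, 4, 5}
  // strides = {60, 1, 15, 3}  (still the NHWC memory layout)
  std::array<int, 4> dims{}, strides{};
  dims[0] = static_cast<int>(input_dims[0]);    // N
  strides[0] = static_cast<int>(pitches[0]);
  dims[1] = static_cast<int>(input_dims[3]);    // C comes from the last NHWC dim
  strides[1] = static_cast<int>(pitches[3]);
  for (int i = 1; i < 3; ++i) {                 // H, W shift one slot to the right
    dims[i + 1] = static_cast<int>(input_dims[i]);
    strides[i + 1] = static_cast<int>(pitches[i]);
  }

  for (int i = 0; i < 4; ++i) std::cout << dims[i] << " / " << strides[i] << "\n";
  return 0;
}

Unlike the previous swap of dims[1] and dims[rank - 1], this preserves the relative order of the spatial dims.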
114 changes: 85 additions & 29 deletions onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu
@@ -7,10 +7,11 @@

#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/shared_inc/fast_divmod.h"
#include "core/providers/cuda/shared_inc/cuda_utils.h"

namespace onnxruntime {
namespace cuda {
template <typename T>
template <typename T, bool Layout>
__global__ void MaxPoolWithIndexKernel(
int64_t batch,
int64_t channels,
@@ -44,11 +45,27 @@ __global__ void MaxPoolWithIndexKernel(
int id = blockIdx.x * blockDim.x + threadIdx.x;
if (id >= output_size) return;

auto compute_offset =
[height, width, depth, channels](int n_index, int c_index, int h_index, int w_index, int d_index) -> int64_t {
if constexpr (Layout == LAYOUT_NCHW) {
return (((n_index * channels + c_index) * height + h_index) * width + w_index) * depth + d_index;
} else if constexpr (Layout == LAYOUT_NHWC) {
return (((n_index * height + h_index) * width + w_index) * depth + d_index) * channels + c_index;
}
};

int d_index, w_index, h_index, c_index, n_index, id_tmp;
fdm_d.divmod(id, id_tmp, d_index);
fdm_w.divmod(id_tmp, id_tmp, w_index);
fdm_h.divmod(id_tmp, id_tmp, h_index);
fdm_c.divmod(id_tmp, n_index, c_index);
if constexpr (Layout == LAYOUT_NCHW) {
fdm_d.divmod(id, id_tmp, d_index);
fdm_w.divmod(id_tmp, id_tmp, w_index);
fdm_h.divmod(id_tmp, id_tmp, h_index);
fdm_c.divmod(id_tmp, n_index, c_index);
} else if constexpr (Layout == LAYOUT_NHWC) {
fdm_c.divmod(id, id_tmp, c_index);
fdm_d.divmod(id_tmp, id_tmp, d_index);
fdm_w.divmod(id_tmp, id_tmp, w_index);
fdm_h.divmod(id_tmp, n_index, h_index);
}

int64_t d_start = d_index * stride_d - pad_d;
int64_t w_start = w_index * stride_w - pad_w;
@@ -64,29 +81,45 @@ __global__ void MaxPoolWithIndexKernel(
int64_t d_index_max = -1;
int64_t w_index_max = -1;
int64_t h_index_max = -1;
int64_t offset = (n_index * channels + c_index) * height * width * depth;
int64_t offset = compute_offset(n_index, c_index, 0, 0, 0);
const T* p_slice = p_input + offset;
T maxval = p_slice[h_start * width * depth + w_start * depth + d_start] - (T)1;
T maxval = p_slice[compute_offset(0, 0, h_start, w_start, d_start)] - (T)1;
for (int64_t d = d_start; d < d_end; d += dilation_d) {
for (int64_t w = w_start; w < w_end; w += dilation_w) {
for (int64_t h = h_start; h < h_end; h += dilation_h) {
if (p_slice[h * width * depth + w * depth + d] > maxval) {
auto pool_offset = compute_offset(0, 0, h, w, d);
if (p_slice[pool_offset] > maxval) {
h_index_max = h;
w_index_max = w;
d_index_max = d;
maxval = static_cast<float>(p_slice[h * width * depth + w * depth + d]);
maxval = static_cast<float>(p_slice[pool_offset]);
}
}
}
}
p_output[id] = p_input[offset + h_index_max * width * depth + w_index_max * depth + d_index_max];
p_output[id] = p_input[offset + compute_offset(0, 0, h_index_max, w_index_max, d_index_max)];

if (p_indices) {
p_indices[id] = storage_order == 0 ? offset + h_index_max * width * depth + w_index_max * depth + d_index_max
: offset + h_index_max + w_index_max * height + d_index_max * width * height;
if constexpr (Layout == LAYOUT_NCHW) {
p_indices[id] = storage_order == 0 ? offset + h_index_max * width * depth + w_index_max * depth + d_index_max
: offset + h_index_max + w_index_max * height + d_index_max * width * height;
} else if constexpr (Layout == LAYOUT_NHWC) {
// The tests currently have to be provided in NHWC layout so that tests do not fail. When converting between
// layouts, does it make sense to do an index conversion as well?
// Storing indices in NHWC layout isn't critical as they are supposed to be used by Unpooling operations
// which currently assume that indices reference to Tensors in NHWC layout.
int64_t id_nchw =
    (((n_index * channels + c_index) * pooled_height + h_index) * pooled_width + w_index) * pooled_depth + d_index;
int64_t offset_nchw = (n_index * channels + c_index) * width * height * depth;

p_indices[id_nchw] = (storage_order == 0)
? offset_nchw + h_index_max * width * depth + w_index_max * depth + d_index_max
: offset_nchw + h_index_max + w_index_max * height + d_index_max * width * height;
}
}
}

template <typename T>
template <typename T, bool Layout>
void MaxPoolWithIndex(
cudaStream_t stream,
const TensorShape& input_shape,
@@ -99,14 +132,29 @@ void MaxPoolWithIndex(
const T* p_input,
T* p_output,
int64_t* p_indices) {
int64_t batchs = input_shape[0];
int64_t channels = input_shape[1];
int64_t height = input_shape[2];
int64_t width = kernel_shape.size() > 1 ? input_shape[3] : 1;
int64_t depth = kernel_shape.size() > 2 ? input_shape[4] : 1;
int64_t pooled_height = output_shape[2];
int64_t pooled_width = kernel_shape.size() > 1 ? output_shape[3] : 1;
int64_t pooled_depth = kernel_shape.size() > 2 ? output_shape[4] : 1;
int64_t batchs, channels, height, width, depth;
int64_t pooled_height, pooled_width, pooled_depth;
if constexpr (Layout == LAYOUT_NCHW) {
batchs = input_shape[0];
channels = input_shape[1];
height = input_shape[2];
width = kernel_shape.size() > 1 ? input_shape[3] : 1;
depth = kernel_shape.size() > 2 ? input_shape[4] : 1;

pooled_height = output_shape[2];
pooled_width = kernel_shape.size() > 1 ? output_shape[3] : 1;
pooled_depth = kernel_shape.size() > 2 ? output_shape[4] : 1;
} else if constexpr (Layout == LAYOUT_NHWC) {
batchs = input_shape[0];
height = input_shape[1];
width = kernel_shape.size() > 1 ? input_shape[2] : 1;
depth = kernel_shape.size() > 2 ? input_shape[3] : 1;
channels = input_shape[input_shape.NumDimensions() - 1];

pooled_height = output_shape[1];
pooled_width = kernel_shape.size() > 1 ? output_shape[2] : 1;
pooled_depth = kernel_shape.size() > 2 ? output_shape[3] : 1;
}
int64_t kernel_h = kernel_shape[0];
int64_t kernel_w = kernel_shape.size() > 1 ? kernel_shape[1] : 1;
int64_t kernel_d = kernel_shape.size() > 2 ? kernel_shape[2] : 1;
@@ -130,7 +178,7 @@ void MaxPoolWithIndex(
fast_divmod fdm_d(static_cast<int>(pooled_depth));

int blocksPerGrid = (int)((output_size + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock);
MaxPoolWithIndexKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
MaxPoolWithIndexKernel<T, Layout><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
batchs,
channels,
height,
@@ -162,8 +210,8 @@ void MaxPoolWithIndex(
p_indices);
}

#define INSTANTIATEMAXPOOLWITHINDEX(T) \
template void MaxPoolWithIndex<T>( \
#define INSTANTIATEMAXPOOLWITHINDEX(T, Layout) \
template void MaxPoolWithIndex<T, Layout>( \
cudaStream_t stream, \
const TensorShape& input_shape, \
const TensorShape& output_shape, \
Expand All @@ -176,11 +224,19 @@ void MaxPoolWithIndex(
T* p_output, \
int64_t* p_indices);

INSTANTIATEMAXPOOLWITHINDEX(float)
INSTANTIATEMAXPOOLWITHINDEX(double)
INSTANTIATEMAXPOOLWITHINDEX(half)
INSTANTIATEMAXPOOLWITHINDEX(int8_t)
INSTANTIATEMAXPOOLWITHINDEX(uint8_t)
INSTANTIATEMAXPOOLWITHINDEX(float, LAYOUT_NCHW)
INSTANTIATEMAXPOOLWITHINDEX(double, LAYOUT_NCHW)
INSTANTIATEMAXPOOLWITHINDEX(half, LAYOUT_NCHW)
INSTANTIATEMAXPOOLWITHINDEX(int8_t, LAYOUT_NCHW)
INSTANTIATEMAXPOOLWITHINDEX(uint8_t, LAYOUT_NCHW)

#ifdef ENABLE_CUDA_NHWC_OPS
INSTANTIATEMAXPOOLWITHINDEX(float, LAYOUT_NHWC)
INSTANTIATEMAXPOOLWITHINDEX(double, LAYOUT_NHWC)
INSTANTIATEMAXPOOLWITHINDEX(half, LAYOUT_NHWC)
INSTANTIATEMAXPOOLWITHINDEX(int8_t, LAYOUT_NHWC)
INSTANTIATEMAXPOOLWITHINDEX(uint8_t, LAYOUT_NHWC)
#endif

} // namespace cuda
} // namespace onnxruntime
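
As a reference for the kernel changes above, here is a host-side sketch (not part of the diff) of the NHWC index arithmetic, using plain division/modulo in place of fast_divmod; the pooled shape values are hypothetical:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical pooled NHWC output shape: N=1, PH=2, PW=3, PD=1, C=4.
  const int64_t channels = 4, pooled_height = 2, pooled_width = 3, pooled_depth = 1;

  // Flat NHWC output index -> (n, h, w, d, c); the channel is the innermost dimension,
  // so it is peeled off first, mirroring the LAYOUT_NHWC divmod order in the kernel.
  int64_t id = 17;
  const int64_t c_index = id % channels;
  int64_t tmp = id / channels;
  const int64_t d_index = tmp % pooled_depth;  tmp /= pooled_depth;
  const int64_t w_index = tmp % pooled_width;  tmp /= pooled_width;
  const int64_t h_index = tmp % pooled_height;
  const int64_t n_index = tmp / pooled_height;

  // The indices output stays in NCHW numbering (see the kernel comment), so the NHWC
  // position is re-linearised against the NCHW-ordered pooled shape.
  const int64_t id_nchw =
      (((n_index * channels + c_index) * pooled_height + h_index) * pooled_width + w_index) * pooled_depth +
      d_index;

  std::printf("n=%lld h=%lld w=%lld d=%lld c=%lld -> id_nchw=%lld\n",
              (long long)n_index, (long long)h_index, (long long)w_index,
              (long long)d_index, (long long)c_index, (long long)id_nchw);
  return 0;
}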
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/nn/max_pool_with_index.h
@@ -7,7 +7,7 @@

namespace onnxruntime {
namespace cuda {
template <typename T>
template <typename T, bool Layout>
void MaxPoolWithIndex(
cudaStream_t stream,
const TensorShape& input_shape,
(Diffs for the remaining 4 changed files are not shown.)
