Skip to content

Commit

Permalink
gpu: amd: enable SYCL kernels (#2024)
Browse files Browse the repository at this point in the history
Co-authored-by: Denis Samoilov <[email protected]>
  • Loading branch information
sgeor255 and densamoilov authored Aug 9, 2024
1 parent 894f16c commit d4d58cf
Show file tree
Hide file tree
Showing 22 changed files with 146 additions and 81 deletions.
16 changes: 16 additions & 0 deletions cmake/options.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,17 @@ set(ONEDNN_ENABLE_GEMM_KERNELS_ISA "ALL" CACHE STRING
SSE41 < AVX2 < AVX512 < AMX (or ALL). It means that if user selects, e.g.
AVX2 ISA, SSE41 kernels will also present at build time.")

# Target architecture used when compiling the generic SYCL kernels for the AMD
# vendor. The default is an empty string; per the description below, leaving it
# empty means the generic SYCL kernels are not compiled.
set(DNNL_AMD_SYCL_KERNELS_TARGET_ARCH "" CACHE STRING
"Specifies the target architecture (e.g. gfx90a when compiling on AMD MI210)
to be used for compiling generic SYCL kernels for AMD vendor.
When this option is set to a valid architecture (see LLVM target column in
https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html#supported-gpus
for supported architectures), the generic SYCL kernels will be enabled for AMD
vendor. If not set, the SYCL kernels will not be compiled.
Warning: This option is temporary and will be removed as soon as the compiler
no longer requires the target architecture to be specified. After the option is
removed, the generic SYCL kernels will always be enabled for AMD vendor.")

# =============
# Optimizations
# =============
Expand Down Expand Up @@ -327,6 +338,11 @@ else()
set(DNNL_WITH_SYCL false)
endif()

# Enable the generic SYCL kernels for the AMD vendor only when DNNL_SYCL_HIP is
# set and a non-empty target architecture has been provided.
if(DNNL_SYCL_HIP AND NOT "${DNNL_AMD_SYCL_KERNELS_TARGET_ARCH}" STREQUAL "")
    # CMake-side switch that other CMake code can test.
    set(DNNL_AMD_ENABLE_SYCL_KERNELS TRUE)
    # Matching preprocessor definition for the compiled sources.
    add_definitions(-DDNNL_AMD_ENABLE_SYCL_KERNELS=1)
endif()

# =============
# Miscellaneous
# =============
Expand Down
4 changes: 4 additions & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ if(DNNL_SYCL_CUDA)
append(CMAKE_CXX_FLAGS "-Wno-linker-warnings")
endif()

# With the AMD generic SYCL kernels enabled, compile for the AMD SYCL target
# and pass the requested offload architecture to the target backend.
if(DNNL_AMD_ENABLE_SYCL_KERNELS)
    append(CMAKE_CXX_FLAGS "-fsycl-targets=amdgcn-amd-amdhsa -Xsycl-target-backend --offload-arch=${DNNL_AMD_SYCL_KERNELS_TARGET_ARCH}")
endif()

# propagate sanitizer flags
append(CMAKE_C_FLAGS "${CMAKE_CCXX_SANITIZER_FLAGS}")
append(CMAKE_CXX_FLAGS "${CMAKE_CCXX_SANITIZER_FLAGS}")
Expand Down
4 changes: 4 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ if(DNNL_WITH_SYCL)
append(CMAKE_SHARED_LINKER_FLAGS "-fsycl-targets=nvptx64-nvidia-cuda")
append(CMAKE_STATIC_LINKER_FLAGS "-fsycl-targets=nvptx64-nvidia-cuda")
endif()
# With the AMD generic SYCL kernels enabled, add the AMD SYCL target and the
# requested offload architecture to both the static and shared linker flags.
if(DNNL_AMD_ENABLE_SYCL_KERNELS)
    append(CMAKE_STATIC_LINKER_FLAGS "-fsycl-targets=amdgcn-amd-amdhsa -Xsycl-target-backend --offload-arch=${DNNL_AMD_SYCL_KERNELS_TARGET_ARCH}")
    append(CMAKE_SHARED_LINKER_FLAGS "-fsycl-targets=amdgcn-amd-amdhsa -Xsycl-target-backend --offload-arch=${DNNL_AMD_SYCL_KERNELS_TARGET_ARCH}")
endif()
endif()

if(ONEDNN_BUILD_GRAPH)
Expand Down
5 changes: 5 additions & 0 deletions src/gpu/amd/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,8 @@ The `miopenTransform` function is the equivalent of oneDNN reorder function.
* Per dimension scaling is not supported (a single alpha and beta value is
accepted by the transform tensor function).
* Supported data types: `f32`

### Other Primitives

Some primitives and features that are missing from the AMD backend are instead
supported through [generic SYCL kernels](../generic/sycl/README.md).
5 changes: 3 additions & 2 deletions src/gpu/generic/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ file(GLOB SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
)

# The generic SYCL kernels are currently enabled for NVIDIA vendor only.
if(DNNL_SYCL_CUDA)
# - Always enable generic SYCL kernels for NVIDIA vendor.
# - Only enable the generic SYCL kernels for AMD vendor if a target architecture has been specified.
if(DNNL_SYCL_CUDA OR DNNL_AMD_ENABLE_SYCL_KERNELS)
add_subdirectory(sycl)
endif()

Expand Down
4 changes: 4 additions & 0 deletions src/gpu/generic/sycl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ if(DNNL_SYCL_CUDA)
append(CMAKE_CXX_FLAGS "-Wno-linker-warnings")
endif()

# With the AMD generic SYCL kernels enabled, build these sources for the AMD
# SYCL target and pass the requested offload architecture to the target backend.
if(DNNL_AMD_ENABLE_SYCL_KERNELS)
    append(CMAKE_CXX_FLAGS "-fsycl-targets=amdgcn-amd-amdhsa -Xsycl-target-backend --offload-arch=${DNNL_AMD_SYCL_KERNELS_TARGET_ARCH}")
endif()

set(OBJ_LIB ${LIB_PACKAGE_NAME}_gpu_generic_sycl)
add_library(${OBJ_LIB} OBJECT ${SOURCES})
set_property(GLOBAL APPEND PROPERTY DNNL_LIB_DEPS
Expand Down
63 changes: 63 additions & 0 deletions src/gpu/generic/sycl/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Supported Primitives

## Batch Normalization

The implementation supports both forward and backward directions.

* Supported formats: `NCDHW`, `NDHWC`, `NCHW`, `NHWC`, `NCW`, `NWC`, `NC`

## Eltwise

The implementation supports both forward and backward directions.

* Supported algorithms: `abs`, `clip`, `clip_v2`, `elu`, `exp`, `gelu_erf`,
`gelu_tanh`, `hardsigmoid`, `hardswish`, `linear`, `log`, `logistic`, `mish`,
`pow`, `relu`, `round`, `soft_relu`, `sqrt`, `square`, `swish` and `tanh`
* Supported formats: `NCDHW`, `NDHWC`, `NCHW`, `NHWC`, `NCW`, `NWC`, `NC`, `N`

## LRN

The implementation supports both forward and backward directions.

* Supported formats: `NCDHW`, `NDHWC`, `NCHW`, `NHWC`, `NCW`, `NWC`, `NC`

## Pooling

The implementation supports both forward and backward directions.

* Supported formats: `NCDHW`, `NDHWC`, `NCHW`, `NHWC`, `NCW`, `NWC`

## PReLU

The implementation supports both forward and backward directions.

* Supported formats: `NCDHW`, `NDHWC`, `NCHW`, `NHWC`, `NCW`, `NWC`, `NC`

* Forward pass supports `f32`, `f16`, `bf16`, `s8` and `u8` data types
* Backward pass supports `f32` and `bf16` data types

## Reorder

* Format support limitations: blocked formats are not supported
* Supported data types: `f32`, `bf16`, `f16`, `s8`, `u8`

## Resampling

The implementation supports both forward and backward directions.

* Supported formats: `NCDHW`, `NDHWC`, `NCHW`, `NHWC`, `NCW`, `NWC`

## Softmax/LogSoftmax

The implementation supports both forward and backward directions.

* Supported formats: `NCDHW`, `NDHWC`, `NCHW`, `NHWC`, `NCW`, `NWC`, `NC`

## Shuffle

The implementation supports both forward and backward directions.

* Supported formats: `NCDHW`, `NDHWC`, `NCHW`, `NHWC`, `NCW`, `NWC`, `NC`

* Forward pass supports `f32`, `f16`, `bf16` and `s8` data types.
* Backward pass supports `f32` and `bf16` data types.
5 changes: 4 additions & 1 deletion src/gpu/gpu_batch_normalization_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,17 @@
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA
#include "gpu/generic/sycl/ref_batch_normalization.hpp"
#include "gpu/nvidia/cudnn_batch_normalization.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/amd/miopen_batch_normalization.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA || DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/generic/sycl/ref_batch_normalization.hpp"
#endif

namespace dnnl {
namespace impl {
namespace gpu {
Expand Down
5 changes: 4 additions & 1 deletion src/gpu/gpu_binary_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA
#include "gpu/generic/sycl/ref_binary.hpp"
#include "gpu/nvidia/cudnn_binary.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/amd/miopen_binary.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA || DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/generic/sycl/ref_binary.hpp"
#endif

namespace dnnl {
namespace impl {
namespace gpu {
Expand Down
5 changes: 4 additions & 1 deletion src/gpu/gpu_convolution_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,17 @@
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA
#include "gpu/generic/sycl/ref_convolution.hpp"
#include "gpu/nvidia/cudnn_convolution.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/amd/miopen_convolution.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA || DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/generic/sycl/ref_convolution.hpp"
#endif

namespace dnnl {
namespace impl {
namespace gpu {
Expand Down
5 changes: 4 additions & 1 deletion src/gpu/gpu_eltwise_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,17 @@
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA
#include "gpu/generic/sycl/ref_eltwise.hpp"
#include "gpu/nvidia/cudnn_eltwise.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/amd/miopen_eltwise.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA || DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/generic/sycl/ref_eltwise.hpp"
#endif

namespace dnnl {
namespace impl {
namespace gpu {
Expand Down
3 changes: 2 additions & 1 deletion src/gpu/gpu_impl_list.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ namespace gpu {
// NOTE: Support for the standalone GENERIC vendor has not been added yet.
#if defined(DNNL_WITH_SYCL) \
&& ((DNNL_GPU_VENDOR == DNNL_VENDOR_GENERIC) \
|| (DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA))
|| (DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA) \
|| (DNNL_AMD_ENABLE_SYCL_KERNELS == 1))
#define DNNL_GPU_GENERIC_SYCL_ONLY(...) __VA_ARGS__
#else
#define DNNL_GPU_GENERIC_SYCL_ONLY(...)
Expand Down
2 changes: 1 addition & 1 deletion src/gpu/gpu_layer_normalization_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include "gpu/intel/ocl/vectorized_lnorm.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA
#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA || DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/generic/sycl/ref_layer_normalizations.hpp"
#endif

Expand Down
5 changes: 4 additions & 1 deletion src/gpu/gpu_lrn_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA
#include "gpu/generic/sycl/ref_lrn.hpp"
#include "gpu/nvidia/cudnn_lrn.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/amd/miopen_lrn.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA || DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/generic/sycl/ref_lrn.hpp"
#endif

namespace dnnl {
namespace impl {
namespace gpu {
Expand Down
5 changes: 4 additions & 1 deletion src/gpu/gpu_pooling_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA
#include "gpu/generic/sycl/ref_pooling.hpp"
#include "gpu/nvidia/cudnn_pooling.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/amd/miopen_pooling.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA || DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/generic/sycl/ref_pooling.hpp"
#endif

namespace dnnl {
namespace impl {
namespace gpu {
Expand Down
2 changes: 1 addition & 1 deletion src/gpu/gpu_prelu_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "gpu/intel/ocl/ref_prelu.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA
#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA || DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/generic/sycl/ref_prelu.hpp"
#endif

Expand Down
4 changes: 3 additions & 1 deletion src/gpu/gpu_reorder_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA
#include "gpu/generic/sycl/ref_reorder.hpp"
#include "gpu/nvidia/cudnn_reorder.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/amd/miopen_reorder.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA || DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/generic/sycl/ref_reorder.hpp"
#endif
namespace dnnl {
namespace impl {
namespace gpu {
Expand Down
5 changes: 4 additions & 1 deletion src/gpu/gpu_resampling_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA
#include "gpu/generic/sycl/ref_resampling.hpp"
#include "gpu/nvidia/cudnn_resampling.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA || DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/generic/sycl/ref_resampling.hpp"
#endif

namespace dnnl {
namespace impl {
namespace gpu {
Expand Down
2 changes: 1 addition & 1 deletion src/gpu/gpu_shuffle_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include "gpu/intel/ocl/shuffle_by_reorder.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA
#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA || DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/generic/sycl/ref_shuffle.hpp"
#endif

Expand Down
5 changes: 4 additions & 1 deletion src/gpu/gpu_softmax_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA
#include "gpu/generic/sycl/ref_softmax.hpp"
#include "gpu/nvidia/cudnn_softmax.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/amd/miopen_softmax.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA || DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/generic/sycl/ref_softmax.hpp"
#endif

namespace dnnl {
namespace impl {
namespace gpu {
Expand Down
5 changes: 4 additions & 1 deletion src/gpu/gpu_sum_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA
#include "gpu/nvidia/cudnn_sum.hpp"
#endif

#if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA || DNNL_GPU_VENDOR == DNNL_VENDOR_AMD
#include "gpu/generic/sycl/ref_sum.hpp"
#include "gpu/generic/sycl/ref_sum_many_inputs.hpp"
#include "gpu/nvidia/cudnn_sum.hpp"
#endif

namespace dnnl {
Expand Down
Loading

0 comments on commit d4d58cf

Please sign in to comment.