Skip to content

Commit

Permalink
gpu: sycl: add PReLU primitive implemented using SYCL kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
densamoilov committed Mar 18, 2023
1 parent ed49414 commit 3933b58
Show file tree
Hide file tree
Showing 8 changed files with 856 additions and 4 deletions.
1 change: 0 additions & 1 deletion examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ if(DNNL_SYCL_CUDA)
${CMAKE_CURRENT_SOURCE_DIR}/primitives/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives/lstm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives/layer_normalization.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives/prelu.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives/reorder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives/shuffle.cpp)
endif()
Expand Down
5 changes: 5 additions & 0 deletions src/gpu/nvidia/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,11 @@ backward propagation respectively.

* Supported data type are `f32`, `f16`, `bf16` and `s8`.

### PReLU
The PReLU primitive (Leaky ReLU with a trainable alpha parameter) is implemented
using SYCL kernels. The primitive supports both forward and backward
propagations for the data types f32, s32, bf16, f16, s8 and u8.

### Reorder

The `cudnnTransform` function is the equivalent of oneDNN reorder function.
Expand Down
6 changes: 5 additions & 1 deletion src/gpu/nvidia/sycl_cuda_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
#include "gpu/nvidia/sycl_cuda_engine.hpp"
#include "gpu/nvidia/sycl_cuda_scoped_context.hpp"
#include "gpu/nvidia/sycl_cuda_stream.hpp"

#include "gpu/sycl/ref_binary.hpp"
#include "gpu/sycl/ref_prelu.hpp"

namespace dnnl {
namespace impl {
Expand Down Expand Up @@ -215,6 +215,10 @@ constexpr dnnl::impl::impl_list_item_t sycl_cuda_impl_list[] = {
INSTANCE(cudnn_batch_normalization_fwd_t)
INSTANCE(cudnn_batch_normalization_bwd_t)

// PReLU
INSTANCE(sycl::ref_prelu_fwd_t)
INSTANCE(sycl::ref_prelu_bwd_t)

// Pooling
INSTANCE(cudnn_pooling_fwd_t)
INSTANCE(cudnn_pooling_bwd_t)
Expand Down
Loading

0 comments on commit 3933b58

Please sign in to comment.