Skip to content

Commit

Permalink
gpu: nvidia: add bf16 support to cudnn based eltwise
Browse files Browse the repository at this point in the history
  • Loading branch information
densamoilov committed Mar 4, 2023
1 parent 241ad64 commit 4b46abf
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions src/gpu/nvidia/cudnn_eltwise.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
* Copyright 2020-2023 Intel Corporation
* Copyright 2020 Codeplay Software Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -36,16 +36,23 @@ struct cudnn_eltwise_fwd_t : public primitive_t {

DECLARE_COMMON_PD_T("cuda:cudnn:any", cudnn_eltwise_fwd_t);

status_t init(engine_t *) {
status_t init(engine_t *engine) {
using namespace alg_kind;

auto sycl_dev
= utils::downcast<impl::sycl::sycl_engine_base_t *>(engine)
->device();

bool ok = is_fwd()
// Supported algorithms
&& utils::one_of(desc()->alg_kind, eltwise_relu,
eltwise_tanh, eltwise_elu, eltwise_logistic)
// Supported data types
&& utils::one_of(src_md()->data_type, data_type::f32,
data_type::f16, data_type::s8)
data_type::bf16, data_type::f16, data_type::s8)
&& src_md()->data_type == dst_md()->data_type
&& IMPLICATION(src_md()->data_type == data_type::bf16,
has_bf16_support(sycl_dev))
&& IMPLICATION(desc()->alg_kind == eltwise_relu,
desc()->alpha == 0)
&& attr()->has_default_values()
Expand Down Expand Up @@ -82,8 +89,13 @@ struct cudnn_eltwise_bwd_t : public primitive_t {
// Supported algorithms
&& utils::one_of(desc()->alg_kind, eltwise_relu)
// Supported data types
&& utils::everyone_is(data_type::f32, data_md()->data_type,
diff_src_md()->data_type, diff_dst_md()->data_type)
&& (utils::everyone_is(data_type::f32, data_md()->data_type,
diff_src_md()->data_type,
diff_dst_md()->data_type)
|| utils::everyone_is(data_type::bf16,
data_md()->data_type,
diff_src_md()->data_type,
diff_dst_md()->data_type))
&& IMPLICATION(desc()->alg_kind == eltwise_relu,
desc()->alpha == 0)
&& attr()->has_default_values()
Expand Down

0 comments on commit 4b46abf

Please sign in to comment.