Skip to content

Commit

Permalink
gpu: nvidia: add bf16 support to cudnn based pool
Browse files Browse the repository at this point in the history
  • Loading branch information
densamoilov committed Mar 4, 2023
1 parent 4b46abf commit 5c03e3f
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/gpu/nvidia/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ The following table documents the supported data types.
| f32 | Training, Inference |
| f16 | Inference |
| s8 | Inference (when applicable) |
| bf16 | Training, Inference |

## Supported Primitives and Implementation Limitations

Expand Down
8 changes: 7 additions & 1 deletion src/gpu/nvidia/cudnn_pooling.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
* Copyright 2020-2023 Intel Corporation
* Copyright 2020-2022 Codeplay Software Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -70,6 +70,12 @@ status_t cudnn_pooling_fwd_t::execute(const exec_ctx_t &ctx) const {
reinterpret_cast<unsigned char &>(val),
dst_wrap.nelems(),
cuda_stream->get_underlying_stream());
} else if (dst_wrap.data_type() == data_type_t::dnnl_bf16) {
bfloat16_t val = nstl::numeric_limits<bfloat16_t>::lowest();
cuMemsetD16Async(reinterpret_cast<CUdeviceptr>(dst),
                    reinterpret_cast<unsigned short &>(val),
                    dst_wrap.nelems(),
                    cuda_stream->get_underlying_stream());
}
});
});
Expand Down
23 changes: 19 additions & 4 deletions src/gpu/nvidia/cudnn_pooling.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
* Copyright 2020-2023 Intel Corporation
* Copyright 2020 Codeplay Software Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -81,20 +81,26 @@ struct cudnn_pooling_fwd_t : public primitive_t {

assert(engine->kind() == engine_kind::gpu);
auto src_dt = src_md()->data_type;
auto *sycl_engine
= utils::downcast<impl::sycl::sycl_engine_base_t *>(engine);

bool ok = true && is_fwd()
&& utils::one_of(desc()->prop_kind, forward_training,
forward_inference)
&& utils::one_of(desc()->alg_kind, pooling_max,
pooling_avg_include_padding,
pooling_avg_exclude_padding)
&& utils::one_of(src_dt, s8, f16, f32)
&& utils::one_of(src_dt, s8, f16, f32, bf16)
&& src_dt == dst_md()->data_type
&& IMPLICATION(utils::one_of(src_dt, f16),
desc()->prop_kind == forward_inference)
&& IMPLICATION(src_dt == s8, desc()->accum_data_type == s32)
&& !is_dilated() && attr()->has_default_values()
&& set_default_params() == status::success && blocking_ok();
&& set_default_params() == status::success && blocking_ok()
&& IMPLICATION(
utils::one_of(data_type::bf16, src_md()->data_type,
dst_md()->data_type),
has_bf16_support(sycl_engine->device()));
if (!ok) return status::unimplemented;

bool is_training = desc_.prop_kind == forward_training;
Expand Down Expand Up @@ -143,6 +149,8 @@ struct cudnn_pooling_bwd_t : public primitive_t {
using namespace alg_kind;
using namespace format_tag;
assert(engine->kind() == engine_kind::gpu);
auto *sycl_engine
= utils::downcast<impl::sycl::sycl_engine_base_t *>(engine);

bool ok = true && !is_fwd()
&& set_default_params() == status::success
Expand All @@ -154,10 +162,17 @@ struct cudnn_pooling_bwd_t : public primitive_t {
diff_dst_md()->data_type,
diff_src_md()->data_type)
|| utils::everyone_is(data_type::f16,
diff_dst_md()->data_type,
diff_src_md()->data_type)
|| utils::everyone_is(data_type::bf16,
diff_dst_md()->data_type,
diff_src_md()->data_type))
&& !is_dilated() && attr()->has_default_values()
&& no_blocking();
&& no_blocking()
&& IMPLICATION(utils::one_of(data_type::bf16,
diff_dst_md()->data_type,
diff_src_md()->data_type),
has_bf16_support(sycl_engine->device()));
if (!ok) return status::unimplemented;

init_mem_by_tag(get_tag(diff_dst_md_), diff_src_md_);
Expand Down

0 comments on commit 5c03e3f

Please sign in to comment.