From 5c03e3f0117ca6aeb67fa1a95daf8b4c32f0b5b2 Mon Sep 17 00:00:00 2001
From: Denis Samoilov
Date: Fri, 3 Mar 2023 14:19:04 -0800
Subject: [PATCH] gpu: nvidia: add bf16 support to cudnn based pool

---
 src/gpu/nvidia/README.md         |  1 +
 src/gpu/nvidia/cudnn_pooling.cpp |  8 +++++++-
 src/gpu/nvidia/cudnn_pooling.hpp | 23 +++++++++++++++++++----
 3 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/src/gpu/nvidia/README.md b/src/gpu/nvidia/README.md
index e2f250dd66a..fffba665074 100644
--- a/src/gpu/nvidia/README.md
+++ b/src/gpu/nvidia/README.md
@@ -38,6 +38,7 @@ The following table documents the supported data types.
 | f32 | Training, Inference |
 | f16 | Inference |
 | s8 | Inference (when applicable) |
+| bf16 | Training, Inference |
 
 ## Supported Primitives and Implementation Limitations
 
diff --git a/src/gpu/nvidia/cudnn_pooling.cpp b/src/gpu/nvidia/cudnn_pooling.cpp
index dfb60afecc4..4ecd0cde418 100644
--- a/src/gpu/nvidia/cudnn_pooling.cpp
+++ b/src/gpu/nvidia/cudnn_pooling.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2022 Intel Corporation
+* Copyright 2020-2023 Intel Corporation
 * Copyright 2020-2022 Codeplay Software Limited
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -70,6 +70,12 @@ status_t cudnn_pooling_fwd_t::execute(const exec_ctx_t &ctx) const {
                         reinterpret_cast<unsigned short &>(val),
                         dst_wrap.nelems(),
                         cuda_stream->get_underlying_stream());
+            } else if (dst_wrap.data_type() == data_type_t::dnnl_bf16) {
+                bfloat16_t val = nstl::numeric_limits<bfloat16_t>::lowest();
+                cuMemsetD16Async(reinterpret_cast<CUdeviceptr>(dst),
+                        reinterpret_cast<unsigned short &>(val),
+                        dst_wrap.nelems(),
+                        cuda_stream->get_underlying_stream());
             }
         });
     });
diff --git a/src/gpu/nvidia/cudnn_pooling.hpp b/src/gpu/nvidia/cudnn_pooling.hpp
index 449f62ccd47..83d396284f6 100644
--- a/src/gpu/nvidia/cudnn_pooling.hpp
+++ b/src/gpu/nvidia/cudnn_pooling.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2020-2022 Intel Corporation
+* Copyright 2020-2023 Intel Corporation
 * Copyright 2020 Codeplay Software Limited
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -81,6 +81,8 @@ struct cudnn_pooling_fwd_t : public primitive_t {
             assert(engine->kind() == engine_kind::gpu);
 
             auto src_dt = src_md()->data_type;
+            auto *sycl_engine
+                    = utils::downcast<impl::sycl::sycl_engine_base_t *>(engine);
 
             bool ok = true && is_fwd()
                     && utils::one_of(desc()->prop_kind, forward_training,
@@ -88,13 +90,17 @@ struct cudnn_pooling_fwd_t : public primitive_t {
                     && utils::one_of(desc()->alg_kind, pooling_max,
                             pooling_avg_include_padding,
                             pooling_avg_exclude_padding)
-                    && utils::one_of(src_dt, s8, f16, f32)
+                    && utils::one_of(src_dt, s8, f16, f32, bf16)
                     && src_dt == dst_md()->data_type
                     && IMPLICATION(utils::one_of(src_dt, f16),
                             desc()->prop_kind == forward_inference)
                     && IMPLICATION(src_dt == s8, desc()->accum_data_type == s32)
                     && !is_dilated() && attr()->has_default_values()
-                    && set_default_params() == status::success && blocking_ok();
+                    && set_default_params() == status::success && blocking_ok()
+                    && IMPLICATION(
+                            utils::one_of(data_type::bf16, src_md()->data_type,
+                                    dst_md()->data_type),
+                            has_bf16_support(sycl_engine->device()));
             if (!ok) return status::unimplemented;
 
             bool is_training = desc_.prop_kind == forward_training;
@@ -143,6 +149,8 @@ struct cudnn_pooling_bwd_t : public primitive_t {
             using namespace alg_kind;
             using namespace format_tag;
             assert(engine->kind() == engine_kind::gpu);
+            auto *sycl_engine
+                    = utils::downcast<impl::sycl::sycl_engine_base_t *>(engine);
 
             bool ok = true && !is_fwd()
                     && set_default_params() == status::success
@@ -154,10 +162,17 @@ struct cudnn_pooling_bwd_t : public primitive_t {
                                 diff_dst_md()->data_type,
                                 diff_src_md()->data_type)
                             || utils::everyone_is(data_type::f16,
+                                    diff_dst_md()->data_type,
+                                    diff_src_md()->data_type)
+                            || utils::everyone_is(data_type::bf16,
                                     diff_dst_md()->data_type,
                                     diff_src_md()->data_type))
                     && !is_dilated() && attr()->has_default_values()
-                    && no_blocking();
+                    && no_blocking()
+                    && IMPLICATION(utils::one_of(data_type::bf16,
+                                           diff_dst_md()->data_type,
+                                           diff_src_md()->data_type),
+                            has_bf16_support(sycl_engine->device()));
             if (!ok) return status::unimplemented;
 
             init_mem_by_tag(get_tag(diff_dst_md_), diff_src_md_);
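
For reviewers who want to exercise the new path, here is a minimal, self-contained
sketch of a bf16 max-pooling call through oneDNN's public C++ API. It assumes a
oneDNN v3.x SYCL build with the NVIDIA vendor backend (so the cuDNN implementation
patched above is dispatched); the file name, shapes, engine index, and choice of
forward_training are illustrative only and are not part of this patch.

// bf16_pool_example.cpp -- hypothetical usage sketch, not part of this patch.
#include "dnnl.hpp"

int main() {
    using namespace dnnl;

    engine eng(engine::kind::gpu, 0); // cuDNN-backed engine in a NVIDIA build
    stream strm(eng);

    // NCHW bf16 source and the matching 2x2, stride-2 max-pool destination.
    auto src_md = memory::desc({1, 16, 32, 32}, memory::data_type::bf16,
            memory::format_tag::nchw);
    auto dst_md = memory::desc({1, 16, 16, 16}, memory::data_type::bf16,
            memory::format_tag::nchw);

    // forward_training is valid for bf16 after this patch (f16 remains
    // inference-only); on a device without bf16 support, primitive
    // descriptor creation now fails up front with unimplemented.
    auto pd = pooling_forward::primitive_desc(eng,
            prop_kind::forward_training, algorithm::pooling_max, src_md,
            dst_md, /*strides=*/{2, 2}, /*kernel=*/{2, 2},
            /*dilation=*/{0, 0}, /*padding_l=*/{0, 0}, /*padding_r=*/{0, 0});

    memory src_mem(src_md, eng), dst_mem(dst_md, eng);
    memory ws_mem(pd.workspace_desc(), eng); // max indices for backward

    pooling_forward(pd).execute(strm,
            {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem},
                    {DNNL_ARG_WORKSPACE, ws_mem}});
    strm.wait();
    return 0;
}

Note that the workspace is only produced in training mode; an inference-only
run could drop DNNL_ARG_WORKSPACE, which also matches the dst-initialization
path touched in cudnn_pooling.cpp above.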