diff --git a/src/gpu/nvidia/README.md b/src/gpu/nvidia/README.md index 503a3f33fb6..6e9f0234708 100644 --- a/src/gpu/nvidia/README.md +++ b/src/gpu/nvidia/README.md @@ -149,16 +149,16 @@ The following table shows the convolution status for the oneDNN Nvidia backend: #### Forward direction | Weights Format | Winograd Supported | Supported Input Format | Supported Output Format | Supported Data Type | Limitations | |----------------|--------------------|------------------------|-------------------------|---------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| 2D NCHW | YES | NCHW, NHWC | NCHW, NHWC | f32, f16 | The Winograd algorithm has limitations: <br>
* Filter size must be 3x3 or 5x5. <br>
* Dilation must be zero for all dimensions. <br>
* Horizontal and vertical filter stride must be 1. | -| 2D NHWC | NO | NHWC | NHWC | f32, f16, int8 | * Dilation must be zero in all dimensions. <br>
* Output feature maps must be multiple of 4 for `int8` type. | -| 3D NCHW | NO | NCHW, NHWC | NCHW, NHWC | f32, f16 | | +| 2D NCHW | YES | NCHW, NHWC | NCHW, NHWC | f32, f16, bf16 | The Winograd algorithm has limitations: <br>
* Filter size must be 3x3 or 5x5. <br>
* Dilation must be zero for all dimensions. <br>
* Horizontal and vertical filter stride must be 1. | +| 2D NHWC | NO | NHWC | NHWC | f32, f16, bf16, int8 | * Dilation must be zero in all dimensions. <br>
* Output feature maps must be multiple of 4 for `int8` type. | +| 3D NCHW | NO | NCHW, NHWC | NCHW, NHWC | f32, f16, bf16 | | #### Backward direction | Weights Format | Winograd Supported | Supported Input Format | Supported Output Format | Supported Data Type | Limitations | |----------------|--------------------|------------------------|-------------------------|---------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| 2D NCHW | YES | NCHW, NHWC | NCHW | f32, f16 | 1. Dilation must be zero for all dimensions. <br>
2. The Winograd algorithm has limitations: <br>
* Filter size must be 3x3 or 5x5. <br>
* Dilation must be zero for all dimensions. <br>
* Horizontal and vertical filter stride must be 1. | -| 2D NHWC | NO | NHWC | NHWC | f32, f16 | | -| 3D NCHW | NO | NCHW, NHWC | NCHW | f32, f16 | | +| 2D NCHW | YES | NCHW, NHWC | NCHW | f32, f16, bf16 | 1. Dilation must be zero for all dimensions. <br>
2. The Winograd algorithm has limitations: <br>
* Filter size must be 3x3 or 5x5. <br>
* Dilation must be zero for all dimensions. <br>
* Horizontal and vertical filter stride must be 1. | +| 2D NHWC | NO | NHWC | NHWC | f32, f16, bf16 | | +| 3D NCHW | NO | NCHW, NHWC | NCHW | f32, f16, bf16 | | ### Deconvolution Deconvolution primitive is implemented through the convolution with swapped diff --git a/src/gpu/nvidia/cudnn_convolution.hpp b/src/gpu/nvidia/cudnn_convolution.hpp index 0cf75ad5ae8..1f72db6250d 100644 --- a/src/gpu/nvidia/cudnn_convolution.hpp +++ b/src/gpu/nvidia/cudnn_convolution.hpp @@ -50,6 +50,8 @@ struct cudnn_convolution_fwd_t : public primitive_t { using sm_t = primitive_attr_t::skip_mask_t; const auto attr_skip_mask = sm_t::scales_runtime | sm_t::post_ops; + auto *sycl_engine + = utils::downcast<impl::sycl::sycl_engine_base_t *>(engine); bool ok = utils::one_of(desc()->prop_kind, prop_kind::forward_training, prop_kind::forward_inference); @@ -60,10 +62,17 @@ weights_md_.data_type, dst_md_.data_type) || utils::everyone_is(f16, src_md_.data_type, weights_md_.data_type, dst_md_.data_type) + || utils::everyone_is(bf16, src_md_.data_type, + weights_md_.data_type, dst_md_.data_type) || (utils::everyone_is(s8, src_md_.data_type, weights_md_.data_type) && utils::one_of( - dst_md_.data_type, f32, s8))); + dst_md_.data_type, f32, s8))) + && IMPLICATION( + utils::one_of(data_type::bf16, src_md_.data_type, + weights_md_.data_type, dst_md_.data_type), + has_bf16_support(sycl_engine->device())); + ok = ok && this->set_default_formats(); ok = ok && IMPLICATION( @@ -202,14 +211,25 @@ struct cudnn_convolution_bwd_data_t : public primitive_t { status_t init(engine_t *engine) { using namespace data_type; + bool ok = desc()->prop_kind == prop_kind::backward_data; + auto *sycl_engine + = utils::downcast<impl::sycl::sycl_engine_base_t *>(engine); ok = ok && this->set_default_formats(); ok = ok && (utils::everyone_is(f32, diff_src_md_.data_type, weights_md_.data_type, diff_dst_md_.data_type) || utils::everyone_is(f16, diff_src_md_.data_type, weights_md_.data_type, - diff_dst_md_.data_type)); +
diff_dst_md_.data_type) + || utils::everyone_is(bf16, diff_src_md_.data_type, + weights_md_.data_type, + diff_dst_md_.data_type)) + && IMPLICATION(utils::one_of(data_type::bf16, + diff_src_md_.data_type, + weights_md_.data_type, + diff_dst_md_.data_type), + has_bf16_support(sycl_engine->device())); ok = ok && IMPLICATION( @@ -269,6 +289,8 @@ struct cudnn_convolution_bwd_weights_t : public primitive_t { status_t init(engine_t *engine) { using namespace data_type; bool ok = desc()->prop_kind == prop_kind::backward_weights; + auto *sycl_engine + = utils::downcast<impl::sycl::sycl_engine_base_t *>(engine); ok = ok && this->set_default_formats(); ok = ok && (utils::everyone_is(f32, src_md_.data_type, @@ -276,7 +298,16 @@ struct cudnn_convolution_bwd_weights_t : public primitive_t { diff_dst_md_.data_type) || utils::everyone_is(f16, src_md_.data_type, diff_weights_md_.data_type, - diff_dst_md_.data_type)); + diff_dst_md_.data_type) + || utils::everyone_is(bf16, src_md_.data_type, + diff_weights_md_.data_type, + diff_dst_md_.data_type)) + + && IMPLICATION( + utils::one_of(data_type::bf16, src_md_.data_type, + diff_weights_md_.data_type, + diff_dst_md_.data_type), + has_bf16_support(sycl_engine->device())); ok = ok && IMPLICATION( diff --git a/src/gpu/nvidia/cudnn_convolution_impl.hpp b/src/gpu/nvidia/cudnn_convolution_impl.hpp index 98d06c8b5b7..299e10d4ee6 100644 --- a/src/gpu/nvidia/cudnn_convolution_impl.hpp +++ b/src/gpu/nvidia/cudnn_convolution_impl.hpp @@ -330,10 +330,14 @@ struct cudnn_convolution_impl_base_t } void set_compute_format() { - if (data_types[x] == CUDNN_DATA_INT8) { - computation_data_type = CUDNN_DATA_INT32; - } else { - computation_data_type = data_types[y]; + switch (data_types[x]) { + case CUDNN_DATA_INT8: + computation_data_type = CUDNN_DATA_INT32; + break; + case CUDNN_DATA_BFLOAT16: + computation_data_type = CUDNN_DATA_FLOAT; + break; + default: computation_data_type = data_types[y]; } }