diff --git a/src/gpu/nvidia/README.md b/src/gpu/nvidia/README.md
index 503a3f33fb6..6e9f0234708 100644
--- a/src/gpu/nvidia/README.md
+++ b/src/gpu/nvidia/README.md
@@ -149,16 +149,16 @@ The following table shows the convolution status for the oneDNN Nvidia backend:
#### Forward direction
| Weights Format | Winograd Supported | Supported Input Format | Supported Output Format | Supported Data Type | Limitations |
|----------------|--------------------|------------------------|-------------------------|---------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| 2D NCHW        | YES                | NCHW, NHWC             | NCHW, NHWC              | f32, f16            | The Winograd algorithm has limitations: <br> * Filter size must be 3x3 or 5x5. <br> * Dilation must be zero for all dimensions. <br> * Horizontal and vertical filter stride must be 1. |
-| 2D NHWC        | NO                 | NHWC                   | NHWC                    | f32, f16, int8      | * Dilation must be zero in all dimensions. <br> * Output feature maps must be multiple of 4 for `int8` type. |
-| 3D NCHW | NO | NCHW, NHWC | NCHW, NHWC | f32, f16 | |
+| 2D NCHW        | YES                | NCHW, NHWC             | NCHW, NHWC              | f32, f16, bf16      | The Winograd algorithm has limitations: <br> * Filter size must be 3x3 or 5x5. <br> * Dilation must be zero for all dimensions. <br> * Horizontal and vertical filter stride must be 1. |
+| 2D NHWC        | NO                 | NHWC                   | NHWC                    | f32, f16, bf16, int8 | * Dilation must be zero in all dimensions. <br> * Output feature maps must be multiple of 4 for `int8` type. |
+| 3D NCHW | NO | NCHW, NHWC | NCHW, NHWC | f32, f16, bf16 | |
#### Backward direction
| Weights Format | Winograd Supported | Supported Input Format | Supported Output Format | Supported Data Type | Limitations |
|----------------|--------------------|------------------------|-------------------------|---------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| 2D NCHW        | YES                | NCHW, NHWC             | NCHW                    | f32, f16            | 1. Dilation must be zero for all dimensions. <br> 2. The Winograd algorithm has limitations: <br> * Filter size must be 3x3 or 5x5. <br> * Dilation must be zero for all dimensions. <br> * Horizontal and vertical filter stride must be 1. |
-| 2D NHWC | NO | NHWC | NHWC | f32, f16 | |
-| 3D NCHW | NO | NCHW, NHWC | NCHW | f32, f16 | |
+| 2D NCHW        | YES                | NCHW, NHWC             | NCHW                    | f32, f16, bf16      | 1. Dilation must be zero for all dimensions. <br> 2. The Winograd algorithm has limitations: <br> * Filter size must be 3x3 or 5x5. <br> * Dilation must be zero for all dimensions. <br> * Horizontal and vertical filter stride must be 1. |
+| 2D NHWC | NO | NHWC | NHWC | f32, f16, bf16 | |
+| 3D NCHW | NO | NCHW, NHWC | NCHW | f32, f16, bf16 | |
### Deconvolution
Deconvolution primitive is implemented through the convolution with swapped
diff --git a/src/gpu/nvidia/cudnn_convolution.hpp b/src/gpu/nvidia/cudnn_convolution.hpp
index 0cf75ad5ae8..1f72db6250d 100644
--- a/src/gpu/nvidia/cudnn_convolution.hpp
+++ b/src/gpu/nvidia/cudnn_convolution.hpp
@@ -50,6 +50,8 @@ struct cudnn_convolution_fwd_t : public primitive_t {
using sm_t = primitive_attr_t::skip_mask_t;
const auto attr_skip_mask = sm_t::scales_runtime | sm_t::post_ops;
+            auto *sycl_engine
+                    = utils::downcast<impl::sycl::sycl_engine_base_t *>(engine);
bool ok = utils::one_of(desc()->prop_kind,
prop_kind::forward_training, prop_kind::forward_inference);
@@ -60,10 +62,17 @@ struct cudnn_convolution_fwd_t : public primitive_t {
weights_md_.data_type, dst_md_.data_type)
|| utils::everyone_is(f16, src_md_.data_type,
weights_md_.data_type, dst_md_.data_type)
+ || utils::everyone_is(bf16, src_md_.data_type,
+ weights_md_.data_type, dst_md_.data_type)
|| (utils::everyone_is(s8, src_md_.data_type,
weights_md_.data_type)
&& utils::one_of(
- dst_md_.data_type, f32, s8)));
+ dst_md_.data_type, f32, s8)))
+ && IMPLICATION(
+ utils::one_of(data_type::bf16, src_md_.data_type,
+ weights_md_.data_type, dst_md_.data_type),
+ has_bf16_support(sycl_engine->device()));
+
ok = ok && this->set_default_formats();
ok = ok
&& IMPLICATION(
@@ -202,14 +211,25 @@ struct cudnn_convolution_bwd_data_t : public primitive_t {
status_t init(engine_t *engine) {
using namespace data_type;
+
bool ok = desc()->prop_kind == prop_kind::backward_data;
+            auto *sycl_engine
+                    = utils::downcast<impl::sycl::sycl_engine_base_t *>(engine);
ok = ok && this->set_default_formats();
ok = ok
&& (utils::everyone_is(f32, diff_src_md_.data_type,
weights_md_.data_type, diff_dst_md_.data_type)
|| utils::everyone_is(f16, diff_src_md_.data_type,
weights_md_.data_type,
- diff_dst_md_.data_type));
+ diff_dst_md_.data_type)
+ || utils::everyone_is(bf16, diff_src_md_.data_type,
+ weights_md_.data_type,
+ diff_dst_md_.data_type))
+ && IMPLICATION(utils::one_of(data_type::bf16,
+ diff_src_md_.data_type,
+ weights_md_.data_type,
+ diff_dst_md_.data_type),
+ has_bf16_support(sycl_engine->device()));
ok = ok
&& IMPLICATION(
@@ -269,6 +289,8 @@ struct cudnn_convolution_bwd_weights_t : public primitive_t {
status_t init(engine_t *engine) {
using namespace data_type;
bool ok = desc()->prop_kind == prop_kind::backward_weights;
+            auto *sycl_engine
+                    = utils::downcast<impl::sycl::sycl_engine_base_t *>(engine);
ok = ok && this->set_default_formats();
ok = ok
&& (utils::everyone_is(f32, src_md_.data_type,
@@ -276,7 +298,16 @@ struct cudnn_convolution_bwd_weights_t : public primitive_t {
diff_dst_md_.data_type)
|| utils::everyone_is(f16, src_md_.data_type,
diff_weights_md_.data_type,
- diff_dst_md_.data_type));
+ diff_dst_md_.data_type)
+ || utils::everyone_is(bf16, src_md_.data_type,
+ diff_weights_md_.data_type,
+ diff_dst_md_.data_type))
+
+ && IMPLICATION(
+ utils::one_of(data_type::bf16, src_md_.data_type,
+ diff_weights_md_.data_type,
+ diff_dst_md_.data_type),
+ has_bf16_support(sycl_engine->device()));
ok = ok
&& IMPLICATION(
diff --git a/src/gpu/nvidia/cudnn_convolution_impl.hpp b/src/gpu/nvidia/cudnn_convolution_impl.hpp
index 98d06c8b5b7..299e10d4ee6 100644
--- a/src/gpu/nvidia/cudnn_convolution_impl.hpp
+++ b/src/gpu/nvidia/cudnn_convolution_impl.hpp
@@ -330,10 +330,14 @@ struct cudnn_convolution_impl_base_t
}
void set_compute_format() {
- if (data_types[x] == CUDNN_DATA_INT8) {
- computation_data_type = CUDNN_DATA_INT32;
- } else {
- computation_data_type = data_types[y];
+ switch (data_types[x]) {
+ case CUDNN_DATA_INT8:
+ computation_data_type = CUDNN_DATA_INT32;
+ break;
+ case CUDNN_DATA_BFLOAT16:
+ computation_data_type = CUDNN_DATA_FLOAT;
+ break;
+ default: computation_data_type = data_types[y];
}
}