gpu: nvidia: add bf16 support to cudnn based convolution
densamoilov committed Mar 11, 2023
1 parent 8db69ae commit 01e8272
Showing 3 changed files with 48 additions and 13 deletions.
src/gpu/nvidia/README.md (12 changes: 6 additions, 6 deletions)

```diff
@@ -149,16 +149,16 @@ The following table shows the convolution status for the oneDNN Nvidia backend:
 #### Forward direction
 | Weights Format | Winograd Supported | Supported Input Format | Supported Output Format | Supported Data Type  | Limitations |
 |----------------|--------------------|------------------------|-------------------------|----------------------|-------------|
-| 2D NCHW        | YES                | NCHW, NHWC             | NCHW, NHWC              | f32, f16             | The Winograd algorithm has limitations: <br> * Filter size must be 3x3 or 5x5. <br> * Dilation must be zero for all dimensions. <br> * Horizontal and vertical filter stride must be 1. |
-| 2D NHWC        | NO                 | NHWC                   | NHWC                    | f32, f16, int8       | * Dilation must be zero in all dimensions. <br> * Output feature maps must be multiple of 4 for `int8` type. |
-| 3D NCHW        | NO                 | NCHW, NHWC             | NCHW, NHWC              | f32, f16             | |
+| 2D NCHW        | YES                | NCHW, NHWC             | NCHW, NHWC              | f32, f16, bf16       | The Winograd algorithm has limitations: <br> * Filter size must be 3x3 or 5x5. <br> * Dilation must be zero for all dimensions. <br> * Horizontal and vertical filter stride must be 1. |
+| 2D NHWC        | NO                 | NHWC                   | NHWC                    | f32, f16, bf16, int8 | * Dilation must be zero in all dimensions. <br> * Output feature maps must be multiple of 4 for `int8` type. |
+| 3D NCHW        | NO                 | NCHW, NHWC             | NCHW, NHWC              | f32, f16, bf16       | |
 
 #### Backward direction
 | Weights Format | Winograd Supported | Supported Input Format | Supported Output Format | Supported Data Type | Limitations |
 |----------------|--------------------|------------------------|-------------------------|---------------------|-------------|
-| 2D NCHW        | YES                | NCHW, NHWC             | NCHW                    | f32, f16            | 1. Dilation must be zero for all dimensions. <br> 2. The Winograd algorithm has limitations: <br> * Filter size must be 3x3 or 5x5. <br> * Dilation must be zero for all dimensions. <br> * Horizontal and vertical filter stride must be 1. |
-| 2D NHWC        | NO                 | NHWC                   | NHWC                    | f32, f16            | |
-| 3D NCHW        | NO                 | NCHW, NHWC             | NCHW                    | f32, f16            | |
+| 2D NCHW        | YES                | NCHW, NHWC             | NCHW                    | f32, f16, bf16      | 1. Dilation must be zero for all dimensions. <br> 2. The Winograd algorithm has limitations: <br> * Filter size must be 3x3 or 5x5. <br> * Dilation must be zero for all dimensions. <br> * Horizontal and vertical filter stride must be 1. |
-| 2D NHWC        | NO                 | NHWC                   | NHWC                    | f32, f16, bf16      | |
+| 3D NCHW        | NO                 | NCHW, NHWC             | NCHW                    | f32, f16, bf16      | |
 ### Deconvolution
 
 Deconvolution primitive is implemented through the convolution with swapped
```
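With these table entries in place, a bf16 forward convolution is reached through the regular oneDNN C++ API. A minimal sketch follows; the tensor shapes, the NHWC choice, and GPU index 0 are illustrative, and the build is assumed to target the Nvidia backend (SYCL/cuDNN):

```cpp
#include <oneapi/dnnl/dnnl.hpp>

using namespace dnnl;

int main() {
    engine eng(engine::kind::gpu, 0); // assumes an Nvidia GPU in a SYCL build
    stream strm(eng);

    // bf16 everywhere, matching the updated support table.
    memory::desc src_md({2, 16, 32, 32}, memory::data_type::bf16,
            memory::format_tag::nhwc);
    memory::desc wei_md({32, 16, 3, 3}, memory::data_type::bf16,
            memory::format_tag::any); // let the library pick a weights layout
    memory::desc dst_md({2, 32, 30, 30}, memory::data_type::bf16,
            memory::format_tag::nhwc);

    auto pd = convolution_forward::primitive_desc(eng,
            prop_kind::forward_inference, algorithm::convolution_direct,
            src_md, wei_md, dst_md,
            /*strides=*/{1, 1}, /*padding_l=*/{0, 0}, /*padding_r=*/{0, 0});

    // Allocate memory with the layouts the primitive descriptor settled on.
    memory src_m(pd.src_desc(), eng), wei_m(pd.weights_desc(), eng),
            dst_m(pd.dst_desc(), eng);

    convolution_forward(pd).execute(strm,
            {{DNNL_ARG_SRC, src_m}, {DNNL_ARG_WEIGHTS, wei_m},
                    {DNNL_ARG_DST, dst_m}});
    strm.wait();
    return 0;
}
```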
src/gpu/nvidia/cudnn_convolution.hpp (37 changes: 34 additions, 3 deletions)

```diff
@@ -50,6 +50,8 @@ struct cudnn_convolution_fwd_t : public primitive_t {
 
         using sm_t = primitive_attr_t::skip_mask_t;
         const auto attr_skip_mask = sm_t::scales_runtime | sm_t::post_ops;
+        auto *sycl_engine
+                = utils::downcast<impl::sycl::sycl_engine_base_t *>(engine);
 
         bool ok = utils::one_of(desc()->prop_kind,
                 prop_kind::forward_training, prop_kind::forward_inference);
@@ -60,10 +62,17 @@ struct cudnn_convolution_fwd_t : public primitive_t {
                         weights_md_.data_type, dst_md_.data_type)
                 || utils::everyone_is(f16, src_md_.data_type,
                         weights_md_.data_type, dst_md_.data_type)
+                || utils::everyone_is(bf16, src_md_.data_type,
+                        weights_md_.data_type, dst_md_.data_type)
                 || (utils::everyone_is(s8, src_md_.data_type,
                            weights_md_.data_type)
                         && utils::one_of(
-                                dst_md_.data_type, f32, s8)));
+                                dst_md_.data_type, f32, s8)))
+                && IMPLICATION(
+                        utils::one_of(data_type::bf16, src_md_.data_type,
+                                weights_md_.data_type, dst_md_.data_type),
+                        has_bf16_support(sycl_engine->device()));
+
         ok = ok && this->set_default_formats();
         ok = ok
                 && IMPLICATION(
```
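The validation pattern combines two helpers: `IMPLICATION(cond, req)` is oneDNN's `(!(cond) || (req))` macro, so the check fails only when bf16 appears in the data-type list *and* the device cannot handle it, and `has_bf16_support()` gates bf16 on hardware capability. For reference, cuDNN runs bf16 convolutions only on compute capability 8.0 (Ampere) and newer. Below is a hypothetical standalone equivalent of such a gate written against the plain CUDA runtime; the actual helper queries the device through the SYCL interop layer, and `device_has_bf16_support` is an illustrative name, not the backend's function:

```cpp
#include <cuda_runtime.h>

// Hypothetical sketch: bf16 convolutions in cuDNN require an Ampere-class
// GPU, i.e. compute capability major version >= 8.
bool device_has_bf16_support(int device_id) {
    int major = 0;
    if (cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor,
                device_id)
            != cudaSuccess)
        return false; // treat query failure as "no bf16"
    return major >= 8;
}
```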
```diff
@@ -202,14 +211,25 @@ struct cudnn_convolution_bwd_data_t : public primitive_t {
 
     status_t init(engine_t *engine) {
         using namespace data_type;
+
         bool ok = desc()->prop_kind == prop_kind::backward_data;
+        auto *sycl_engine
+                = utils::downcast<impl::sycl::sycl_engine_base_t *>(engine);
         ok = ok && this->set_default_formats();
         ok = ok
                 && (utils::everyone_is(f32, diff_src_md_.data_type,
                            weights_md_.data_type, diff_dst_md_.data_type)
                         || utils::everyone_is(f16, diff_src_md_.data_type,
                                 weights_md_.data_type,
-                                diff_dst_md_.data_type));
+                                diff_dst_md_.data_type)
+                        || utils::everyone_is(bf16, diff_src_md_.data_type,
+                                weights_md_.data_type,
+                                diff_dst_md_.data_type))
+                && IMPLICATION(utils::one_of(data_type::bf16,
+                                       diff_src_md_.data_type,
+                                       weights_md_.data_type,
+                                       diff_dst_md_.data_type),
+                        has_bf16_support(sycl_engine->device()));
 
         ok = ok
                 && IMPLICATION(
@@ -269,14 +289,25 @@ struct cudnn_convolution_bwd_weights_t : public primitive_t {
     status_t init(engine_t *engine) {
         using namespace data_type;
         bool ok = desc()->prop_kind == prop_kind::backward_weights;
+        auto *sycl_engine
+                = utils::downcast<impl::sycl::sycl_engine_base_t *>(engine);
         ok = ok && this->set_default_formats();
         ok = ok
                 && (utils::everyone_is(f32, src_md_.data_type,
                            diff_weights_md_.data_type,
                            diff_dst_md_.data_type)
                         || utils::everyone_is(f16, src_md_.data_type,
                                 diff_weights_md_.data_type,
-                                diff_dst_md_.data_type));
+                                diff_dst_md_.data_type)
+                        || utils::everyone_is(bf16, src_md_.data_type,
+                                diff_weights_md_.data_type,
+                                diff_dst_md_.data_type))
+
+                && IMPLICATION(
+                        utils::one_of(data_type::bf16, src_md_.data_type,
+                                diff_weights_md_.data_type,
+                                diff_dst_md_.data_type),
+                        has_bf16_support(sycl_engine->device()));
 
         ok = ok
                 && IMPLICATION(
```
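The backward-data and backward-weights checks mirror the forward one. For completeness, a sketch of how the newly enabled bf16 backward-data path is reached from the oneDNN C++ API; backward primitive descriptors take the forward primitive descriptor as a hint, and the helper name, strides, and paddings here are illustrative:

```cpp
#include <oneapi/dnnl/dnnl.hpp>

using namespace dnnl;

// Hypothetical helper: build a bf16 backward-data convolution. All memory
// descriptors are bf16, matching the rows added to the README tables.
convolution_backward_data::primitive_desc make_bf16_bwd_data_pd(
        const engine &eng, const memory::desc &diff_src_md,
        const memory::desc &wei_md, const memory::desc &diff_dst_md) {
    // Backward primitives require a forward primitive descriptor as a hint.
    auto fwd_hint = convolution_forward::primitive_desc(eng,
            prop_kind::forward_training, algorithm::convolution_direct,
            diff_src_md, wei_md, diff_dst_md,
            /*strides=*/{1, 1}, /*padding_l=*/{0, 0}, /*padding_r=*/{0, 0});
    return convolution_backward_data::primitive_desc(eng,
            algorithm::convolution_direct, diff_src_md, wei_md, diff_dst_md,
            /*strides=*/{1, 1}, /*padding_l=*/{0, 0}, /*padding_r=*/{0, 0},
            fwd_hint);
}
```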
src/gpu/nvidia/cudnn_convolution_impl.hpp (12 changes: 8 additions, 4 deletions)

```diff
@@ -330,10 +330,14 @@ struct cudnn_convolution_impl_base_t
     }
 
     void set_compute_format() {
-        if (data_types[x] == CUDNN_DATA_INT8) {
-            computation_data_type = CUDNN_DATA_INT32;
-        } else {
-            computation_data_type = data_types[y];
+        switch (data_types[x]) {
+            case CUDNN_DATA_INT8:
+                computation_data_type = CUDNN_DATA_INT32;
+                break;
+            case CUDNN_DATA_BFLOAT16:
+                computation_data_type = CUDNN_DATA_FLOAT;
+                break;
+            default: computation_data_type = data_types[y];
         }
     }
```
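`set_compute_format()` picks the accumulator (compute) type handed to cuDNN's convolution descriptor: int8 data accumulates in int32, and bf16 data accumulates in f32, since cuDNN convolutions expose no bf16 accumulation mode. A minimal sketch of where that value ends up; the padding, stride, and dilation parameters are placeholders, not the backend's actual configuration:

```cpp
#include <cudnn.h>

// Illustrative: the chosen computation_data_type becomes the descriptor's
// computeType. For bf16 tensors this must be CUDNN_DATA_FLOAT.
cudnnStatus_t make_conv_desc(cudnnConvolutionDescriptor_t *conv_desc,
        cudnnDataType_t compute_type) {
    cudnnStatus_t st = cudnnCreateConvolutionDescriptor(conv_desc);
    if (st != CUDNN_STATUS_SUCCESS) return st;
    return cudnnSetConvolution2dDescriptor(*conv_desc,
            /*pad_h=*/0, /*pad_w=*/0, /*stride_h=*/1, /*stride_w=*/1,
            /*dilation_h=*/1, /*dilation_w=*/1, CUDNN_CROSS_CORRELATION,
            compute_type); // CUDNN_DATA_FLOAT when the tensors are bf16
}
```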
