diff --git a/include/oneapi/dnnl/dnnl.hpp b/include/oneapi/dnnl/dnnl.hpp index 80ca2fc15db..7ba89edf877 100644 --- a/include/oneapi/dnnl/dnnl.hpp +++ b/include/oneapi/dnnl/dnnl.hpp @@ -856,6 +856,8 @@ struct memory : public handle { enum class data_type { /// Undefined data type (used for empty memory descriptors). undef = dnnl_data_type_undef, + /// 4-bit float data type with 3-bit exponent and 0 bit mantissa. + f4_e3m0 = dnnl_f4_e3m0, /// [MX-compliant 4-bit float data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 2-bit exponent and 1 bit mantissa. f4_e2m1 = dnnl_f4_e2m1, /// [MX-compliant 8-bit compliant scale data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 8-bit exponent. diff --git a/include/oneapi/dnnl/dnnl_common_types.h b/include/oneapi/dnnl/dnnl_common_types.h index 737f6dfb037..d80a2e7d0a6 100644 --- a/include/oneapi/dnnl/dnnl_common_types.h +++ b/include/oneapi/dnnl/dnnl_common_types.h @@ -106,6 +106,8 @@ typedef enum { dnnl_e8m0 = 13, /// [MX-compliant 4-bit float data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 2-bit exponent and 1 bit mantissa. dnnl_f4_e2m1 = 14, + /// 4-bit float data type with 3-bit exponent and 0 bit mantissa. + dnnl_f4_e3m0 = 15, /// Parameter to allow internal only data_types without undefined behavior. /// This parameter is chosen to be valid for so long as sizeof(int) >= 2. diff --git a/src/common/c_types_map.hpp b/src/common/c_types_map.hpp index e7789e3ca88..b22a63ba2dd 100644 --- a/src/common/c_types_map.hpp +++ b/src/common/c_types_map.hpp @@ -153,6 +153,7 @@ const alg_kind_t eltwise_stochastic_round using data_type_t = dnnl_data_type_t; namespace data_type { const data_type_t undef = dnnl_data_type_undef; +const data_type_t f4_e3m0 = dnnl_f4_e3m0; const data_type_t f4_e2m1 = dnnl_f4_e2m1; const data_type_t e8m0 = dnnl_e8m0; const data_type_t f8_e5m2 = dnnl_f8_e5m2; diff --git a/src/common/dnnl_debug_autogenerated.cpp b/src/common/dnnl_debug_autogenerated.cpp index 055cf19e93e..ecd88375f86 100644 --- a/src/common/dnnl_debug_autogenerated.cpp +++ b/src/common/dnnl_debug_autogenerated.cpp @@ -59,6 +59,7 @@ const char *dnnl_dt2str(dnnl_data_type_t v) { if (v == dnnl_u4) return "u4"; if (v == dnnl_e8m0) return "e8m0"; if (v == dnnl_f4_e2m1) return "f4_e2m1"; + if (v == dnnl_f4_e3m0) return "f4_e3m0"; if (v == dnnl_data_type_max) return "data_type_max"; assert(!"unknown dt"); return "unknown dt"; diff --git a/src/common/dnnl_traits.hpp b/src/common/dnnl_traits.hpp index d9aaef8b1c1..618cc7c77d6 100644 --- a/src/common/dnnl_traits.hpp +++ b/src/common/dnnl_traits.hpp @@ -38,6 +38,10 @@ struct typesize_traits {}; /* ::data_type_size -> f32 */ template struct pkind_traits {}; /* ::desc_type, ::query_d */ +template <> +struct prec_traits { + typedef float4_e3m0_t type; +}; template <> struct prec_traits { typedef float4_e2m1_t type; @@ -95,6 +99,10 @@ struct prec_traits { typedef bool type; }; +template <> +struct data_traits { + static constexpr data_type_t data_type = data_type::f4_e3m0; +}; template <> struct data_traits { static constexpr data_type_t data_type = data_type::f4_e2m1; diff --git a/src/common/float4.cpp b/src/common/float4.cpp index 09ad8268301..636770a621e 100644 --- a/src/common/float4.cpp +++ b/src/common/float4.cpp @@ -32,7 +32,7 @@ uint8_t float2e2m1(float f) { // There is no NaN or infinity in e2m1, for now we just return zero // TODO: figure if there is a standard value to return uint32_t naninf_mask = 0x7f800000; - if ((f_raw & naninf_mask) == naninf_mask) return 0x00000000; + if ((f_raw & naninf_mask) == naninf_mask) return 0x00; // we convert with naive closest value computation out of 8 float e2m1_val_table[8] = {0.0f, .5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f}; @@ -91,5 +91,72 @@ float4_e2m1_t::operator float16_t() const { return e2m1_table[raw_bits_]; } +uint8_t float2e3m0(float f) { + uint32_t f_raw = float2int(f); + uint32_t sign = f_raw & 0x80000000; + + // There is no NaN or infinity in e3m0, we just return maxval + uint32_t naninf_mask = 0x7f800000; + if ((f_raw & naninf_mask) == naninf_mask) return 0x7; + + // we convert with naive closest value computation out of 8 + float e3m0_val_table[8] = {0.0f, .25f, .5f, 1.0f, 2.0f, 4.0f, 8.0f, 16.0f}; + + float abs_f = int2float(f_raw ^ sign); + + int idx = 0; + float min_diff = ::fabsf(e3m0_val_table[idx] - abs_f); + uint8_t raw_bits = idx; + for (++idx; idx < 8; ++idx) { + float diff = ::fabsf(e3m0_val_table[idx] - abs_f); + if (diff < min_diff) { + min_diff = diff; + raw_bits = idx; + } + // Special case for midpoint, we round to even (so even index) + if ((diff == min_diff) && !(idx & 1)) raw_bits = idx; + } + assert(raw_bits < 8); + // reapply sign + if (sign) raw_bits = raw_bits | 0x08; + assert(raw_bits < 16); + return raw_bits; +} + +float4_e3m0_t &float4_e3m0_t::operator=(bfloat16_t f) { + float f32 = f; + raw_bits_ = float2e3m0(f32); + return *this; +} + +float4_e3m0_t &float4_e3m0_t::operator=(float16_t f) { + float f32 = f; + raw_bits_ = float2e3m0(f32); + return *this; +} + +float4_e3m0_t &float4_e3m0_t::operator=(float f) { + raw_bits_ = float2e3m0(f); + return *this; +} + +float4_e3m0_t::operator float() const { + // List of e3m0 values. The index of each value maps to its encoding. + static const float e3m0_table[16] + = {0.0f, .25f, .5f, 1.0f, 2.0f, 4.0f, 8.0f, 16.0f, -0.0f, -.25f, + -.5f, -1.0f, -2.0f, -4.0f, -8.0f, -16.0f}; + assert(raw_bits_ < 16); + return e3m0_table[raw_bits_]; +} + +float4_e3m0_t::operator float16_t() const { + // List of e3m0 values. The index of each value maps to its encoding. + static const float16_t e3m0_table[16] + = {0.0f, .25f, .5f, 1.0f, 2.0f, 4.0f, 8.0f, 16.0f, -0.0f, -.25f, + -.5f, -1.0f, -2.0f, -4.0f, -8.0f, -16.0f}; + assert(raw_bits_ < 16); + return e3m0_table[raw_bits_]; +} + } // namespace impl } // namespace dnnl diff --git a/src/common/float4.hpp b/src/common/float4.hpp index 1d24c07235a..44be31d9d0a 100644 --- a/src/common/float4.hpp +++ b/src/common/float4.hpp @@ -49,6 +49,29 @@ struct float4_e2m1_t { }; static_assert(sizeof(float4_e2m1_t) == 1, "float4_e2m1_t must be 1 byte"); +struct float4_e3m0_t { + uint8_t raw_bits_; + float4_e3m0_t() = default; + constexpr float4_e3m0_t(uint8_t r, bool = true) : raw_bits_(r) {} + float4_e3m0_t(float f) { (*this) = f; } + float4_e3m0_t(float16_t f) { (*this) = f; } + float4_e3m0_t(bfloat16_t f) { (*this) = f; } + + float4_e3m0_t DNNL_API &operator=(float f); + float4_e3m0_t DNNL_API &operator=(float16_t f); + float4_e3m0_t DNNL_API &operator=(bfloat16_t f); + + DNNL_API operator float() const; + DNNL_API operator float16_t() const; + DNNL_API operator bfloat16_t() const; + + float4_e3m0_t &operator+=(const float a) { + (*this) = float {*this} + a; + return *this; + } +}; +static_assert(sizeof(float4_e3m0_t) == 1, "float4_e3m0_t must be 1 byte"); + } // namespace impl } // namespace dnnl diff --git a/src/common/nstl.hpp b/src/common/nstl.hpp index 1f5093ba40a..3623ff4e8b6 100644 --- a/src/common/nstl.hpp +++ b/src/common/nstl.hpp @@ -157,6 +157,22 @@ struct numeric_limits : public std::numeric_limits {}; template <> struct numeric_limits : public std::numeric_limits {}; +template <> +struct numeric_limits { + static constexpr float4_e3m0_t lowest() { return float4_e3m0_t(0xf, true); } + // Min normal is equal to the value 1.0 + static constexpr float4_e3m0_t min() { return float4_e3m0_t(0x1, true); } + // Max normal is equal to the value 6.0 + static constexpr float4_e3m0_t max() { return float4_e3m0_t(0x7, true); } + + static constexpr int bias = 0x3; + static constexpr int digits = 1; // 1 implicit bit + + static constexpr float4_e3m0_t epsilon() { + return float4_e3m0_t(0x3, true); + } +}; + template <> struct numeric_limits { static constexpr float4_e2m1_t lowest() { return float4_e2m1_t(0xf, true); } diff --git a/src/common/type_helpers.hpp b/src/common/type_helpers.hpp index 83aeaae8a7e..4ad0eb8be5f 100644 --- a/src/common/type_helpers.hpp +++ b/src/common/type_helpers.hpp @@ -93,6 +93,7 @@ namespace types { inline size_t data_type_size(data_type_t data_type) { using namespace data_type; switch ((int)data_type) { + case f4_e3m0: return sizeof(prec_traits::type); case f4_e2m1: return sizeof(prec_traits::type); case e8m0: return sizeof(prec_traits::type); case f8_e5m2: return sizeof(prec_traits::type); @@ -139,6 +140,7 @@ inline T min_value(data_type_t data_type) { case x: \ return static_cast(nstl::numeric_limits::type>::min()) switch (data_type) { + CASE(f4_e3m0); CASE(f4_e2m1); CASE(e8m0); CASE(f8_e5m2); @@ -166,6 +168,7 @@ inline T max_value(data_type_t data_type) { case x: \ return static_cast(nstl::numeric_limits::type>::max()) switch (data_type) { + CASE(f4_e3m0); CASE(f4_e2m1); CASE(e8m0); CASE(f8_e5m2); @@ -195,6 +198,7 @@ inline float max_value(data_type_t data_type) { return static_cast( \ nstl::numeric_limits::type>::max()) switch (data_type) { + CASE(f4_e3m0); CASE(f4_e2m1); CASE(e8m0); CASE(f8_e5m2); @@ -233,6 +237,7 @@ inline T lowest_value(data_type_t data_type) { return static_cast( \ nstl::numeric_limits::type>::lowest()) switch (data_type) { + CASE(f4_e3m0); CASE(f4_e2m1); CASE(e8m0); CASE(f8_e5m2); @@ -261,6 +266,7 @@ inline T digits(data_type_t data_type) { return static_cast( \ nstl::numeric_limits::type>::digits) switch (data_type) { + CASE(f4_e3m0); CASE(f4_e2m1); CASE(e8m0); CASE(f8_e5m2); @@ -419,6 +425,7 @@ inline data_type_t default_accum_data_type( // true if (one_of(src_dt, s8, u8, u4, s4) && (dst_dt != f32 || strict)) return s32; + if (one_of(f4_e3m0, src_dt, dst_dt)) return f32; if (one_of(f4_e2m1, src_dt, dst_dt)) return f32; if (one_of(f8_e5m2, src_dt, dst_dt)) return f32; if (one_of(f8_e4m3, src_dt, dst_dt)) return f32; @@ -461,6 +468,7 @@ inline data_type_t default_accum_data_type(data_type_t src_dt, return f32; } + if (one_of(f4_e3m0, src_dt, wei_dt, dst_dt)) return f32; if (one_of(f4_e2m1, src_dt, wei_dt, dst_dt)) return f32; if (one_of(f8_e5m2, src_dt, wei_dt, dst_dt)) return f32; if (one_of(f8_e4m3, src_dt, wei_dt, dst_dt)) return f32; @@ -1262,8 +1270,8 @@ inline bool memory_desc_sanity_check(int ndims, const dims_t dims, if (ndims == 0) return true; bool ok = dims != nullptr && 0 < ndims && ndims <= DNNL_MAX_NDIMS - && utils::one_of(data_type, f4_e2m1, e8m0, f8_e5m2, f8_e4m3, f16, - bf16, f32, f64, s32, s8, u8, s4, u4); + && utils::one_of(data_type, f4_e3m0, f4_e2m1, e8m0, f8_e5m2, + f8_e4m3, f16, bf16, f32, f64, s32, s8, u8, s4, u4); if (!ok) return false; bool has_runtime_dims = false; diff --git a/tests/benchdnn/dnnl_debug_autogenerated.cpp b/tests/benchdnn/dnnl_debug_autogenerated.cpp index fe27010fcf0..4e621c66f09 100644 --- a/tests/benchdnn/dnnl_debug_autogenerated.cpp +++ b/tests/benchdnn/dnnl_debug_autogenerated.cpp @@ -50,6 +50,7 @@ dnnl_data_type_t str2dt(const char *str) { CASE(u4); CASE(e8m0); CASE(f4_e2m1); + CASE(f4_e3m0); CASE(data_type_max); #undef CASE if (!strcmp("undef", str) || !strcmp("dnnl_data_type_undef", str))