Skip to content

Commit

Permalink
api: add fp4_e3m0 support
Browse files Browse the repository at this point in the history
  • Loading branch information
mgouicem committed Dec 20, 2024
1 parent ed96323 commit 32dc361
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 3 deletions.
2 changes: 2 additions & 0 deletions include/oneapi/dnnl/dnnl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,8 @@ struct memory : public handle<dnnl_memory_t> {
enum class data_type {
/// Undefined data type (used for empty memory descriptors).
undef = dnnl_data_type_undef,
/// 4-bit float data type with 3-bit exponent and 0 bit mantissa.
f4_e3m0 = dnnl_f4_e3m0,
/// [MX-compliant 4-bit float data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 2-bit exponent and 1 bit mantissa.
f4_e2m1 = dnnl_f4_e2m1,
/// [MX-compliant 8-bit scale data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 8-bit exponent.
Expand Down
2 changes: 2 additions & 0 deletions include/oneapi/dnnl/dnnl_common_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ typedef enum {
dnnl_e8m0 = 13,
/// [MX-compliant 4-bit float data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 2-bit exponent and 1 bit mantissa.
dnnl_f4_e2m1 = 14,
/// 4-bit float data type with 3-bit exponent and 0 bit mantissa.
dnnl_f4_e3m0 = 15,

/// Parameter to allow internal only data_types without undefined behavior.
/// This parameter is chosen to be valid for so long as sizeof(int) >= 2.
Expand Down
1 change: 1 addition & 0 deletions src/common/c_types_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ const alg_kind_t eltwise_stochastic_round
using data_type_t = dnnl_data_type_t;
namespace data_type {
const data_type_t undef = dnnl_data_type_undef;
const data_type_t f4_e3m0 = dnnl_f4_e3m0;
const data_type_t f4_e2m1 = dnnl_f4_e2m1;
const data_type_t e8m0 = dnnl_e8m0;
const data_type_t f8_e5m2 = dnnl_f8_e5m2;
Expand Down
1 change: 1 addition & 0 deletions src/common/dnnl_debug_autogenerated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ const char *dnnl_dt2str(dnnl_data_type_t v) {
if (v == dnnl_u4) return "u4";
if (v == dnnl_e8m0) return "e8m0";
if (v == dnnl_f4_e2m1) return "f4_e2m1";
if (v == dnnl_f4_e3m0) return "f4_e3m0";
if (v == dnnl_data_type_max) return "data_type_max";
assert(!"unknown dt");
return "unknown dt";
Expand Down
8 changes: 8 additions & 0 deletions src/common/dnnl_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ struct typesize_traits {}; /* ::data_type_size -> f32 */
template <primitive_kind_t>
struct pkind_traits {}; /* ::desc_type, ::query_d */

/// Maps the data_type::f4_e3m0 tag to the C++ type that stores its values.
template <>
struct prec_traits<data_type::f4_e3m0> {
    using type = float4_e3m0_t;
};
template <>
struct prec_traits<data_type::f4_e2m1> {
typedef float4_e2m1_t type;
Expand Down Expand Up @@ -95,6 +99,10 @@ struct prec_traits<data_type::boolean> {
typedef bool type;
};

// Maps the float4_e3m0_t storage type back to its data_type_t tag
// (data_type::f4_e3m0), i.e. the reverse of prec_traits<data_type::f4_e3m0>.
template <>
struct data_traits<float4_e3m0_t> {
static constexpr data_type_t data_type = data_type::f4_e3m0;
};
template <>
struct data_traits<float4_e2m1_t> {
static constexpr data_type_t data_type = data_type::f4_e2m1;
Expand Down
69 changes: 68 additions & 1 deletion src/common/float4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ uint8_t float2e2m1(float f) {
// There is no NaN or infinity in e2m1, for now we just return zero
// TODO: figure if there is a standard value to return
uint32_t naninf_mask = 0x7f800000;
if ((f_raw & naninf_mask) == naninf_mask) return 0x00000000;
if ((f_raw & naninf_mask) == naninf_mask) return 0x00;

// we convert with naive closest value computation out of 8
float e2m1_val_table[8] = {0.0f, .5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
Expand Down Expand Up @@ -91,5 +91,72 @@ float4_e2m1_t::operator float16_t() const {
return e2m1_table[raw_bits_];
}

/// Converts an f32 value to the 4-bit e3m0 encoding.
///
/// The encoding keeps the sign in bit 3 and a magnitude index in bits 0-2.
/// The magnitude index is the position of the closest value among
/// {0, 0.25, 0.5, 1, 2, 4, 8, 16}. Exact midpoints are resolved toward the
/// even index, and magnitudes above 16 saturate to 16.
///
/// e3m0 has no NaN or infinity encoding, so those inputs are mapped to the
/// largest magnitude with the sign of the input preserved.
///
/// @param f Value to convert.
/// @return The encoding, stored in the low 4 bits of the returned byte.
uint8_t float2e3m0(float f) {
    const uint32_t f_raw = float2int(f);
    const uint32_t sign = f_raw & 0x80000000;
    // The sign occupies bit 3 of the 4-bit encoding.
    const uint8_t sign_bit = sign ? 0x08 : 0x00;

    // There is no NaN or infinity in e3m0: saturate to the largest
    // magnitude and keep the sign, so that -inf maps to -16 and not +16.
    const uint32_t naninf_mask = 0x7f800000;
    if ((f_raw & naninf_mask) == naninf_mask)
        return static_cast<uint8_t>(0x07 | sign_bit);

    // Representable magnitudes. The index of each value is its encoding.
    static const float e3m0_val_table[8]
            = {0.0f, .25f, .5f, 1.0f, 2.0f, 4.0f, 8.0f, 16.0f};

    // Clear the sign bit to search on the magnitude only.
    const float abs_f = int2float(f_raw ^ sign);

    // Naive closest-value search over the 8 magnitudes.
    uint8_t raw_bits = 0;
    float min_diff = ::fabsf(e3m0_val_table[0] - abs_f);
    for (int idx = 1; idx < 8; ++idx) {
        const float diff = ::fabsf(e3m0_val_table[idx] - abs_f);
        const bool closer = diff < min_diff;
        // Midpoint between two neighbours: round to even (even index).
        const bool tie_to_even = (diff == min_diff) && ((idx & 1) == 0);
        if (closer || tie_to_even) {
            min_diff = diff;
            raw_bits = static_cast<uint8_t>(idx);
        }
    }
    assert(raw_bits < 8);

    // Reapply the sign.
    raw_bits = static_cast<uint8_t>(raw_bits | sign_bit);
    assert(raw_bits < 16);
    return raw_bits;
}

/// Assigns from a bf16 value: the input is first converted to f32 and the
/// result is then encoded with float2e3m0.
float4_e3m0_t &float4_e3m0_t::operator=(bfloat16_t f) {
    const float as_f32 = f;
    raw_bits_ = float2e3m0(as_f32);
    return *this;
}

/// Assigns from an f16 value: the input is first converted to f32 and the
/// result is then encoded with float2e3m0.
float4_e3m0_t &float4_e3m0_t::operator=(float16_t f) {
    const float as_f32 = f;
    raw_bits_ = float2e3m0(as_f32);
    return *this;
}

/// Assigns from an f32 value by storing its e3m0 encoding.
float4_e3m0_t &float4_e3m0_t::operator=(float f) {
    const uint8_t encoded = float2e3m0(f);
    raw_bits_ = encoded;
    return *this;
}

/// Decodes the stored bit pattern to f32 through a lookup table.
float4_e3m0_t::operator float() const {
    // Decoded value of every encoding, indexed by the raw bits: entries 0-7
    // are the non-negative values and entries 8-15 their negated
    // counterparts.
    static const float decoded_values[16] = {0.0f, .25f, .5f, 1.0f, 2.0f,
            4.0f, 8.0f, 16.0f, -0.0f, -.25f, -.5f, -1.0f, -2.0f, -4.0f,
            -8.0f, -16.0f};
    assert(raw_bits_ < 16);
    const uint8_t encoding = raw_bits_;
    return decoded_values[encoding];
}

/// Decodes the stored bit pattern to f16 through a lookup table.
float4_e3m0_t::operator float16_t() const {
    // Decoded value of every encoding, indexed by the raw bits: entries 0-7
    // are the non-negative values and entries 8-15 their negated
    // counterparts.
    static const float16_t decoded_values[16] = {0.0f, .25f, .5f, 1.0f, 2.0f,
            4.0f, 8.0f, 16.0f, -0.0f, -.25f, -.5f, -1.0f, -2.0f, -4.0f,
            -8.0f, -16.0f};
    assert(raw_bits_ < 16);
    const uint8_t encoding = raw_bits_;
    return decoded_values[encoding];
}

} // namespace impl
} // namespace dnnl
23 changes: 23 additions & 0 deletions src/common/float4.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,29 @@ struct float4_e2m1_t {
};
static_assert(sizeof(float4_e2m1_t) == 1, "float4_e2m1_t must be 1 byte");

struct float4_e3m0_t {
uint8_t raw_bits_;
float4_e3m0_t() = default;
constexpr float4_e3m0_t(uint8_t r, bool = true) : raw_bits_(r) {}
float4_e3m0_t(float f) { (*this) = f; }
float4_e3m0_t(float16_t f) { (*this) = f; }
float4_e3m0_t(bfloat16_t f) { (*this) = f; }

float4_e3m0_t DNNL_API &operator=(float f);
float4_e3m0_t DNNL_API &operator=(float16_t f);
float4_e3m0_t DNNL_API &operator=(bfloat16_t f);

DNNL_API operator float() const;
DNNL_API operator float16_t() const;
DNNL_API operator bfloat16_t() const;

float4_e3m0_t &operator+=(const float a) {
(*this) = float {*this} + a;
return *this;
}
};
static_assert(sizeof(float4_e3m0_t) == 1, "float4_e3m0_t must be 1 byte");

} // namespace impl
} // namespace dnnl

Expand Down
16 changes: 16 additions & 0 deletions src/common/nstl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,22 @@ struct numeric_limits<int8_t> : public std::numeric_limits<int8_t> {};
template <>
struct numeric_limits<uint8_t> : public std::numeric_limits<uint8_t> {};

// Numeric limits for float4_e3m0_t. Every value is built from its raw
// encoding, where bit 3 is the sign and bits 0-2 index the magnitudes
// {0, 0.25, 0.5, 1, 2, 4, 8, 16}.
template <>
struct numeric_limits<float4_e3m0_t> {
// Most negative value: encoding 0xf decodes to -16.0
static constexpr float4_e3m0_t lowest() { return float4_e3m0_t(0xf, true); }
// Smallest positive value: encoding 0x1 decodes to 0.25
static constexpr float4_e3m0_t min() { return float4_e3m0_t(0x1, true); }
// Largest value: encoding 0x7 decodes to 16.0
static constexpr float4_e3m0_t max() { return float4_e3m0_t(0x7, true); }

static constexpr int bias = 0x3;
static constexpr int digits = 1; // 1 implicit bit

// Encoding 0x3 decodes to 1.0, which is the gap between 1.0 and the next
// representable value (2.0)
static constexpr float4_e3m0_t epsilon() {
return float4_e3m0_t(0x3, true);
}
};

template <>
struct numeric_limits<float4_e2m1_t> {
static constexpr float4_e2m1_t lowest() { return float4_e2m1_t(0xf, true); }
Expand Down
12 changes: 10 additions & 2 deletions src/common/type_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ namespace types {
inline size_t data_type_size(data_type_t data_type) {
using namespace data_type;
switch ((int)data_type) {
case f4_e3m0: return sizeof(prec_traits<f4_e3m0>::type);
case f4_e2m1: return sizeof(prec_traits<f4_e2m1>::type);
case e8m0: return sizeof(prec_traits<e8m0>::type);
case f8_e5m2: return sizeof(prec_traits<f8_e5m2>::type);
Expand Down Expand Up @@ -139,6 +140,7 @@ inline T min_value(data_type_t data_type) {
case x: \
return static_cast<T>(nstl::numeric_limits<prec_traits<x>::type>::min())
switch (data_type) {
CASE(f4_e3m0);
CASE(f4_e2m1);
CASE(e8m0);
CASE(f8_e5m2);
Expand Down Expand Up @@ -166,6 +168,7 @@ inline T max_value(data_type_t data_type) {
case x: \
return static_cast<T>(nstl::numeric_limits<prec_traits<x>::type>::max())
switch (data_type) {
CASE(f4_e3m0);
CASE(f4_e2m1);
CASE(e8m0);
CASE(f8_e5m2);
Expand Down Expand Up @@ -195,6 +198,7 @@ inline float max_value(data_type_t data_type) {
return static_cast<float>( \
nstl::numeric_limits<prec_traits<x>::type>::max())
switch (data_type) {
CASE(f4_e3m0);
CASE(f4_e2m1);
CASE(e8m0);
CASE(f8_e5m2);
Expand Down Expand Up @@ -233,6 +237,7 @@ inline T lowest_value(data_type_t data_type) {
return static_cast<T>( \
nstl::numeric_limits<prec_traits<x>::type>::lowest())
switch (data_type) {
CASE(f4_e3m0);
CASE(f4_e2m1);
CASE(e8m0);
CASE(f8_e5m2);
Expand Down Expand Up @@ -261,6 +266,7 @@ inline T digits(data_type_t data_type) {
return static_cast<T>( \
nstl::numeric_limits<prec_traits<x>::type>::digits)
switch (data_type) {
CASE(f4_e3m0);
CASE(f4_e2m1);
CASE(e8m0);
CASE(f8_e5m2);
Expand Down Expand Up @@ -419,6 +425,7 @@ inline data_type_t default_accum_data_type(
// true
if (one_of(src_dt, s8, u8, u4, s4) && (dst_dt != f32 || strict)) return s32;

if (one_of(f4_e3m0, src_dt, dst_dt)) return f32;
if (one_of(f4_e2m1, src_dt, dst_dt)) return f32;
if (one_of(f8_e5m2, src_dt, dst_dt)) return f32;
if (one_of(f8_e4m3, src_dt, dst_dt)) return f32;
Expand Down Expand Up @@ -461,6 +468,7 @@ inline data_type_t default_accum_data_type(data_type_t src_dt,
return f32;
}

if (one_of(f4_e3m0, src_dt, wei_dt, dst_dt)) return f32;
if (one_of(f4_e2m1, src_dt, wei_dt, dst_dt)) return f32;
if (one_of(f8_e5m2, src_dt, wei_dt, dst_dt)) return f32;
if (one_of(f8_e4m3, src_dt, wei_dt, dst_dt)) return f32;
Expand Down Expand Up @@ -1262,8 +1270,8 @@ inline bool memory_desc_sanity_check(int ndims, const dims_t dims,
if (ndims == 0) return true;

bool ok = dims != nullptr && 0 < ndims && ndims <= DNNL_MAX_NDIMS
&& utils::one_of(data_type, f4_e2m1, e8m0, f8_e5m2, f8_e4m3, f16,
bf16, f32, f64, s32, s8, u8, s4, u4);
&& utils::one_of(data_type, f4_e3m0, f4_e2m1, e8m0, f8_e5m2,
f8_e4m3, f16, bf16, f32, f64, s32, s8, u8, s4, u4);
if (!ok) return false;

bool has_runtime_dims = false;
Expand Down
1 change: 1 addition & 0 deletions tests/benchdnn/dnnl_debug_autogenerated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ dnnl_data_type_t str2dt(const char *str) {
CASE(u4);
CASE(e8m0);
CASE(f4_e2m1);
CASE(f4_e3m0);
CASE(data_type_max);
#undef CASE
if (!strcmp("undef", str) || !strcmp("dnnl_data_type_undef", str))
Expand Down

0 comments on commit 32dc361

Please sign in to comment.