From bb352ad1ea0ea8a8e03a7e228df07c1fdd5f9b51 Mon Sep 17 00:00:00 2001 From: Mourad Gouicem Date: Mon, 16 Dec 2024 00:58:27 -0800 Subject: [PATCH] cpu: reorder: support reorders to and from e3m0 --- src/cpu/ref_io_helper.hpp | 14 ++++++++++++++ src/cpu/reorder/cpu_reorder.cpp | 2 ++ src/cpu/reorder/cpu_reorder_regular_fp4.cpp | 8 ++++++++ src/cpu/reorder/simple_reorder.hpp | 8 ++++---- 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/cpu/ref_io_helper.hpp b/src/cpu/ref_io_helper.hpp index 537184976ee..9024385bfc5 100644 --- a/src/cpu/ref_io_helper.hpp +++ b/src/cpu/ref_io_helper.hpp @@ -99,6 +99,12 @@ inline float load_float_value(data_type_t dt, const void *ptr, dim_t idx) { float4_e2m1_t val(nibble_pair.get(idx % 2), true); return static_cast<float>(val); } + case f4_e3m0: { + const nibble2_t nibble_pair + = reinterpret_cast<const nibble2_t *>(ptr)[idx / 2]; + float4_e3m0_t val(nibble_pair.get(idx % 2), true); + return static_cast<float>(val); + } default: assert(!"bad data_type"); } @@ -133,6 +139,14 @@ inline void store_float_value(data_type_t dt, float val, void *ptr, dim_t idx) { dst_[idx / 2] = nibble_pair; break; } + case f4_e3m0: { + auto dst_ = reinterpret_cast<nibble2_t *>(ptr); + nibble2_t nibble_pair = dst_[idx / 2]; + float4_e3m0_t f4_val(val); + nibble_pair.set(f4_val.raw_bits_, idx % 2); + dst_[idx / 2] = nibble_pair; + break; + } default: assert(!"bad data_type"); } diff --git a/src/cpu/reorder/cpu_reorder.cpp b/src/cpu/reorder/cpu_reorder.cpp index ba02835bee0..bd5878d8be8 100644 --- a/src/cpu/reorder/cpu_reorder.cpp +++ b/src/cpu/reorder/cpu_reorder.cpp @@ -26,6 +26,7 @@ static const std::map & regular_impl_list_map() { static const std::map the_map = { {{f32, f4_e2m1, 0}, &regular_fp4_impl_list_map()}, + {{f32, f4_e3m0, 0}, &regular_fp4_impl_list_map()}, {{f32, e8m0, 0}, &regular_f32_fp8_impl_list_map()}, {{f32, f8_e5m2, 0}, &regular_f32_fp8_impl_list_map()}, {{f32, f8_e4m3, 0}, &regular_f32_fp8_impl_list_map()}, @@ -36,6 +37,7 @@ regular_impl_list_map() { {{f32, s8, 0}, 
&regular_f32_s8_impl_list_map()}, {{f32, u8, 0}, &regular_f32_u8_impl_list_map()}, {{f4_e2m1, data_type::undef, 0}, &regular_fp4_impl_list_map()}, + {{f4_e3m0, data_type::undef, 0}, &regular_fp4_impl_list_map()}, {{f8_e5m2, data_type::undef, 0}, &regular_fp8_impl_list_map()}, {{f8_e4m3, data_type::undef, 0}, &regular_fp8_impl_list_map()}, {{bf16, data_type::undef, 0}, &regular_bf16_impl_list_map()}, diff --git a/src/cpu/reorder/cpu_reorder_regular_fp4.cpp b/src/cpu/reorder/cpu_reorder_regular_fp4.cpp index 09424fe04af..49e3a0ae604 100644 --- a/src/cpu/reorder/cpu_reorder_regular_fp4.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_fp4.cpp @@ -32,6 +32,14 @@ const impl_list_map_t &regular_fp4_impl_list_map() { REG_SR(f4_e2m1, any, f32, any, fmt_order::any, spec::reference) nullptr, }}, + {{f32, f4_e3m0, 0}, { + REG_SR(f32, any, f4_e3m0, any, fmt_order::any, spec::reference) + nullptr, + }}, + {{f4_e3m0, data_type::undef, 0}, { + REG_SR(f4_e3m0, any, f32, any, fmt_order::any, spec::reference) + nullptr, + }}, }); return the_map; } diff --git a/src/cpu/reorder/simple_reorder.hpp b/src/cpu/reorder/simple_reorder.hpp index fbc13f38e0c..00b8654ae62 100644 --- a/src/cpu/reorder/simple_reorder.hpp +++ b/src/cpu/reorder/simple_reorder.hpp @@ -2154,7 +2154,7 @@ struct simple_reorder_impl::type> { static status_t is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { @@ -2240,7 +2240,7 @@ struct simple_reorder_impl::type> { @@ -2509,9 +2509,9 @@ struct simple_reorder_impl::type> { static status_t is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {