Skip to content

Commit

Permalink
cpu: reorder: support reorders to and from e3m0
Browse files Browse the repository at this point in the history
  • Loading branch information
mgouicem committed Dec 20, 2024
1 parent 1587666 commit bb352ad
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 4 deletions.
14 changes: 14 additions & 0 deletions src/cpu/ref_io_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ inline float load_float_value(data_type_t dt, const void *ptr, dim_t idx) {
float4_e2m1_t val(nibble_pair.get(idx % 2), true);
return static_cast<float>(val);
}
case f4_e3m0: {
const nibble2_t nibble_pair
= reinterpret_cast<const nibble2_t *>(ptr)[idx / 2];
float4_e3m0_t val(nibble_pair.get(idx % 2), true);
return static_cast<float>(val);
}
default: assert(!"bad data_type");
}

Expand Down Expand Up @@ -133,6 +139,14 @@ inline void store_float_value(data_type_t dt, float val, void *ptr, dim_t idx) {
dst_[idx / 2] = nibble_pair;
break;
}
case f4_e3m0: {
auto dst_ = reinterpret_cast<nibble2_t *>(ptr);
nibble2_t nibble_pair = dst_[idx / 2];
float4_e3m0_t f4_val(val);
nibble_pair.set(f4_val.raw_bits_, idx % 2);
dst_[idx / 2] = nibble_pair;
break;
}
default: assert(!"bad data_type");
}

Expand Down
2 changes: 2 additions & 0 deletions src/cpu/reorder/cpu_reorder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ static const std::map<reorder_impl_key_t, const void *> &
regular_impl_list_map() {
static const std::map<reorder_impl_key_t, const void *> the_map = {
{{f32, f4_e2m1, 0}, &regular_fp4_impl_list_map()},
{{f32, f4_e3m0, 0}, &regular_fp4_impl_list_map()},
{{f32, e8m0, 0}, &regular_f32_fp8_impl_list_map()},
{{f32, f8_e5m2, 0}, &regular_f32_fp8_impl_list_map()},
{{f32, f8_e4m3, 0}, &regular_f32_fp8_impl_list_map()},
Expand All @@ -36,6 +37,7 @@ regular_impl_list_map() {
{{f32, s8, 0}, &regular_f32_s8_impl_list_map()},
{{f32, u8, 0}, &regular_f32_u8_impl_list_map()},
{{f4_e2m1, data_type::undef, 0}, &regular_fp4_impl_list_map()},
{{f4_e3m0, data_type::undef, 0}, &regular_fp4_impl_list_map()},
{{f8_e5m2, data_type::undef, 0}, &regular_fp8_impl_list_map()},
{{f8_e4m3, data_type::undef, 0}, &regular_fp8_impl_list_map()},
{{bf16, data_type::undef, 0}, &regular_bf16_impl_list_map()},
Expand Down
8 changes: 8 additions & 0 deletions src/cpu/reorder/cpu_reorder_regular_fp4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ const impl_list_map_t &regular_fp4_impl_list_map() {
REG_SR(f4_e2m1, any, f32, any, fmt_order::any, spec::reference)
nullptr,
}},
{{f32, f4_e3m0, 0}, {
REG_SR(f32, any, f4_e3m0, any, fmt_order::any, spec::reference)
nullptr,
}},
{{f4_e3m0, data_type::undef, 0}, {
REG_SR(f4_e3m0, any, f32, any, fmt_order::any, spec::reference)
nullptr,
}},
});
return the_map;
}
Expand Down
8 changes: 4 additions & 4 deletions src/cpu/reorder/simple_reorder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2154,7 +2154,7 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
typename utils::enable_if<tag_i == format_tag::any
&& tag_o == format_tag::any && type_i == data_type::f32
&& utils::one_of(type_o, data_type::s4, data_type::u4,
data_type::f4_e2m1),
data_type::f4_e2m1, data_type::f4_e3m0),
spec::reference>::type> {
static status_t is_applicable(const memory_desc_wrapper &input_d,
const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
Expand Down Expand Up @@ -2240,7 +2240,7 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
typename utils::enable_if<tag_i == format_tag::any
&& tag_o == format_tag::any
&& utils::one_of(type_i, data_type::s4, data_type::u4,
data_type::f4_e2m1)
data_type::f4_e2m1, data_type::f4_e3m0)
&& utils::one_of(type_o, data_type::f32,
data_type::bf16, data_type::f16),
spec::reference>::type> {
Expand Down Expand Up @@ -2509,9 +2509,9 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
&& order_keep == fmt_order::any
// 4-bit types (s4, u4, f4_e2m1, f4_e3m0) require a special implementation
&& !utils::one_of(type_i, data_type::s4, data_type::u4,
data_type::f4_e2m1)
data_type::f4_e2m1, data_type::f4_e3m0)
&& !utils::one_of(type_o, data_type::s4, data_type::u4,
data_type::f4_e2m1),
data_type::f4_e2m1, data_type::f4_e3m0),
spec::reference>::type> {
static status_t is_applicable(const memory_desc_wrapper &input_d,
const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
Expand Down

0 comments on commit bb352ad

Please sign in to comment.