From 545310d33879ab584dd3b736068ba936e3df4f00 Mon Sep 17 00:00:00 2001 From: Ankit Manerikar Date: Tue, 10 Dec 2024 13:13:29 -0800 Subject: [PATCH] tests: benchdnn: update binary tests to include select op --- tests/benchdnn/binary/bench_binary.cpp | 14 +++++------ tests/benchdnn/binary/binary.cpp | 24 +++++++++++++++---- tests/benchdnn/binary/binary.hpp | 2 ++ tests/benchdnn/binary/ref_binary.cpp | 9 ++++++- tests/benchdnn/dnn_types.cpp | 13 +++++++--- tests/benchdnn/dnn_types.hpp | 4 +++- tests/benchdnn/doc/driver_binary.md | 4 ++-- .../inputs/binary/harness_binary_bf16 | 2 +- .../benchdnn/inputs/binary/harness_binary_f16 | 2 +- .../benchdnn/inputs/binary/harness_binary_f32 | 2 +- .../benchdnn/inputs/binary/harness_binary_i8 | 2 +- tests/benchdnn/inputs/binary/test_binary_ci | 2 +- .../inputs/binary/test_binary_different_dt_ci | 2 +- 13 files changed, 57 insertions(+), 25 deletions(-) diff --git a/tests/benchdnn/binary/bench_binary.cpp b/tests/benchdnn/binary/bench_binary.cpp index 9c383d6cc47..77e9a319477 100644 --- a/tests/benchdnn/binary/bench_binary.cpp +++ b/tests/benchdnn/binary/bench_binary.cpp @@ -62,33 +62,33 @@ int verify_input(const settings_t &s) { static constexpr int n_inputs = 2; if (s.prb_vdims.n_inputs() != n_inputs) { - fprintf(stderr, + BENCHDNN_PRINT(0, "%s\n", "ERROR: binary driver: expect problem dimensions in format " - "`A0xA1x...:B0xB1x...`.\n"); + "`A0xA1x...:B0xB1x...`."); SAFE_V(FAIL); } for (const auto &i_sdt : s.sdt) { if (i_sdt.size() != n_inputs) { - fprintf(stderr, + BENCHDNN_PRINT(0, "%s\n", "ERROR: binary driver: expect data types in format " - "`DT:DT`.\n"); + "`DT:DT`."); SAFE_V(FAIL); } } for (const auto &i_stag : s.stag) { if (i_stag.size() != n_inputs) { - fprintf(stderr, + BENCHDNN_PRINT(0, "%s\n", "ERROR: binary driver: expect format tags in format " - "`TAG:TAG`.\n"); + "`TAG:TAG`."); SAFE_V(FAIL); } } for (const auto &i_alg : s.alg) { if (!(i_alg > alg_t::BINARY_START && i_alg < alg_t::BINARY_END)) { - fprintf(stderr, + BENCHDNN_PRINT(0, "ERROR: binary driver: algorithm `%s` does not belong to " "binary algorithm type.\n", attr_t::post_ops_t::kind2str(i_alg)); diff --git a/tests/benchdnn/binary/binary.cpp b/tests/benchdnn/binary/binary.cpp index 0994c3c2db4..74ff72e6dcc 100644 --- a/tests/benchdnn/binary/binary.cpp +++ b/tests/benchdnn/binary/binary.cpp @@ -56,7 +56,7 @@ int fill_mem( const auto dt = mem_dt.dt(); const int range = 16; - const int f_min = dt == dnnl_u8 ? 0 : -range / 2; + const int f_min = (dt == dnnl_u8 || input_idx == 2) ? 0 : -range / 2; benchdnn_parallel_nd(nelems, [&](int64_t i) { const int64_t gen = (12 * i + 5 * input_idx + 16) % (range + 1); @@ -93,10 +93,14 @@ dnnl_status_t init_pd(init_pd_args_t &init_pd_args) { auto dnnl_attr = make_benchdnn_dnnl_wrapper( create_dnnl_attr(prb->attr, attr_args)); - TIME_C_PD(DNN_SAFE_STATUS(dnnl_binary_primitive_desc_create( + auto src2_d = prb->is_ternary_op() ? dnn_mem_t::init_md(prb->ndims, + prb->vdims[0].data(), dnnl_s8, prb->stag[0]) + : nullptr; + + TIME_C_PD(DNN_SAFE_STATUS(dnnl_binary_primitive_desc_create_v2( &init_pd_args.pd, init_pd_args.engine, alg, - init_pd_args.src_md ? init_pd_args.src_md : src0_d, src1_d, dst_d, - dnnl_attr))); + init_pd_args.src_md ? init_pd_args.src_md : src0_d, src1_d, src2_d, + dst_d, dnnl_attr))); return dnnl_success; } @@ -118,6 +122,12 @@ void skip_unimplemented_prb(const prb_t *prb, res_t *res) { return; } + if (prb->is_ternary_op()) { + res->state = SKIPPED; + res->reason = skip_reason::case_not_supported; + return; + } + // gpu does not support s32 for (const auto &dt : dts) if (dt == dnnl_s32) { @@ -188,6 +198,7 @@ std::vector supported_exec_args(dir_t dir) { static const std::vector exec_args = { DNNL_ARG_SRC_0, DNNL_ARG_SRC_1, + DNNL_ARG_SRC_2, DNNL_ARG_DST, }; return exec_args; @@ -231,9 +242,12 @@ int init_ref_memory_args(dnn_mem_map_t &ref_mem_map, dnn_mem_map_t &mem_map, case DNNL_ARG_SRC_1: SAFE(fill_mem(prb, 1, mem, ref_mem), WARN); break; + case DNNL_ARG_SRC_2: + SAFE(fill_mem(prb, 2, mem, ref_mem), WARN); + break; case DNNL_ARG_DST: if (prb->attr.post_ops.find(alg_t::SUM) >= 0) { - SAFE(fill_mem(prb, 2, mem, ref_mem), WARN); + SAFE(fill_mem(prb, 3, mem, ref_mem), WARN); // Bitwise mode for sum requires a copy due to data for // post-op will be overwritten and it must be refreshed. diff --git a/tests/benchdnn/binary/binary.hpp b/tests/benchdnn/binary/binary.hpp index 76077ef7303..b4c1264d21d 100644 --- a/tests/benchdnn/binary/binary.hpp +++ b/tests/benchdnn/binary/binary.hpp @@ -109,6 +109,8 @@ struct prb_t : public prb_vdims_t { const char *str() const { return repro.c_str(); } + const bool is_ternary_op() const { return alg == alg_t::SELECT; } + private: std::string repro; diff --git a/tests/benchdnn/binary/ref_binary.cpp b/tests/benchdnn/binary/ref_binary.cpp index 1470700948e..07397a0df92 100644 --- a/tests/benchdnn/binary/ref_binary.cpp +++ b/tests/benchdnn/binary/ref_binary.cpp @@ -22,8 +22,10 @@ namespace binary { void compute_ref( const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { + const dnn_mem_t &src0 = args.find(DNNL_ARG_SRC_0); const dnn_mem_t &src1 = args.find(DNNL_ARG_SRC_1); + const dnn_mem_t &src2 = args.find(DNNL_ARG_SRC_2); const dnn_mem_t &dst = args.find(DNNL_ARG_DST); float *dst_ptr = (float *)dst; @@ -41,8 +43,13 @@ void compute_ref( benchdnn_parallel_nd(nelems, [&](int64_t i) { const auto idx_A = dst.get_idx(i, broadcast_mask_A); const auto idx_B = dst.get_idx(i, broadcast_mask_B); + + const bool c_val = prb->is_ternary_op() + ? static_cast(src2.get_elem(idx_A)) + : false; + float res = compute_binary( - prb->alg, scales[0] * A[idx_A], scales[1] * B[idx_B]); + prb->alg, scales[0] * A[idx_A], scales[1] * B[idx_B], c_val); float &dst_fp = dst_ptr[i]; const auto v_po_vals = prepare_po_vals(dst, args, v_po_masks, i); diff --git a/tests/benchdnn/dnn_types.cpp b/tests/benchdnn/dnn_types.cpp index d471754b9c4..804d133a5b2 100644 --- a/tests/benchdnn/dnn_types.cpp +++ b/tests/benchdnn/dnn_types.cpp @@ -532,6 +532,7 @@ static po_table_entry_t kind_table[] = { {pk_t::MIN, {"min", "binary_min"}, dnnl_binary_min}, {pk_t::MUL, {"mul", "binary_mul"}, dnnl_binary_mul}, {pk_t::NE, {"ne", "binary_ne"}, dnnl_binary_ne}, + {pk_t::SELECT, {"select", "binary_select"}, dnnl_binary_select}, {pk_t::SUB, {"sub", "binary_sub"}, dnnl_binary_sub}, {pk_t::BINARY_END, {"binary_undef"}, dnnl_alg_kind_undef}, // prelu @@ -628,7 +629,11 @@ bool attr_t::post_ops_t::entry_t::is_eltwise_kind() const { return kind > ELTWISE_START && kind < ELTWISE_END; } bool attr_t::post_ops_t::entry_t::is_binary_kind() const { - return kind > pk_t::BINARY_START && kind < pk_t::BINARY_END; + // binary select is a ternary operation and not currently + // supported in post-ops for the binary primitive + // TODO: add post-ops support for binary select operation + return kind > pk_t::BINARY_START && kind < pk_t::BINARY_END + && kind != pk_t::SELECT; } bool attr_t::post_ops_t::entry_t::is_prelu_kind() const { return kind == PRELU; @@ -1577,7 +1582,7 @@ float compute_eltwise_bwd( return NAN; } -float compute_binary(pk_t kind, float src0, float src1) { +float compute_binary(pk_t kind, float src0, float src1, bool src2) { // don't compute on nan, propagate it if (std::isnan(src0) || std::isnan(src1)) return NAN; @@ -1605,6 +1610,8 @@ float compute_binary(pk_t kind, float src0, float src1) { return src0 == src1; } else if (kind == pk_t::NE) { return src0 != src1; + } else if (kind == pk_t::SELECT) { + return src2 ? src0 : src1; } else { assert(!"operation not supported!"); } @@ -1664,7 +1671,7 @@ void maybe_post_ops(const attr_t &attr, float &val, float sum_val, const auto &b = e.eltwise.beta; val = compute_eltwise_fwd(e.kind, val, a, b); } else if (e.is_binary_kind()) { - val = compute_binary(e.kind, val, *it_po); + val = compute_binary(e.kind, val, *it_po, false); it_po++; } else if (e.is_prelu_kind()) { val = val > 0 ? val : val * (*it_po); diff --git a/tests/benchdnn/dnn_types.hpp b/tests/benchdnn/dnn_types.hpp index 9427c090795..c1cf07a9871 100644 --- a/tests/benchdnn/dnn_types.hpp +++ b/tests/benchdnn/dnn_types.hpp @@ -274,6 +274,7 @@ struct attr_t { MIN, MUL, NE, + SELECT, SUB, BINARY_END, // a guard to check kind is binary // prelu @@ -656,7 +657,8 @@ float compute_eltwise_fwd( attr_t::post_ops_t::kind_t kind, float src, float alpha, float beta); float compute_eltwise_bwd(attr_t::post_ops_t::kind_t kind, float d_dst, float src, float alpha, float beta); -float compute_binary(attr_t::post_ops_t::kind_t kind, float src0, float src1); +float compute_binary( + attr_t::post_ops_t::kind_t kind, float src0, float src1, bool src2); void maybe_dropout(const attr_t &attr, float &val, int64_t offset, const dnn_mem_t &dropout); void maybe_round(const attr_t &attr, int arg, float &val, int64_t offset, diff --git a/tests/benchdnn/doc/driver_binary.md b/tests/benchdnn/doc/driver_binary.md index 20a41c1cc71..9e39c89c891 100644 --- a/tests/benchdnn/doc/driver_binary.md +++ b/tests/benchdnn/doc/driver_binary.md @@ -17,8 +17,8 @@ where *binary-knobs* are: Refer to [tags](knobs_tag.md) for details. - `--dtag={any [default], ...}` -- physical dst memory layout. Refer to [tags](knobs_tag.md) for details. - - `--alg={ADD [default], DIV, EQ, GE, GT, LE, LT, MAX, MIN, MUL, NE, SUB}` -- - algorithm for binary operations. + - `--alg={ADD [default], DIV, EQ, GE, GT, LE, LT, MAX, MIN, MUL, NE, SELECT, SUB}` + -- algorithm for binary operations. Refer to [binary primitive](https://oneapi-src.github.io/oneDNN/dev_guide_binary.html) for details. - `--inplace=BOOL` -- memory mode for the primitive. If `true`, it uses input diff --git a/tests/benchdnn/inputs/binary/harness_binary_bf16 b/tests/benchdnn/inputs/binary/harness_binary_bf16 index 6b5dc832a15..22be4f83cd9 100644 --- a/tests/benchdnn/inputs/binary/harness_binary_bf16 +++ b/tests/benchdnn/inputs/binary/harness_binary_bf16 @@ -4,7 +4,7 @@ --inplace=true,false --ddt=bf16 --sdt=bf16:bf16 ---alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE +--alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE,SELECT --batch=option_set_all --batch=option_set_src0_bcast diff --git a/tests/benchdnn/inputs/binary/harness_binary_f16 b/tests/benchdnn/inputs/binary/harness_binary_f16 index a931e174652..508c4f1c75a 100644 --- a/tests/benchdnn/inputs/binary/harness_binary_f16 +++ b/tests/benchdnn/inputs/binary/harness_binary_f16 @@ -4,7 +4,7 @@ --inplace=true,false --ddt=f16 --sdt=f16:f16 ---alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE +--alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE,SELECT --batch=option_set_all --batch=option_set_src0_bcast diff --git a/tests/benchdnn/inputs/binary/harness_binary_f32 b/tests/benchdnn/inputs/binary/harness_binary_f32 index 6d772ff2b68..32dbf42ff5f 100644 --- a/tests/benchdnn/inputs/binary/harness_binary_f32 +++ b/tests/benchdnn/inputs/binary/harness_binary_f32 @@ -4,7 +4,7 @@ --inplace=true,false --ddt=f32 --sdt=f32:f32 ---alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE +--alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE,SELECT --batch=option_set_all --batch=option_set_src0_bcast diff --git a/tests/benchdnn/inputs/binary/harness_binary_i8 b/tests/benchdnn/inputs/binary/harness_binary_i8 index 22454c9153e..d10c4c7d875 100644 --- a/tests/benchdnn/inputs/binary/harness_binary_i8 +++ b/tests/benchdnn/inputs/binary/harness_binary_i8 @@ -9,7 +9,7 @@ --inplace=true,false --ddt=s8 --sdt=s8:s8 ---alg=ADD,MUL,MAX,MIN,SUB,GE,GT,LE,LT,EQ,NE +--alg=ADD,MUL,MAX,MIN,SUB,GE,GT,LE,LT,EQ,NE,SELECT --batch=option_set_all --batch=option_set_src0_bcast diff --git a/tests/benchdnn/inputs/binary/test_binary_ci b/tests/benchdnn/inputs/binary/test_binary_ci index ab461d707e1..fac598a8548 100644 --- a/tests/benchdnn/inputs/binary/test_binary_ci +++ b/tests/benchdnn/inputs/binary/test_binary_ci @@ -1,7 +1,7 @@ --reset --inplace=true,false ---alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE +--alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE,SELECT --stag=abx:any,axb:any --ddt=f32 --sdt=f32:f32 diff --git a/tests/benchdnn/inputs/binary/test_binary_different_dt_ci b/tests/benchdnn/inputs/binary/test_binary_different_dt_ci index dd82d7d9108..158d279b80b 100644 --- a/tests/benchdnn/inputs/binary/test_binary_different_dt_ci +++ b/tests/benchdnn/inputs/binary/test_binary_different_dt_ci @@ -2,7 +2,7 @@ --inplace=false # Different src and dst data types does not support in-place mode. --ddt=s8,u8,f32,s32 --sdt=s8:u8,u8:s8,s8:f32,f32:u8,f32:f32,f32:s32,s32:f32 ---alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE +--alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE,SELECT --stag=abx:any,axb:any --batch=shapes_ci