Skip to content

Commit

Permalink
tests: benchdnn: update binary tests to include select op
Browse files Browse the repository at this point in the history
  • Loading branch information
avmanerikar committed Dec 11, 2024
1 parent 65c8687 commit 545310d
Show file tree
Hide file tree
Showing 13 changed files with 57 additions and 25 deletions.
14 changes: 7 additions & 7 deletions tests/benchdnn/binary/bench_binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,33 +62,33 @@ int verify_input(const settings_t &s) {
static constexpr int n_inputs = 2;

if (s.prb_vdims.n_inputs() != n_inputs) {
fprintf(stderr,
BENCHDNN_PRINT(0, "%s\n",
"ERROR: binary driver: expect problem dimensions in format "
"`A0xA1x...:B0xB1x...`.\n");
"`A0xA1x...:B0xB1x...`.");
SAFE_V(FAIL);
}

for (const auto &i_sdt : s.sdt) {
if (i_sdt.size() != n_inputs) {
fprintf(stderr,
BENCHDNN_PRINT(0, "%s\n",
"ERROR: binary driver: expect data types in format "
"`DT:DT`.\n");
"`DT:DT`.");
SAFE_V(FAIL);
}
}

for (const auto &i_stag : s.stag) {
if (i_stag.size() != n_inputs) {
fprintf(stderr,
BENCHDNN_PRINT(0, "%s\n",
"ERROR: binary driver: expect format tags in format "
"`TAG:TAG`.\n");
"`TAG:TAG`.");
SAFE_V(FAIL);
}
}

for (const auto &i_alg : s.alg) {
if (!(i_alg > alg_t::BINARY_START && i_alg < alg_t::BINARY_END)) {
fprintf(stderr,
BENCHDNN_PRINT(0,
"ERROR: binary driver: algorithm `%s` does not belong to "
"binary algorithm type.\n",
attr_t::post_ops_t::kind2str(i_alg));
Expand Down
24 changes: 19 additions & 5 deletions tests/benchdnn/binary/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ int fill_mem(

const auto dt = mem_dt.dt();
const int range = 16;
const int f_min = dt == dnnl_u8 ? 0 : -range / 2;
const int f_min = (dt == dnnl_u8 || input_idx == 2) ? 0 : -range / 2;

benchdnn_parallel_nd(nelems, [&](int64_t i) {
const int64_t gen = (12 * i + 5 * input_idx + 16) % (range + 1);
Expand Down Expand Up @@ -93,10 +93,14 @@ dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
auto dnnl_attr = make_benchdnn_dnnl_wrapper(
create_dnnl_attr(prb->attr, attr_args));

TIME_C_PD(DNN_SAFE_STATUS(dnnl_binary_primitive_desc_create(
auto src2_d = prb->is_ternary_op() ? dnn_mem_t::init_md(prb->ndims,
prb->vdims[0].data(), dnnl_s8, prb->stag[0])
: nullptr;

TIME_C_PD(DNN_SAFE_STATUS(dnnl_binary_primitive_desc_create_v2(
&init_pd_args.pd, init_pd_args.engine, alg,
init_pd_args.src_md ? init_pd_args.src_md : src0_d, src1_d, dst_d,
dnnl_attr)));
init_pd_args.src_md ? init_pd_args.src_md : src0_d, src1_d, src2_d,
dst_d, dnnl_attr)));

return dnnl_success;
}
Expand All @@ -118,6 +122,12 @@ void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
return;
}

if (prb->is_ternary_op()) {
res->state = SKIPPED;
res->reason = skip_reason::case_not_supported;
return;
}

// gpu does not support s32
for (const auto &dt : dts)
if (dt == dnnl_s32) {
Expand Down Expand Up @@ -188,6 +198,7 @@ std::vector<int> supported_exec_args(dir_t dir) {
static const std::vector<int> exec_args = {
DNNL_ARG_SRC_0,
DNNL_ARG_SRC_1,
DNNL_ARG_SRC_2,
DNNL_ARG_DST,
};
return exec_args;
Expand Down Expand Up @@ -231,9 +242,12 @@ int init_ref_memory_args(dnn_mem_map_t &ref_mem_map, dnn_mem_map_t &mem_map,
case DNNL_ARG_SRC_1:
SAFE(fill_mem(prb, 1, mem, ref_mem), WARN);
break;
case DNNL_ARG_SRC_2:
SAFE(fill_mem(prb, 2, mem, ref_mem), WARN);
break;
case DNNL_ARG_DST:
if (prb->attr.post_ops.find(alg_t::SUM) >= 0) {
SAFE(fill_mem(prb, 2, mem, ref_mem), WARN);
SAFE(fill_mem(prb, 3, mem, ref_mem), WARN);

// Bitwise mode for sum requires a copy due to data for
// post-op will be overwritten and it must be refreshed.
Expand Down
2 changes: 2 additions & 0 deletions tests/benchdnn/binary/binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ struct prb_t : public prb_vdims_t {

const char *str() const { return repro.c_str(); }

const bool is_ternary_op() const { return alg == alg_t::SELECT; }

private:
std::string repro;

Expand Down
9 changes: 8 additions & 1 deletion tests/benchdnn/binary/ref_binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ namespace binary {

void compute_ref(
const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {

const dnn_mem_t &src0 = args.find(DNNL_ARG_SRC_0);
const dnn_mem_t &src1 = args.find(DNNL_ARG_SRC_1);
const dnn_mem_t &src2 = args.find(DNNL_ARG_SRC_2);
const dnn_mem_t &dst = args.find(DNNL_ARG_DST);

float *dst_ptr = (float *)dst;
Expand All @@ -41,8 +43,13 @@ void compute_ref(
benchdnn_parallel_nd(nelems, [&](int64_t i) {
const auto idx_A = dst.get_idx(i, broadcast_mask_A);
const auto idx_B = dst.get_idx(i, broadcast_mask_B);

const bool c_val = prb->is_ternary_op()
? static_cast<bool>(src2.get_elem(idx_A))
: false;

float res = compute_binary(
prb->alg, scales[0] * A[idx_A], scales[1] * B[idx_B]);
prb->alg, scales[0] * A[idx_A], scales[1] * B[idx_B], c_val);
float &dst_fp = dst_ptr[i];

const auto v_po_vals = prepare_po_vals(dst, args, v_po_masks, i);
Expand Down
13 changes: 10 additions & 3 deletions tests/benchdnn/dnn_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ static po_table_entry_t kind_table[] = {
{pk_t::MIN, {"min", "binary_min"}, dnnl_binary_min},
{pk_t::MUL, {"mul", "binary_mul"}, dnnl_binary_mul},
{pk_t::NE, {"ne", "binary_ne"}, dnnl_binary_ne},
{pk_t::SELECT, {"select", "binary_select"}, dnnl_binary_select},
{pk_t::SUB, {"sub", "binary_sub"}, dnnl_binary_sub},
{pk_t::BINARY_END, {"binary_undef"}, dnnl_alg_kind_undef},
// prelu
Expand Down Expand Up @@ -628,7 +629,11 @@ bool attr_t::post_ops_t::entry_t::is_eltwise_kind() const {
return kind > ELTWISE_START && kind < ELTWISE_END;
}
bool attr_t::post_ops_t::entry_t::is_binary_kind() const {
return kind > pk_t::BINARY_START && kind < pk_t::BINARY_END;
// binary select is a ternary operation and not currently
// supported in post-ops for the binary primitive
// TODO: add post-ops support for binary select operation
return kind > pk_t::BINARY_START && kind < pk_t::BINARY_END
&& kind != pk_t::SELECT;
}
bool attr_t::post_ops_t::entry_t::is_prelu_kind() const {
return kind == PRELU;
Expand Down Expand Up @@ -1577,7 +1582,7 @@ float compute_eltwise_bwd(
return NAN;
}

float compute_binary(pk_t kind, float src0, float src1) {
float compute_binary(pk_t kind, float src0, float src1, bool src2) {
// don't compute on nan, propagate it
if (std::isnan(src0) || std::isnan(src1)) return NAN;

Expand Down Expand Up @@ -1605,6 +1610,8 @@ float compute_binary(pk_t kind, float src0, float src1) {
return src0 == src1;
} else if (kind == pk_t::NE) {
return src0 != src1;
} else if (kind == pk_t::SELECT) {
return src2 ? src0 : src1;
} else {
assert(!"operation not supported!");
}
Expand Down Expand Up @@ -1664,7 +1671,7 @@ void maybe_post_ops(const attr_t &attr, float &val, float sum_val,
const auto &b = e.eltwise.beta;
val = compute_eltwise_fwd(e.kind, val, a, b);
} else if (e.is_binary_kind()) {
val = compute_binary(e.kind, val, *it_po);
val = compute_binary(e.kind, val, *it_po, false);
it_po++;
} else if (e.is_prelu_kind()) {
val = val > 0 ? val : val * (*it_po);
Expand Down
4 changes: 3 additions & 1 deletion tests/benchdnn/dnn_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ struct attr_t {
MIN,
MUL,
NE,
SELECT,
SUB,
BINARY_END, // a guard to check kind is binary
// prelu
Expand Down Expand Up @@ -656,7 +657,8 @@ float compute_eltwise_fwd(
attr_t::post_ops_t::kind_t kind, float src, float alpha, float beta);
float compute_eltwise_bwd(attr_t::post_ops_t::kind_t kind, float d_dst,
float src, float alpha, float beta);
float compute_binary(attr_t::post_ops_t::kind_t kind, float src0, float src1);
float compute_binary(
attr_t::post_ops_t::kind_t kind, float src0, float src1, bool src2);
void maybe_dropout(const attr_t &attr, float &val, int64_t offset,
const dnn_mem_t &dropout);
void maybe_round(const attr_t &attr, int arg, float &val, int64_t offset,
Expand Down
4 changes: 2 additions & 2 deletions tests/benchdnn/doc/driver_binary.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ where *binary-knobs* are:
Refer to [tags](knobs_tag.md) for details.
- `--dtag={any [default], ...}` -- physical dst memory layout.
Refer to [tags](knobs_tag.md) for details.
- `--alg={ADD [default], DIV, EQ, GE, GT, LE, LT, MAX, MIN, MUL, NE, SUB}` --
algorithm for binary operations.
- `--alg={ADD [default], DIV, EQ, GE, GT, LE, LT, MAX, MIN, MUL, NE, SELECT, SUB}`
-- algorithm for binary operations.
Refer to [binary primitive](https://oneapi-src.github.io/oneDNN/dev_guide_binary.html)
for details.
- `--inplace=BOOL` -- memory mode for the primitive. If `true`, it uses input
Expand Down
2 changes: 1 addition & 1 deletion tests/benchdnn/inputs/binary/harness_binary_bf16
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
--inplace=true,false
--ddt=bf16 --sdt=bf16:bf16

--alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE
--alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE,SELECT
--batch=option_set_all
--batch=option_set_src0_bcast

Expand Down
2 changes: 1 addition & 1 deletion tests/benchdnn/inputs/binary/harness_binary_f16
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
--inplace=true,false
--ddt=f16 --sdt=f16:f16

--alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE
--alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE,SELECT
--batch=option_set_all
--batch=option_set_src0_bcast

Expand Down
2 changes: 1 addition & 1 deletion tests/benchdnn/inputs/binary/harness_binary_f32
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
--inplace=true,false
--ddt=f32 --sdt=f32:f32

--alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE
--alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE,SELECT
--batch=option_set_all
--batch=option_set_src0_bcast

Expand Down
2 changes: 1 addition & 1 deletion tests/benchdnn/inputs/binary/harness_binary_i8
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
--inplace=true,false
--ddt=s8 --sdt=s8:s8

--alg=ADD,MUL,MAX,MIN,SUB,GE,GT,LE,LT,EQ,NE
--alg=ADD,MUL,MAX,MIN,SUB,GE,GT,LE,LT,EQ,NE,SELECT
--batch=option_set_all
--batch=option_set_src0_bcast

Expand Down
2 changes: 1 addition & 1 deletion tests/benchdnn/inputs/binary/test_binary_ci
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
--reset

--inplace=true,false
--alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE
--alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE,SELECT
--stag=abx:any,axb:any

--ddt=f32 --sdt=f32:f32
Expand Down
2 changes: 1 addition & 1 deletion tests/benchdnn/inputs/binary/test_binary_different_dt_ci
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

--inplace=false # Different src and dst data types does not support in-place mode.
--ddt=s8,u8,f32,s32 --sdt=s8:u8,u8:s8,s8:f32,f32:u8,f32:f32,f32:s32,s32:f32
--alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE
--alg=ADD,MUL,MAX,MIN,DIV,SUB,GE,GT,LE,LT,EQ,NE,SELECT
--stag=abx:any,axb:any
--batch=shapes_ci

Expand Down

0 comments on commit 545310d

Please sign in to comment.