Skip to content

Commit

Permalink
gpu: ocl: remove unecessary post_op kind
Browse files Browse the repository at this point in the history
The binary and eltwise enums are guaranteed to not overlap via the API
  • Loading branch information
rjoursler authored and vpirogov committed May 12, 2023
1 parent 9a66ac6 commit 87fd48f
Showing 1 changed file with 28 additions and 37 deletions.
65 changes: 28 additions & 37 deletions src/gpu/ocl/ocl_post_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,23 @@
#include "gpu/ocl/ocl_eltwise.h"
#include "gpu/ocl/ocl_types.h"

float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
float alpha, float beta, float scale) {
if (kind == PO_BINARY) {
switch (algorithm) {
// binary
case BINARY_ADD: return x + y; break;
case BINARY_MUL: return x * y; break;
case BINARY_MIN: return x < y ? x : y; break;
case BINARY_MAX: return x > y ? x : y; break;
case BINARY_DIV: return x / y; break;
case BINARY_SUB: return x - y; break;
case BINARY_GE: return x >= y; break;
case BINARY_GT: return x > y; break;
case BINARY_LE: return x <= y; break;
case BINARY_LT: return x < y; break;
case BINARY_EQ: return x == y; break;
case BINARY_NE: return x != y; break;
case RELU: // binary && relu = prelu
return fwd_eltwise_common(RELU, x, y, beta, scale);
break;
default: return 0.f;
}
} else { // eltwise kind
return fwd_eltwise_common(algorithm, x, alpha, beta, scale);
float fwd_Xnary(unsigned algorithm, float x, float y, float alpha, float beta,
float scale) {
switch (algorithm) {
// binary
case BINARY_ADD: return x + y; break;
case BINARY_MUL: return x * y; break;
case BINARY_MIN: return x < y ? x : y; break;
case BINARY_MAX: return x > y ? x : y; break;
case BINARY_DIV: return x / y; break;
case BINARY_SUB: return x - y; break;
case BINARY_GE: return x >= y; break;
case BINARY_GT: return x > y; break;
case BINARY_LE: return x <= y; break;
case BINARY_LT: return x < y; break;
case BINARY_EQ: return x == y; break;
case BINARY_NE: return x != y; break;
default: return fwd_eltwise_common(algorithm, x, alpha, beta, scale);
}
}

Expand All @@ -65,28 +58,26 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
ret_val; \
})

#define FWD_XNARY_GENERIC_DT(po_kind, algorithm, result, result_elem_dt, \
arg0_ptr, arg0_len, arg1_ptr, arg1_len, alpha, beta, scale) \
#define FWD_XNARY_GENERIC_DT(algorithm, result, result_elem_dt, arg0_ptr, \
arg0_len, arg1_ptr, arg1_len, alpha, beta, scale) \
{ \
auto ty = arg0_len + arg1_len; \
const typeof(ty) out_len \
= max((typeof(ty))arg0_len, (typeof(ty))arg1_len); \
result_elem_dt *res_ptr = (result_elem_dt *)(&result); \
unroll_for(typeof(out_len + 0) idx = 0; idx < out_len; ++idx) { \
if (arg0_len == 1 && arg1_len == 1) { \
*res_ptr = fwd_Xnary(po_kind, algorithm, \
convert_float(*arg0_ptr), convert_float(*arg1_ptr), \
alpha, beta, scale); \
*res_ptr = fwd_Xnary(algorithm, convert_float(*arg0_ptr), \
convert_float(*arg1_ptr), alpha, beta, scale); \
} else if (arg0_len == 1) { \
res_ptr[idx] = fwd_Xnary(po_kind, algorithm, \
convert_float(*arg0_ptr), \
res_ptr[idx] = fwd_Xnary(algorithm, convert_float(*arg0_ptr), \
convert_float(arg1_ptr[idx]), alpha, beta, scale); \
} else if (arg1_len == 1) { \
res_ptr[idx] = fwd_Xnary(po_kind, algorithm, \
convert_float(arg0_ptr[idx]), \
convert_float(*arg1_ptr), alpha, beta, scale); \
res_ptr[idx] \
= fwd_Xnary(algorithm, convert_float(arg0_ptr[idx]), \
convert_float(*arg1_ptr), alpha, beta, scale); \
} else { \
res_ptr[idx] = fwd_Xnary(po_kind, algorithm, \
res_ptr[idx] = fwd_Xnary(algorithm, \
convert_float(arg0_ptr[idx]), \
convert_float(arg1_ptr[idx]), alpha, beta, scale); \
} \
Expand Down Expand Up @@ -277,7 +268,7 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
REPLICATE_DATA(bin_arg_ptr, bin_arg_size, x0_s, x1_size, x2_s, \
x3_s, x4_s, x5_s); \
} \
FWD_XNARY_GENERIC_DT(PO_BINARY, CONCAT3(PO_, idx, _ALG), accumulator, \
FWD_XNARY_GENERIC_DT(CONCAT3(PO_, idx, _ALG), accumulator, \
acc_elem_dt, ((acc_elem_dt *)(&accumulator)), \
(sizeof(accumulator) / sizeof(acc_elem_dt)), bin_arg_ptr, \
bin_arg_size, 0.0f, 0.0f, 1.0f); \
Expand All @@ -292,7 +283,7 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,

#define APPLY_PO_ELTWISE(idx, accumulator, acc_elem_dt) \
{ \
FWD_XNARY_GENERIC_DT(PO_ELTWISE, CONCAT3(PO_, idx, _ALG), accumulator, \
FWD_XNARY_GENERIC_DT(CONCAT3(PO_, idx, _ALG), accumulator, \
acc_elem_dt, ((acc_elem_dt *)(&accumulator)), \
(sizeof(accumulator) / sizeof(acc_elem_dt)), \
((acc_elem_dt *)(&accumulator)), \
Expand Down

0 comments on commit 87fd48f

Please sign in to comment.