Skip to content

Commit

Permalink
Revert "gpu: ocl: remove unecessary post_op kind"
Browse files Browse the repository at this point in the history
This reverts commit 00d1bac.
  • Loading branch information
dyoussif authored and karturov committed May 31, 2023
1 parent ad3c62f commit c8943f5
Showing 1 changed file with 28 additions and 21 deletions.
49 changes: 28 additions & 21 deletions src/gpu/ocl/ocl_post_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,30 @@
#include "gpu/ocl/ocl_eltwise.h"
#include "gpu/ocl/ocl_types.h"

float fwd_Xnary(unsigned algorithm, float x, float y, float alpha, float beta,
float scale) {
switch (algorithm) {
// binary
case BINARY_ADD: return x + y; break;
case BINARY_MUL: return x * y; break;
case BINARY_MIN: return x < y ? x : y; break;
case BINARY_MAX: return x > y ? x : y; break;
case BINARY_DIV: return x / y; break;
case BINARY_SUB: return x - y; break;
case BINARY_GE: return x >= y; break;
case BINARY_GT: return x > y; break;
case BINARY_LE: return x <= y; break;
case BINARY_LT: return x < y; break;
case BINARY_EQ: return x == y; break;
case BINARY_NE: return x != y; break;
default: return fwd_eltwise_common(algorithm, x, alpha, beta, scale);
float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
float alpha, float beta, float scale) {
if (kind == PO_BINARY) {
switch (algorithm) {
// binary
case BINARY_ADD: return x + y; break;
case BINARY_MUL: return x * y; break;
case BINARY_MIN: return x < y ? x : y; break;
case BINARY_MAX: return x > y ? x : y; break;
case BINARY_DIV: return x / y; break;
case BINARY_SUB: return x - y; break;
case BINARY_GE: return x >= y; break;
case BINARY_GT: return x > y; break;
case BINARY_LE: return x <= y; break;
case BINARY_LT: return x < y; break;
case BINARY_EQ: return x == y; break;
case BINARY_NE: return x != y; break;
case RELU: // binary && relu = prelu
return fwd_eltwise_common(RELU, x, y, beta, scale);
break;
default: return 0.f;
}
} else { // eltwise kind
return fwd_eltwise_common(algorithm, x, alpha, beta, scale);
}
}

Expand All @@ -58,8 +65,8 @@ float fwd_Xnary(unsigned algorithm, float x, float y, float alpha, float beta,
ret_val; \
})

#define FWD_XNARY_GENERIC_DT(algorithm, result, result_elem_dt, arg0_ptr, \
arg0_len, arg1_ptr, arg1_len, alpha, beta, scale) \
#define FWD_XNARY_GENERIC_DT(po_kind, algorithm, result, result_elem_dt, \
arg0_ptr, arg0_len, arg1_ptr, arg1_len, alpha, beta, scale) \
{ \
auto ty = arg0_len + arg1_len; \
const typeof(ty) out_len \
Expand Down Expand Up @@ -258,7 +265,7 @@ float fwd_Xnary(unsigned algorithm, float x, float y, float alpha, float beta,
REPLICATE_DATA(bin_arg_ptr, bin_arg_size, x0_s, x1_size, x2_s, \
x3_s, x4_s, x5_s); \
} \
FWD_XNARY_GENERIC_DT(CONCAT3(PO_, idx, _ALG), accumulator, \
FWD_XNARY_GENERIC_DT(PO_BINARY, CONCAT3(PO_, idx, _ALG), accumulator, \
acc_elem_dt, ((acc_elem_dt *)(&accumulator)), \
(sizeof(accumulator) / sizeof(acc_elem_dt)), bin_arg_ptr, \
bin_arg_size, 0.0f, 0.0f, 1.0f); \
Expand All @@ -273,7 +280,7 @@ float fwd_Xnary(unsigned algorithm, float x, float y, float alpha, float beta,

#define APPLY_PO_ELTWISE(idx, accumulator, acc_elem_dt) \
{ \
FWD_XNARY_GENERIC_DT(CONCAT3(PO_, idx, _ALG), accumulator, \
FWD_XNARY_GENERIC_DT(PO_ELTWISE, CONCAT3(PO_, idx, _ALG), accumulator, \
acc_elem_dt, ((acc_elem_dt *)(&accumulator)), \
(sizeof(accumulator) / sizeof(acc_elem_dt)), \
((acc_elem_dt *)(&accumulator)), \
Expand Down

0 comments on commit c8943f5

Please sign in to comment.