Skip to content

Commit

Permalink
api: eltwise: rename soft_relu_v2
Browse files Browse the repository at this point in the history
  • Loading branch information
dzarukin authored and densamoilov committed Oct 3, 2022
1 parent 8cdeb8f commit 22b5364
Show file tree
Hide file tree
Showing 37 changed files with 123 additions and 118 deletions.
5 changes: 2 additions & 3 deletions doc/primitives/eltwise.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ The following operations are supported:
| pow | #dnnl_eltwise_pow | \f$ d = \alpha s^{\beta} \f$ | \f$ ds = dd \cdot \alpha \beta s^{\beta - 1} \f$ | -- |
| relu | #dnnl_eltwise_relu <br> #dnnl_eltwise_relu_use_dst_for_bwd | \f$ d = \begin{cases} s & \text{if}\ s > 0 \\ \alpha s & \text{if}\ s \leq 0 \end{cases} \f$ | \f$ ds = \begin{cases} dd & \text{if}\ s > 0 \\ \alpha \cdot dd & \text{if}\ s \leq 0 \end{cases} \f$ | \f$ ds = \begin{cases} dd & \text{if}\ d > 0 \\ \alpha \cdot dd & \text{if}\ d \leq 0 \end{cases}. See\ (2). \f$ |
| round | #dnnl_eltwise_round | \f$ d = round(s) \f$ | -- | -- |
| soft_relu_v2 | #dnnl_eltwise_soft_relu_v2 | \f$ d =\frac{1}{\alpha} \log_{e}(1+e^{\alpha s}) \f$ | \f$ ds = \frac{dd}{1 + e^{-\alpha s}} \f$ | -- |
| soft_relu | #dnnl_eltwise_soft_relu | \f$ d =\frac{1}{\alpha} \log_{e}(1+e^{\alpha s}) \f$ | \f$ ds = \frac{dd}{1 + e^{-\alpha s}} \f$ | -- |
| sqrt | #dnnl_eltwise_sqrt <br> #dnnl_eltwise_sqrt_use_dst_for_bwd | \f$ d = \sqrt{s} \f$ | \f$ ds = \frac{dd}{2\sqrt{s}} \f$ | \f$ ds = \frac{dd}{2d} \f$ |
| square | #dnnl_eltwise_square | \f$ d = s^2 \f$ | \f$ ds = dd \cdot 2 s \f$ | -- |
| swish | #dnnl_eltwise_swish | \f$ d = \frac{s}{1+e^{-\alpha s}} \f$ | \f$ ds = \frac{dd}{1 + e^{-\alpha s}}(1 + \alpha s (1 - \frac{1}{1 + e^{-\alpha s}})) \f$ | -- |
Expand All @@ -54,8 +54,7 @@ The following operations are supported:
\f$ (3)\ \text{where, } \omega = e^{3s} + 4 \cdot e^{2s} + e^{s} \cdot (4 \cdot s + 6) + 4 \cdot (s + 1) \text{ and } \delta = e^{2s} + 2 \cdot e^{s} + 2. \f$

Note that following equations hold:
* \f$ soft\_relu(s) = soft\_relu\_v2(s, 1) \f$
* \f$ logsigmoid(s) = soft\_relu\_v2(s, -1) \f$
* \f$ logsigmoid(s) = soft\_relu(s, -1) \f$

#### Difference Between Forward Training and Forward Inference

Expand Down
4 changes: 2 additions & 2 deletions include/oneapi/dnnl/dnnl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,8 @@ enum class algorithm {
eltwise_linear = dnnl_eltwise_linear,
/// Elementwise: bounded_relu
eltwise_bounded_relu = dnnl_eltwise_bounded_relu,
/// Elementwise: soft_relu version 2
eltwise_soft_relu_v2 = dnnl_eltwise_soft_relu_v2,
/// Elementwise: soft_relu
eltwise_soft_relu = dnnl_eltwise_soft_relu,
/// Elementwise: mish
eltwise_mish = dnnl_eltwise_mish,
/// Elementwise: logistic
Expand Down
4 changes: 2 additions & 2 deletions include/oneapi/dnnl/dnnl_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -1456,8 +1456,8 @@ typedef enum {
dnnl_eltwise_linear = 0x7f,
/// Eltwise: bounded_relu
dnnl_eltwise_bounded_relu = 0x8f,
/// Eltwise: soft_relu version 2
dnnl_eltwise_soft_relu_v2 = 0xa0,
/// Eltwise: soft_relu
dnnl_eltwise_soft_relu = 0xa0,
/// Eltwise: hardsigmoid
dnnl_eltwise_hardsigmoid = 0xa1,
/// Eltwise: logistic
Expand Down
2 changes: 1 addition & 1 deletion src/common/c_types_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ const alg_kind_t eltwise_sqrt = dnnl_eltwise_sqrt;
const alg_kind_t eltwise_swish = dnnl_eltwise_swish;
const alg_kind_t eltwise_linear = dnnl_eltwise_linear;
const alg_kind_t eltwise_bounded_relu = dnnl_eltwise_bounded_relu;
const alg_kind_t eltwise_soft_relu_v2 = dnnl_eltwise_soft_relu_v2;
const alg_kind_t eltwise_soft_relu = dnnl_eltwise_soft_relu;
const alg_kind_t eltwise_logistic = dnnl_eltwise_logistic;
const alg_kind_t eltwise_mish = dnnl_eltwise_mish;
const alg_kind_t eltwise_exp = dnnl_eltwise_exp;
Expand Down
2 changes: 1 addition & 1 deletion src/common/dnnl_debug_autogenerated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1131,7 +1131,7 @@ const char *dnnl_alg_kind2str(dnnl_alg_kind_t v) {
if (v == dnnl_eltwise_sqrt) return "eltwise_sqrt";
if (v == dnnl_eltwise_linear) return "eltwise_linear";
if (v == dnnl_eltwise_bounded_relu) return "eltwise_bounded_relu";
if (v == dnnl_eltwise_soft_relu_v2) return "eltwise_soft_relu_v2";
if (v == dnnl_eltwise_soft_relu) return "eltwise_soft_relu";
if (v == dnnl_eltwise_hardsigmoid) return "eltwise_hardsigmoid";
if (v == dnnl_eltwise_logistic) return "eltwise_logistic";
if (v == dnnl_eltwise_exp) return "eltwise_exp";
Expand Down
2 changes: 1 addition & 1 deletion src/common/eltwise_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ struct eltwise_bwd_pd_t : public eltwise_pd_t {
eltwise_clip_v2, eltwise_elu, eltwise_exp,
eltwise_gelu_erf, eltwise_gelu_tanh, eltwise_hardsigmoid,
eltwise_linear, eltwise_logistic, eltwise_mish,
eltwise_relu, eltwise_soft_relu_v2, eltwise_square,
eltwise_relu, eltwise_soft_relu, eltwise_square,
eltwise_swish, eltwise_tanh)
|| one_of(alg, eltwise_elu_use_dst_for_bwd,
eltwise_exp_use_dst_for_bwd,
Expand Down
12 changes: 6 additions & 6 deletions src/common/math_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,27 +246,27 @@ inline U logistic_bwd_use_dst(T dd, T d) {

template <typename T, typename A,
typename U = typename utils::remove_reference<T>::type>
inline U soft_relu_v2_fwd(T s, A alpha) {
inline U soft_relu_fwd(T s, A alpha) {
float exp_overflow_bound = 88.72283172607421875;
float in = (float)s * (float)alpha;
float v = (in < exp_overflow_bound ? (U)(::log1pf(::expf(in))) : (U)in);
return (U)(v / alpha);
}
template <typename T, typename A,
typename U = typename utils::remove_reference<T>::type>
inline U soft_relu_v2_bwd(T dd, T s, A alpha) {
inline U soft_relu_bwd(T dd, T s, A alpha) {
float in = (float)s * (float)alpha;
return (U)(dd * logistic_fwd<float>(in));
}

template <typename T, typename U = typename utils::remove_reference<T>::type>
inline U mish_fwd(T s) {
return s * tanh_fwd(soft_relu_v2_fwd(s, 1.f));
return s * tanh_fwd(soft_relu_fwd(s, 1.f));
}
template <typename T, typename U = typename utils::remove_reference<T>::type>
inline U mish_bwd(T dd, T s) {
const float tanh = tanh_fwd(soft_relu_v2_fwd(s, 1.f));
const float srelu_bwd = soft_relu_v2_bwd(1.f, s, 1.f);
const float tanh = tanh_fwd(soft_relu_fwd(s, 1.f));
const float srelu_bwd = soft_relu_bwd(1.f, s, 1.f);
const float derivative = tanh + s * srelu_bwd * (1 - ::powf(tanh, 2.0f));
return dd * derivative;
}
Expand Down Expand Up @@ -411,7 +411,7 @@ inline bool is_eltwise_ok(
const bool eltwise_use_src
= one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
eltwise_bounded_relu, eltwise_soft_relu_v2, eltwise_mish,
eltwise_bounded_relu, eltwise_soft_relu, eltwise_mish,
eltwise_logistic, eltwise_exp, eltwise_gelu_tanh,
eltwise_hardsigmoid, eltwise_hardswish, eltwise_swish,
eltwise_log, eltwise_clip, eltwise_clip_v2, eltwise_pow,
Expand Down
4 changes: 2 additions & 2 deletions src/common/opdesc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ struct eltwise_desc_t {
// #dnnl_eltwise_tanh, #dnnl_eltwise_elu, #dnnl_eltwise_square,
// #dnnl_eltwise_abs, #dnnl_eltwise_sqrt, #dnnl_eltwise_linear,
// #dnnl_eltwise_bounded_relu,
// #dnnl_eltwise_soft_relu_v2, #dnnl_eltwise_logistic, #dnnl_eltwise_exp,
// #dnnl_eltwise_soft_relu, #dnnl_eltwise_logistic, #dnnl_eltwise_exp,
// #dnnl_eltwise_gelu_tanh, #dnnl_eltwise_swish, #dnnl_eltwise_log,
// #dnnl_eltwise_clip, #dnnl_eltwise_clip_v2, #dnnl_eltwise_pow,
// #dnnl_eltwise_gelu_erf, #dnnl_eltwise_round,
Expand All @@ -222,7 +222,7 @@ struct eltwise_desc_t {
// - #dnnl_eltwise_sqrt: @p alpha and @p beta ignored
// - #dnnl_eltwise_linear: @p alpha -- scale, @p beta -- shift
// - #dnnl_eltwise_bounded_relu: @p alpha -- upper bound, @p beta ignored
// - #dnnl_eltwise_soft_relu_v2: @p alpha -- soft_relu_v2 arg scaling, @p beta ignored
// - #dnnl_eltwise_soft_relu: @p alpha -- soft_relu arg scaling, @p beta ignored
// - #dnnl_eltwise_logistic: @p alpha and @p beta ignored
// - #dnnl_eltwise_exp: @p alpha and @p beta ignored
// - #dnnl_eltwise_gelu_tanh: @p alpha and @p beta ignored
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/aarch64/acl_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ status_t convert_to_acl_act(alg_kind_t eltwise_alg, float alpha, float beta,
case eltwise_bounded_relu:
act_info = ActivationLayerInfo(act_func::BOUNDED_RELU, alpha, beta);
break;
case eltwise_soft_relu_v2:
case eltwise_soft_relu:
act_info = ActivationLayerInfo(act_func::SOFT_RELU, alpha, beta);
break;
case eltwise_logistic:
Expand Down
10 changes: 5 additions & 5 deletions src/cpu/aarch64/injectors/jit_uni_eltwise_injector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ bool is_alg_supported(alg_kind_t alg) {
using namespace alg_kind;
return utils::one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
eltwise_bounded_relu, /*eltwise_soft_relu_v2,*/
eltwise_bounded_relu, /*eltwise_soft_relu,*/
eltwise_logistic, /*eltwise_mish,*/ eltwise_exp, eltwise_gelu_tanh,
/*eltwise_hardswish,*/ eltwise_swish, eltwise_log, eltwise_clip,
/*eltwise_clip_v2, eltwise_pow,*/ eltwise_gelu_erf, eltwise_round,
Expand Down Expand Up @@ -257,7 +257,7 @@ void jit_uni_eltwise_injector_f32<isa>::set_coef_to_regs() {
table_val(beta, vmm_aux0);
break;
case eltwise_bounded_relu: table_val(alpha, z_tmp); break;
// case eltwise_soft_relu_v2: // TODO: enable me.
// case eltwise_soft_relu: // TODO: enable me.
case eltwise_logistic_use_dst_for_bwd:
case eltwise_logistic:
case eltwise_exp_use_dst_for_bwd:
Expand Down Expand Up @@ -287,7 +287,7 @@ void jit_uni_eltwise_injector_f32<isa>::set_coef_to_regs() {
case eltwise_sqrt:
case eltwise_linear:
case eltwise_bounded_relu:
// case eltwise_soft_relu_v2:
// case eltwise_soft_relu:
case eltwise_logistic_use_dst_for_bwd:
case eltwise_logistic:
case eltwise_exp_use_dst_for_bwd:
Expand Down Expand Up @@ -1297,7 +1297,7 @@ size_t jit_uni_eltwise_injector_f32<isa>::aux_vecs_count() {
case eltwise_sqrt: return 0;
case eltwise_linear: return 2;
case eltwise_bounded_relu: return 1;
// case eltwise_soft_relu_v2: return 5;
// case eltwise_soft_relu: return 5;
case eltwise_logistic_use_dst_for_bwd:
case eltwise_logistic: return 5; /* = exp + 1 */
case eltwise_exp_use_dst_for_bwd:
Expand All @@ -1324,7 +1324,7 @@ size_t jit_uni_eltwise_injector_f32<isa>::aux_vecs_count() {
case eltwise_sqrt: return 2;
case eltwise_linear: return 1;
case eltwise_bounded_relu: return 1;
// case eltwise_soft_relu_v2: return 5; /* = logistic */
// case eltwise_soft_relu: return 5; /* = logistic */
case eltwise_logistic_use_dst_for_bwd: return 2;
case eltwise_logistic: return 5; /* = logistic */
case eltwise_exp_use_dst_for_bwd: return 0;
Expand Down
4 changes: 2 additions & 2 deletions src/cpu/aarch64/jit_uni_eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ status_t jit_uni_eltwise_fwd_t<isa, d_type>::pd_t::init(engine_t *engine) {
eltwise_relu, eltwise_elu_use_dst_for_bwd, eltwise_elu,
eltwise_tanh_use_dst_for_bwd, eltwise_tanh, eltwise_square,
eltwise_abs, eltwise_sqrt_use_dst_for_bwd, eltwise_sqrt,
eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu_v2,
eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu,
eltwise_logistic_use_dst_for_bwd, eltwise_logistic,
eltwise_exp_use_dst_for_bwd, eltwise_exp, eltwise_gelu_tanh,
eltwise_swish, eltwise_log, eltwise_clip, eltwise_gelu_erf,
Expand Down Expand Up @@ -279,7 +279,7 @@ status_t jit_uni_eltwise_bwd_t<isa, d_type>::pd_t::init(engine_t *engine) {
eltwise_relu, eltwise_elu_use_dst_for_bwd, eltwise_elu,
eltwise_tanh_use_dst_for_bwd, eltwise_tanh, eltwise_square,
eltwise_abs, eltwise_sqrt_use_dst_for_bwd, eltwise_sqrt,
eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu_v2,
eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu,
eltwise_logistic_use_dst_for_bwd, eltwise_logistic,
eltwise_exp_use_dst_for_bwd, eltwise_exp, eltwise_gelu_tanh,
eltwise_swish, eltwise_log, eltwise_clip, eltwise_gelu_erf);
Expand Down
6 changes: 3 additions & 3 deletions src/cpu/primitive_attr_postops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ float compute_eltwise_scalar_fwd(
case eltwise_sqrt: d = sqrt_fwd(s); break;
case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
case eltwise_bounded_relu: d = bounded_relu_fwd(s, alpha); break;
case eltwise_soft_relu_v2: d = soft_relu_v2_fwd(s, alpha); break;
case eltwise_soft_relu: d = soft_relu_fwd(s, alpha); break;
case eltwise_logistic: d = logistic_fwd(s); break;
case eltwise_exp: d = exp_fwd(s); break;
case eltwise_gelu_tanh: d = gelu_tanh_fwd(s); break;
Expand Down Expand Up @@ -97,7 +97,7 @@ float compute_eltwise_scalar_bwd(
case eltwise_sqrt: ds = sqrt_bwd(dd, s); break;
case eltwise_linear: ds = linear_bwd(dd, s, alpha, beta); break;
case eltwise_bounded_relu: ds = bounded_relu_bwd(dd, s, alpha); break;
case eltwise_soft_relu_v2: ds = soft_relu_v2_bwd(dd, s, alpha); break;
case eltwise_soft_relu: ds = soft_relu_bwd(dd, s, alpha); break;
case eltwise_logistic: ds = logistic_bwd(dd, s); break;
case eltwise_exp: ds = exp_bwd(dd, s); break;
case eltwise_gelu_tanh: ds = gelu_tanh_bwd(dd, s); break;
Expand Down Expand Up @@ -154,7 +154,7 @@ ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(
: alg_(alg), alpha_(alpha), beta_(beta), scale_(scale) {
assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
eltwise_bounded_relu, eltwise_soft_relu_v2, eltwise_mish,
eltwise_bounded_relu, eltwise_soft_relu, eltwise_mish,
eltwise_logistic, eltwise_exp, eltwise_gelu_tanh, eltwise_swish,
eltwise_log, eltwise_clip, eltwise_clip_v2, eltwise_pow,
eltwise_gelu_erf, eltwise_round, eltwise_hardsigmoid,
Expand Down
Loading

0 comments on commit 22b5364

Please sign in to comment.