diff --git a/src/gpu/ocl/ocl_post_ops.h b/src/gpu/ocl/ocl_post_ops.h index 38c58872c36..5443c896e48 100644 --- a/src/gpu/ocl/ocl_post_ops.h +++ b/src/gpu/ocl/ocl_post_ops.h @@ -27,30 +27,23 @@ #include "gpu/ocl/ocl_eltwise.h" #include "gpu/ocl/ocl_types.h" -float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y, - float alpha, float beta, float scale) { - if (kind == PO_BINARY) { - switch (algorithm) { - // binary - case BINARY_ADD: return x + y; break; - case BINARY_MUL: return x * y; break; - case BINARY_MIN: return x < y ? x : y; break; - case BINARY_MAX: return x > y ? x : y; break; - case BINARY_DIV: return x / y; break; - case BINARY_SUB: return x - y; break; - case BINARY_GE: return x >= y; break; - case BINARY_GT: return x > y; break; - case BINARY_LE: return x <= y; break; - case BINARY_LT: return x < y; break; - case BINARY_EQ: return x == y; break; - case BINARY_NE: return x != y; break; - case RELU: // binary && relu = prelu - return fwd_eltwise_common(RELU, x, y, beta, scale); - break; - default: return 0.f; - } - } else { // eltwise kind - return fwd_eltwise_common(algorithm, x, alpha, beta, scale); +float fwd_Xnary(unsigned algorithm, float x, float y, float alpha, float beta, + float scale) { + switch (algorithm) { + // binary + case BINARY_ADD: return x + y; break; + case BINARY_MUL: return x * y; break; + case BINARY_MIN: return x < y ? x : y; break; + case BINARY_MAX: return x > y ? x : y; break; + case BINARY_DIV: return x / y; break; + case BINARY_SUB: return x - y; break; + case BINARY_GE: return x >= y; break; + case BINARY_GT: return x > y; break; + case BINARY_LE: return x <= y; break; + case BINARY_LT: return x < y; break; + case BINARY_EQ: return x == y; break; + case BINARY_NE: return x != y; break; + default: return fwd_eltwise_common(algorithm, x, alpha, beta, scale); } } @@ -65,8 +58,8 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y, ret_val; \ }) -#define FWD_XNARY_GENERIC_DT(po_kind, algorithm, result, result_elem_dt, \ - arg0_ptr, arg0_len, arg1_ptr, arg1_len, alpha, beta, scale) \ +#define FWD_XNARY_GENERIC_DT(algorithm, result, result_elem_dt, arg0_ptr, \ + arg0_len, arg1_ptr, arg1_len, alpha, beta, scale) \ { \ auto ty = arg0_len + arg1_len; \ const typeof(ty) out_len \ @@ -74,19 +67,17 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y, result_elem_dt *res_ptr = (result_elem_dt *)(&result); \ unroll_for(typeof(out_len + 0) idx = 0; idx < out_len; ++idx) { \ if (arg0_len == 1 && arg1_len == 1) { \ - *res_ptr = fwd_Xnary(po_kind, algorithm, \ - convert_float(*arg0_ptr), convert_float(*arg1_ptr), \ - alpha, beta, scale); \ + *res_ptr = fwd_Xnary(algorithm, convert_float(*arg0_ptr), \ + convert_float(*arg1_ptr), alpha, beta, scale); \ } else if (arg0_len == 1) { \ - res_ptr[idx] = fwd_Xnary(po_kind, algorithm, \ - convert_float(*arg0_ptr), \ + res_ptr[idx] = fwd_Xnary(algorithm, convert_float(*arg0_ptr), \ convert_float(arg1_ptr[idx]), alpha, beta, scale); \ } else if (arg1_len == 1) { \ - res_ptr[idx] = fwd_Xnary(po_kind, algorithm, \ - convert_float(arg0_ptr[idx]), \ - convert_float(*arg1_ptr), alpha, beta, scale); \ + res_ptr[idx] \ + = fwd_Xnary(algorithm, convert_float(arg0_ptr[idx]), \ + convert_float(*arg1_ptr), alpha, beta, scale); \ } else { \ - res_ptr[idx] = fwd_Xnary(po_kind, algorithm, \ + res_ptr[idx] = fwd_Xnary(algorithm, \ convert_float(arg0_ptr[idx]), \ convert_float(arg1_ptr[idx]), alpha, beta, scale); \ } \ @@ -277,7 +268,7 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y, REPLICATE_DATA(bin_arg_ptr, bin_arg_size, x0_s, x1_size, x2_s, \ x3_s, x4_s, x5_s); \ } \ - FWD_XNARY_GENERIC_DT(PO_BINARY, CONCAT3(PO_, idx, _ALG), accumulator, \ + FWD_XNARY_GENERIC_DT(CONCAT3(PO_, idx, _ALG), accumulator, \ acc_elem_dt, ((acc_elem_dt *)(&accumulator)), \ (sizeof(accumulator) / sizeof(acc_elem_dt)), bin_arg_ptr, \ bin_arg_size, 0.0f, 0.0f, 1.0f); \ @@ -292,7 +283,7 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y, #define APPLY_PO_ELTWISE(idx, accumulator, acc_elem_dt) \ { \ - FWD_XNARY_GENERIC_DT(PO_ELTWISE, CONCAT3(PO_, idx, _ALG), accumulator, \ + FWD_XNARY_GENERIC_DT(CONCAT3(PO_, idx, _ALG), accumulator, \ acc_elem_dt, ((acc_elem_dt *)(&accumulator)), \ (sizeof(accumulator) / sizeof(acc_elem_dt)), \ ((acc_elem_dt *)(&accumulator)), \