diff --git a/src/gpu/ocl/ocl_post_ops.h b/src/gpu/ocl/ocl_post_ops.h index d007175d87a..0b2f0fcc5c5 100644 --- a/src/gpu/ocl/ocl_post_ops.h +++ b/src/gpu/ocl/ocl_post_ops.h @@ -66,21 +66,10 @@ float fwd_Xnary(unsigned algorithm, float x, float y, float alpha, float beta, = max((typeof(ty))arg0_len, (typeof(ty))arg1_len); \ result_elem_dt *res_ptr = (result_elem_dt *)(&result); \ unroll_for(typeof(out_len + 0) idx = 0; idx < out_len; ++idx) { \ - if (arg0_len == 1 && arg1_len == 1) { \ - *res_ptr = fwd_Xnary(algorithm, convert_float(*arg0_ptr), \ - convert_float(*arg1_ptr), alpha, beta, scale); \ - } else if (arg0_len == 1) { \ - res_ptr[idx] = fwd_Xnary(algorithm, convert_float(*arg0_ptr), \ - convert_float(arg1_ptr[idx]), alpha, beta, scale); \ - } else if (arg1_len == 1) { \ - res_ptr[idx] \ - = fwd_Xnary(algorithm, convert_float(arg0_ptr[idx]), \ - convert_float(*arg1_ptr), alpha, beta, scale); \ - } else { \ - res_ptr[idx] = fwd_Xnary(algorithm, \ - convert_float(arg0_ptr[idx]), \ - convert_float(arg1_ptr[idx]), alpha, beta, scale); \ - } \ + const int arg0_idx = arg0_len == 1 ? 0 : idx; \ + const int arg1_idx = arg1_len == 1 ? 0 : idx; \ + res_ptr[idx] = fwd_Xnary(algorithm, convert_float(arg0_ptr[idx]), \ + convert_float(arg1_ptr[idx]), alpha, beta, scale); \ } \ }