From 05d68df233bb67046697758962cd32bd6d23a956 Mon Sep 17 00:00:00 2001 From: Andrey Kalinin Date: Fri, 15 Nov 2024 17:59:06 -0800 Subject: [PATCH] x64: brgemm convolution: update req_cal_comp_pad condition --- src/cpu/x64/jit_brgemm_conv.cpp | 4 ++++ src/cpu/x64/jit_brgemm_conv_utils.cpp | 12 +++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/cpu/x64/jit_brgemm_conv.cpp b/src/cpu/x64/jit_brgemm_conv.cpp index a4a725645cf..0a31fca4672 100644 --- a/src/cpu/x64/jit_brgemm_conv.cpp +++ b/src/cpu/x64/jit_brgemm_conv.cpp @@ -1526,6 +1526,10 @@ status_t brgemm_convolution_fwd_t::cal_compensation( const int max_ker_sz = adjusted_k.size(); const auto comp_buffer_ow = jcp.exec_type != exec_vpad ? jcp.ow : 1; + // TODO: revise the thread distribution here because the work_amount may be + // insufficient + // TODO: revise comp_vpad_pbuffer_ generator to avoid huge code for cases + // with big ow const auto work_amount = static_cast(jcp.ngroups) * jcp.nb_oc * max_ker_sz; const auto is_small_shape = work_amount <= jcp.nthr diff --git a/src/cpu/x64/jit_brgemm_conv_utils.cpp b/src/cpu/x64/jit_brgemm_conv_utils.cpp index 3682e6409e1..4dd38041c6c 100644 --- a/src/cpu/x64/jit_brgemm_conv_utils.cpp +++ b/src/cpu/x64/jit_brgemm_conv_utils.cpp @@ -2296,12 +2296,18 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, // For padding shapes, we calculate the comp along with the computation // inside brgemm kernel when output size is small to get optimal perf - // Or we calculate the comp using brgemm_coomp_pad kernel + // For shapes with large ow we calculate the comp inside brgemm kernel too + // because the current implementation of brgemm_comp_pad kernel is unrolled + // by ow, so it is not optimal for large ow.
+ // Otherwise we calculate the comp using brgemm_comp_pad kernel const auto output_sz = static_cast(jcp.mb) * jcp.ngroups * jcp.oc * jcp.od * jcp.oh * jcp.ow; + // TODO: revise the condition below to avoid the limitation for big ow + const auto shape_for_brgemm_kernel + = (output_sz <= 8192 && jcp.oc < 512) || jcp.ow > 128; + const auto is_relo = jcp.is_relo() && jcp.relo_conv_weights; jcp.req_brg_comp_pad = compensation_w_padding && jcp.exec_type != exec_trans - && IMPLICATION(!(jcp.is_relo() && jcp.relo_conv_weights), - output_sz <= 8192 && jcp.oc < 512); + && IMPLICATION(!is_relo, shape_for_brgemm_kernel); jcp.req_cal_comp_pad = compensation_w_padding && !jcp.req_brg_comp_pad && IMPLICATION(jcp.exec_type == exec_vpad, jcp.t_pad > 0 || jcp.b_pad > 0 || jcp.f_pad > 0