diff --git a/src/cpu/x64/jit_brgemm_conv_utils.cpp b/src/cpu/x64/jit_brgemm_conv_utils.cpp index d6d05242393..ae209175697 100644 --- a/src/cpu/x64/jit_brgemm_conv_utils.cpp +++ b/src/cpu/x64/jit_brgemm_conv_utils.cpp @@ -1897,7 +1897,12 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, best_brgb.oc_block = min_oc_block; brg_blocking_t cur_brgb = zero(); cur_brgb.get_from_jcp(jcp); - auto start_ocb = (is_amx(isa) && jcp.is_os_blocking) ? 2 : 4; + const int est_amx_job = div_up(jcp.mb * div_up(jcp.os, 4 * 16) + * jcp.ngroups * div_up(jcp.oc, 4 * 16), + nthreads); + const bool small_amx_job = est_amx_job < 64 || jcp.oc < 256; + auto start_ocb + = (is_amx(isa) && jcp.is_os_blocking && small_amx_job) ? 2 : 4; if (jcp.wei_plain) start_ocb = nstl::min(jcp.ic > 128 ? (jcp.ic > 256 ? 8 : 16) : 32, div_up(jcp.oc, 16));