diff --git a/src/cpu/x64/jit_brgemm_inner_product.cpp b/src/cpu/x64/jit_brgemm_inner_product.cpp index 09f31408ad9..dae79ae396e 100644 --- a/src/cpu/x64/jit_brgemm_inner_product.cpp +++ b/src/cpu/x64/jit_brgemm_inner_product.cpp @@ -1669,8 +1669,10 @@ void brgemm_inner_product_bwd_weights_t< const bool is_f32_out = jbgp.wei_dt == data_type::f32; const int icb_scale = is_f32_out ? jbgp.ic_block / jbgp.simd_w : 1; - const int icb_work = ti->ic_c_work * jbgp.nb_ic_blocking; - const int ocb_work = ti->oc_c_work * jbgp.nb_oc_blocking; + const int icb_work = nstl::min(ti->ic_c_work * jbgp.nb_ic_blocking, + jbgp.nb_ic - ti->ic_c_start * jbgp.nb_ic_blocking); + const int ocb_work = nstl::min(ti->oc_c_work * jbgp.nb_oc_blocking, + jbgp.nb_oc - ti->oc_c_start * jbgp.nb_oc_blocking); const int work = ocb_work * icb_work; int os_chunks = utils::div_up(jbgp.nb_os, jbgp.nb_os_blocking);