diff --git a/src/cpu/x64/jit_brgemm_conv_bwd_w.cpp b/src/cpu/x64/jit_brgemm_conv_bwd_w.cpp index 666eed571ba..0c6237c5658 100644 --- a/src/cpu/x64/jit_brgemm_conv_bwd_w.cpp +++ b/src/cpu/x64/jit_brgemm_conv_bwd_w.cpp @@ -137,7 +137,8 @@ status_t brgemm_convolution_bwd_weights_t::pd_t::init(engine_t *engine) { brgattr.max_bottom_vpad = 0; brgattr.LDA2 = jcp_.tr_iw * jcp_.ih_block * jcp_.id; - brgattr.LDB2 = jcp_.tr_ow * jcp_.oc_block * jcp_.oh * jcp_.od; + brgattr.LDB2 + = jcp_.tr_ow * jcp_.oc_block * jcp_.oh_block * jcp_.od; brgattr.LDC2_M = jcp_.oc_block * jcp_.kd * jcp_.kh * jcp_.kw; brgattr.LDC2_N = jcp_.nb_ic * jcp_.ic_block * jcp_.oc_block * jcp_.kd * jcp_.kh * jcp_.kw; @@ -474,7 +475,7 @@ struct brgemm_convolution_bwd_weights_t::thread_info_t { size_t tr_diff_dst_off(int g, int ocb, int od, int oh) const { const size_t tr_row_size = jcp.tr_ow * jcp.oc_block; - const size_t tr_3d_size = tr_row_size * jcp.oh; + const size_t tr_3d_size = tr_row_size * jcp.oh_block; int adj = (jcp.global_transpose) ? 1 : jcp.nb_oc_blocking; return tr_diff_dst_buf_number(g, ocb) * adj * jcp.tr_diff_dst_buf_size + od * tr_3d_size + oh * tr_row_size; @@ -1026,7 +1027,7 @@ void brgemm_convolution_bwd_weights_t::compute_diff_weights_3d( + (bs_id_s - id_s) * jcp.ih_block * jcp.tr_iw * jcp.ic_block; const void *ptr_B = ((diff_dst_data_t *)p_dst) + (bs_oh_s - oh_s) * jcp.tr_ow * jcp.oc_block - + (bs_od_s - od_s) * jcp.oh * jcp.tr_ow * jcp.oc_block; + + (bs_od_s - od_s) * jcp.oh_block * jcp.tr_ow * jcp.oc_block; void *ptr_C = (jcp.transform_to_vnni) ? diff_wei + wei_offset_int(g, oc_b, ic_b, kd, kh, kw) : diff_wei @@ -1049,7 +1050,7 @@ void brgemm_convolution_bwd_weights_t::compute_diff_weights_3d( * jcp.ic_block * jcp.stride_d; ti->brg_batch[odb * bs_h + ohb].ptr.B = (char *)ptr_B + ohb * jcp.typesize_in * jcp.tr_ow * jcp.oc_block - + odb * jcp.typesize_in * jcp.oh * jcp.tr_ow + + odb * jcp.typesize_in * jcp.oh_block * jcp.tr_ow * jcp.oc_block; } } @@ -1127,8 +1128,8 @@ void brgemm_convolution_bwd_weights_t::compute_diff_weights_3d( && (odb_s == od_s) && (iodb == odb_s) && (ohb_s == oh_s); bp.dst = ((diff_dst_data_t *)p_dst) - + (iodb - od_s) * jcp.oh * jcp.tr_ow - * jcp.oc_block + + (iodb - od_s) * jcp.oh_block + * jcp.tr_ow * jcp.oc_block + (ohb_s - oh_s) * jcp.tr_ow * jcp.oc_block; (*diff_bias_kernel_)(&bp); diff --git a/src/cpu/x64/jit_brgemm_conv_utils.cpp b/src/cpu/x64/jit_brgemm_conv_utils.cpp index fff992e1740..a8899313a25 100644 --- a/src/cpu/x64/jit_brgemm_conv_utils.cpp +++ b/src/cpu/x64/jit_brgemm_conv_utils.cpp @@ -2760,7 +2760,7 @@ status_t init_conf_bwd_w(jit_brgemm_conv_conf_t &jcp, ? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups : jcp.nthr; jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih_block * jcp.id; - jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od; + jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh_block * jcp.od; const int iframe_size = irow_size * jcp.id; const int oframe_size = orow_size * jcp.od;