From 796a600c3de2993b5d5819995ad13eb70d097496 Mon Sep 17 00:00:00 2001 From: Andrey Kalinin Date: Wed, 12 Apr 2023 17:30:20 -0700 Subject: [PATCH] x64: brgemm bwd_w convolution: update scratchpad data preparing --- src/cpu/x64/jit_brgemm_conv_bwd_w.cpp | 31 ++++++++++++++++++--------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/src/cpu/x64/jit_brgemm_conv_bwd_w.cpp b/src/cpu/x64/jit_brgemm_conv_bwd_w.cpp index 2c15154b806..cc2ae19ff11 100644 --- a/src/cpu/x64/jit_brgemm_conv_bwd_w.cpp +++ b/src/cpu/x64/jit_brgemm_conv_bwd_w.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2022 Intel Corporation +* Copyright 2022-2023 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1358,15 +1358,26 @@ void brgemm_convolution_bwd_weights_t::prepare_scratchpad_data( const auto &jcp = pd()->jcp_; auto tr_src = scratchpad.template get(key_conv_tr_src); - // Zero out guard elements that cross a buffer boundary to prevent a - // race condition due to buffer overflows from memory optimization where - // buffers sharing padding - // TODO: optimize it - for (size_t isb = 1; isb <= jcp.tr_src_buf_count; ++isb) { - src_data_t *ts - = &tr_src[isb * jcp.tr_src_buf_size * jcp.nb_ic_blocking]; - for (int i = 0; i < jcp.tr_src_num_guard_elems; ++i) - ts[i] = 0; + const auto tr_src_full_size = jcp.tr_src_buf_size * jcp.nb_ic_blocking; + if (jcp.oh_block < jcp.oh || jcp.id > 1) { + // if (oh_block < oh) or (id > 1) then we zero all buffer because last + // elements position may vary depending on position of od_s, oh_block, + // padding and kh + parallel_nd(jcp.tr_src_buf_count, [&](size_t isb) { + src_data_t *ts = &tr_src[isb * tr_src_full_size]; + std::memset(ts, 0, jcp.src_dsz * tr_src_full_size); + }); + // Zero out last guard elements + src_data_t *ts = &tr_src[jcp.tr_src_buf_count * tr_src_full_size]; + std::memset(ts, 0, jcp.src_dsz * jcp.tr_src_num_guard_elems); + } else { + // Zero out guard elements that cross a buffer boundary to prevent a + // race condition due to buffer overflows from memory optimization where + // buffers sharing padding + parallel_nd(jcp.tr_src_buf_count, [&](size_t isb) { + src_data_t *ts = &tr_src[(isb + 1) * tr_src_full_size]; + std::memset(ts, 0, jcp.src_dsz * jcp.tr_src_num_guard_elems); + }); } if (jcp.global_transpose && jcp.nthr_oc_b > 1) {