From 21bdc21f37ff835b9ce54d4b713d7bfd65060e30 Mon Sep 17 00:00:00 2001
From: Andrey Kalinin
Date: Sat, 18 Feb 2023 14:31:07 -0800
Subject: [PATCH] x64: brgemm bwd_w convolution: update threading for small minibatch

---
 src/cpu/x64/jit_brgemm_conv_utils.cpp | 155 ++++++++++++++++++++++++++
 1 file changed, 155 insertions(+)

diff --git a/src/cpu/x64/jit_brgemm_conv_utils.cpp b/src/cpu/x64/jit_brgemm_conv_utils.cpp
index a8899313a25..2d3348b2e68 100644
--- a/src/cpu/x64/jit_brgemm_conv_utils.cpp
+++ b/src/cpu/x64/jit_brgemm_conv_utils.cpp
@@ -2499,6 +2499,161 @@ void balance_bwd_w(jit_brgemm_conv_conf_t &jcp) {
         }
         nthr_ic_b = jcp.nthr / (nthr_mb * nthr_oc_b);
         nthr = nthr_mb * nthr_g * nthr_oc_b * nthr_ic_b;
+    } else if (is_amx(jcp.isa) && jcp.mb <= jcp.nthr / 2 && jcp.oc >= 64
+            && jcp.ic >= 64 && jcp.ngroups == 1) {
+        // This heuristic is intended for usual convolutions if the minibatch
+        // is much less than the number of threads: it tries to divide the
+        // total amount of work into more or less 4-dimensional (by mb, g, oc,
+        // ic) "cubic" pieces
+        enum bwd_w_dims { g, ic, oc, sp };
+        constexpr int nd = 4;
+        // Keep maximum values for each dimension as a map
+        std::map<bwd_w_dims, int> maxv;
+        maxv.emplace(bwd_w_dims::g, jcp.ngroups);
+        maxv.emplace(bwd_w_dims::ic, div_up(jcp.nb_ic, 2));
+        maxv.emplace(bwd_w_dims::oc, div_up(jcp.nb_oc, 2));
+        maxv.emplace(bwd_w_dims::sp, jcp.mb * jcp.od * jcp.oh);
+
+        // Keep dimension values as a vector
+        std::vector<std::pair<double, bwd_w_dims>> dv;
+        const auto ks = jcp.kd * jcp.kh * jcp.kw;
+        double v = (jcp.ngroups > 1) ? static_cast<double>(jcp.ic) * jcp.oc
+                        * jcp.ngroups * jcp.ngroups * ks
+                                     : 1;
+        dv.emplace_back(v, bwd_w_dims::g);
+        v = 5 * div_up(jcp.ic, jcp.amx_h) * ks;
+        dv.emplace_back(v, bwd_w_dims::ic);
+        v = 3 * div_up(jcp.oc, jcp.amx_h) * ks;
+        dv.emplace_back(v, bwd_w_dims::oc);
+        v = div_up(jcp.mb * jcp.od * jcp.oh * jcp.ow, jcp.amx_w);
+        dv.emplace_back(v, bwd_w_dims::sp);
+        // Estimate the size of a "cubic" piece
+        double xd = 1;
+        for (int j = 0; j < nd; j++)
+            xd *= dv[j].first;
+        xd = pow(xd / jcp.nthr, 1.f / nd);
+        // Adjust the piece to fit into the dimensions
+        std::sort(dv.begin(), dv.end());
+        double tot_v = 1;
+        for (int i = 0; i < nd; i++) {
+            auto &dvf = dv[i].first;
+            const auto &dvs = dv[i].second;
+            const auto maxvf = static_cast<double>(maxv[dvs]);
+            if (dvf < xd) {
+                v = 1;
+                xd = 1;
+                for (int j = i + 1; j < nd; j++)
+                    xd *= dv[j].first;
+                xd = pow(xd / jcp.nthr, 1.f / (nd - i - 1));
+            } else {
+                v = nstl::min(dvf / xd, maxvf);
+            }
+            tot_v *= v;
+            dvf = v;
+        }
+        std::sort(dv.begin(), dv.end());
+
+        // Normalize dimension values so that their product is ~= nthr
+        double knorm = pow(jcp.nthr / tot_v, 1.f / nd);
+        tot_v = 1;
+        for (int i = 0; i < nd; i++) {
+            auto &dvf = dv[i].first;
+            auto &dvs = dv[i].second;
+            const auto maxvf = static_cast<double>(maxv[dvs]);
+            const auto new_dvf = dvf * knorm;
+            dvf = utils::saturate(1., maxvf, new_dvf);
+            knorm *= pow(new_dvf / dvf, 1.f / (nd - i - 1));
+            tot_v *= dvf;
+        }
+        std::sort(dv.begin(), dv.end());
+        knorm = jcp.nthr / tot_v;
+        for (int i = 0; i < nd; i++) {
+            auto &dvf = dv[i].first;
+            auto &dvs = dv[i].second;
+            const auto maxvf = static_cast<double>(maxv[dvs]);
+            const auto new_dvf = dvf * knorm;
+            dvf = utils::saturate(1., maxvf, new_dvf);
+            knorm = new_dvf / dvf;
+        }
+        std::sort(dv.begin(), dv.end());
+
+        // Select the number of threads for every dimension closest to what
+        // we defined above
+        auto calc_diff =
+                [&](const std::vector<std::pair<int, bwd_w_dims>> &cv) {
+                    auto tot_n = 1;
+                    double res = 1;
+                    for (int i = 0; i < nd; i++) {
+                        const auto nvf = dv[i].first;
+                        const auto n = cv[i].first;
+                        const auto v = maxv[cv[i].second];
+                        const auto disb
+                                = nvf * static_cast<double>(rnd_up(v, n)) / v;
+                        const auto nf = static_cast<double>(n);
+                        const auto var = ((nf > nvf) ? (nf / nvf) : (nvf / nf));
+                        tot_n *= n;
+                        res *= disb * var;
+                    }
+                    const auto thr_disb = static_cast<double>(jcp.nthr) / tot_n;
+                    return res * thr_disb;
+                };
+
+        // nv: vector to keep the result of the selection
+        std::vector<std::pair<int, bwd_w_dims>> nv;
+        // Initial selection and its estimate
+        for (int i = 0; i < nd; i++) {
+            const auto dvf = dv[i].first;
+            const auto dvs = dv[i].second;
+            const auto maxvf = maxv[dvs];
+            nv.emplace_back(
+                    utils::saturate(1, maxvf, static_cast<int>(dvf + 0.5f)),
+                    dvs);
+        }
+        nv[nd - 1].first = jcp.nthr / (nv[0].first * nv[1].first * nv[2].first);
+        double best_diff = calc_diff(nv);
+
+        // Iterate through all combinations of thread counts
+        std::vector<std::pair<int, bwd_w_dims>> cv = nv;
+        const auto n0_max = jcp.nthr;
+        for (int n0 = 1; n0 <= n0_max; n0++) {
+            if (n0 > maxv[dv[0].second]) continue;
+            cv[0].first = n0;
+            const auto n1_max = n0_max / n0;
+            for (int n1 = 1; n1 <= n1_max; n1++) {
+                if (n1 > maxv[dv[1].second]) continue;
+                cv[1].first = n1;
+                const auto n2_max = n1_max / n1;
+                for (int n2 = 1; n2 <= n2_max; n2++) {
+                    if (n2 > maxv[dv[2].second]) continue;
+                    cv[2].first = n2;
+                    const auto n3_max = n2_max / n2;
+                    for (int n3 = n3_max; n3 >= 1; n3--) {
+                        if (n3 > maxv[dv[3].second]) continue;
+                        cv[3].first = n3;
+                        const auto tot_n = n0 * n1 * n2 * n3;
+                        const auto cdiff = calc_diff(cv);
+                        if (cdiff < best_diff && tot_n <= jcp.nthr) {
+                            best_diff = cdiff;
+                            nv = cv;
+                        }
+                    }
+                }
+            }
+        }
+
+        for (size_t i = 0; i < nd; i++) {
+            const auto &nvf = nv[i].first;
+            const auto &nvs = nv[i].second;
+            if (nvs == bwd_w_dims::g)
+                nthr_g = nvf;
+            else if (nvs == bwd_w_dims::ic)
+                nthr_ic_b = nvf;
+            else if (nvs == bwd_w_dims::oc)
+                nthr_oc_b = nvf;
+            else if (nvs == bwd_w_dims::sp)
+                nthr_mb = nvf;
+        }
+        nthr = nthr_mb * nthr_g * nthr_oc_b * nthr_ic_b;
     } else if (jcp.ngroups == 1 && (jcp.oc > 2048 || jcp.ic > 2048)) {
         const bool more_oc = (jcp.ic < jcp.oc);
         if (more_oc) {
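
For readers who want to experiment with the partitioning idea outside the
oneDNN tree, below is a minimal standalone C++ sketch of the same approach:
build an "ideal" fractional split proportional to the per-dimension work, then
exhaustively search integer thread counts minimizing an imbalance metric in
the spirit of calc_diff(). Everything in the sketch (dim_work_t, imbalance(),
the sample sizes) is an illustrative assumption, not oneDNN code.

// Standalone sketch of the thread-partitioning heuristic above.
// All identifiers and numbers here are hypothetical examples.
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdio>

// Per-dimension descriptor: relative amount of work and the maximum useful
// number of threads for that dimension (analogous to dv[i] and maxv[...]).
struct dim_work_t {
    double work;
    int max_thr;
};

// Imbalance metric in the spirit of calc_diff(): penalizes rounding waste,
// deviation from the ideal fractional split, and unused threads.
static double imbalance(const std::array<dim_work_t, 4> &dims,
        const std::array<double, 4> &ideal, const std::array<int, 4> &n,
        int nthr) {
    double res = 1.0;
    int tot_n = 1;
    for (int i = 0; i < 4; i++) {
        // rnd_up(max_thr, n) / max_thr: waste from uneven chunking
        const int rounded = ((dims[i].max_thr + n[i] - 1) / n[i]) * n[i];
        const double disb = ideal[i] * rounded / dims[i].max_thr;
        const double nf = n[i];
        const double var = nf > ideal[i] ? nf / ideal[i] : ideal[i] / nf;
        res *= disb * var;
        tot_n *= n[i];
    }
    // Penalize leaving threads idle
    return res * (static_cast<double>(nthr) / tot_n);
}

int main() {
    const int nthr = 56; // e.g. one socket worth of cores
    // Hypothetical shape: small minibatch, one group, large oc/ic
    const std::array<dim_work_t, 4> dims = {{
            {4.0, 4}, // spatial (mb * od * oh)
            {1.0, 1}, // groups
            {8.0, 16}, // oc blocks
            {8.0, 16}, // ic blocks
    }};

    // "Ideal" fractional split: proportional to work, product ~= nthr
    double tot_w = 1.0;
    for (const auto &d : dims)
        tot_w *= d.work;
    const double scale = std::pow(nthr / tot_w, 1.0 / 4);
    std::array<double, 4> ideal;
    for (int i = 0; i < 4; i++)
        ideal[i] = std::max(
                1.0, std::min<double>(dims[i].max_thr, dims[i].work * scale));

    // Exhaustive search over integer thread counts, as in the patch
    std::array<int, 4> best = {1, 1, 1, 1};
    double best_diff = imbalance(dims, ideal, best, nthr);
    for (int n0 = 1; n0 <= std::min(nthr, dims[0].max_thr); n0++)
        for (int n1 = 1; n1 <= std::min(nthr / n0, dims[1].max_thr); n1++)
            for (int n2 = 1; n2 <= std::min(nthr / (n0 * n1), dims[2].max_thr);
                    n2++)
                for (int n3 = 1;
                        n3 <= std::min(nthr / (n0 * n1 * n2), dims[3].max_thr);
                        n3++) {
                    const std::array<int, 4> cur = {n0, n1, n2, n3};
                    const double d = imbalance(dims, ideal, cur, nthr);
                    if (d < best_diff) {
                        best_diff = d;
                        best = cur;
                    }
                }

    std::printf("sp=%d g=%d oc=%d ic=%d (uses %d of %d threads)\n", best[0],
            best[1], best[2], best[3], best[0] * best[1] * best[2] * best[3],
            nthr);
    return 0;
}

As in the patch, the exhaustive search stays cheap because the loop bounds
shrink multiplicatively (n1_max = n0_max / n0, and so on), so the number of
candidate (n0, n1, n2, n3) tuples is far smaller than nthr^4.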