From 3ba7e8b9c14948da35c86d4d74725f0d23511fc8 Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Tue, 11 Apr 2023 14:22:26 -0700 Subject: [PATCH] gpu: jit: gemm: save ldco in fixed systolic inner loop --- src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp b/src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp index b465449c841..4163c8611e1 100644 --- a/src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp +++ b/src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp @@ -18794,6 +18794,7 @@ bool gemm_kernel_generator_t::sysgemmAccumulateC( auto diagC = saveData[1].ud(1); auto effCO = saveData[1].uq(1); auto slotAB = saveData[1].ud(4); + auto ldco = saveData[1].ud(5); auto effAs = saveData[1].uq(2).reinterpret(0, state.effA.getType()); auto effBs = saveData[1].uq(3).reinterpret(0, state.effB.getType()); auto saveI0 = saveData[1].ud(1); @@ -18825,6 +18826,7 @@ bool gemm_kernel_generator_t::sysgemmAccumulateC( if (state.effCO.isValid()) { effCO = effCO.reinterpret(0, state.effCO.getType()); emov(1, effCO, state.effCO, strategy, state); + if (state.inputs.ldco.isValid()) mov(1, ldco, state.inputs.ldco); } if (problem.hasBinaryPostOp()) { if (state.diagC.isValid()) stub(); @@ -18862,6 +18864,7 @@ bool gemm_kernel_generator_t::sysgemmAccumulateC( state.ra.release(state.remFusedStorage); state.ra.release(state.diagC); state.ra.release(state.effCO); + state.ra.release(state.inputs.ldco); state.ra.release(state.fusedGEMM.slotA); state.ra.release(state.fusedGEMM.slotB); @@ -18916,6 +18919,7 @@ bool gemm_kernel_generator_t::sysgemmAccumulateC( } if (state.diagC.isValid()) state.diagC = diagC; if (state.effCO.isValid()) state.effCO = effCO; + if (state.inputs.ldco.isValid()) state.inputs.ldco = ldco; if (state.fusedGEMM.slotA.isValid()) { state.fusedGEMM.slotA = slotAB.uw(0); state.fusedGEMM.slotB = slotAB.uw(1); @@ -20228,6 +20232,7 @@ bool gemm_kernel_generator_t::sysgemm2AccumulateC( auto effCO = saveData[1].uq(1); auto C_ptr = saveData[1].uq(2); auto slotAB = saveData[1].ud(6); + auto ldco = saveData[1].ud(7); auto effAs = a0.ud(4); // dwords 4-5 auto effBs = a0.ud(6); // dwords 6-7 auto saveI0 = saveData[1].ud(1); @@ -20260,6 +20265,7 @@ bool gemm_kernel_generator_t::sysgemm2AccumulateC( if (state.effCO.isValid()) { effCO = effCO.reinterpret(0, state.effCO.getType()); emov(1, effCO, state.effCO, strategy, state); + if (state.inputs.ldco.isValid()) mov(1, ldco, state.inputs.ldco); } if (problem.hasBinaryPostOp()) { if (state.diagC.isValid()) stub(); @@ -20303,6 +20309,7 @@ bool gemm_kernel_generator_t::sysgemm2AccumulateC( state.ra.release(state.remFusedStorage); state.ra.release(state.diagC); state.ra.release(state.effCO); + state.ra.release(state.inputs.ldco); state.ra.release(state.fusedGEMM.slotA); state.ra.release(state.fusedGEMM.slotB); @@ -20355,6 +20362,7 @@ bool gemm_kernel_generator_t::sysgemm2AccumulateC( } if (state.diagC.isValid()) state.diagC = diagC; if (state.effCO.isValid()) state.effCO = effCO; + if (state.inputs.ldco.isValid()) state.inputs.ldco = ldco; if (state.fusedGEMM.slotA.isValid()) { state.fusedGEMM.slotA = slotAB.uw(0); state.fusedGEMM.slotB = slotAB.uw(1);