diff --git a/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.cpp b/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.cpp index a0e66ec14061bd..6997c8d0f25389 100644 --- a/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.cpp @@ -37,8 +37,8 @@ std::set> jit_brgemm_emitter::get_supported_precision } void jit_brgemm_emitter::validate_arguments(const std::vector& in, const std::vector& out) const { - OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got" + std::to_string(in.size())); - OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got" + std::to_string(out.size())); + OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got ", in.size()); + OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got ", out.size()); } void jit_brgemm_emitter::emit_code_impl(const std::vector& in, diff --git a/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.cpp b/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.cpp index 2e29f3c92b64df..00f26e4a641013 100644 --- a/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.cpp +++ b/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.cpp @@ -29,6 +29,13 @@ size_t BrgemmKernelConfig::compute_hash() const { void BrgemmKernelConfig::update(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, float beta) { BrgemmGenericKernelConfig::update(M, N, K, LDA, LDB, LDC, beta); + // Update the compile flags, which must be reset depending on beta. 
It is the combination of beta and the static compile flags and + // is considered in hash() and operator== + libxsmm_bitfield new_flag = get_static_compile_flags(); + if (beta == 0) { + new_flag |= LIBXSMM_GEMM_FLAG_BETA_0; + } + set_compile_flags(new_flag); m_hash = compute_hash(); } @@ -136,8 +143,6 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression } config.update(M, N, K, io_strides[0], io_strides[1], io_strides[2], beta); - // update compile flag, which is depend on beta. should be part of hash. - config.set_compile_flags_with_zero_beta(config.get_beta() == 0); } void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* executor, void* in0, void* in1, void* out0) { diff --git a/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.hpp b/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.hpp index 47952a212bed2a..169f2df364f092 100644 --- a/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.hpp +++ b/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.hpp @@ -35,13 +35,6 @@ struct BrgemmKernelConfig : public BrgemmGenericKernelConfig { libxsmm_bitfield get_compile_flags() const { return m_compile_flags; } - void set_compile_flags_with_zero_beta(bool zero_beta) { - if (zero_beta) { - m_compile_flags = m_static_params->m_compile_flags | LIBXSMM_GEMM_FLAG_BETA_0; - } else { - m_compile_flags = m_static_params->m_compile_flags; - } - } bool get_prefetching_flags() const { return m_static_params->m_prefetching_flags; } @@ -57,6 +50,9 @@ struct BrgemmKernelConfig : public BrgemmGenericKernelConfig { libxsmm_datatype get_type_exec() const { return m_static_params->m_type_exec; } + libxsmm_bitfield get_static_compile_flags() const { + return m_static_params->m_compile_flags; + } #ifdef SNIPPETS_DEBUG_CAPS std::string to_string() const override; #endif @@ -64,7 +60,6 @@ struct BrgemmKernelConfig : public BrgemmGenericKernelConfig { private: struct StaticParams { 
StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype); - virtual ~StaticParams() = default; bool operator==(const StaticParams& rhs) const; bool operator!=(const StaticParams& rhs) const { @@ -88,10 +83,13 @@ struct BrgemmKernelConfig : public BrgemmGenericKernelConfig { size_t m_hash{SIZE_MAX}; }; + const std::shared_ptr& get_static_params() const { return m_static_params; } - + void set_compile_flags(const libxsmm_bitfield& compile_flags) { + m_compile_flags = compile_flags; + } libxsmm_bitfield m_compile_flags{0}; std::shared_ptr m_static_params{nullptr}; @@ -109,7 +107,6 @@ struct BrgemmTppCompiledKernel { class BrgemmKernelExecutor : public CPUKernelExecutor { public: BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmKernelConfig config); - virtual ~BrgemmKernelExecutor() = default; // Function that will be called in runtime to execute the kernel static void execute(const BrgemmKernelExecutor* executor, void* in0, void* in1, void* out0); diff --git a/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.cpp b/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.cpp index d41dfae2b860b3..ed6a28482dcba3 100644 --- a/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.cpp @@ -41,8 +41,8 @@ std::set> BrgemmTppEmitter::get_supported_precisions( } void BrgemmTppEmitter::validate_arguments(const std::vector& in, const std::vector& out) const { - OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got" + std::to_string(in.size())); - OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got" + std::to_string(out.size())); + OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got ", in.size()); + OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got ", out.size()); } const uintptr_t BrgemmTppEmitter::get_compiled_kernel_ptr() const {