Skip to content

Commit

Permalink
Vladislav comments apply-2
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Feb 17, 2025
1 parent 51d29ef commit ce3e097
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precision
}

void jit_brgemm_emitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got" + std::to_string(in.size()));
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got" + std::to_string(out.size()));
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got", in.size());
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got", out.size());
}

void jit_brgemm_emitter::emit_code_impl(const std::vector<size_t>& in,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ size_t BrgemmKernelConfig::compute_hash() const {

void BrgemmKernelConfig::update(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, float beta) {
BrgemmGenericKernelConfig::update(M, N, K, LDA, LDB, LDC, beta);
// update compile flag, which should be reset depend on beta. It is combination of beta and static_compile_flag and
// considered in hash() and operator==
libxsmm_bitfield new_flag = get_static_compile_flags();
if (beta == 0) {
new_flag |= LIBXSMM_GEMM_FLAG_BETA_0;
}
set_compile_flags(new_flag);
m_hash = compute_hash();
}

Expand Down Expand Up @@ -136,8 +143,6 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression
}

config.update(M, N, K, io_strides[0], io_strides[1], io_strides[2], beta);
// update compile flag, which is depend on beta. should be part of hash.
config.set_compile_flags_with_zero_beta(config.get_beta() == 0);
}

void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* executor, void* in0, void* in1, void* out0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,6 @@ struct BrgemmKernelConfig : public BrgemmGenericKernelConfig {
libxsmm_bitfield get_compile_flags() const {
return m_compile_flags;
}
void set_compile_flags_with_zero_beta(bool zero_beta) {
if (zero_beta) {
m_compile_flags = m_static_params->m_compile_flags | LIBXSMM_GEMM_FLAG_BETA_0;
} else {
m_compile_flags = m_static_params->m_compile_flags;
}
}
bool get_prefetching_flags() const {
return m_static_params->m_prefetching_flags;
}
Expand All @@ -57,14 +50,16 @@ struct BrgemmKernelConfig : public BrgemmGenericKernelConfig {
libxsmm_datatype get_type_exec() const {
return m_static_params->m_type_exec;
}
libxsmm_bitfield get_static_compile_flags() const {
return m_static_params->m_compile_flags;
}
#ifdef SNIPPETS_DEBUG_CAPS
std::string to_string() const override;
#endif

private:
struct StaticParams {
StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype);
virtual ~StaticParams() = default;

bool operator==(const StaticParams& rhs) const;
bool operator!=(const StaticParams& rhs) const {
Expand All @@ -88,10 +83,13 @@ struct BrgemmKernelConfig : public BrgemmGenericKernelConfig {

size_t m_hash{SIZE_MAX};
};

const std::shared_ptr<StaticParams>& get_static_params() const {
return m_static_params;
}

void set_compile_flags(const libxsmm_bitfield& compile_flags) {
m_compile_flags = compile_flags;
}
libxsmm_bitfield m_compile_flags{0};
std::shared_ptr<StaticParams> m_static_params{nullptr};

Expand All @@ -109,7 +107,6 @@ struct BrgemmTppCompiledKernel {
class BrgemmKernelExecutor : public CPUKernelExecutor<BrgemmKernelConfig, BrgemmTppCompiledKernel> {
public:
BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmKernelConfig config);
virtual ~BrgemmKernelExecutor() = default;

// Function that will be called in runtime to execute the kernel
static void execute(const BrgemmKernelExecutor* executor, void* in0, void* in1, void* out0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ std::set<std::vector<element::Type>> BrgemmTppEmitter::get_supported_precisions(
}

void BrgemmTppEmitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got" + std::to_string(in.size()));
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got" + std::to_string(out.size()));
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got", in.size());
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got", out.size());
}

const uintptr_t BrgemmTppEmitter::get_compiled_kernel_ptr() const {
Expand Down

0 comments on commit ce3e097

Please # to comment.