Skip to content

Commit

Permalink
apply comments continue
Browse files Browse the repository at this point in the history
Vladislav comments apply-2
  • Loading branch information
chenhu-wang committed Feb 17, 2025
1 parent 4379a78 commit 5411abb
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 35 deletions.
10 changes: 6 additions & 4 deletions src/plugins/intel_cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ else()
endif()
ov_dependent_option(ENABLE_SHL_FOR_CPU "Enable SHL for OpenVINO CPU Plugin" ${ENABLE_SHL_FOR_CPU_DEFAULT} "RISCV64" OFF)

# Enable the libxsmm-based TPP backend for snippets on AArch64 builds, except on Android:
# libxsmm doesn't support arm on android, see
# https://github.com/libxsmm/libxsmm/wiki/Q&A#what-operating-systems-are-covered-by-libxsmm-and-what-about-microsoft-windows
# The variable is set before add_subdirectory(thirdparty) below, because the thirdparty
# CMakeLists.txt also checks ENABLE_SNIPPETS_LIBXSMM_TPP.
if(AARCH64 AND (NOT ANDROID))
set(ENABLE_SNIPPETS_LIBXSMM_TPP ON)
endif()

add_subdirectory(thirdparty)

if(WIN32)
Expand All @@ -156,10 +162,6 @@ if(ENABLE_CPU_DEBUG_CAPS)
add_definitions(-DCPU_DEBUG_CAPS)
endif()

if(AARCH64 AND (NOT ANDROID))
set(ENABLE_SNIPPETS_LIBXSMM_TPP ON)
endif()

if (ENABLE_SNIPPETS_LIBXSMM_TPP)
# Note: LIBXSMM_DEFAULT_CONFIG needed so libxsmm_config can be included without issues
add_definitions(-DSNIPPETS_LIBXSMM_TPP -DLIBXSMM_DEFAULT_CONFIG)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "brgemm_base.hpp"
#include "brgemm_generic.hpp"

#include "common/utils.hpp"
#include "dnnl_extension_utils.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "emitters/snippets/cpu_kernel_executor_table.hpp"
#include "emitters/utils.hpp"
#include "openvino/core/type/element_type.hpp"
#include "snippets/lowered/loop_info.hpp"
#include "snippets/lowered/loop_manager.hpp"
#include "utils/general_utils.h"
Expand Down Expand Up @@ -68,6 +67,7 @@ class BrgemmKernelExecutorHelper {
int loop_id,
const ov::snippets::lowered::ExpandedLoopInfoPtr& current_expanded_loop_info);

// This function returns M, N, K dimensions and beta of brgemm as a tuple, based on loop info in linear_ir.
static std::tuple<int64_t, int64_t, int64_t, float> get_runtime_brgemm_params(
const ov::snippets::lowered::ExpressionPtr& expr,
const ov::snippets::lowered::LinearIRCPtr& linear_ir);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <cpu/x64/brgemm/brgemm.hpp>

#include "cpu/x64/cpu_isa_traits.hpp"
#include "emitters/snippets/brgemm_base.hpp"
#include "emitters/snippets/brgemm_generic.hpp"

namespace ov::intel_cpu::x64 {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precision
}

void jit_brgemm_emitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got" + std::to_string(in.size()));
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got" + std::to_string(out.size()));
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got", in.size());
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got", out.size());
}

void jit_brgemm_emitter::emit_code_impl(const std::vector<size_t>& in,
Expand All @@ -51,6 +51,7 @@ void jit_brgemm_emitter::emit_code_impl(const std::vector<size_t>& in,

void jit_brgemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
validate_arguments(in, out);
// todo: use optimized reg spill after CVS-162498
std::unordered_set<size_t> exclude = {};
store_context(exclude);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ BrgemmKernelConfig::BrgemmKernelConfig(const element::Type& in0_dtype, const ele

bool BrgemmKernelConfig::operator==(const BrgemmKernelConfig& rhs) const {
return BrgemmGenericKernelConfig::operator==(rhs) &&
(get_static_params() == rhs.get_static_params() ||
*get_static_params() == *(rhs.get_static_params()));
(get_static_params() == rhs.get_static_params() || *get_static_params() == *(rhs.get_static_params()));
}

size_t BrgemmKernelConfig::compute_hash() const {
Expand All @@ -30,6 +29,13 @@ size_t BrgemmKernelConfig::compute_hash() const {

void BrgemmKernelConfig::update(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, float beta) {
BrgemmGenericKernelConfig::update(M, N, K, LDA, LDB, LDC, beta);
// update compile flag, which should be reset depend on beta. It is combination of beta and static_compile_flag and
// considered in hash() and operator==
libxsmm_bitfield new_flag = get_static_compile_flags();
if (beta == 0) {
new_flag |= LIBXSMM_GEMM_FLAG_BETA_0;
}
set_compile_flags(new_flag);
m_hash = compute_hash();
}

Expand Down Expand Up @@ -76,6 +82,7 @@ std::string BrgemmKernelConfig::to_string() const {
std::stringstream ss;
ss << get_static_params()->to_string() << "\n";
ss << BrgemmGenericKernelConfig::to_string() << "\n";
PRINT(m_compile_flags);
return ss.str();
}
#endif
Expand All @@ -96,8 +103,8 @@ std::shared_ptr<BrgemmTppCompiledKernel> BrgemmKernelExecutor::compile_kernel(co
config.get_LDB(),
config.get_LDA(),
config.get_LDC(),
config.get_type_in0(),
config.get_type_in1(),
config.get_type_in0(),
config.get_type_out0(),
config.get_type_exec());
compiled_kernel->brgemm_kernel =
Expand Down Expand Up @@ -136,8 +143,6 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression
}

config.update(M, N, K, io_strides[0], io_strides[1], io_strides[2], beta);
// update compile flag, which is depend on beta. should be part of hash.
config.set_compile_flags_with_zero_beta(config.get_beta() == 0);
}

void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* executor, void* in0, void* in1, void* out0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#pragma once

#include "common/utils.hpp"
#include "emitters/snippets/brgemm_base.hpp"
#include "emitters/snippets/brgemm_generic.hpp"
#include "emitters/utils.hpp"
#include "libxsmm.h"

Expand All @@ -32,19 +32,9 @@ struct BrgemmKernelConfig : public BrgemmGenericKernelConfig {
}
size_t compute_hash() const;

libxsmm_bitfield get_static_compile_flags() const {
return m_static_params->m_compile_flags;
}
// Returns the compile flags currently stored in this config (m_compile_flags),
// as opposed to the static flags kept in m_static_params.
libxsmm_bitfield get_compile_flags() const {
return m_compile_flags;
}
void set_compile_flags_with_zero_beta(bool zero_beta) {
if (zero_beta) {
m_compile_flags = get_static_compile_flags() | LIBXSMM_GEMM_FLAG_BETA_0;
} else {
m_compile_flags = get_static_compile_flags();
}
}
// Returns the prefetching flags stored in the static parameters (m_static_params).
// NOTE(review): the return type is bool; if m_prefetching_flags is declared as a bitfield,
// its value collapses to true/false here — confirm this is intended.
bool get_prefetching_flags() const {
return m_static_params->m_prefetching_flags;
}
Expand All @@ -60,14 +50,16 @@ struct BrgemmKernelConfig : public BrgemmGenericKernelConfig {
// Returns the libxsmm execution data type stored in the static parameters (m_static_params).
libxsmm_datatype get_type_exec() const {
return m_static_params->m_type_exec;
}
libxsmm_bitfield get_static_compile_flags() const {
return m_static_params->m_compile_flags;
}
#ifdef SNIPPETS_DEBUG_CAPS
std::string to_string() const override;
#endif

private:
struct StaticParams {
StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype);
virtual ~StaticParams() = default;

bool operator==(const StaticParams& rhs) const;
bool operator!=(const StaticParams& rhs) const {
Expand All @@ -91,10 +83,13 @@ struct BrgemmKernelConfig : public BrgemmGenericKernelConfig {

size_t m_hash{SIZE_MAX};
};
std::shared_ptr<StaticParams> get_static_params() const {

const std::shared_ptr<StaticParams>& get_static_params() const {
return m_static_params;
}

// Overwrites the stored compile flags (m_compile_flags) with the given value.
// This setter only assigns the member; it does not recompute m_hash.
void set_compile_flags(const libxsmm_bitfield& compile_flags) {
m_compile_flags = compile_flags;
}
libxsmm_bitfield m_compile_flags{0};
std::shared_ptr<StaticParams> m_static_params{nullptr};

Expand All @@ -112,7 +107,6 @@ struct BrgemmTppCompiledKernel {
class BrgemmKernelExecutor : public CPUKernelExecutor<BrgemmKernelConfig, BrgemmTppCompiledKernel> {
public:
BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmKernelConfig config);
virtual ~BrgemmKernelExecutor() = default;

// Function that will be called in runtime to execute the kernel
static void execute(const BrgemmKernelExecutor* executor, void* in0, void* in1, void* out0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ std::set<std::vector<element::Type>> BrgemmTppEmitter::get_supported_precisions(
}

void BrgemmTppEmitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got" + std::to_string(in.size()));
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got" + std::to_string(out.size()));
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got", in.size());
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got", out.size());
}

const uintptr_t BrgemmTppEmitter::get_compiled_kernel_ptr() const {
Expand Down
4 changes: 0 additions & 4 deletions src/plugins/intel_cpu/thirdparty/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,6 @@ function(ov_add_onednn)
endif()
endfunction()

if(AARCH64 AND (NOT ANDROID))
set(ENABLE_SNIPPETS_LIBXSMM_TPP ON)
endif()

if (ENABLE_SNIPPETS_LIBXSMM_TPP)
# This flag is to suppress "warning as error" in libxsmm compilation, such as
# "generator_common_aarch64.c:60:6: error: no previous declaration for ‘libxsmm_generator_vcvt_f32i8_aarch64_sve’ [-Werror=missing-declarations]"
Expand Down

0 comments on commit 5411abb

Please sign in to comment.