Skip to content

[NupharEP] Multiple optimizations #2380

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account.

Merged
merged 12 commits into from
Nov 14, 2019
36 changes: 25 additions & 11 deletions onnxruntime/core/codegen/mti/math/matmul_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,22 +124,36 @@ tvm::Tensor MatMul(const tvm::Tensor& A, const tvm::Tensor& B, const std::string
// Computes the output shape of MatMul following numpy matmul semantics,
// with optional transpose of the two trailing dims of either input.
// - 1-D x 1-D is a dot product: the output shape is empty (scalar).
// - 1-D A is treated as a row vector prepended to B's batch dims; the
//   temporary dim is dropped from the output.
// - 1-D B is treated as a column vector; the output keeps A's leading dims.
// - Otherwise the batch (leading) dims are broadcast pairwise and the two
//   trailing dims come from A's rows and B's cols (swapped under transpose).
// Transpose flags are not supported for 1-D inputs (asserted below).
tvm::Array<tvm::Expr>
ComputeMatMulShape(
    const tvm::Array<tvm::Expr>& A_shape,
    const tvm::Array<tvm::Expr>& B_shape,
    bool trans_a,
    bool trans_b) {
  auto a_rank = A_shape.size();
  auto b_rank = B_shape.size();
  tvm::Array<tvm::Expr> output_shape;
  int64_t output_rank = std::max(a_rank, b_rank);
  MTI_ASSERT(a_rank > 0 && b_rank > 0);
  if (a_rank == 1 && b_rank == 1) {
    MTI_ASSERT(!trans_a && !trans_b);
    // reduction, output shape is empty
  } else if (a_rank == 1) {
    MTI_ASSERT(!trans_a && !trans_b);
    output_shape = SliceShapeToDimension(B_shape, b_rank - 2);
    output_shape.push_back(B_shape[b_rank - 1]);
  } else if (b_rank == 1) {
    MTI_ASSERT(!trans_a && !trans_b);
    output_shape = SliceShapeToDimension(A_shape, a_rank - 1);
  } else {
    // broadcast all batch dims; both inputs must be broadcastable per dim
    for (int64_t i = 0; i < output_rank - 2; i++) {
      tvm::Expr broadcasted_dim = tvm::make_const(HalideIR::Int(32), 1);
      bool broadcasted =
          BroadcastDim(A_shape, i, output_rank, broadcasted_dim) &&
          BroadcastDim(B_shape, i, output_rank, broadcasted_dim);
      MTI_ASSERT(broadcasted);
      output_shape.push_back(broadcasted_dim);
    }
    // rows from A, cols from B; transpose swaps which trailing dim is used
    output_shape.push_back(A_shape[a_rank - (trans_a ? 1 : 2)]);
    output_shape.push_back(B_shape[b_rank - (trans_b ? 2 : 1)]);
  }
  return output_shape;
}

Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/codegen/mti/math/matmul_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ namespace tvm_codegen {
// Computes the MatMul output shape from the two input shapes, following
// numpy matmul semantics. trans_a/trans_b optionally transpose the two
// trailing dims; they default to false for backward compatibility.
tvm::Array<tvm::Expr>
ComputeMatMulShape(
    const tvm::Array<tvm::Expr>& A_shape,
    const tvm::Array<tvm::Expr>& B_shape,
    bool trans_a = false,
    bool trans_b = false);

tvm::Tensor MatMul2D(const tvm::Tensor& A, const tvm::Tensor& B, bool trans_a = false, bool trans_b = false, const std::string& name = "matmul2d");

Expand Down
51 changes: 51 additions & 0 deletions onnxruntime/core/codegen/passes/op_ir_creator/math/unary_funcs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/framework/op_kernel_info.h"

namespace onnxruntime {
namespace tvm_codegen {
// helper class for unary_ops with alpha
class FuncWithAlpha {
public:
FuncWithAlpha(const Node& node) {
ProtoHelperNodeContext ctx(node);
OpNodeProtoHelper<ProtoHelperNodeContext> attrs(&ctx);
ORT_ENFORCE(attrs.GetAttr<float>("alpha", &alpha_).IsOK());
}

protected:
float alpha_;
};

// helper class for unary_ops with alpha and beta
class FuncWithAlphaBeta {
public:
FuncWithAlphaBeta(const Node& node) {
ProtoHelperNodeContext ctx(node);
OpNodeProtoHelper<ProtoHelperNodeContext> attrs(&ctx);
ORT_ENFORCE(attrs.GetAttr<float>("alpha", &alpha_).IsOK());
ORT_ENFORCE(attrs.GetAttr<float>("beta", &beta_).IsOK());
}

protected:
float alpha_;
float beta_;
};

// helper class for unary_ops with alpha and gamma
class FuncWithAlphaGamma {
public:
FuncWithAlphaGamma(const Node& node) {
ProtoHelperNodeContext ctx(node);
OpNodeProtoHelper<ProtoHelperNodeContext> attrs(&ctx);
ORT_ENFORCE(attrs.GetAttr<float>("alpha", &alpha_).IsOK());
ORT_ENFORCE(attrs.GetAttr<float>("gamma", &gamma_).IsOK());
}

protected:
float alpha_;
float gamma_;
};
} // namespace tvm_codegen
} // namespace onnxruntime
45 changes: 1 addition & 44 deletions onnxruntime/core/codegen/passes/op_ir_creator/math/unary_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,54 +5,11 @@

#include "core/codegen/common/op_macro.h"
#include "core/codegen/mti/math/unary_ops.h"
#include "core/framework/op_kernel_info.h"
#include "core/codegen/passes/op_ir_creator/math/unary_funcs.h"

namespace onnxruntime {
namespace tvm_codegen {

// helper class for unary_ops with alpha
class FuncWithAlpha {
public:
FuncWithAlpha(const Node& node) {
ProtoHelperNodeContext ctx(node);
OpNodeProtoHelper<ProtoHelperNodeContext> attrs(&ctx);
ORT_ENFORCE(attrs.GetAttr<float>("alpha", &alpha_).IsOK());
}

protected:
float alpha_;
};

// helper class for unary_ops with alpha and beta
class FuncWithAlphaBeta {
public:
FuncWithAlphaBeta(const Node& node) {
ProtoHelperNodeContext ctx(node);
OpNodeProtoHelper<ProtoHelperNodeContext> attrs(&ctx);
ORT_ENFORCE(attrs.GetAttr<float>("alpha", &alpha_).IsOK());
ORT_ENFORCE(attrs.GetAttr<float>("beta", &beta_).IsOK());
}

protected:
float alpha_;
float beta_;
};

// helper class for unary_ops with alpha and gamma
class FuncWithAlphaGamma {
public:
FuncWithAlphaGamma(const Node& node) {
ProtoHelperNodeContext ctx(node);
OpNodeProtoHelper<ProtoHelperNodeContext> attrs(&ctx);
ORT_ENFORCE(attrs.GetAttr<float>("alpha", &alpha_).IsOK());
ORT_ENFORCE(attrs.GetAttr<float>("gamma", &gamma_).IsOK());
}

protected:
float alpha_;
float gamma_;
};

// helper macro declares unary_ops helper class without attribute
#define FuncClass(name) \
class Func##name { \
Expand Down
4 changes: 0 additions & 4 deletions onnxruntime/core/providers/nuphar/common/nuphar_subgraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,6 @@ struct NupharSubgraphUnit {
return nodes.size() == 1;
}

// Name of the unit's first node (assumes `nodes` is non-empty).
const std::string& Name() const {
return nodes.front()->Name();
}

// Stable string form of this subgraph unit's numeric id; used e.g. for
// building cache/packed-function names.
std::string UniqueId() const {
return std::to_string(id_);
}
Expand Down
66 changes: 66 additions & 0 deletions onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,5 +170,71 @@ std::string GetPackedFuncName(const nuphar::NupharSubgraphUnit& subgraph, const
return NormalizeCppName("_" + subgraph.UniqueId() + " " + codegen_target.GetTargetName());
}

// Tries to fold an initializer tensor into a single tvm constant scalar.
// Succeeds when the tensor has exactly one element, or when all of its
// elements share the same bit pattern, and its element type maps onto a
// HalideIR type. On success writes the constant into `scalar` and returns
// true; otherwise returns false.
bool TryCreateConstantScalar(
    tvm::Expr& scalar,
    const Tensor* tensor) {
  if (!tensor)
    return false;

  auto num_elements = tensor->Shape().Size();
  // An empty tensor has no element 0 to read below; bail out early.
  if (num_elements < 1)
    return false;

  if (num_elements > 1) {
    // for non-scalar, only fold to constant scalar when all values are identical
    const auto& dtype = tensor->DataType();
    auto elem_size = dtype->Size();
    const void* data = tensor->DataRaw();

// Compare raw bit patterns so one loop per element width covers all
// element types of that size (ints, floats, ...).
#define CHECK_ALL_TENSOR_SAME(T)                                                    \
  for (int64_t i = 1; i < num_elements; ++i) {                                      \
    if (reinterpret_cast<const T*>(data)[i] != reinterpret_cast<const T*>(data)[0]) \
      return false;                                                                 \
  }

    switch (elem_size) {
      case 1:
        CHECK_ALL_TENSOR_SAME(int8_t);
        break;
      case 2:
        CHECK_ALL_TENSOR_SAME(int16_t);
        break;
      case 4:
        CHECK_ALL_TENSOR_SAME(int32_t);
        break;
      case 8:
        CHECK_ALL_TENSOR_SAME(int64_t);
        break;
      default:
        return false;
    }

#undef CHECK_ALL_TENSOR_SAME
  }

// Map the tensor's element type to the matching HalideIR constant type.
#define ASSIGN_TVM_SCALAR(tvm_type, tensor_type)                 \
  if (tensor->IsDataType<tensor_type>()) {                       \
    scalar = tvm::make_const(tvm_type, *tensor->Data<tensor_type>()); \
  }

#define ASSIGN_TVM_SCALAR_ELSE(tvm_type, tensor_type) \
  else ASSIGN_TVM_SCALAR(tvm_type, tensor_type)

  ASSIGN_TVM_SCALAR(HalideIR::Float(32), float)
  ASSIGN_TVM_SCALAR_ELSE(HalideIR::Float(64), double)
  ASSIGN_TVM_SCALAR_ELSE(HalideIR::Int(64), int64_t)
  ASSIGN_TVM_SCALAR_ELSE(HalideIR::Int(32), int32_t)
  ASSIGN_TVM_SCALAR_ELSE(HalideIR::Int(16), int16_t)
  ASSIGN_TVM_SCALAR_ELSE(HalideIR::Int(8), int8_t)
  ASSIGN_TVM_SCALAR_ELSE(HalideIR::UInt(64), uint64_t)
  ASSIGN_TVM_SCALAR_ELSE(HalideIR::UInt(32), uint32_t)
  ASSIGN_TVM_SCALAR_ELSE(HalideIR::UInt(16), uint16_t)
  ASSIGN_TVM_SCALAR_ELSE(HalideIR::UInt(8), uint8_t)
  else {
    // unsupported element type (e.g. bool, string)
    return false;
  }

#undef ASSIGN_TVM_SCALAR_ELSE
#undef ASSIGN_TVM_SCALAR

  return true;
}

} // namespace nuphar
} // namespace onnxruntime
6 changes: 5 additions & 1 deletion onnxruntime/core/providers/nuphar/common/nuphar_tvm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
#include "core/graph/graph.h"

namespace onnxruntime {
class CodeGenTarget; //forward

//forward
class CodeGenTarget;
class Tensor;

namespace nuphar {

Expand All @@ -22,5 +25,6 @@ void SaveTVMModuleToCache(const std::string& filename, tvm::runtime::Module& mod

std::string GetPackedFuncName(const nuphar::NupharSubgraphUnit& subgraph, const CodeGenTarget& codegen_target);

bool TryCreateConstantScalar(tvm::Expr& scalar, const Tensor* tensor);
} // namespace nuphar
} // namespace onnxruntime
11 changes: 11 additions & 0 deletions onnxruntime/core/providers/nuphar/compiler/nuphar_codegen_ctx.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "core/providers/nuphar/compiler/initializer_info.h"
#include "core/providers/nuphar/compiler/nuphar_handle.h"

#include <set>
#include <tvm/tvm.h>

namespace onnxruntime {
Expand Down Expand Up @@ -121,7 +122,17 @@ class NupharCodeGenCtx : public tvm_codegen::CodeGenContext {
return tvm_tensor_ctx_;
}

// Records a tensor name whose values were folded into a literal constant
// scalar, so later scheduling passes can special-case it.
void InsertLiteral(const std::string& str) {
  literalized_scalars_.emplace(str);
}

// Returns true iff the given name was previously registered via InsertLiteral.
bool CheckLiteral(const std::string& str) {
  return literalized_scalars_.find(str) != literalized_scalars_.end();
}

private:
std::set<std::string> literalized_scalars_;

std::unique_ptr<NupharSubgraphUnitStats> graph_stats_;

const NupharCodeGenHandle* nuphar_handle_;
Expand Down
34 changes: 34 additions & 0 deletions onnxruntime/core/providers/nuphar/compiler/nuphar_op_ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "core/codegen/passes/op_ir_creator/tvm_ir_builder.h"
#include "core/codegen/passes/utils/ort_tvm_utils.h"
#include "core/common/common.h"
#include "core/providers/nuphar/common/nuphar_tvm_utils.h"
#include "core/providers/nuphar/compiler/initializer_info.h"
#include "core/providers/nuphar/compiler/x86/op_ir_creator/all_ops.h"

Expand All @@ -28,6 +29,10 @@ static const tvm::Tensor& GetOrCreateInitializer(const NodeArg* def,
bool is_sliced,
NupharCodeGenCtx& ctx_codegen);

static bool CreateScalarTensorFromInitializer(const Tensor* tensor,
const std::string& name,
NupharCodeGenCtx& ctx_codegen);

// CreateInputPlaceholder create tvm input placeholder (tvm::Tensor)
// NOTE: here we assume axis 0 is sequence
// TODO: add support for sequence not axis 0
Expand All @@ -51,6 +56,12 @@ static bool CreateInput(
return false;

ORT_ENFORCE(def->Shape());

if (nullptr != initialized_tensor &&
CreateScalarTensorFromInitializer(initialized_tensor, def->Name(), ctx_codegen)) {
return false; // constant scalar tensor do not need to be in input
}

if (nullptr != initialized_tensor) {
input = GetOrCreateInitializer(def, initialized_tensor, is_sliced, ctx_codegen);
} else {
Expand All @@ -68,6 +79,29 @@ static bool CreateInput(
return true;
}

bool CreateScalarTensorFromInitializer(const Tensor* tensor,
const std::string& name,
NupharCodeGenCtx& ctx_codegen) {
TVMTensorCtx& ctx_tensor = ctx_codegen.GetTVMTensorCtx();
ORT_ENFORCE(tensor != nullptr);

tvm::Expr constant_scalar;
if (!TryCreateConstantScalar(constant_scalar, tensor))
return false;

std::string normalized_name = NormalizeCppName(name);
auto tvm_tensor = tvm::compute(
tvm_codegen::ToTvmArray(tensor->Shape().GetDims()),
[&](const tvm::Array<tvm::Var>&) {
return constant_scalar;
},
normalized_name);

ctx_codegen.InsertLiteral(normalized_name);
ctx_tensor.inputs.emplace(name, std::move(tvm_tensor));
return true;
}

// GetOrCreateInitializer create tvm::placeholder for a marshalled weight
// with correpsonding data layout transfomration for a weight,
// Note the weight is fed during build
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ static void Traverse(const tvm::Tensor& tensor,
if (t->op->InputTensors().size() > 0) {
auto current_node = ctx_codegen.FindNode(t);
Traverse(t, current_node, ctx_codegen, ctx_schedule);
} else if (ctx_codegen.CheckLiteral(t->op->name)) {
TryInlineSchedule(t, ctx_schedule);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ namespace nuphar {
#define NUPHAR_TVM_X86_OP_IR_CREATOR_STRING(OP) \
STRINGIZE(NUPHAR_TVM_X86_OP_IR_CREATOR_CLASS(OP))

// Pooling ops handled by the x86 op IR creators.
#define LIST_X86_POOL_OPS() \
  POOL_OP(MaxPool)          \
  POOL_OP(AveragePool)      \
  POOL_OP(GlobalMaxPool)    \
  POOL_OP(GlobalAveragePool)

#define LIST_X86_UNARY_OPS() \
Expand All @@ -41,7 +41,8 @@ namespace nuphar {
// Vertical (axis-wise) reduction ops handled by the x86 op IR creators.
#define LIST_REDUCE_V_OPS() \
  REDUCE_V_OP(ReduceMax)    \
  REDUCE_V_OP(ReduceMin)    \
  REDUCE_V_OP(ReduceSum)    \
  REDUCE_V_OP(ReduceMean)

#define LIST_ALL_X86_OPS() \
LIST_REDUCE_V_OPS() \
Expand All @@ -52,6 +53,7 @@ namespace nuphar {
ADD_OP_ITEM(MatMul) \
ADD_OP_ITEM(MatMulInteger) \
ADD_OP_ITEM(MatMulInteger16) \
ADD_OP_ITEM(Pow) \
ADD_OP_ITEM(Scatter) \
ADD_OP_ITEM(ScatterElements) \
ADD_OP_ITEM(Slice) \
Expand Down
Loading