Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 84 additions & 3 deletions lib/Transforms/FlattenMemRefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"

namespace circt {
Expand All @@ -46,6 +47,21 @@ struct FunctionRewrite {
FunctionType type;
};

static std::atomic<unsigned> globalCounter(0);
static DenseMap<StringAttr, StringAttr> globalNameMap;

static MemRefType getFlattenedMemRefType(MemRefType type) {
return MemRefType::get(SmallVector<int64_t>{type.getNumElements()},
type.getElementType());
}

static std::string getFlattenedMemRefName(StringAttr baseName,
MemRefType type) {
unsigned uniqueID = globalCounter++;
return llvm::formatv("{0}_{1}x{2}_{3}", baseName, type.getNumElements(),
type.getElementType(), uniqueID);
}

// Flatten indices by generating the product of the i'th index and the [0:i-1]
// shapes, for each index, and then summing these.
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
Expand Down Expand Up @@ -154,13 +170,74 @@ struct AllocOpConversion : public OpConversionPattern<memref::AllocOp> {
MemRefType type = op.getType();
if (isUniDimensional(type) || !type.hasStaticShape())
return failure();
MemRefType newType = MemRefType::get(
SmallVector<int64_t>{type.getNumElements()}, type.getElementType());
MemRefType newType = getFlattenedMemRefType(type);
rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
return success();
}
};

struct GlobalOpConversion : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::GlobalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType type = op.getType();
if (isUniDimensional(type) || !type.hasStaticShape())
return failure();
MemRefType newType = getFlattenedMemRefType(type);

auto cstAttr =
llvm::dyn_cast_or_null<DenseElementsAttr>(op.getConstantInitValue());

SmallVector<Attribute> flattenedVals;
for (auto attr : cstAttr.getValues<Attribute>())
flattenedVals.push_back(attr);

auto newTypeAttr = TypeAttr::get(newType);
auto newNameStr = getFlattenedMemRefName(op.getConstantAttrName(), type);
auto newName = rewriter.getStringAttr(newNameStr);
globalNameMap[op.getSymNameAttr()] = newName;

RankedTensorType tensorType = RankedTensorType::get(
{static_cast<int64_t>(flattenedVals.size())}, type.getElementType());
auto newInitValue = DenseElementsAttr::get(tensorType, flattenedVals);

rewriter.replaceOpWithNewOp<memref::GlobalOp>(
op, newName, op.getSymVisibilityAttr(), newTypeAttr, newInitValue,
op.getConstantAttr(), op.getAlignmentAttr());

return success();
}
};

struct GetGlobalOpConversion : public OpConversionPattern<memref::GetGlobalOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto *symbolTableOp = op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
SymbolTable::lookupSymbolIn(symbolTableOp, op.getNameAttr()));

MemRefType type = globalOp.getType();
if (isUniDimensional(type) || !type.hasStaticShape())
return failure();

MemRefType newType = getFlattenedMemRefType(type);
auto originalName = globalOp.getSymNameAttr();
auto newNameIt = globalNameMap.find(originalName);
if (newNameIt == globalNameMap.end())
return failure();
auto newName = newNameIt->second;

rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, newType, newName);

return success();
}
};

// A generic pattern which will replace an op with a new op of the same type
// but using the adaptor (type converted) operands.
template <typename TOp>
Expand Down Expand Up @@ -256,7 +333,10 @@ static void populateFlattenMemRefsLegality(ConversionTarget &target) {
[](memref::StoreOp op) { return op.getIndices().size() == 1; });
target.addDynamicallyLegalOp<memref::LoadOp>(
[](memref::LoadOp op) { return op.getIndices().size() == 1; });

target.addDynamicallyLegalOp<memref::GlobalOp>(
[](memref::GlobalOp op) { return isUniDimensional(op.getType()); });
target.addDynamicallyLegalOp<memref::GetGlobalOp>(
[](memref::GetGlobalOp op) { return isUniDimensional(op.getType()); });
addGenericLegalityConstraint<mlir::cf::CondBranchOp, mlir::cf::BranchOp,
func::CallOp, func::ReturnOp, memref::DeallocOp,
memref::CopyOp>(target);
Expand Down Expand Up @@ -323,6 +403,7 @@ struct FlattenMemRefPass
RewritePatternSet patterns(ctx);
SetVector<StringRef> rewrittenCallees;
patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
GlobalOpConversion, GetGlobalOpConversion,
OperandConversionPattern<func::ReturnOp>,
OperandConversionPattern<memref::DeallocOp>,
CondBranchOpConversion,
Expand Down
86 changes: 86 additions & 0 deletions test/Transforms/flatten_memref.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,89 @@ func.func @dealloc_copy(%arg : memref<4x4xi32>) -> memref<4x4xi32> {
memref.dealloc %0 : memref<4x4xi32>
return %0 : memref<4x4xi32>
}

// -----

module {
// CHECK-LABEL: memref.global "private" constant @constant_10xf32_0 : memref<10xf32> = dense<[0.433561265, 0.0884729773, -0.39487046, -0.190938368, 0.705071926, -0.648731529, -0.00710275536, -0.278010637, -0.573243499, 5.029220e-01]> {alignment = 64 : i64}
memref.global "private" constant @__constant_5x2xf32 : memref<5x2xf32> = dense<[[0.433561265, 0.0884729773], [-0.39487046, -0.190938368], [0.705071926, -0.648731529], [-0.00710275536, -0.278010637], [-0.573243499, 5.029220e-01]]> {alignment = 64 : i64}

// CHECK-LABEL: func.func @forward() -> f32 {
// CHECK: %[[VAL_0:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_2:.*]] = memref.get_global @constant_10xf32_0 : memref<10xf32>
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_4:.*]] = arith.shli %[[VAL_0]], %[[VAL_3]] : index
// CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_1]] : index
// CHECK: %[[VAL_6:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref<10xf32>
// CHECK: return %[[VAL_6]] : f32
// CHECK: }
// CHECK: }
func.func @forward() -> f32 {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%0 = memref.get_global @__constant_5x2xf32 : memref<5x2xf32>
%1 = memref.load %0[%c2, %c1] : memref<5x2xf32>
return %1 :f32
}
}

// GlobalOp/GetGlobalOp may result in name conflict after flattening

module {
// CHECK-LABEL: module {
// CHECK: memref.global "private" constant @__constant_1xf32 : memref<1xf32> = dense<-0.344258487> {alignment = 64 : i64}
// CHECK: memref.global "private" constant @constant_2xf32_1 : memref<2xf32> = dense<[-0.154929623, 0.142687559]> {alignment = 64 : i64}
// CHECK: memref.global "private" constant @__constant_2xf32 : memref<2xf32> = dense<[-0.23427248, 0.918611288]> {alignment = 64 : i64}
// CHECK: memref.global "private" constant @constant_2xf32_2 : memref<2xf32> = dense<[0.764538527, 0.83000791]> {alignment = 64 : i64}
memref.global "private" constant @__constant_1xf32 : memref<1xf32> = dense<-0.344258487> {alignment = 64 : i64}
memref.global "private" constant @__constant_1x2xf32 : memref<1x2xf32> = dense<[[-0.154929623, 0.142687559]]> {alignment = 64 : i64}
memref.global "private" constant @__constant_2xf32 : memref<2xf32> = dense<[-0.23427248, 0.918611288]> {alignment = 64 : i64}
memref.global "private" constant @__constant_2x1xf32 : memref<2x1xf32> = dense<[[0.764538527], [0.83000791]]> {alignment = 64 : i64}

// CHECK: func.func @main(%[[VAL_0:.*]]: memref<2xf32>, %[[VAL_1:.*]]: memref<1xf32>) {
// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_6:.*]] = memref.get_global @constant_2xf32_2 : memref<2xf32>
// CHECK: %[[VAL_7:.*]] = memref.get_global @__constant_2xf32 : memref<2xf32>
// CHECK: %[[VAL_8:.*]] = memref.get_global @constant_2xf32_1 : memref<2xf32>
// CHECK: %[[VAL_9:.*]] = memref.get_global @__constant_1xf32 : memref<1xf32>
// CHECK: %[[VAL_10:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_11:.*]] = arith.shli %[[VAL_3]], %[[VAL_10]] : index
// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_4]] : index
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<2xf32>
// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<2xf32>
// CHECK: %[[VAL_15:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_16:.*]] = arith.shli %[[VAL_4]], %[[VAL_15]] : index
// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_3]] : index
// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<2xf32>
// CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_13]], %[[VAL_14]] : f32
// CHECK: %[[VAL_20:.*]] = arith.addf %[[VAL_18]], %[[VAL_19]] : f32
// CHECK: memref.store %[[VAL_20]], %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref<1xf32>
// CHECK: memref.copy %[[VAL_9]], %[[VAL_1]] : memref<1xf32> to memref<1xf32>
// CHECK: return
// CHECK: }
// CHECK: }

func.func @main(%arg0: memref<2x1xf32>, %arg1: memref<1xf32>) {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = memref.get_global @__constant_2x1xf32 : memref<2x1xf32>
%1 = memref.get_global @__constant_2xf32 : memref<2xf32>
%2 = memref.get_global @__constant_1x2xf32 : memref<1x2xf32>
%3 = memref.get_global @__constant_1xf32 : memref<1xf32>
%4 = memref.load %0[%c1, %c0] : memref<2x1xf32>
%5 = memref.load %1[%c0] : memref<2xf32>
%6 = memref.load %2[%c0, %c1] : memref<1x2xf32>
%7 = arith.mulf %4, %5 : f32
%8 = arith.addf %6, %7 : f32
memref.store %8, %3[%c0] : memref<1xf32>
memref.copy %3, %arg1 : memref<1xf32> to memref<1xf32>
return
}
}

Loading