From ac8eb3848fd6c26c71bdd90509e5115a0c9a4dca Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 6 Feb 2025 15:52:29 -0800 Subject: [PATCH 01/16] refactor ContextAwareTypeConversion --- .../Secret/Conversions/SecretToBGV/BUILD | 1 + .../Conversions/SecretToBGV/SecretToBGV.cpp | 18 +- .../Conversions/SecretToCKKS/SecretToCKKS.cpp | 96 +++++---- lib/Utils/BUILD | 16 ++ lib/Utils/ContextAwareTypeConversion.cpp | 172 ++++++++++++++++ lib/Utils/ContextAwareTypeConversion.h | 123 +++++++++++ lib/Utils/ConversionUtils.cpp | 151 ++------------ lib/Utils/ConversionUtils.h | 194 ++++++++---------- .../Conversions/secret_to_bgv/invalid.mlir | 7 +- .../Conversions/secret_to_ckks/invalid.mlir | 12 +- 10 files changed, 492 insertions(+), 298 deletions(-) create mode 100644 lib/Utils/ContextAwareTypeConversion.cpp create mode 100644 lib/Utils/ContextAwareTypeConversion.h diff --git a/lib/Dialect/Secret/Conversions/SecretToBGV/BUILD b/lib/Dialect/Secret/Conversions/SecretToBGV/BUILD index d9f110792..20f23abac 100644 --- a/lib/Dialect/Secret/Conversions/SecretToBGV/BUILD +++ b/lib/Dialect/Secret/Conversions/SecretToBGV/BUILD @@ -22,6 +22,7 @@ cc_library( "@heir//lib/Dialect/Secret/IR:Dialect", "@heir//lib/Parameters/BGV:Params", "@heir//lib/Utils", + "@heir//lib/Utils:ContextAwareTypeConversion", "@heir//lib/Utils:ConversionUtils", "@heir//lib/Utils/Polynomial", "@llvm-project//llvm:Support", diff --git a/lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp b/lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp index 7d9baba4a..744dba416 100644 --- a/lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp +++ b/lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp @@ -23,6 +23,7 @@ #include "lib/Dialect/Secret/IR/SecretOps.h" #include "lib/Dialect/Secret/IR/SecretTypes.h" #include "lib/Parameters/BGV/Params.h" +#include "lib/Utils/ContextAwareTypeConversion.h" #include "lib/Utils/ConversionUtils.h" #include "lib/Utils/Polynomial/Polynomial.h" #include "lib/Utils/Utils.h" @@ -88,12 +89,14 @@ polynomial::RingAttr getRlweRNSRingWithLevel(polynomial::RingAttr ringAttr, } // namespace -class SecretToBGVTypeConverter : public TypeWithAttrTypeConverter { +class SecretToBGVTypeConverter + : public UniquelyNamedAttributeAwareTypeConverter { public: SecretToBGVTypeConverter(MLIRContext *ctx, ::mlir::heir::polynomial::RingAttr rlweRing, int64_t ptm) - : TypeWithAttrTypeConverter(mgmt::MgmtDialect::kArgMgmtAttrName) { + : UniquelyNamedAttributeAwareTypeConverter( + mgmt::MgmtDialect::kArgMgmtAttrName) { ring = rlweRing; plaintextModulus = ptm; @@ -132,10 +135,8 @@ class SecretToBGVTypeConverter : public TypeWithAttrTypeConverter { lwe::ModulusChainAttr::get(ctx, moduliChain, level)); } - Type convertTypeWithAttr(Type type, Attribute attr) const override { - auto secretTy = dyn_cast(type); - // guard against null attribute - if (secretTy && attr) { + FailureOr convert(Type type, Attribute attr) const override { + if (auto secretTy = dyn_cast(type)) { auto mgmtAttr = dyn_cast(attr); if (mgmtAttr) { return convertSecretTypeWithMgmtAttr(secretTy, mgmtAttr); @@ -283,9 +284,8 @@ struct SecretToBGV : public impl::SecretToBGVBase { target.addIllegalDialect(); target.addIllegalOp(); target.addIllegalOp(); - target.addDynamicallyLegalOp([&](func::FuncOp op) { - return typeConverter.isFuncArgumentAndResultLegal(op); - }); + target.addDynamicallyLegalOp( + [&](func::FuncOp op) { return typeConverter.isLegal(op); }); patterns.add< ConvertFuncWithContextAwareTypeConverter, diff --git a/lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.cpp b/lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.cpp index 65292d59d..e391bc0a6 100644 --- a/lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.cpp +++ b/lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.cpp @@ -116,12 +116,14 @@ FailureOr> getNonUnitDimension( } // namespace -class SecretToCKKSTypeConverter : public TypeWithAttrTypeConverter { +class SecretToCKKSTypeConverter + : public UniquelyNamedAttributeAwareTypeConverter { public: SecretToCKKSTypeConverter(MLIRContext *ctx, ::mlir::heir::polynomial::RingAttr rlweRing, bool packTensorInSlots) - : TypeWithAttrTypeConverter(mgmt::MgmtDialect::kArgMgmtAttrName) { + : UniquelyNamedAttributeAwareTypeConverter( + mgmt::MgmtDialect::kArgMgmtAttrName) { addConversion([](Type type) { return type; }); ring_ = rlweRing; @@ -184,10 +186,8 @@ class SecretToCKKSTypeConverter : public TypeWithAttrTypeConverter { ciphertext); } - Type convertTypeWithAttr(Type type, Attribute attr) const override { - auto secretTy = dyn_cast(type); - // guard against null attribute - if (secretTy && attr) { + FailureOr convert(Type type, Attribute attr) const override { + if (auto secretTy = dyn_cast(type)) { auto mgmtAttr = dyn_cast(attr); if (mgmtAttr) { return convertSecretTypeWithMgmtAttr(secretTy, mgmtAttr); @@ -281,25 +281,35 @@ class SecretForOpConversion : public OpConversionPattern { affine::AffineForOp forOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { const auto *contextAwareTypeConverter = - dynamic_cast(getTypeConverter()); + dynamic_cast(getTypeConverter()); SmallVector newInitTypes; - contextAwareTypeConverter->convertValueRangeTypes(forOp.getInits(), - newInitTypes); + if (failed(contextAwareTypeConverter->convertValueRangeTypes( + forOp.getInits(), newInitTypes))) + return failure(); + + // A hack because OpAdaptor doesn't use the context aware type converter to + // convert operand types. + for (size_t i = 0; i < newInitTypes.size(); i++) { + adaptor.getInits()[i].setType(newInitTypes[i]); + } Location loc = forOp.getLoc(); auto newForOp = rewriter.create( loc, ValueRange(forOp.getLowerBoundOperands()), forOp.getLowerBoundMap(), ValueRange(forOp.getUpperBoundOperands()), forOp.getUpperBoundMap(), forOp.getStep().getZExtValue(), - adaptor.getInits(), [](OpBuilder &, Location, Value, ValueRange) {}); + adaptor.getInits()); + newForOp->setAttrs(forOp->getAttrs()); mlir::Block *newBody = newForOp.getBody(); mlir::Block *oldBody = forOp.getBody(); - rewriter.setInsertionPoint(newBody, newBody->begin()); + rewriter.setInsertionPointToStart(newBody); + IRMapping mp; SmallVector newBlockArgs; for (auto arg : newBody->getArguments()) { + auto oldArg = oldBody->getArgument(arg.getArgNumber()); if (auto lweTy = dyn_cast(arg.getType())) { // For each block arg that is a secret type, we need to create an // operation inside the for loops block that can hold the mgmt attr for @@ -309,19 +319,36 @@ class SecretForOpConversion : public OpConversionPattern { lweTy.getApplicationData().getMessageType()); auto cast = rewriter.create( loc, underlyingTy, arg); - cast->setAttrs(oldBody->getArgument(arg.getArgNumber()) - .getUsers() - .begin() - ->getAttrs()); + cast->setAttrs(oldArg.getUsers().begin()->getAttrs()); newBlockArgs.push_back(cast.getResult(0)); + mp.map(oldArg, cast.getResult(0)); } else { newBlockArgs.push_back(arg); + mp.map(oldArg, arg); } } - // Move the body of the old ForOp to the new one. - rewriter.mergeBlocks(oldBody, newBody, newBlockArgs); + for (auto &op : *oldBody) { + rewriter.clone(op, mp); + } + + // Hack: ensure the yield is converted at the same time as the for op. + auto yieldOp = + cast(newForOp.getBody()->getTerminator()); + SmallVector newYieldTypes; + if (failed(contextAwareTypeConverter->convertValueRangeTypes( + yieldOp.getOperands(), newYieldTypes))) + return failure(); + + rewriter.modifyOpInPlace(yieldOp, [&] { + for (auto [newYield, newYieldType] : + llvm::zip(yieldOp.getOperands(), newYieldTypes)) { + newYield.setType(newYieldType); + } + }); + rewriter.replaceOp(forOp, newForOp); + return success(); } }; @@ -471,37 +498,21 @@ struct SecretToCKKS : public impl::SecretToCKKSBase { target.addIllegalOp(); // for mod reduce on tensor ciphertext target.addLegalOp(); - target.addDynamicallyLegalOp([&](func::FuncOp op) { - return typeConverter.isFuncArgumentAndResultLegal(op); - }); // to resolve unlinked block arguments target.addLegalOp(); + target.addDynamicallyLegalOp( + [&](func::FuncOp op) { return typeConverter.isLegal(op); }); // We add an explicit allowlist of operations to mark legal. If we use // markUnknownOpDynamicallyLegal, then ConvertAny will be applied to any // remaining operations and potentially cause a crash. + target.addDynamicallyLegalOp( + [&](Operation *op) { return typeConverter.isLegal(op); }); target.addDynamicallyLegalOp( - [&](Operation *op) { return typeConverter.isOperationLegal(op); }); - target.addDynamicallyLegalOp( - [&](Operation *op) { return typeConverter.isOperationLegal(op); }); - target.addDynamicallyLegalOp( - [&](Operation *op) { return typeConverter.isOperationLegal(op); }); - // We can't use typeConverter.isOperationLegal here because it requires the - // region being converted. - target.addDynamicallyLegalOp( - [&](affine::AffineForOp op) { - for (auto operand : op->getOperands()) { - if (!typeConverter.isValueLegal(operand)) { - return false; - } - } - for (auto result : op->getResults()) { - if (!typeConverter.isValueLegal(result)) { - return false; - } - } - return true; - }); + [&](Operation *op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](Operation *op) { return typeConverter.isLegal(op); }); patterns.add< ConvertFuncWithContextAwareTypeConverter, @@ -525,7 +536,8 @@ struct SecretToCKKS : public impl::SecretToCKKSBase { SecretGenericOpCipherPlainConversion, SecretGenericOpCipherPlainConversion, SecretForOpConversion, ConvertAny, - SecretGenericFuncCallConversion>(typeConverter, context); + ConvertAny, SecretGenericFuncCallConversion>( + typeConverter, context); if (failed(applyPartialConversion(module, target, std::move(patterns)))) { return signalPassFailure(); diff --git a/lib/Utils/BUILD b/lib/Utils/BUILD index 65eaa55ec..db36d904e 100644 --- a/lib/Utils/BUILD +++ b/lib/Utils/BUILD @@ -31,6 +31,7 @@ cc_library( srcs = ["ConversionUtils.cpp"], hdrs = ["ConversionUtils.h"], deps = [ + ":ContextAwareTypeConversion", "@heir//lib/Dialect/LWE/IR:Dialect", "@heir//lib/Dialect/Mgmt/IR:Dialect", "@heir//lib/Dialect/Secret/IR:Dialect", @@ -50,6 +51,21 @@ cc_library( ], ) +cc_library( + name = "ContextAwareTypeConversion", + srcs = ["ContextAwareTypeConversion.cpp"], + hdrs = ["ContextAwareTypeConversion.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + cc_library( name = "TargetUtils", srcs = ["TargetUtils.cpp"], diff --git a/lib/Utils/ContextAwareTypeConversion.cpp b/lib/Utils/ContextAwareTypeConversion.cpp new file mode 100644 index 000000000..2e7bb0c45 --- /dev/null +++ b/lib/Utils/ContextAwareTypeConversion.cpp @@ -0,0 +1,172 @@ +#include "lib/Utils/ContextAwareTypeConversion.h" + +#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace heir { + +bool ContextAwareTypeConverter::isLegal(Operation *op) const { + SmallVector newOperandTypes; + if (failed(convertValueRangeTypes(op->getOperands(), newOperandTypes))) + return false; + + SmallVector newResultTypes; + if (failed(convertValueRangeTypes(op->getResults(), newResultTypes))) + return false; + + return op->getOperandTypes() == newOperandTypes && + op->getResultTypes() == newResultTypes; +} + +bool ContextAwareTypeConverter::isLegal(FunctionOpInterface funcOp) const { + SmallVector newOperandTypes; + SmallVector newResultTypes; + + if (failed(convertFuncSignature(funcOp, newOperandTypes, newResultTypes))) + return false; + + return funcOp.getArgumentTypes() == ArrayRef(newOperandTypes) && + funcOp.getResultTypes() == ArrayRef(newResultTypes); +} + +// Convert a range of values, with converted types stored in newTypes. +LogicalResult AttributeAwareTypeConverter::convertValueRangeTypes( + ValueRange values, SmallVectorImpl &newTypes) const { + newTypes.reserve(values.size()); + for (auto value : values) { + FailureOr attr = getContextualAttr(value); + // If no contextual attribute is found, it may be a type that doesn't need + // conversion. In this case, just use the type as is. An example of this is + // a bgv.rotate op, which consumes a ciphertext (which can be converted) + // and an index to rotate (which needs no conversion). + if (failed(attr)) { + newTypes.push_back(value.getType()); + continue; + } + + FailureOr newType = convert(value.getType(), attr.value()); + if (failed(newType)) return failure(); + + newTypes.push_back(newType.value()); + } + + return success(); +} + +// Convert types of the arguments and results of a function. +LogicalResult AttributeAwareTypeConverter::convertFuncSignature( + FunctionOpInterface funcOp, SmallVectorImpl &newArgTypes, + SmallVectorImpl &newResultTypes) const { + if (funcOp.isDeclaration()) { + if (failed(convertTypes(funcOp.getArgumentTypes(), newArgTypes))) + return failure(); + if (failed(convertTypes(funcOp.getResultTypes(), newResultTypes))) + return failure(); + return success(); + } + + for (auto argument : funcOp.getArguments()) { + FailureOr attr = getContextualAttr(argument); + if (failed(attr)) { + newArgTypes.push_back(argument.getType()); + continue; + } + FailureOr newType = convert(argument.getType(), attr.value()); + if (failed(newType)) return failure(); + newArgTypes.push_back(newType.value()); + } + // To get the value corresponding to the func's return types, we need to get + // the terminator operands. + for (auto &block : funcOp.getBlocks()) { + for (auto result : block.getTerminator()->getOperands()) { + FailureOr attr = getContextualAttr(result); + if (failed(attr)) return failure(); + FailureOr newType = convert(result.getType(), attr.value()); + if (failed(newType)) return failure(); + newResultTypes.push_back(newType.value()); + } + } + return success(); +} + +LogicalResult ConvertFuncWithContextAwareTypeConverter::matchAndRewrite( + func::FuncOp op, PatternRewriter &rewriter) const { + auto funcOp = cast(op); + + SmallVector newFuncOperandsType; + SmallVector newFuncResultsType; + if (failed(contextAwareTypeConverter->convertFuncSignature( + op, newFuncOperandsType, newFuncResultsType))) + return failure(); + + auto newFuncType = + FunctionType::get(getContext(), newFuncOperandsType, newFuncResultsType); + rewriter.modifyOpInPlace(funcOp, [&] { + funcOp.setType(newFuncType); + + if (funcOp.isDeclaration()) return; + + // Set the block argument types to match the new signature + for (auto [arg, newType] : llvm::zip( + funcOp.getBody().front().getArguments(), newFuncOperandsType)) { + arg.setType(newType); + } + + // This is a weird part related to the hacky nature of this context-aware + // type conversion. In order to make this work, we have to also modify the + // func.return at the same time as the func.func containing it. Otherwise, + // if we tried to make a separate "context aware" conversion pattern for + // the func.return op, it would not have the type-converted operands + // available to its OpAdaptor. Furthermore, when I tried making a separate + // context-aware pattern for func.return in isolation, I couldn't get it to + // legalize and the conversion engine looped infinitely. + Block &block = funcOp.getBody().front(); + for (auto [returnOperand, newType] : + llvm::zip(block.getTerminator()->getOperands(), newFuncResultsType)) { + returnOperand.setType(newType); + } + }); + + return success(); +} + +LogicalResult convertAnyOperand(const ContextAwareTypeConverter *typeConverter, + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) { + if (typeConverter->isLegal(op)) { + return failure(); + } + + SmallVector newOperandTypes; + if (failed(typeConverter->convertValueRangeTypes(op->getOperands(), + newOperandTypes))) + return failure(); + + SmallVector newResultTypes; + if (failed(typeConverter->convertValueRangeTypes(op->getResults(), + newResultTypes))) + return failure(); + + SmallVector, 1> regions; + if (!op->getRegions().empty()) { + // Because the Dialect conversion framework handles converting region types + // and it requires some extra work supporting block signature type + // conversion, etc. Do this when the need arises. + return op->emitError( + "Generic context-aware op conversion requires op to have no regions"); + } + + Operation *newOp = rewriter.create(OperationState( + op->getLoc(), op->getName().getStringRef(), operands, newResultTypes, + op->getAttrs(), op->getSuccessors(), regions)); + + rewriter.replaceOp(op, newOp); + return success(); +} + +} // namespace heir +} // namespace mlir diff --git a/lib/Utils/ContextAwareTypeConversion.h b/lib/Utils/ContextAwareTypeConversion.h new file mode 100644 index 000000000..8128fe2b5 --- /dev/null +++ b/lib/Utils/ContextAwareTypeConversion.h @@ -0,0 +1,123 @@ +#ifndef LIB_UTILS_CONTEXTAWARETYPECONVERSION_H_ +#define LIB_UTILS_CONTEXTAWARETYPECONVERSION_H_ + +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir { +namespace heir { + +// A class to manage type conversions when the context of the value matters. +// Note this excludes the ability to use this to convert ops that don't define +// values, such as a func.func declaration (isDeclaration() is true). +// +// This framework only supports 1-1 type conversions, without dropping or +// inserting additional types during the conversion. +// +// The inheritance from TypeConverter does nothing, but is necessary to allow +// this class to be used with the DialectConversion framework, particularly +// the ConversionPattern constructor requires a TypeConverter instance. +struct ContextAwareTypeConverter : TypeConverter { + public: + ContextAwareTypeConverter() { + // Required to conform to dialect conversion, otherwise using this in a + // conversion pattern will always fail. + addConversion([](Type type) { return type; }); + } + + // Convert a range of values, with converted types stored in newTypes. + virtual LogicalResult convertValueRangeTypes( + ValueRange values, SmallVectorImpl &newTypes) const = 0; + + // Convert types of the arguments and results of a function. + virtual LogicalResult convertFuncSignature( + FunctionOpInterface funcOp, SmallVectorImpl &newArgTypes, + SmallVectorImpl &newResultTypes) const = 0; + + // For use with the normal DialectConversion framework to trigger conversion + // via dynamic legality checks. + bool isLegal(Operation *op) const; + bool isLegal(FunctionOpInterface funcOp) const; + bool isLegal(func::FuncOp funcOp) const { + return isLegal(cast(funcOp.getOperation())); + } +}; + +// A ContextAwareTypeConverter for which the only context needed is an +// attribute, which this class is in charge of retrieving. +struct AttributeAwareTypeConverter : ContextAwareTypeConverter { + public: + virtual FailureOr convert(Type type, Attribute attr) const = 0; + + // Return an Attribute used by convertValueRangeTypes to convert the type of + // the input `value`. If no usable attribute is found, returns a failure. + // This may indicate that no type conversion is necessary. As a result, + // the returned Attribute is never nullptr. + virtual FailureOr getContextualAttr(Value value) const = 0; + + // Convert a range of values, with converted types stored in newTypes. + LogicalResult convertValueRangeTypes( + ValueRange values, SmallVectorImpl &newTypes) const override; + + // Convert types of the arguments and results of a function. + LogicalResult convertFuncSignature( + FunctionOpInterface funcOp, SmallVectorImpl &newArgTypes, + SmallVectorImpl &newResultTypes) const override; +}; + +// An AttributeAwareTypeConverter for which the attribute is determined uniquely +// by a specific string name on the defining op or as a func arg attr. +struct UniquelyNamedAttributeAwareTypeConverter : AttributeAwareTypeConverter { + public: + UniquelyNamedAttributeAwareTypeConverter(StringRef attrName) + : attrName(attrName) {} + + FailureOr getContextualAttr(Value value) const override { + if (auto blockArg = dyn_cast(value)) { + auto *parentOp = blockArg.getOwner()->getParentOp(); + auto funcOp = dyn_cast(parentOp); + if (!funcOp) return failure(); + auto argAttr = funcOp.getArgAttr(blockArg.getArgNumber(), attrName); + if (!argAttr) return failure(); + + return argAttr; + } + + auto *parentOp = value.getDefiningOp(); + if (!parentOp || !parentOp->hasAttr(attrName)) return failure(); + + return parentOp->getAttr(attrName); + } + + private: + std::string attrName; +}; + +struct ConvertFuncWithContextAwareTypeConverter + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + ConvertFuncWithContextAwareTypeConverter( + const ContextAwareTypeConverter &contextAwareTypeConverter, + MLIRContext *context) + : OpRewritePattern(context), + contextAwareTypeConverter(&contextAwareTypeConverter) {} + + LogicalResult matchAndRewrite(func::FuncOp op, + PatternRewriter &rewriter) const override; + + private: + const ContextAwareTypeConverter *contextAwareTypeConverter; +}; + +} // namespace heir +} // namespace mlir + +#endif // LIB_UTILS_CONTEXTAWARETYPECONVERSION_H_ diff --git a/lib/Utils/ConversionUtils.cpp b/lib/Utils/ConversionUtils.cpp index db1a4db53..a24ea0c33 100644 --- a/lib/Utils/ConversionUtils.cpp +++ b/lib/Utils/ConversionUtils.cpp @@ -1,6 +1,5 @@ #include "lib/Utils/ConversionUtils.h" -#include #include #include #include @@ -23,9 +22,8 @@ #include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/IR/Verifier.h" // from @llvm-project #include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir { @@ -39,10 +37,10 @@ LogicalResult convertAnyOperand(const TypeConverter *typeConverter, Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) { const auto *typeWithAttrTypeConverter = - dynamic_cast(typeConverter); + dynamic_cast(typeConverter); if (typeWithAttrTypeConverter) { - if (typeWithAttrTypeConverter->isOperationLegal(op)) { + if (typeWithAttrTypeConverter->isLegal(op)) { return failure(); } } else { @@ -54,10 +52,18 @@ LogicalResult convertAnyOperand(const TypeConverter *typeConverter, SmallVector newOperandTypes; SmallVector newResultTypes; if (typeWithAttrTypeConverter) { - typeWithAttrTypeConverter->convertOpResultTypes(op, newResultTypes); - typeWithAttrTypeConverter->convertValueRangeTypes(op->getOperands(), - newOperandTypes); + if (failed(typeWithAttrTypeConverter->convertValueRangeTypes( + op->getResults(), newResultTypes))) + return failure(); + + if (failed(typeWithAttrTypeConverter->convertValueRangeTypes( + op->getOperands(), newOperandTypes))) + return failure(); + if (newOperandTypes == op->getOperandTypes() && + newResultTypes == op->getResultTypes()) { + return failure(); + } } else { auto result = typeConverter->convertTypes(op->getResultTypes(), newResultTypes); @@ -308,127 +314,14 @@ int widthFromEncodingAttr(Attribute encoding) { }); } -Attribute TypeWithAttrTypeConverter::getValueAttr(Value value) const { - Attribute attr; - if (auto blockArg = dyn_cast(value)) { - auto *parentOp = blockArg.getOwner()->getParentOp(); - auto funcOp = dyn_cast(parentOp); - if (funcOp) { - attr = funcOp.getArgAttr(blockArg.getArgNumber(), attrName); - } - } else { - auto *parentOp = value.getDefiningOp(); - attr = parentOp->getAttr(attrName); - } - return attr; -} - -void TypeWithAttrTypeConverter::convertValueRangeTypes( - ValueRange values, SmallVectorImpl &newTypes) const { - newTypes.reserve(values.size()); - for (auto value : values) { - Attribute attr = getValueAttr(value); - auto newType = convertTypeWithAttr(value.getType(), attr); - // this is actually unsafe... - // all the thing should be done through the rewriter, - // if we are using the rewriter - value.setType(newType); - newTypes.push_back(newType); - } -} - -void TypeWithAttrTypeConverter::convertOpResultTypes( - Operation *op, SmallVectorImpl &newResultTypes) const { - newResultTypes.reserve(op->getResultTypes().size()); - auto attr = op->getAttr(attrName); - for (auto resultType : op->getResultTypes()) { - auto newType = convertTypeWithAttr(resultType, attr); - newResultTypes.push_back(newType); - } -} - -void TypeWithAttrTypeConverter::convertFuncArgumentAndResultTypes( - FunctionOpInterface funcOp, SmallVectorImpl &newArgTypes, - SmallVectorImpl &newResultTypes) const { - for (auto argument : funcOp.getArguments()) { - auto attr = funcOp.getArgAttr(argument.getArgNumber(), attrName); - auto newType = convertTypeWithAttr(argument.getType(), attr); - // this is actually unsafe... - // we should go through rewriter.convertRegionTypes, - // which will create unresolved_materializaton, - // and everything is safe. - argument.setType(newType); - newArgTypes.push_back(newType); - } - // did not convert block arg/signature though.. - for (auto &block : funcOp.getBlocks()) { - for (auto result : block.getTerminator()->getOperands()) { - auto attr = getValueAttr(result); - auto newType = convertTypeWithAttr(result.getType(), attr); - result.setType(newType); - newResultTypes.push_back(newType); - } - } -} - -bool TypeWithAttrTypeConverter::isValueLegal(Value value) const { - auto attr = getValueAttr(value); - return value.getType() == convertTypeWithAttr(value.getType(), attr); -} - -bool TypeWithAttrTypeConverter::isOperationLegal(Operation *op) const { - for (auto operand : op->getOperands()) { - if (!isValueLegal(operand)) { - return false; - } - } - for (auto result : op->getResults()) { - if (!isValueLegal(result)) { - return false; - } - } - for (auto ®ion : op->getRegions()) { - for (auto &block : region) { - for (auto argument : block.getArguments()) { - if (!isValueLegal(argument)) { - return false; - } - } - for (auto result : block.getTerminator()->getOperands()) { - if (!isValueLegal(result)) { - return false; - } - } - } - } - return true; -} - -bool TypeWithAttrTypeConverter::isFuncArgumentAndResultLegal( - FunctionOpInterface funcOp) { - for (auto argument : funcOp.getArguments()) { - if (!isValueLegal(argument)) { - return false; - } - } - for (auto &block : funcOp.getBlocks()) { - for (auto result : block.getTerminator()->getOperands()) { - if (!isValueLegal(result)) { - return false; - } - } - } - return true; -} - FailureOr getContextualArgFromFunc(Operation *op, Type argType) { - for (auto block_arg : op->getParentOfType() - .getBody() - .getBlocks() - .front() - .getArguments()) { - if (block_arg.getType() == argType) { - return block_arg; + for (auto blockArg : op->getParentOfType() + .getBody() + .getBlocks() + .front() + .getArguments()) { + if (blockArg.getType() == argType) { + return blockArg; } } return failure(); diff --git a/lib/Utils/ConversionUtils.h b/lib/Utils/ConversionUtils.h index 527d01acc..591f7eae2 100644 --- a/lib/Utils/ConversionUtils.h +++ b/lib/Utils/ConversionUtils.h @@ -12,6 +12,7 @@ #include "lib/Dialect/Secret/IR/SecretOps.h" #include "lib/Dialect/TensorExt/IR/TensorExtOps.h" #include "lib/Dialect/TfheRust/IR/TfheRustTypes.h" +#include "lib/Utils/ContextAwareTypeConversion.h" #include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project #include "llvm/include/llvm/Support/Casting.h" // from @llvm-project #include "llvm/include/llvm/Support/ErrorHandling.h" // from @llvm-project @@ -100,97 +101,6 @@ struct ConvertBinOp : public OpConversionPattern { } }; -struct ContextAwareTypeConverter : public TypeConverter { - public: - // Convert types of the values in the input range, taking into account the - // context of the values (e.g., defining ops or uses). - // NOTE that this also converts the types of the values themselves, - // beyond just calculate the new type - virtual void convertValueRangeTypes( - ValueRange values, SmallVectorImpl &newTypes) const = 0; - - // Convert types of the results of an op, taking into account the context of - // the op when selecting the new type. - // NOTE that this also converts the types of the results themselves, - // beyond just calculate the new type - virtual void convertOpResultTypes( - Operation *op, SmallVectorImpl &newResultTypes) const = 0; - - // Convert types of the arguments and results of a function, taking into - // account the context of the function when selecting the new types. - // Note that this method is not used for converting the function type itself. - // NOTE that this also converts the types of the arguments/results themselves, - // beyond just calculate the new type - virtual void convertFuncArgumentAndResultTypes( - FunctionOpInterface funcOp, SmallVectorImpl &newArgTypes, - SmallVectorImpl &newResultTypes) const = 0; -}; - -struct TypeWithAttrTypeConverter : public ContextAwareTypeConverter { - TypeWithAttrTypeConverter(llvm::StringLiteral attrName) - : attrName(attrName) {} - - // inherited TypeConverter should implement this to do actual type conversion - virtual Type convertTypeWithAttr(Type type, Attribute attr) const = 0; - - // Find the attribute associated with the value, if any. - Attribute getValueAttr(Value value) const; - - // Impl the ContextAwareTypeConverter interface - // in it we will use convertTypeWithAttr to do the actual conversion - void convertValueRangeTypes(ValueRange values, - SmallVectorImpl &newTypes) const override; - - void convertOpResultTypes( - Operation *op, SmallVectorImpl &newResultTypes) const override; - - void convertFuncArgumentAndResultTypes( - FunctionOpInterface funcOp, SmallVectorImpl &newArgTypes, - SmallVectorImpl &newResultTypes) const override; - - // Custom hook to check legality - bool isValueLegal(Value value) const; - - bool isOperationLegal(Operation *op) const; - - bool isFuncArgumentAndResultLegal(FunctionOpInterface funcOp); - - protected: - llvm::StringLiteral attrName; -}; - -struct ConvertFuncWithContextAwareTypeConverter - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - ConvertFuncWithContextAwareTypeConverter( - const ContextAwareTypeConverter &contextAwareTypeConverter, - MLIRContext *context) - : OpRewritePattern(context), - contextAwareTypeConverter(&contextAwareTypeConverter) {} - - LogicalResult matchAndRewrite(func::FuncOp op, - PatternRewriter &rewriter) const override { - auto funcOp = cast(op); - - SmallVector newFuncOperandsType; - SmallVector newFuncResultsType; - contextAwareTypeConverter->convertFuncArgumentAndResultTypes( - op, newFuncOperandsType, newFuncResultsType); - - // update the signature - auto newFuncType = FunctionType::get(getContext(), newFuncOperandsType, - newFuncResultsType); - rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newFuncType); }); - - return success(); - } - - private: - const ContextAwareTypeConverter *contextAwareTypeConverter; -}; - template class SecretGenericOpConversion : public OpConversionPattern { @@ -231,14 +141,23 @@ class SecretGenericOpConversion if (contextAwareTypeConverter) { // manually do the OpAdaptor's work SmallVector inputTypes; - contextAwareTypeConverter->convertValueRangeTypes(inputs, inputTypes); + if (failed(contextAwareTypeConverter->convertValueRangeTypes(inputs, + inputTypes))) + return failure(); + rewriter.modifyOpInPlace(op, [&] { + for (const auto &[input, newType] : llvm::zip(inputs, inputTypes)) { + input.setType(newType); + } + }); } // else OpAdaptor will do it for us // convert the result types SmallVector resultTypes; if (contextAwareTypeConverter) { - contextAwareTypeConverter->convertOpResultTypes(op, resultTypes); + if (failed(contextAwareTypeConverter->convertValueRangeTypes( + op.getResults(), resultTypes))) + return failure(); } else { auto result = getTypeConverter()->convertTypes(op.getResultTypes(), resultTypes); @@ -285,14 +204,16 @@ class SecretGenericOpCipherConversion : public SecretGenericOpConversion { secret::GenericOp op, TypeRange outputTypes, ValueRange inputs, ArrayRef attributes, ConversionPatternRewriter &rewriter) const override { - // Check that all inputs are ciphertext. - if (!llvm::all_of(inputs, [&](Value input) { - return isa(input.getType()); - })) { - return failure(); - } - rewriter.replaceOpWithNewOp(op, outputTypes, inputs) - ->setDialectAttrs(attributes); + auto newOp = rewriter.replaceOpWithNewOp(op, outputTypes, inputs); + newOp->setDialectAttrs(attributes); + + // The context-aware type conversions may look for attributes attached + // to the ops that generate operands given to a later op. This new op + // may be the op that produces those operands, so we need to preserve + // the attributes from the original (secret.generic) op on the newly + // created op. + for (auto attribute : op->getAttrs()) + newOp->setAttr(attribute.getName(), attribute.getValue()); return success(); } }; @@ -303,6 +224,14 @@ class SecretGenericOpCipherPlainConversion public: using SecretGenericOpConversion::SecretGenericOpConversion; + // Ciphertext-plaintext ops should take precedence over ciphertext-ciphertext + // ops because the ops being converted (e.g., addi) don't have a plaintext + // variant. + SecretGenericOpCipherPlainConversion(const TypeConverter &typeConverter, + MLIRContext *context) + : SecretGenericOpConversion(typeConverter, context, /*benefit=*/2) { + } + LogicalResult matchAndRewriteInner( secret::GenericOp op, TypeRange outputTypes, ValueRange inputs, ArrayRef attributes, @@ -338,8 +267,10 @@ class SecretGenericOpCipherPlainConversion ciphertextTy.getPlaintextSpace().getEncoding(), ciphertextTy.getPlaintextSpace().getRing()); - rewriter.replaceOpWithNewOp(op, ciphertext, plaintext) - ->setDialectAttrs(attributes); + auto newOp = rewriter.replaceOpWithNewOp(op, ciphertext, plaintext); + newOp->setDialectAttrs(attributes); + for (auto attribute : op->getAttrs()) + newOp->setAttr(attribute.getName(), attribute.getValue()); return success(); } }; @@ -364,11 +295,41 @@ class SecretGenericOpRelinearizeConversion } SmallVector toBasis = {0, 1}; - rewriter - .replaceOpWithNewOp(op, inputs[0], - rewriter.getDenseI32ArrayAttr(fromBasis), - rewriter.getDenseI32ArrayAttr(toBasis)) - ->setDialectAttrs(attributes); + auto newOp = rewriter.replaceOpWithNewOp( + op, inputs[0], rewriter.getDenseI32ArrayAttr(fromBasis), + rewriter.getDenseI32ArrayAttr(toBasis)); + newOp->setDialectAttrs(attributes); + for (auto attribute : op->getAttrs()) + newOp->setAttr(attribute.getName(), attribute.getValue()); + return success(); + } +}; + +template +class SecretGenericOpMulConversion : public SecretGenericOpConversion { + public: + using SecretGenericOpConversion::SecretGenericOpConversion; + + LogicalResult matchAndRewriteInner( + secret::GenericOp op, TypeRange outputTypes, ValueRange inputs, + ArrayRef attributes, + ConversionPatternRewriter &rewriter) const override { + auto plaintextValues = + llvm::to_vector(llvm::make_filter_range(inputs, [&](Value input) { + return !isa(input.getType()); + })); + if (!plaintextValues.empty()) { + return failure(); + } + + // only left for CKKS, should be removed later + auto newOp = rewriter.replaceOpWithNewOp( + op, rewriter.create(op.getLoc(), inputs), + rewriter.getDenseI32ArrayAttr({0, 1, 2}), + rewriter.getDenseI32ArrayAttr({0, 1})); + newOp->setDialectAttrs(attributes); + for (auto attribute : op->getAttrs()) + newOp->setAttr(attribute.getName(), attribute.getValue()); return success(); } }; @@ -392,8 +353,11 @@ class SecretGenericOpRotateConversion op.emitError("expected constant offset for rotate"); } auto offsetAttr = llvm::dyn_cast(constantOffset.getValue()); - rewriter.replaceOpWithNewOp(op, outputTypes, inputs[0], offsetAttr) - ->setDialectAttrs(attributes); + + auto newOp = rewriter.replaceOpWithNewOp(op, inputs[0], offsetAttr); + newOp->setDialectAttrs(attributes); + for (auto attribute : op->getAttrs()) + newOp->setAttr(attribute.getName(), attribute.getValue()); return success(); } }; @@ -448,10 +412,16 @@ class SecretGenericOpModulusSwitchConversion resultOp = insert; } rewriter.replaceOp(op, resultOp); + for (auto attribute : op->getAttrs()) + resultOp->setAttr(attribute.getName(), attribute.getValue()); return success(); } - rewriter.replaceOpWithNewOp(op, outputTypes[0], inputs[0], outputRing) - ->setDialectAttrs(attributes); + + auto newOp = rewriter.replaceOpWithNewOp(op, outputTypes[0], inputs[0], + outputRing); + newOp->setDialectAttrs(attributes); + for (auto attribute : op->getAttrs()) + newOp->setAttr(attribute.getName(), attribute.getValue()); return success(); } }; diff --git a/tests/Dialect/Secret/Conversions/secret_to_bgv/invalid.mlir b/tests/Dialect/Secret/Conversions/secret_to_bgv/invalid.mlir index d3ecc3208..2015c30c7 100644 --- a/tests/Dialect/Secret/Conversions/secret_to_bgv/invalid.mlir +++ b/tests/Dialect/Secret/Conversions/secret_to_bgv/invalid.mlir @@ -1,17 +1,20 @@ // RUN: heir-opt --split-input-file --secret-to-bgv="poly-mod-degree=1024" --verify-diagnostics %s | FileCheck %s // Tests invalid secret types +#mgmt = #mgmt.mgmt module { // expected-error@below {{expected batched secret types to be tensors with dimension matching ring parameter}} - func.func @test_invalid_dimension(%arg0 : !secret.secret>) -> (!secret.secret>) { + func.func @test_invalid_dimension(%arg0 : !secret.secret> {mgmt.mgmt = #mgmt}) -> (!secret.secret>) { return %arg0 : !secret.secret> } } // ----- +#mgmt = #mgmt.mgmt + // CHECK: test_valid_dimension -func.func @test_valid_dimension(%arg0 : !secret.secret>) -> (!secret.secret>) { +func.func @test_valid_dimension(%arg0 : !secret.secret> {mgmt.mgmt = #mgmt}) -> (!secret.secret>) { return %arg0 : !secret.secret> } diff --git a/tests/Dialect/Secret/Conversions/secret_to_ckks/invalid.mlir b/tests/Dialect/Secret/Conversions/secret_to_ckks/invalid.mlir index 9c9006851..f4ed33a6d 100644 --- a/tests/Dialect/Secret/Conversions/secret_to_ckks/invalid.mlir +++ b/tests/Dialect/Secret/Conversions/secret_to_ckks/invalid.mlir @@ -2,17 +2,19 @@ // Tests invalid secret types +#mgmt = #mgmt.mgmt // expected-warning@below {{expected secret types to be tensors with dimension matching ring parameter, pass will not pack tensors into ciphertext SIMD slots}} module { - func.func @test_invalid_dimension(%arg0 : !secret.secret>) -> (!secret.secret>) { + func.func @test_invalid_dimension(%arg0 : !secret.secret> {mgmt.mgmt = #mgmt}) -> (!secret.secret>) { return %arg0 : !secret.secret> } } // ----- +#mgmt = #mgmt.mgmt // CHECK: test_valid_dimension -func.func @test_valid_dimension(%arg0 : !secret.secret>) -> (!secret.secret>) { +func.func @test_valid_dimension(%arg0 : !secret.secret> {mgmt.mgmt = #mgmt}) -> (!secret.secret>) { return %arg0 : !secret.secret> } @@ -21,10 +23,12 @@ func.func @test_valid_dimension(%arg0 : !secret.secret>) -> (!se // Currently we don't support lowering adds on tensors of ciphertexts - the // lowering must implement a loop of add operations on each element. +#mgmt = #mgmt.mgmt // expected-warning@below {{expected secret types to be tensors with dimension matching ring parameter, pass will not pack tensors into ciphertext SIMD slots}} module { - func.func @test_add_tensor_not_packed(%arg0 : !secret.secret>) -> (!secret.secret>) { - // expected-error@below {{failed to legalize}} + // expected-error@below {{failed to legalize}} + // expected-error@below {{Failed to convert function signature}} + func.func @test_add_tensor_not_packed(%arg0 : !secret.secret> {mgmt.mgmt = #mgmt}) -> (!secret.secret>) { %0 = secret.generic ins(%arg0 : !secret.secret>) { ^bb0(%ARG0 : tensor<1023xf32>): %1 = arith.addf %ARG0, %ARG0 : tensor<1023xf32> From d430fa27a7c56ac61dc239cae089079e324c9d12 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Fri, 31 Jan 2025 16:27:52 -0800 Subject: [PATCH 02/16] Add empty shell for new pass and starter tests --- .../ConvertToCiphertextSemantics/BUILD | 29 ++++++++++ .../ConvertToCiphertextSemantics.cpp | 52 ++++++++++++++++++ .../ConvertToCiphertextSemantics.h | 18 ++++++ .../ConvertToCiphertextSemantics.td | 55 +++++++++++++++++++ .../convert_to_ciphertext_semantics/BUILD | 10 ++++ .../convert_to_ciphertext_semantics.mlir | 16 ++++++ .../linalg_reduce.mlir | 39 +++++++++++++ 7 files changed, 219 insertions(+) create mode 100644 lib/Transforms/ConvertToCiphertextSemantics/BUILD create mode 100644 lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp create mode 100644 lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.h create mode 100644 lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.td create mode 100644 tests/Transforms/convert_to_ciphertext_semantics/BUILD create mode 100644 tests/Transforms/convert_to_ciphertext_semantics/convert_to_ciphertext_semantics.mlir create mode 100644 tests/Transforms/convert_to_ciphertext_semantics/linalg_reduce.mlir diff --git a/lib/Transforms/ConvertToCiphertextSemantics/BUILD b/lib/Transforms/ConvertToCiphertextSemantics/BUILD new file mode 100644 index 000000000..cc0166f9d --- /dev/null +++ b/lib/Transforms/ConvertToCiphertextSemantics/BUILD @@ -0,0 +1,29 @@ +load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "ConvertToCiphertextSemantics", + srcs = ["ConvertToCiphertextSemantics.cpp"], + hdrs = ["ConvertToCiphertextSemantics.h"], + deps = [ + ":pass_inc_gen", + "@heir//lib/Dialect/TensorExt/IR:Dialect", + "@heir//lib/Utils:ContextAwareTypeConversion", + "@heir//lib/Utils:ConversionUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) + +add_heir_transforms( + generated_target_name = "pass_inc_gen", + pass_name = "ConvertToCiphertextSemantics", + td_file = "ConvertToCiphertextSemantics.td", +) diff --git a/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp b/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp new file mode 100644 index 000000000..974f2e6ce --- /dev/null +++ b/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp @@ -0,0 +1,52 @@ +#include "lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.h" + +#include "lib/Dialect/TensorExt/IR/TensorExtDialect.h" +#include "lib/Utils/ContextAwareTypeConversion.h" +#include "lib/Utils/ConversionUtils.h" +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project + +#define DEBUG_TYPE "linalg-to-tensor-ext" + +namespace mlir { +namespace heir { + +#define GEN_PASS_DEF_CONVERTTOCIPHERTEXTSEMANTICS +#include "lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.h.inc" + +bool isPowerOfTwo(int64_t n) { return (n > 0) && ((n & (n - 1)) == 0); } + +// This type converter converts types like tensor where the dimensions +// represent tensor-semantic data to tensor, where the last dimension represents the ciphertext or plaintext slot +// count, and the other dimensions are determined by a layout attribute +// indexing. +struct LayoutMaterializationTypeConverter : public AttributeAwareTypeConverter { + public: + LayoutMaterializationTypeConverter(int numSlots) : numSlots(numSlots) {} + + private: + // The number of slots available in each ciphertext. + int numSlots; +}; + +struct ConvertToCiphertextSemantics + : impl::ConvertToCiphertextSemanticsBase { + using ConvertToCiphertextSemanticsBase::ConvertToCiphertextSemanticsBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + auto *module = getOperation(); + + LayoutMaterializationTypeConverter typeConverter; + + // RewritePatternSet patterns(context); + + // (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace heir +} // namespace mlir diff --git a/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.h b/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.h new file mode 100644 index 000000000..b660841d1 --- /dev/null +++ b/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.h @@ -0,0 +1,18 @@ +#ifndef LIB_TRANSFORMS_CONVERTTOCIPHERTEXTSEMANTICS_CONVERTTOCIPHERTEXTSEMANTICS_H_ +#define LIB_TRANSFORMS_CONVERTTOCIPHERTEXTSEMANTICS_CONVERTTOCIPHERTEXTSEMANTICS_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DECL +#include "lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.h.inc" + +#define GEN_PASS_REGISTRATION +#include "lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.h.inc" + +} // namespace heir +} // namespace mlir + +#endif // LIB_TRANSFORMS_CONVERTTOCIPHERTEXTSEMANTICS_CONVERTTOCIPHERTEXTSEMANTICS_H_ diff --git a/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.td b/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.td new file mode 100644 index 000000000..6914e00d9 --- /dev/null +++ b/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.td @@ -0,0 +1,55 @@ +#ifndef LIB_TRANSFORMS_CONVERTTOCIPHERTEXTSEMANTICS_CONVERTTOCIPHERTEXTSEMANTICS_TD_ +#define LIB_TRANSFORMS_CONVERTTOCIPHERTEXTSEMANTICS_CONVERTTOCIPHERTEXTSEMANTICS_TD_ + +include "mlir/Pass/PassBase.td" + +def ConvertToCiphertextSemantics : Pass<"convert-to-ciphertext-semantics"> { + let summary = "Converts programs with tensor semantics to ciphertext semantics"; + let description = [{ + This pass performs two inherently intertwined transformations: + + 1. Convert a program from tensor semantics to ciphertext semantics, explained below. + 2. Implement ops defined on tensor-semantic types in terms of ops defined on + ciphertext-semantic types. + + A program is defined to have _tensor semantics_ if the tensor-typed values + are manipulated according to standard MLIR tensor operations and semantics. + + A program is defined to have _ciphertext semantics_ if the tensor-typed + values correspond to tensors of FHE ciphertexts, where the last dimension of + the tensor type is the number of ciphertext slots. + + For example, a tensor of type `tensor<32x32xi16>` with tensor semantics might + be converted by this pass, depending on the pass options, to a single + ciphertext-semantics `tensor<65536xi16>`. A larger tensor might, depending on + the layout chosen by earlier passes, be converted to a `tensor<4x32768xi16>`, + where the trailing dimension corresponds to the number of slots in the + ciphertext. + + Tensors with ciphertext semantics can be thought of as an intermediate step + between lowering from tensor types with tensor semantics to concrete `lwe` + dialect ciphertext types in a particular FHE scheme. Having this intermediate + step is useful because some optimizations are easier to implement, and can be + implemented more generically, in the abstract FHE computational model + where the data types are large tensors, and the operations are SIMD additions, + multiplications, and cyclic rotations. + + Function arguments and return values are annotated with the original tensor + type in the `secret.original_type` attribute. This enables later lowerings + to implement appropriate encoding and decoding routines for FHE schemes. + + The second role of this pass is to implement FHE kernels for various high-level + tensor operations, such as `linalg.matvec`. This must happen at the same time + as the type conversion because the high-level ops like `linalg.matvec` are + not well-defined on ciphertext-semantic tensors, while their implementation + as SIMD/rotation ops are not well-defined on tensor-semantic tensors. + + FIXME: provide examples + }]; + let dependentDialects = [ + "mlir::tensor::TensorDialect", + "mlir::heir::tensor_ext::TensorExtDialect", + ]; +} + +#endif // LIB_TRANSFORMS_CONVERTTOCIPHERTEXTSEMANTICS_CONVERTTOCIPHERTEXTSEMANTICS_TD_ diff --git a/tests/Transforms/convert_to_ciphertext_semantics/BUILD b/tests/Transforms/convert_to_ciphertext_semantics/BUILD new file mode 100644 index 000000000..c571e6fc6 --- /dev/null +++ b/tests/Transforms/convert_to_ciphertext_semantics/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/Transforms/convert_to_ciphertext_semantics/convert_to_ciphertext_semantics.mlir b/tests/Transforms/convert_to_ciphertext_semantics/convert_to_ciphertext_semantics.mlir new file mode 100644 index 000000000..c933c03d0 --- /dev/null +++ b/tests/Transforms/convert_to_ciphertext_semantics/convert_to_ciphertext_semantics.mlir @@ -0,0 +1,16 @@ +// RUN: heir-opt %s --convert-to-ciphertext-semantics | FileCheck %s + +// CHECK-LABEL: @convert_minimal_example +#map = affine_map<(d0, d1) -> (d0 * 32 + d1)> + +func.func @convert_minimal_example( + %arg0: !secret.secret> {tensor_ext.layout = #tensor_ext.layout (d0 * 32 + d1)>}) -> + (!secret.secret> {tensor_ext.layout = #tensor_ext.layout (d0 * 32 + d1)>}) { + %0 = secret.generic ins(%arg0, : !secret.secret>) + attrs = {arg0 = {layout = #map}, layout = [#map]} { + ^body(%input0: tensor<32x32xi16>): + %1 = linalg.concat dim(0) %input0, %input0 : (tensor<32x32xi16>, tensor<32x32xi16>) -> tensor<64x32xi16> + secret.yield %1 : tensor<32xi16> + } -> !secret.secret> + return %0 : !secret.secret> +} diff --git a/tests/Transforms/convert_to_ciphertext_semantics/linalg_reduce.mlir b/tests/Transforms/convert_to_ciphertext_semantics/linalg_reduce.mlir new file mode 100644 index 000000000..6260e14ee --- /dev/null +++ b/tests/Transforms/convert_to_ciphertext_semantics/linalg_reduce.mlir @@ -0,0 +1,39 @@ +// RUN: heir-opt %s --convert-to-ciphertext-semantics | FileCheck %s + +// CHECK-LABEL: @convert_linalg_reduce +#map = affine_map<(d0, d1) -> (d0 * 32 + d1)> +#map1 = affine_map<(d0) -> (d0)> +#map2 = affine_map<(d0) -> (d0 * 32)> + +// FIXME: decide what this pass should do +func.func @convert_linalg_reduce( + %arg0: !secret.secret> {tensor_ext.layout = #tensor_ext.layout (d0 * 32 + d1)>}, + %arg1: !secret.secret> {tensor_ext.layout = #tensor_ext.layout (d0 * 32 + d1)>}) -> + (!secret.secret> {tensor_ext.layout = #tensor_ext.layout (d0)>}) { + %cst = arith.constant dense<0> : tensor<32xi16> + %cst_0 = arith.constant dense<0> : tensor<32xi16> + + %0 = secret.generic ins(%arg0, %arg1 : !secret.secret>, !secret.secret>) + attrs = {arg0 = {layout = #map}, arg1 = {layout = #map}, layout = [#map1]} { + ^body(%input0: tensor<32x32xi16>, %input1: tensor<32x32xi16>): + %1 = tensor_ext.assign_layout %cst {layout = #map1} : tensor<32xi16> + + %reduced = linalg.reduce { arith.addi {overflowFlags = #arith.overflow} } + ins(%input0 : tensor<32x32xi16>) + outs(%1 : tensor<32xi16>) + dimensions = [0] {layout = [#map1]} + + %2 = tensor_ext.assign_layout %cst_0 {layout = #map1} : tensor<32xi16> + %3 = tensor_ext.convert_layout %2 {from_layout = #map1, layout = [#map2], to_layout = #map2} : tensor<32xi16> + + %reduced_1 = linalg.reduce { arith.addi {overflowFlags = #arith.overflow} } + ins(%input1 : tensor<32x32xi16>) + outs(%3 : tensor<32xi16>) + dimensions = [1] {layout = [#map2]} + + %4 = tensor_ext.convert_layout %reduced_1 {from_layout = #map2, layout = [#map1], to_layout = #map1} : tensor<32xi16> + %5 = arith.addi %reduced, %4 {layout = [#map1]} : tensor<32xi16> + secret.yield %5 : tensor<32xi16> + } -> !secret.secret> + return %0 : !secret.secret> +} From 15a03e04e83be92afdc6b0c18123601304f4e257 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 19 Feb 2025 16:07:12 -0800 Subject: [PATCH 03/16] add pass to heir-opt --- tools/BUILD | 1 + tools/heir-opt.cpp | 2 ++ 2 files changed, 3 insertions(+) diff --git a/tools/BUILD b/tools/BUILD index 3a7f63c5f..8b21d9a0a 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -106,6 +106,7 @@ cc_binary( "@heir//lib/Transforms/ConvertSecretForToStaticFor", "@heir//lib/Transforms/ConvertSecretInsertToStaticInsert", "@heir//lib/Transforms/ConvertSecretWhileToStaticFor", + "@heir//lib/Transforms/ConvertToCiphertextSemantics", "@heir//lib/Transforms/DropUnitDims", "@heir//lib/Transforms/ElementwiseToAffine", "@heir//lib/Transforms/ForwardInsertToExtract", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 0afc29821..d70d297f7 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -59,6 +59,7 @@ #include "lib/Transforms/ConvertSecretForToStaticFor/ConvertSecretForToStaticFor.h" #include "lib/Transforms/ConvertSecretInsertToStaticInsert/ConvertSecretInsertToStaticInsert.h" #include "lib/Transforms/ConvertSecretWhileToStaticFor/ConvertSecretWhileToStaticFor.h" +#include "lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.h" #include "lib/Transforms/DropUnitDims/DropUnitDims.h" #include "lib/Transforms/ElementwiseToAffine/ElementwiseToAffine.h" #include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h" @@ -246,6 +247,7 @@ int main(int argc, char **argv) { registerConvertSecretWhileToStaticForPasses(); registerConvertSecretExtractToStaticExtractPasses(); registerConvertSecretInsertToStaticInsertPasses(); + registerConvertToCiphertextSemanticsPasses(); registerDropUnitDims(); registerAnnotateSecretnessPasses(); registerApplyFoldersPasses(); From ce43e35e34d2724c9800fa651a2dfe95c2e8d26e Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 19 Feb 2025 16:07:39 -0800 Subject: [PATCH 04/16] add helper for iterating over a tensor index domain --- lib/Utils/Utils.cpp | 19 +++++++++++++++++++ lib/Utils/Utils.h | 7 +++++++ 2 files changed, 26 insertions(+) diff --git a/lib/Utils/Utils.cpp b/lib/Utils/Utils.cpp index 0002028ef..17528166a 100644 --- a/lib/Utils/Utils.cpp +++ b/lib/Utils/Utils.cpp @@ -69,5 +69,24 @@ bool containsArgumentOfType(Operation *op, TypePredicate predicate) { }); } +void iterateIndices(const std::vector &shape, + const IndexTupleConsumer &process) { + if (shape.empty()) return; + std::vector index(shape.size(), 0); + bool done = false; + while (!done) { + process(index); + for (int i = shape.size() - 1; i >= 0; --i) { + if (++index[i] < shape[i]) { + break; + } + index[i] = 0; + if (i == 0) { + done = true; + } + } + } +} + } // namespace heir } // namespace mlir diff --git a/lib/Utils/Utils.h b/lib/Utils/Utils.h index f8c0a08d0..e98957453 100644 --- a/lib/Utils/Utils.h +++ b/lib/Utils/Utils.h @@ -23,6 +23,8 @@ typedef std::function TypePredicate; typedef std::function DialectPredicate; +using IndexTupleConsumer = std::function &)>; + template OpPredicate OpEqual() { return [](Operation *op) { return mlir::isa(op); }; @@ -102,6 +104,11 @@ bool containsArgumentOfType(Operation *op) { return containsArgumentOfType(op, TypeEqual()); } +// A helper to iterate over the space of indices of a multidimensional +// tensor whose shape is given by `shape` +void iterateIndices(const std::vector &shape, + const IndexTupleConsumer &processFunc); + } // namespace heir } // namespace mlir From fcc5ad8ad9046332d0e82c6fe0ee69d0a993cbbd Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 20 Feb 2025 09:09:32 -0800 Subject: [PATCH 05/16] add func op finalization hook --- lib/Utils/ContextAwareTypeConversion.cpp | 24 ++++++++++++++---------- lib/Utils/ContextAwareTypeConversion.h | 10 ++++++++++ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/lib/Utils/ContextAwareTypeConversion.cpp b/lib/Utils/ContextAwareTypeConversion.cpp index 2e7bb0c45..3d366cec9 100644 --- a/lib/Utils/ContextAwareTypeConversion.cpp +++ b/lib/Utils/ContextAwareTypeConversion.cpp @@ -94,17 +94,17 @@ LogicalResult AttributeAwareTypeConverter::convertFuncSignature( } LogicalResult ConvertFuncWithContextAwareTypeConverter::matchAndRewrite( - func::FuncOp op, PatternRewriter &rewriter) const { - auto funcOp = cast(op); - - SmallVector newFuncOperandsType; - SmallVector newFuncResultsType; + func::FuncOp funcOp, PatternRewriter &rewriter) const { + SmallVector oldFuncOperandTypes(funcOp.getFunctionType().getInputs()); + SmallVector oldFuncResultTypes(funcOp.getFunctionType().getResults()); + SmallVector newFuncOperandTypes; + SmallVector newFuncResultTypes; if (failed(contextAwareTypeConverter->convertFuncSignature( - op, newFuncOperandsType, newFuncResultsType))) - return failure(); + funcOp, newFuncOperandTypes, newFuncResultTypes))) + return funcOp->emitError("Failed to convert function signature"); auto newFuncType = - FunctionType::get(getContext(), newFuncOperandsType, newFuncResultsType); + FunctionType::get(getContext(), newFuncOperandTypes, newFuncResultTypes); rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newFuncType); @@ -112,7 +112,7 @@ LogicalResult ConvertFuncWithContextAwareTypeConverter::matchAndRewrite( // Set the block argument types to match the new signature for (auto [arg, newType] : llvm::zip( - funcOp.getBody().front().getArguments(), newFuncOperandsType)) { + funcOp.getBody().front().getArguments(), newFuncOperandTypes)) { arg.setType(newType); } @@ -126,11 +126,15 @@ LogicalResult ConvertFuncWithContextAwareTypeConverter::matchAndRewrite( // legalize and the conversion engine looped infinitely. Block &block = funcOp.getBody().front(); for (auto [returnOperand, newType] : - llvm::zip(block.getTerminator()->getOperands(), newFuncResultsType)) { + llvm::zip(block.getTerminator()->getOperands(), newFuncResultTypes)) { returnOperand.setType(newType); } }); + if (failed(finalizeFuncOpModification(funcOp, oldFuncOperandTypes, + oldFuncResultTypes, rewriter))) + return failure(); + return success(); } diff --git a/lib/Utils/ContextAwareTypeConversion.h b/lib/Utils/ContextAwareTypeConversion.h index 8128fe2b5..5c7437c77 100644 --- a/lib/Utils/ContextAwareTypeConversion.h +++ b/lib/Utils/ContextAwareTypeConversion.h @@ -113,6 +113,16 @@ struct ConvertFuncWithContextAwareTypeConverter LogicalResult matchAndRewrite(func::FuncOp op, PatternRewriter &rewriter) const override; + // An overridable hook that allows subclasses to perform additional + // modifications of the func op after its type signature has been converted. + // For example, a subclass may use this hook to modify arg attrs. + LogicalResult finalizeFuncOpModification(func::FuncOp op, + ArrayRef oldArgTypes, + ArrayRef oldResultTypes, + PatternRewriter &rewriter) const { + return success(); + }; + private: const ContextAwareTypeConverter *contextAwareTypeConverter; }; From ef3dd89a8d28095426fd6cb25225624a6f92229f Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Mon, 24 Feb 2025 08:56:04 -0800 Subject: [PATCH 06/16] add original_type attribute --- lib/Dialect/TensorExt/IR/TensorExtAttributes.td | 13 +++++++++++++ lib/Dialect/TensorExt/IR/TensorExtDialect.td | 2 ++ 2 files changed, 15 insertions(+) diff --git a/lib/Dialect/TensorExt/IR/TensorExtAttributes.td b/lib/Dialect/TensorExt/IR/TensorExtAttributes.td index 9bac27eac..6ae7056a5 100644 --- a/lib/Dialect/TensorExt/IR/TensorExtAttributes.td +++ b/lib/Dialect/TensorExt/IR/TensorExtAttributes.td @@ -64,4 +64,17 @@ def SIMDPacking_Attr : TensorExt_Attr<"SIMDPacking", "simd_packing", let assemblyFormat = "`<` struct(params) `>`"; } +def OriginalType_Attr : TensorExt_Attr<"OriginalType", "original_type"> { + let summary = "The original type of a secret tensor whose layout has been converted to ciphertext semantics."; + let description = [{ + // FIXME: add description + }]; + let parameters = (ins + "::mlir::Type":$originalType, + "::mlir::AffineMap":$layout + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + + #endif // LIB_DIALECT_TENSOREXT_IR_TENSOREXTATTRIBUTES_TD_ diff --git a/lib/Dialect/TensorExt/IR/TensorExtDialect.td b/lib/Dialect/TensorExt/IR/TensorExtDialect.td index 0815f2785..bea5daf7e 100644 --- a/lib/Dialect/TensorExt/IR/TensorExtDialect.td +++ b/lib/Dialect/TensorExt/IR/TensorExtDialect.td @@ -20,6 +20,8 @@ def TensorExt_Dialect : Dialect { let extraClassDeclaration = [{ constexpr const static ::llvm::StringLiteral kLayoutAttrName = "tensor_ext.layout"; + constexpr const static ::llvm::StringLiteral + kOriginalTypeAttrName = "tensor_ext.original_type"; }]; From 9725d3a1c28bfdf465eb982c42298b32e4060264 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 20 Feb 2025 09:09:58 -0800 Subject: [PATCH 07/16] checkpoint implementation --- .../ConvertToCiphertextSemantics/BUILD | 2 + .../ConvertToCiphertextSemantics.cpp | 173 +++++++++++++++++- .../ConvertToCiphertextSemantics.td | 14 +- lib/Utils/ContextAwareTypeConversion.cpp | 1 + lib/Utils/ContextAwareTypeConversion.h | 7 +- lib/Utils/ConversionUtils.h | 4 +- .../convert_to_ciphertext_semantics.mlir | 14 +- 7 files changed, 195 insertions(+), 20 deletions(-) diff --git a/lib/Transforms/ConvertToCiphertextSemantics/BUILD b/lib/Transforms/ConvertToCiphertextSemantics/BUILD index cc0166f9d..adb2a4d8a 100644 --- a/lib/Transforms/ConvertToCiphertextSemantics/BUILD +++ b/lib/Transforms/ConvertToCiphertextSemantics/BUILD @@ -12,9 +12,11 @@ cc_library( deps = [ ":pass_inc_gen", "@heir//lib/Dialect/TensorExt/IR:Dialect", + "@heir//lib/Utils", "@heir//lib/Utils:ContextAwareTypeConversion", "@heir//lib/Utils:ConversionUtils", "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", diff --git a/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp b/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp index 974f2e6ce..c6ecd67d9 100644 --- a/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp +++ b/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp @@ -1,14 +1,17 @@ #include "lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.h" +#include "lib/Dialect/TensorExt/IR/TensorExtAttributes.h" #include "lib/Dialect/TensorExt/IR/TensorExtDialect.h" #include "lib/Utils/ContextAwareTypeConversion.h" #include "lib/Utils/ConversionUtils.h" +#include "lib/Utils/Utils.h" #include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project -#define DEBUG_TYPE "linalg-to-tensor-ext" +#define DEBUG_TYPE "convert-to-ciphertext-semantics" namespace mlir { namespace heir { @@ -25,11 +28,159 @@ bool isPowerOfTwo(int64_t n) { return (n > 0) && ((n & (n - 1)) == 0); } // indexing. struct LayoutMaterializationTypeConverter : public AttributeAwareTypeConverter { public: - LayoutMaterializationTypeConverter(int numSlots) : numSlots(numSlots) {} + LayoutMaterializationTypeConverter(int ciphertextSize) + : ciphertextSize(ciphertextSize) {} + + FailureOr convert(Type type, Attribute attr) const override { + // Convert secret> to secret> + // Convert tensor<...> to tensor<...> + bool isSecret = isa(type); + if (isSecret) { + auto secretType = cast(type); + auto innerType = secretType.getValueType(); + auto convertedInnerType = convert(innerType, attr); + if (failed(convertedInnerType)) return failure(); + return secret::SecretType::get(convertedInnerType.value()); + } + + auto rankedTensorType = dyn_cast(type); + if (!rankedTensorType) return failure(); + + auto layoutAttr = dyn_cast(attr); + if (!layoutAttr) return failure(); + AffineMap layout = layoutAttr.getValue(); + + MLIRContext *ctx = type.getContext(); + OpBuilder b(ctx); + + // Each ciphertext will always have ciphertextSize many slots, so the main + // goal is to determine how many ciphertexts are needed. We do this by + // iterating over the input type's index domain, and apply the layout + // affine map to each index, and keep track of the maximum value of each + // index of the map results. These maxima (plus 1 for zero indexing) + // will be the shape of the new type. + SmallVector outputTensorShape(layout.getNumResults(), 0); + outputTensorShape[layout.getNumResults() - 1] = ciphertextSize; + + // Evaluate the affine map on the input indices and update the + // outputTensorShape to be a max over visited indices. + IndexTupleConsumer evaluateNextIndex = + [&](const std::vector &indices) { + SmallVector mapInputs = llvm::map_to_vector( + indices, + [&](int64_t i) { return cast(b.getIndexAttr(i)); }); + + // Evaluate the affine map on the inputs + SmallVector results; + if (failed(layout.constantFold(mapInputs, results))) { + assert(false && "constant folding should never fail here"); + } + + // minus 1 to skip the last dimension (ciphertext dimension) + for (int i = 0; i < layout.getNumResults() - 1; ++i) { + // 1 + to account for zero indexing + outputTensorShape[i] = + std::max(outputTensorShape[i], + 1 + cast(results[i]).getInt()); + } + }; + + iterateIndices(rankedTensorType.getShape(), evaluateNextIndex); + return RankedTensorType::get(outputTensorShape, + rankedTensorType.getElementType()); + } + + // Each value is expected to be produced by an operation whose `layout` + // attributes correspond to the chosen layouts of the operation results. + FailureOr getContextualAttr(Value value) const override { + auto *parentOp = value.getDefiningOp(); + + // It may be a block argument + if (!parentOp) { + // It may be a func arg + auto blockArg = dyn_cast(value); + auto *parentOp = blockArg.getOwner()->getParentOp(); + auto funcOp = dyn_cast(parentOp); + if (funcOp) { + auto argAttr = + funcOp.getArgAttr(blockArg.getArgNumber(), + tensor_ext::TensorExtDialect::kLayoutAttrName); + if (!argAttr) return failure(); + + return argAttr; + } + + // It may be a secret.generic arg + auto genericOp = dyn_cast(parentOp); + if (genericOp) { + return cast( + genericOp.getArgAttr(blockArg.getArgNumber(), "layout")); + } + + return failure(); + } + + // For any other op, the layout attribute is an array of result layouts + ArrayAttr resultLayouts = parentOp->getAttrOfType("layout"); + + int valueIndex = -1; + for (auto result : parentOp->getResults()) { + ++valueIndex; + if (result == value) break; + } + + if (valueIndex == -1) { + return failure(); + } + + return cast(resultLayouts[valueIndex]); + } private: // The number of slots available in each ciphertext. - int numSlots; + int ciphertextSize; +}; + +bool hasLayoutArgAttrs(func::FuncOp op) { + for (int i = 0; i < op.getNumArguments(); ++i) { + if (op.getArgAttr(i, tensor_ext::TensorExtDialect::kLayoutAttrName)) + return true; + } + return false; +} + +bool hasLayoutResultAttrs(Operation *op) { + return (op->getAttrOfType("layout") != nullptr); +} + +struct ConvertFunc : public ConvertFuncWithContextAwareTypeConverter { + public: + using ConvertFuncWithContextAwareTypeConverter:: + ConvertFuncWithContextAwareTypeConverter; + + ConvertFunc(const ContextAwareTypeConverter &typeConverter, + MLIRContext *context) + : ConvertFuncWithContextAwareTypeConverter(typeConverter, context) {} + + LogicalResult finalizeFuncOpModification( + func::FuncOp op, ArrayRef oldArgTypes, + ArrayRef oldResultTypes, PatternRewriter &rewriter) const override { + // Replace layout arg attrs with secret.original_type arg attrs + rewriter.modifyOpInPlace(op, [&] { + for (int i = 0; i < op.getNumArguments(); ++i) { + auto layoutAttr = + op.getArgAttr(i, tensor_ext::TensorExtDialect::kLayoutAttrName); + if (!layoutAttr) continue; + + op.removeArgAttr(i, tensor_ext::TensorExtDialect::kLayoutAttrName); + AffineMap layout = cast(layoutAttr).getValue(); + op.setArgAttr(i, tensor_ext::TensorExtDialect::kOriginalTypeAttrName, + tensor_ext::OriginalTypeAttr::get( + getContext(), oldArgTypes[i], layout)); + } + }); + return success(); + }; }; struct ConvertToCiphertextSemantics @@ -40,11 +191,21 @@ struct ConvertToCiphertextSemantics MLIRContext *context = &getContext(); auto *module = getOperation(); - LayoutMaterializationTypeConverter typeConverter; + LayoutMaterializationTypeConverter typeConverter = + LayoutMaterializationTypeConverter(ciphertextSize); + + RewritePatternSet patterns(context); + ConversionTarget target(*context); + target.addDynamicallyLegalOp( + [&](func::FuncOp op) { return !hasLayoutArgAttrs(op); }); + target.markUnknownOpDynamicallyLegal( + [&](Operation *op) { return !hasLayoutResultAttrs(op); }); - // RewritePatternSet patterns(context); + patterns.add(typeConverter, context); - // (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + return signalPassFailure(); + } } }; diff --git a/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.td b/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.td index 6914e00d9..f5ea6b733 100644 --- a/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.td +++ b/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.td @@ -47,8 +47,20 @@ def ConvertToCiphertextSemantics : Pass<"convert-to-ciphertext-semantics"> { FIXME: provide examples }]; let dependentDialects = [ - "mlir::tensor::TensorDialect", "mlir::heir::tensor_ext::TensorExtDialect", + "mlir::linalg::LinalgDialect", + "mlir::tensor::TensorDialect", + ]; + + // TODO(#4102): reevaluate flag name + let options = [ + Option< + "ciphertextSize", + "ciphertext-size", + "int", + /*default=*/"1024", + "Power of two length of the ciphertexts the data is packed in." + > ]; } diff --git a/lib/Utils/ContextAwareTypeConversion.cpp b/lib/Utils/ContextAwareTypeConversion.cpp index 3d366cec9..89d0836cc 100644 --- a/lib/Utils/ContextAwareTypeConversion.cpp +++ b/lib/Utils/ContextAwareTypeConversion.cpp @@ -99,6 +99,7 @@ LogicalResult ConvertFuncWithContextAwareTypeConverter::matchAndRewrite( SmallVector oldFuncResultTypes(funcOp.getFunctionType().getResults()); SmallVector newFuncOperandTypes; SmallVector newFuncResultTypes; + if (failed(contextAwareTypeConverter->convertFuncSignature( funcOp, newFuncOperandTypes, newFuncResultTypes))) return funcOp->emitError("Failed to convert function signature"); diff --git a/lib/Utils/ContextAwareTypeConversion.h b/lib/Utils/ContextAwareTypeConversion.h index 5c7437c77..64492b433 100644 --- a/lib/Utils/ContextAwareTypeConversion.h +++ b/lib/Utils/ContextAwareTypeConversion.h @@ -116,10 +116,9 @@ struct ConvertFuncWithContextAwareTypeConverter // An overridable hook that allows subclasses to perform additional // modifications of the func op after its type signature has been converted. // For example, a subclass may use this hook to modify arg attrs. - LogicalResult finalizeFuncOpModification(func::FuncOp op, - ArrayRef oldArgTypes, - ArrayRef oldResultTypes, - PatternRewriter &rewriter) const { + virtual LogicalResult finalizeFuncOpModification( + func::FuncOp op, ArrayRef oldArgTypes, + ArrayRef oldResultTypes, PatternRewriter &rewriter) const { return success(); }; diff --git a/lib/Utils/ConversionUtils.h b/lib/Utils/ConversionUtils.h index 591f7eae2..c9b8d0f87 100644 --- a/lib/Utils/ConversionUtils.h +++ b/lib/Utils/ConversionUtils.h @@ -353,8 +353,8 @@ class SecretGenericOpRotateConversion op.emitError("expected constant offset for rotate"); } auto offsetAttr = llvm::dyn_cast(constantOffset.getValue()); - - auto newOp = rewriter.replaceOpWithNewOp(op, inputs[0], offsetAttr); + auto newOp = + rewriter.replaceOpWithNewOp(op, outputTypes, inputs[0], offsetAttr); newOp->setDialectAttrs(attributes); for (auto attribute : op->getAttrs()) newOp->setAttr(attribute.getName(), attribute.getValue()); diff --git a/tests/Transforms/convert_to_ciphertext_semantics/convert_to_ciphertext_semantics.mlir b/tests/Transforms/convert_to_ciphertext_semantics/convert_to_ciphertext_semantics.mlir index c933c03d0..f90830e21 100644 --- a/tests/Transforms/convert_to_ciphertext_semantics/convert_to_ciphertext_semantics.mlir +++ b/tests/Transforms/convert_to_ciphertext_semantics/convert_to_ciphertext_semantics.mlir @@ -4,13 +4,13 @@ #map = affine_map<(d0, d1) -> (d0 * 32 + d1)> func.func @convert_minimal_example( - %arg0: !secret.secret> {tensor_ext.layout = #tensor_ext.layout (d0 * 32 + d1)>}) -> - (!secret.secret> {tensor_ext.layout = #tensor_ext.layout (d0 * 32 + d1)>}) { - %0 = secret.generic ins(%arg0, : !secret.secret>) + %arg0: !secret.secret> {tensor_ext.layout = affine_map<(d0, d1) -> (d0 * 32 + d1)>}) -> + (!secret.secret> {tensor_ext.layout = affine_map<(d0, d1) -> (d0 * 32 + d1)>}) { + %0 = secret.generic ins(%arg0 : !secret.secret>) attrs = {arg0 = {layout = #map}, layout = [#map]} { ^body(%input0: tensor<32x32xi16>): - %1 = linalg.concat dim(0) %input0, %input0 : (tensor<32x32xi16>, tensor<32x32xi16>) -> tensor<64x32xi16> - secret.yield %1 : tensor<32xi16> - } -> !secret.secret> - return %0 : !secret.secret> + %1 = arith.addi %input0, %input0 : tensor<32x32xi16> + secret.yield %1 : tensor<32x32xi16> + } -> !secret.secret> + return %0 : !secret.secret> } From b94ce85101d008d27a729b78a98af62f0d73f805 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Fri, 28 Feb 2025 17:22:56 -0800 Subject: [PATCH 08/16] more struggles with type conversion --- lib/Dialect/Secret/IR/SecretOps.cpp | 23 ++- lib/Dialect/Secret/IR/SecretOps.td | 3 + .../ConvertToCiphertextSemantics.cpp | 172 +++++++++++++++--- lib/Utils/ConversionUtils.cpp | 25 +-- lib/Utils/ConversionUtils.h | 80 +++++++- .../convert_to_ciphertext_semantics.mlir | 2 +- 6 files changed, 256 insertions(+), 49 deletions(-) diff --git a/lib/Dialect/Secret/IR/SecretOps.cpp b/lib/Dialect/Secret/IR/SecretOps.cpp index fd173a697..83f9e43fa 100644 --- a/lib/Dialect/Secret/IR/SecretOps.cpp +++ b/lib/Dialect/Secret/IR/SecretOps.cpp @@ -277,19 +277,24 @@ YieldOp GenericOp::getYieldOp() { return *getBody()->getOps().begin(); } -GenericOp cloneWithNewResultTypes(GenericOp op, TypeRange newTypes, - PatternRewriter &rewriter) { - return rewriter.create( - op.getLoc(), op.getOperands(), newTypes, +GenericOp GenericOp::cloneWithNewResultTypes(TypeRange newTypes, + PatternRewriter &rewriter, + bool preserveAttrs) { + auto newOp = rewriter.create( + getLoc(), getOperands(), newTypes, [&](OpBuilder &b, Location loc, ValueRange blockArguments) { IRMapping mp; - for (BlockArgument blockArg : op.getBody()->getArguments()) { + for (BlockArgument blockArg : getBody()->getArguments()) { mp.map(blockArg, blockArguments[blockArg.getArgNumber()]); } - for (auto &op : op.getBody()->getOperations()) { + for (auto &op : getBody()->getOperations()) { b.clone(op, mp); } }); + if (preserveAttrs) { + newOp->setAttrs(getOperation()->getAttrs()); + } + return newOp; } std::pair GenericOp::addNewYieldedValues( @@ -301,7 +306,7 @@ std::pair GenericOp::addNewYieldedValues( SecretType newTy = secret::SecretType::get(t); return newTy; })); - GenericOp newOp = cloneWithNewResultTypes(*this, newTypes, rewriter); + GenericOp newOp = cloneWithNewResultTypes(newTypes, rewriter); auto newResultStartIter = newOp.getResults().drop_front( newOp.getNumResults() - newValuesToYield.size()); @@ -340,7 +345,7 @@ GenericOp GenericOp::removeYieldedValues(ValueRange yieldedValuesToRemove, return newTy; })); - return cloneWithNewResultTypes(*this, newResultTypes, rewriter); + return cloneWithNewResultTypes(newResultTypes, rewriter); } GenericOp GenericOp::removeYieldedValues(ArrayRef yieldedIndicesToRemove, @@ -371,7 +376,7 @@ GenericOp GenericOp::removeYieldedValues(ArrayRef yieldedIndicesToRemove, return newTy; })); - return cloneWithNewResultTypes(*this, newResultTypes, rewriter); + return cloneWithNewResultTypes(newResultTypes, rewriter); } GenericOp GenericOp::extractOpBeforeGeneric(Operation *opToExtract, diff --git a/lib/Dialect/Secret/IR/SecretOps.td b/lib/Dialect/Secret/IR/SecretOps.td index 8ee1915e0..e630b3285 100644 --- a/lib/Dialect/Secret/IR/SecretOps.td +++ b/lib/Dialect/Secret/IR/SecretOps.td @@ -246,6 +246,9 @@ def Secret_GenericOp : Secret_Op<"generic", [ inlineInPlaceDroppingSecrets(rewriter, getOperands()); } + GenericOp cloneWithNewResultTypes(TypeRange newTypes, PatternRewriter &rewriter, + bool preserveAttrs = false); + //===------------------------------------------------------------------===// // Argument Attributes //===------------------------------------------------------------------===// diff --git a/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp b/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp index c6ecd67d9..e747bd5f9 100644 --- a/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp +++ b/lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp @@ -16,22 +16,60 @@ namespace mlir { namespace heir { +auto &kLayoutAttrName = tensor_ext::TensorExtDialect::kLayoutAttrName; +auto &kOriginalTypeAttrName = + tensor_ext::TensorExtDialect::kOriginalTypeAttrName; + #define GEN_PASS_DEF_CONVERTTOCIPHERTEXTSEMANTICS #include "lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.h.inc" bool isPowerOfTwo(int64_t n) { return (n > 0) && ((n & (n - 1)) == 0); } +int getIndexOfOpResult(Operation *op, Value result) { + int index = 0; + for (auto res : op->getResults()) { + if (res == result) return index; + ++index; + } + return -1; +} + +// Remove the layout attribute from the defining op of a given value. Since ops +// may have multiple results, this will not delete the attribute, but rather +// set it to nullptr and expect the rest of this pass to treat a null attribute +// as meaning the type has already been converted. +void tryRemoveLayoutAttrFromDefiningOp(Value value) { + auto *parentOp = value.getDefiningOp(); + if (!parentOp) return; + + ArrayAttr resultLayouts = parentOp->getAttrOfType("layout"); + int resultIndex = getIndexOfOpResult(parentOp, value); + if (resultIndex == -1) return; + + SmallVector newResultLayouts(resultLayouts.begin(), + resultLayouts.end()); + newResultLayouts[resultIndex] = nullptr; + parentOp->setAttr("layout", + ArrayAttr::get(value.getContext(), newResultLayouts)); +} + // This type converter converts types like tensor where the dimensions // represent tensor-semantic data to tensor, where the last dimension represents the ciphertext or plaintext slot // count, and the other dimensions are determined by a layout attribute // indexing. +// +// The presence of a layout attribute on the op definine a value is required +// for this type converter to trigger. So patterns that use this and convert +// types must remove any layout attributes when they are done. struct LayoutMaterializationTypeConverter : public AttributeAwareTypeConverter { public: LayoutMaterializationTypeConverter(int ciphertextSize) : ciphertextSize(ciphertextSize) {} FailureOr convert(Type type, Attribute attr) const override { + LLVM_DEBUG(llvm::dbgs() << "Converting type " << type << " with layout " + << attr << "\n"); // Convert secret> to secret> // Convert tensor<...> to tensor<...> bool isSecret = isa(type); @@ -103,9 +141,10 @@ struct LayoutMaterializationTypeConverter : public AttributeAwareTypeConverter { auto funcOp = dyn_cast(parentOp); if (funcOp) { auto argAttr = - funcOp.getArgAttr(blockArg.getArgNumber(), - tensor_ext::TensorExtDialect::kLayoutAttrName); + funcOp.getArgAttr(blockArg.getArgNumber(), kLayoutAttrName); if (!argAttr) return failure(); + LLVM_DEBUG(llvm::dbgs() + << "Found layout attr " << argAttr << " on function args\n"); return argAttr; } @@ -113,8 +152,13 @@ struct LayoutMaterializationTypeConverter : public AttributeAwareTypeConverter { // It may be a secret.generic arg auto genericOp = dyn_cast(parentOp); if (genericOp) { - return cast( + auto attr = dyn_cast_or_null( genericOp.getArgAttr(blockArg.getArgNumber(), "layout")); + if (!attr) return failure(); + LLVM_DEBUG(llvm::dbgs() + << "Found layout attr " << attr + << " on secret generic, for value " << value << "\n"); + return attr; } return failure(); @@ -122,18 +166,17 @@ struct LayoutMaterializationTypeConverter : public AttributeAwareTypeConverter { // For any other op, the layout attribute is an array of result layouts ArrayAttr resultLayouts = parentOp->getAttrOfType("layout"); - - int valueIndex = -1; - for (auto result : parentOp->getResults()) { - ++valueIndex; - if (result == value) break; - } - - if (valueIndex == -1) { + if (!resultLayouts) return failure(); + int resultIndex = getIndexOfOpResult(parentOp, value); + if (resultIndex == -1) { return failure(); } - return cast(resultLayouts[valueIndex]); + auto attr = dyn_cast_or_null(resultLayouts[resultIndex]); + if (!attr) return failure(); + LLVM_DEBUG(llvm::dbgs() << "Found layout attr " << attr + << " on defining op, for value " << value << "\n"); + return attr; } private: @@ -143,14 +186,33 @@ struct LayoutMaterializationTypeConverter : public AttributeAwareTypeConverter { bool hasLayoutArgAttrs(func::FuncOp op) { for (int i = 0; i < op.getNumArguments(); ++i) { - if (op.getArgAttr(i, tensor_ext::TensorExtDialect::kLayoutAttrName)) - return true; + if (op.getArgAttr(i, kLayoutAttrName)) return true; + } + return false; +} + +bool hasLayoutArgAttrs(secret::GenericOp op) { + for (int i = 0; i < op.getNumOperands(); ++i) { + if (op.getArgAttr(i, "layout")) return true; } return false; } bool hasLayoutResultAttrs(Operation *op) { - return (op->getAttrOfType("layout") != nullptr); + auto layoutAttrs = op->getAttrOfType("layout"); + if (!layoutAttrs) return false; + + // If any layout attribute is nullptr, it means the type has already been + // converted. If any result has a non-null attribute, it still needs to be + // converted. + return llvm::any_of(layoutAttrs, + [](Attribute attr) { return attr != nullptr; }); +} + +bool hasOperandsWithLayouts(Operation *op) { + return llvm::any_of(op->getOperands(), [](Value operand) { + return hasLayoutResultAttrs(operand.getDefiningOp()); + }); } struct ConvertFunc : public ConvertFuncWithContextAwareTypeConverter { @@ -165,24 +227,89 @@ struct ConvertFunc : public ConvertFuncWithContextAwareTypeConverter { LogicalResult finalizeFuncOpModification( func::FuncOp op, ArrayRef oldArgTypes, ArrayRef oldResultTypes, PatternRewriter &rewriter) const override { - // Replace layout arg attrs with secret.original_type arg attrs + // Replace layout arg attrs with secret.original_type arg attrs This is + // necessary so that later encoding/decoding functions can know what the + // original type of the tensor was and how it was encoded. rewriter.modifyOpInPlace(op, [&] { for (int i = 0; i < op.getNumArguments(); ++i) { - auto layoutAttr = - op.getArgAttr(i, tensor_ext::TensorExtDialect::kLayoutAttrName); + auto layoutAttr = op.getArgAttr(i, kLayoutAttrName); if (!layoutAttr) continue; - op.removeArgAttr(i, tensor_ext::TensorExtDialect::kLayoutAttrName); + op.removeArgAttr(i, kLayoutAttrName); AffineMap layout = cast(layoutAttr).getValue(); - op.setArgAttr(i, tensor_ext::TensorExtDialect::kOriginalTypeAttrName, + op.setArgAttr(i, kOriginalTypeAttrName, tensor_ext::OriginalTypeAttr::get( getContext(), oldArgTypes[i], layout)); } + + for (int i = 0; i < op.getNumResults(); ++i) { + auto layoutAttr = dyn_cast_or_null( + op.getResultAttr(i, kLayoutAttrName)); + if (!layoutAttr) continue; + + op.setResultAttr( + i, kOriginalTypeAttrName, + tensor_ext::OriginalTypeAttr::get(getContext(), oldResultTypes[i], + layoutAttr.getValue())); + } + + // Since the func.return was converted, we need to erase layout ops from + // the operations that generated the return's operands. + auto returnOperands = op.getBody().front().getTerminator()->getOperands(); + for (auto returnOperand : returnOperands) { + tryRemoveLayoutAttrFromDefiningOp(returnOperand); + } }); return success(); }; }; +struct ConvertGeneric : public SecretGenericConversion { + using SecretGenericConversion::SecretGenericConversion; + + public: + LogicalResult finalizeOpModification( + secret::GenericOp op, + ConversionPatternRewriter &rewriter) const override { + LLVM_DEBUG(llvm::dbgs() + << "Finalizing secret.generic conversion for " << op << "\n"); + rewriter.modifyOpInPlace(op, [&] { + for (int i = 0; i < op.getNumOperands(); ++i) { + op.removeArgAttr(i, "layout"); + } + + for (int i = 0; i < op.getNumResults(); ++i) { + op->removeAttr("layout"); + } + }); + LLVM_DEBUG(llvm::dbgs() << "Post-Finalization: " << op << "\n"); + return success(); + }; +}; + +// A clone of ConvertAny<> but which erases the layout attribute afterward. +struct ConvertAnyRemovingLayout : public ConversionPattern { + ConvertAnyRemovingLayout(const TypeConverter &anyTypeConverter, + MLIRContext *context) + : ConversionPattern(anyTypeConverter, RewritePattern::MatchAnyOpTypeTag(), + /*benefit=*/0, context) { + setDebugName("ConvertAny"); + setHasBoundedRewriteRecursion(true); + } + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + FailureOr result = + convertAnyOperand(getTypeConverter(), op, operands, rewriter); + if (failed(result)) return failure(); + + Operation *newOp = result.value(); + rewriter.modifyOpInPlace(newOp, [&] { newOp->removeAttr("layout"); }); + return success(); + } +}; + struct ConvertToCiphertextSemantics : impl::ConvertToCiphertextSemanticsBase { using ConvertToCiphertextSemanticsBase::ConvertToCiphertextSemanticsBase; @@ -198,10 +325,13 @@ struct ConvertToCiphertextSemantics ConversionTarget target(*context); target.addDynamicallyLegalOp( [&](func::FuncOp op) { return !hasLayoutArgAttrs(op); }); + target.addDynamicallyLegalOp( + [&](secret::GenericOp op) { return !hasLayoutArgAttrs(op); }); target.markUnknownOpDynamicallyLegal( [&](Operation *op) { return !hasLayoutResultAttrs(op); }); - patterns.add(typeConverter, context); + patterns.add( + typeConverter, context); if (failed(applyPartialConversion(module, target, std::move(patterns)))) { return signalPassFailure(); diff --git a/lib/Utils/ConversionUtils.cpp b/lib/Utils/ConversionUtils.cpp index a24ea0c33..6abf934b1 100644 --- a/lib/Utils/ConversionUtils.cpp +++ b/lib/Utils/ConversionUtils.cpp @@ -33,14 +33,15 @@ using ::mlir::func::CallOp; using ::mlir::func::FuncOp; using ::mlir::func::ReturnOp; -LogicalResult convertAnyOperand(const TypeConverter *typeConverter, - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) { - const auto *typeWithAttrTypeConverter = - dynamic_cast(typeConverter); - - if (typeWithAttrTypeConverter) { - if (typeWithAttrTypeConverter->isLegal(op)) { +FailureOr convertAnyOperand(const TypeConverter *typeConverter, + Operation *op, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + const auto *contextAwareTypeConverter = + dynamic_cast(typeConverter); + + if (contextAwareTypeConverter) { + if (contextAwareTypeConverter->isLegal(op)) { return failure(); } } else { @@ -51,12 +52,12 @@ LogicalResult convertAnyOperand(const TypeConverter *typeConverter, SmallVector newOperandTypes; SmallVector newResultTypes; - if (typeWithAttrTypeConverter) { - if (failed(typeWithAttrTypeConverter->convertValueRangeTypes( + if (contextAwareTypeConverter) { + if (failed(contextAwareTypeConverter->convertValueRangeTypes( op->getResults(), newResultTypes))) return failure(); - if (failed(typeWithAttrTypeConverter->convertValueRangeTypes( + if (failed(contextAwareTypeConverter->convertValueRangeTypes( op->getOperands(), newOperandTypes))) return failure(); @@ -88,7 +89,7 @@ LogicalResult convertAnyOperand(const TypeConverter *typeConverter, op->getAttrs(), op->getSuccessors(), regions)); rewriter.replaceOp(op, newOp); - return success(); + return newOp; } struct ConvertExtract : public OpConversionPattern { diff --git a/lib/Utils/ConversionUtils.h b/lib/Utils/ConversionUtils.h index c9b8d0f87..c790f3593 100644 --- a/lib/Utils/ConversionUtils.h +++ b/lib/Utils/ConversionUtils.h @@ -38,9 +38,10 @@ namespace mlir { namespace heir { -LogicalResult convertAnyOperand(const TypeConverter *typeConverter, - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter); +FailureOr convertAnyOperand(const TypeConverter *typeConverter, + Operation *op, + ArrayRef operands, + ConversionPatternRewriter &rewriter); template struct ConvertAny : public ConversionPattern { @@ -159,9 +160,9 @@ class SecretGenericOpConversion op.getResults(), resultTypes))) return failure(); } else { - auto result = - getTypeConverter()->convertTypes(op.getResultTypes(), resultTypes); - if (failed(result)) return failure(); + if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), + resultTypes))) + return failure(); } // only preserve dialect attrs @@ -195,6 +196,73 @@ class SecretGenericOpConversion } }; +// A converter for a secret.generic that doesn't touch the inner ops +class SecretGenericConversion : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + virtual LogicalResult finalizeOpModification( + secret::GenericOp op, ConversionPatternRewriter &rewriter) const { + return success(); + }; + + LogicalResult matchAndRewrite( + secret::GenericOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + // NOTE: use C++ RTTI instead of LLVM RTTI + // because TypeConverter does not support LLVM RTTI + const auto *contextAwareTypeConverter = + dynamic_cast(getTypeConverter()); + + SmallVector inputs = op.getOperands(); + SmallVector inputTypes; + if (contextAwareTypeConverter) { + // manually do the OpAdaptor's work + SmallVector inputTypes; + if (failed(contextAwareTypeConverter->convertValueRangeTypes(inputs, + inputTypes))) + return failure(); + rewriter.modifyOpInPlace(op, [&] { + for (const auto &[opOperand, newType] : + llvm::zip(op->getOpOperands(), inputTypes)) { + opOperand.get().setType(newType); + + // And set the inner block argument + auto secretType = dyn_cast(newType); + if (secretType) { + auto blockArg = + op.getBody()->getArgument(opOperand.getOperandNumber()); + blockArg.setType(secretType.getValueType()); + } + } + }); + } + // else OpAdaptor will do it for us + + // convert the result types + SmallVector resultTypes; + if (contextAwareTypeConverter) { + if (failed(contextAwareTypeConverter->convertValueRangeTypes( + op.getResults(), resultTypes))) + return failure(); + } else { + if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), + resultTypes))) + return failure(); + } + + // Replace the generic op with a new generic op that has the new result + // types + secret::GenericOp newGeneric = + op.cloneWithNewResultTypes(resultTypes, rewriter); + rewriter.replaceOp(op, newGeneric); + + if (failed(finalizeOpModification(newGeneric, rewriter))) return failure(); + + return success(); + } +}; + template class SecretGenericOpCipherConversion : public SecretGenericOpConversion { public: diff --git a/tests/Transforms/convert_to_ciphertext_semantics/convert_to_ciphertext_semantics.mlir b/tests/Transforms/convert_to_ciphertext_semantics/convert_to_ciphertext_semantics.mlir index f90830e21..b2a2806c8 100644 --- a/tests/Transforms/convert_to_ciphertext_semantics/convert_to_ciphertext_semantics.mlir +++ b/tests/Transforms/convert_to_ciphertext_semantics/convert_to_ciphertext_semantics.mlir @@ -9,7 +9,7 @@ func.func @convert_minimal_example( %0 = secret.generic ins(%arg0 : !secret.secret>) attrs = {arg0 = {layout = #map}, layout = [#map]} { ^body(%input0: tensor<32x32xi16>): - %1 = arith.addi %input0, %input0 : tensor<32x32xi16> + %1 = arith.addi %input0, %input0 {layout = [#map]} : tensor<32x32xi16> secret.yield %1 : tensor<32x32xi16> } -> !secret.secret> return %0 : !secret.secret> From 05a5b6687741f10a93a0d7dfe19fdb6c640b9dd3 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Fri, 28 Feb 2025 22:55:01 -0800 Subject: [PATCH 09/16] fork dialect conversion for attribute-awareness --- lib/Utils/BUILD | 17 + lib/Utils/ContextAwareDialectConversion.cpp | 2797 +++++++++++++++++++ lib/Utils/ContextAwareDialectConversion.h | 408 +++ lib/Utils/ContextAwareTypeConversion.cpp | 297 +- lib/Utils/ContextAwareTypeConversion.h | 580 +++- lib/Utils/ConversionUtils.cpp | 8 +- lib/Utils/ConversionUtils.h | 41 +- 7 files changed, 4030 insertions(+), 118 deletions(-) create mode 100644 lib/Utils/ContextAwareDialectConversion.cpp create mode 100644 lib/Utils/ContextAwareDialectConversion.h diff --git a/lib/Utils/BUILD b/lib/Utils/BUILD index db36d904e..1d648770a 100644 --- a/lib/Utils/BUILD +++ b/lib/Utils/BUILD @@ -66,6 +66,23 @@ cc_library( ], ) +cc_library( + name = "ContextAwareDialectConversion", + srcs = ["ContextAwareDialectConversion.cpp"], + hdrs = ["ContextAwareDialectConversion.h"], + deps = [ + ":ContextAwareTypeConversion", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + cc_library( name = "TargetUtils", srcs = ["TargetUtils.cpp"], diff --git a/lib/Utils/ContextAwareDialectConversion.cpp b/lib/Utils/ContextAwareDialectConversion.cpp new file mode 100644 index 000000000..15859f605 --- /dev/null +++ b/lib/Utils/ContextAwareDialectConversion.cpp @@ -0,0 +1,2797 @@ +#include "lib/Utils/ContextAwareDialectConversion.h" + +#include + +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project +#include "llvm/include/llvm/Support/SaveAndRestore.h" // from @llvm-project +#include "llvm/include/llvm/Support/ScopedPrinter.h" // from @llvm-project +#include "mlir/include/mlir/Config/mlir-config.h" // from @llvm-project +#include "mlir/include/mlir/IR/Block.h" // from @llvm-project +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/Dominance.h" // from @llvm-project +#include "mlir/include/mlir/IR/Iterators.h" // from @llvm-project +#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project +#include "mlir/include/mlir/Rewrite/PatternApplicator.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir { +namespace heir { + +using namespace detail; + +#define DEBUG_TYPE "context-aware-dialect-conversion" + +/// A utility function to log a successful result for the given reason. +template +static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { + LLVM_DEBUG({ + os.unindent(); + os.startLine() << "} -> SUCCESS"; + if (!fmt.empty()) + os.getOStream() << " : " + << llvm::formatv(fmt.data(), std::forward(args)...); + os.getOStream() << "\n"; + }); +} + +/// A utility function to log a failure result for the given reason. +template +static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { + LLVM_DEBUG({ + os.unindent(); + os.startLine() << "} -> FAILURE : " + << llvm::formatv(fmt.data(), std::forward(args)...) + << "\n"; + }); +} + +/// Helper function that computes an insertion point where the given value is +/// defined and can be used without a dominance violation. +static OpBuilder::InsertPoint computeInsertPoint(Value value) { + Block *insertBlock = value.getParentBlock(); + Block::iterator insertPt = insertBlock->begin(); + if (OpResult inputRes = dyn_cast(value)) + insertPt = ++inputRes.getOwner()->getIterator(); + return OpBuilder::InsertPoint(insertBlock, insertPt); +} + +/// Helper function that computes an insertion point where the given values are +/// defined and can be used without a dominance violation. +static OpBuilder::InsertPoint computeInsertPoint(ArrayRef vals) { + assert(!vals.empty() && "expected at least one value"); + DominanceInfo domInfo; + OpBuilder::InsertPoint pt = computeInsertPoint(vals.front()); + for (Value v : vals.drop_front()) { + // Choose the "later" insertion point. + OpBuilder::InsertPoint nextPt = computeInsertPoint(v); + if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(), + nextPt.getPoint())) { + // pt is before nextPt => choose nextPt. + pt = nextPt; + } else { +#ifndef NDEBUG + // nextPt should be before pt => choose pt. + // If pt, nextPt are no dominance relationship, then there is no valid + // insertion point at which all given values are defined. + bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(), + pt.getBlock(), pt.getPoint()); + assert(dom && "unable to find valid insertion point"); +#endif // NDEBUG + } + } + return pt; +} + +//===----------------------------------------------------------------------===// +// ConversionValueMapping +//===----------------------------------------------------------------------===// + +/// A vector of SSA values, optimized for the most common case of a single +/// value. +using ValueVector = SmallVector; + +namespace { + +/// Helper class to make it possible to use `ValueVector` as a key in DenseMap. +struct ValueVectorMapInfo { + static ValueVector getEmptyKey() { return ValueVector{Value()}; } + static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; } + static ::llvm::hash_code getHashValue(const ValueVector &val) { + return ::llvm::hash_combine_range(val.begin(), val.end()); + } + static bool isEqual(const ValueVector &lhs, const ValueVector &rhs) { + return lhs == rhs; + } +}; + +/// This class wraps a IRMapping to provide recursive lookup +/// functionality, i.e. we will traverse if the mapped value also has a mapping. +struct ConversionValueMapping { + /// Return "true" if an SSA value is mapped to the given value. May return + /// false positives. + bool isMappedTo(Value value) const { return mappedTo.contains(value); } + + /// Lookup the most recently mapped values with the desired types in the + /// mapping. + /// + /// Special cases: + /// - If the desired type range is empty, simply return the most recently + /// mapped values. + /// - If there is no mapping to the desired types, also return the most + /// recently mapped values. + /// - If there is no mapping for the given values at all, return the given + /// value. + ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; + + /// Lookup the given value within the map, or return an empty vector if the + /// value is not mapped. If it is mapped, this follows the same behavior + /// as `lookupOrDefault`. + ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const; + + template + struct IsValueVector : std::is_same, ValueVector> {}; + + /// Map a value vector to the one provided. + template + std::enable_if_t::value && IsValueVector::value> + map(OldVal &&oldVal, NewVal &&newVal) { + LLVM_DEBUG({ + ValueVector next(newVal); + while (true) { + assert(next != oldVal && "inserting cyclic mapping"); + auto it = mapping.find(next); + if (it == mapping.end()) break; + next = it->second; + } + }); + for (Value v : newVal) mappedTo.insert(v); + + mapping[std::forward(oldVal)] = std::forward(newVal); + } + + /// Map a value vector or single value to the one provided. + template + std::enable_if_t::value || + !IsValueVector::value> + map(OldVal &&oldVal, NewVal &&newVal) { + if constexpr (IsValueVector{}) { + map(std::forward(oldVal), ValueVector{newVal}); + } else if constexpr (IsValueVector{}) { + map(ValueVector{oldVal}, std::forward(newVal)); + } else { + map(ValueVector{oldVal}, ValueVector{newVal}); + } + } + + /// Drop the last mapping for the given values. + void erase(const ValueVector &value) { mapping.erase(value); } + + private: + /// Current value mappings. + DenseMap mapping; + + /// All SSA values that are mapped to. May contain false positives. + DenseSet mappedTo; +}; +} // namespace + +ValueVector ConversionValueMapping::lookupOrDefault( + Value from, TypeRange desiredTypes) const { + // Try to find the deepest values that have the desired types. If there is no + // such mapping, simply return the deepest values. + ValueVector desiredValue; + ValueVector current{from}; + do { + // Store the current value if the types match. + if (TypeRange(ValueRange(current)) == desiredTypes) desiredValue = current; + + // If possible, Replace each value with (one or multiple) mapped values. + ValueVector next; + for (Value v : current) { + auto it = mapping.find({v}); + if (it != mapping.end()) { + llvm::append_range(next, it->second); + } else { + next.push_back(v); + } + } + if (next != current) { + // If at least one value was replaced, continue the lookup from there. + current = std::move(next); + continue; + } + + // Otherwise: Check if there is a mapping for the entire vector. Such + // mappings are materializations. (N:M mapping are not supported for value + // replacements.) + // + // Note: From a correctness point of view, materializations do not have to + // be stored (and looked up) in the mapping. But for performance reasons, + // we choose to reuse existing IR (when possible) instead of creating it + // multiple times. + auto it = mapping.find(current); + if (it == mapping.end()) { + // No mapping found: The lookup stops here. + break; + } + current = it->second; + } while (true); + + // If the desired values were found use them, otherwise default to the leaf + // values. + // Note: If `desiredTypes` is empty, this function always returns `current`. + return !desiredValue.empty() ? std::move(desiredValue) : std::move(current); +} + +ValueVector ConversionValueMapping::lookupOrNull(Value from, + TypeRange desiredTypes) const { + ValueVector result = lookupOrDefault(from, desiredTypes); + if (result == ValueVector{from} || + (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes)) + return {}; + return result; +} + +//===----------------------------------------------------------------------===// +// Rewriter and Translation State +//===----------------------------------------------------------------------===// +namespace { +/// This class contains a snapshot of the current conversion rewriter state. +/// This is useful when saving and undoing a set of rewrites. +struct RewriterState { + RewriterState(unsigned numRewrites, unsigned numIgnoredOperations, + unsigned numReplacedOps) + : numRewrites(numRewrites), + numIgnoredOperations(numIgnoredOperations), + numReplacedOps(numReplacedOps) {} + + /// The current number of rewrites performed. + unsigned numRewrites; + + /// The current number of ignored operations. + unsigned numIgnoredOperations; + + /// The current number of replaced ops that are scheduled for erasure. + unsigned numReplacedOps; +}; + +//===----------------------------------------------------------------------===// +// IR rewrites +//===----------------------------------------------------------------------===// + +/// An IR rewrite that can be committed (upon success) or rolled back (upon +/// failure). +/// +/// The dialect conversion keeps track of IR modifications (requested by the +/// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites +/// are directly applied to the IR as the rewriter API is used, some are applied +/// partially, and some are delayed until the `IRRewrite` objects are committed. +class IRRewrite { + public: + /// The kind of the rewrite. Rewrites can be undone if the conversion fails. + /// Enum values are ordered, so that they can be used in `classof`: first all + /// block rewrites, then all operation rewrites. + enum class Kind { + // Block rewrites + CreateBlock, + EraseBlock, + InlineBlock, + MoveBlock, + BlockTypeConversion, + ReplaceBlockArg, + // Operation rewrites + MoveOperation, + ModifyOperation, + ReplaceOperation, + CreateOperation, + UnresolvedMaterialization + }; + + virtual ~IRRewrite() = default; + + /// Roll back the rewrite. Operations may be erased during rollback. + virtual void rollback() = 0; + + /// Commit the rewrite. At this point, it is certain that the dialect + /// conversion will succeed. All IR modifications, except for operation/block + /// erasure, must be performed through the given rewriter. + /// + /// Instead of erasing operations/blocks, they should merely be unlinked + /// commit phase and finally be erased during the cleanup phase. This is + /// because internal dialect conversion state (such as `mapping`) may still + /// be using them. + /// + /// Any IR modification that was already performed before the commit phase + /// (e.g., insertion of an op) must be communicated to the listener that may + /// be attached to the given rewriter. + virtual void commit(RewriterBase &rewriter) {} + + /// Cleanup operations/blocks. Cleanup is called after commit. + virtual void cleanup(RewriterBase &rewriter) {} + + Kind getKind() const { return kind; } + + static bool classof(const IRRewrite *rewrite) { return true; } + + protected: + IRRewrite(Kind kind, ContextAwareConversionPatternRewriterImpl &rewriterImpl) + : kind(kind), rewriterImpl(rewriterImpl) {} + + const ConversionConfig &getConfig() const; + + const Kind kind; + ContextAwareConversionPatternRewriterImpl &rewriterImpl; +}; + +/// A block rewrite. +class BlockRewrite : public IRRewrite { + public: + /// Return the block that this rewrite operates on. + Block *getBlock() const { return block; } + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() >= Kind::CreateBlock && + rewrite->getKind() <= Kind::ReplaceBlockArg; + } + + protected: + BlockRewrite(Kind kind, + ContextAwareConversionPatternRewriterImpl &rewriterImpl, + Block *block) + : IRRewrite(kind, rewriterImpl), block(block) {} + + // The block that this rewrite operates on. + Block *block; +}; + +/// Creation of a block. Block creations are immediately reflected in the IR. +/// There is no extra work to commit the rewrite. During rollback, the newly +/// created block is erased. +class CreateBlockRewrite : public BlockRewrite { + public: + CreateBlockRewrite(ContextAwareConversionPatternRewriterImpl &rewriterImpl, + Block *block) + : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::CreateBlock; + } + + void commit(RewriterBase &rewriter) override { + // The block was already created and inserted. Just inform the listener. + if (auto *listener = rewriter.getListener()) + listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{}); + } + + void rollback() override { + // Unlink all of the operations within this block, they will be deleted + // separately. + auto &blockOps = block->getOperations(); + while (!blockOps.empty()) blockOps.remove(blockOps.begin()); + block->dropAllUses(); + if (block->getParent()) + block->erase(); + else + delete block; + } +}; + +/// Erasure of a block. Block erasures are partially reflected in the IR. Erased +/// blocks are immediately unlinked, but only erased during cleanup. This makes +/// it easier to rollback a block erasure: the block is simply inserted into its +/// original location. +class EraseBlockRewrite : public BlockRewrite { + public: + EraseBlockRewrite(ContextAwareConversionPatternRewriterImpl &rewriterImpl, + Block *block) + : BlockRewrite(Kind::EraseBlock, rewriterImpl, block), + region(block->getParent()), + insertBeforeBlock(block->getNextNode()) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::EraseBlock; + } + + ~EraseBlockRewrite() override { + assert(!block && + "rewrite was neither rolled back nor committed/cleaned up"); + } + + void rollback() override { + // The block (owned by this rewrite) was not actually erased yet. It was + // just unlinked. Put it back into its original position. + assert(block && "expected block"); + auto &blockList = region->getBlocks(); + Region::iterator before = insertBeforeBlock + ? Region::iterator(insertBeforeBlock) + : blockList.end(); + blockList.insert(before, block); + block = nullptr; + } + + void commit(RewriterBase &rewriter) override { + // Erase the block. + assert(block && "expected block"); + assert(block->empty() && "expected empty block"); + + // Notify the listener that the block is about to be erased. + if (auto *listener = + dyn_cast_or_null(rewriter.getListener())) + listener->notifyBlockErased(block); + } + + void cleanup(RewriterBase &rewriter) override { + // Erase the block. + block->dropAllDefinedValueUses(); + delete block; + block = nullptr; + } + + private: + // The region in which this block was previously contained. + Region *region; + + // The original successor of this block before it was unlinked. "nullptr" if + // this block was the only block in the region. + Block *insertBeforeBlock; +}; + +/// Inlining of a block. This rewrite is immediately reflected in the IR. +/// Note: This rewrite represents only the inlining of the operations. The +/// erasure of the inlined block is a separate rewrite. +class InlineBlockRewrite : public BlockRewrite { + public: + InlineBlockRewrite(ContextAwareConversionPatternRewriterImpl &rewriterImpl, + Block *block, Block *sourceBlock, Block::iterator before) + : BlockRewrite(Kind::InlineBlock, rewriterImpl, block), + sourceBlock(sourceBlock), + firstInlinedInst(sourceBlock->empty() ? nullptr + : &sourceBlock->front()), + lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) { + // If a listener is attached to the dialect conversion, ops must be moved + // one-by-one. When they are moved in bulk, notifications cannot be sent + // because the ops that used to be in the source block at the time of the + // inlining (before the "commit" phase) are unknown at the time when + // notifications are sent (which is during the "commit" phase). + assert(!getConfig().listener && + "InlineBlockRewrite not supported if listener is attached"); + } + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::InlineBlock; + } + + void rollback() override { + // Put the operations from the destination block (owned by the rewrite) + // back into the source block. + if (firstInlinedInst) { + assert(lastInlinedInst && "expected operation"); + sourceBlock->getOperations().splice(sourceBlock->begin(), + block->getOperations(), + Block::iterator(firstInlinedInst), + ++Block::iterator(lastInlinedInst)); + } + } + + private: + // The block that originally contained the operations. + Block *sourceBlock; + + // The first inlined operation. + Operation *firstInlinedInst; + + // The last inlined operation. + Operation *lastInlinedInst; +}; + +/// Moving of a block. This rewrite is immediately reflected in the IR. +class MoveBlockRewrite : public BlockRewrite { + public: + MoveBlockRewrite(ContextAwareConversionPatternRewriterImpl &rewriterImpl, + Block *block, Region *region, Block *insertBeforeBlock) + : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), + region(region), + insertBeforeBlock(insertBeforeBlock) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::MoveBlock; + } + + void commit(RewriterBase &rewriter) override { + // The block was already moved. Just inform the listener. + if (auto *listener = rewriter.getListener()) { + // Note: `previousIt` cannot be passed because this is a delayed + // notification and iterators into past IR state cannot be represented. + listener->notifyBlockInserted(block, /*previous=*/region, + /*previousIt=*/{}); + } + } + + void rollback() override { + // Move the block back to its original position. + Region::iterator before = + insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end(); + region->getBlocks().splice(before, block->getParent()->getBlocks(), block); + } + + private: + // The region in which this block was previously contained. + Region *region; + + // The original successor of this block before it was moved. "nullptr" if + // this block was the only block in the region. + Block *insertBeforeBlock; +}; + +/// Block type conversion. This rewrite is partially reflected in the IR. +class BlockTypeConversionRewrite : public BlockRewrite { + public: + BlockTypeConversionRewrite( + ContextAwareConversionPatternRewriterImpl &rewriterImpl, Block *origBlock, + Block *newBlock) + : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock), + newBlock(newBlock) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::BlockTypeConversion; + } + + Block *getOrigBlock() const { return block; } + + Block *getNewBlock() const { return newBlock; } + + void commit(RewriterBase &rewriter) override; + + void rollback() override; + + private: + /// The new block that was created as part of this signature conversion. + Block *newBlock; +}; + +/// Replacing a block argument. This rewrite is not immediately reflected in the +/// IR. An internal IR mapping is updated, but the actual replacement is delayed +/// until the rewrite is committed. +class ReplaceBlockArgRewrite : public BlockRewrite { + public: + ReplaceBlockArgRewrite( + ContextAwareConversionPatternRewriterImpl &rewriterImpl, Block *block, + BlockArgument arg, const ContextAwareTypeConverter *converter) + : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), + arg(arg), + converter(converter) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::ReplaceBlockArg; + } + + void commit(RewriterBase &rewriter) override; + + void rollback() override; + + private: + BlockArgument arg; + + /// The current type converter when the block argument was replaced. + const ContextAwareTypeConverter *converter; +}; + +/// An operation rewrite. +class OperationRewrite : public IRRewrite { + public: + /// Return the operation that this rewrite operates on. + Operation *getOperation() const { return op; } + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() >= Kind::MoveOperation && + rewrite->getKind() <= Kind::UnresolvedMaterialization; + } + + protected: + OperationRewrite(Kind kind, + ContextAwareConversionPatternRewriterImpl &rewriterImpl, + Operation *op) + : IRRewrite(kind, rewriterImpl), op(op) {} + + // The operation that this rewrite operates on. + Operation *op; +}; + +/// Moving of an operation. This rewrite is immediately reflected in the IR. +class MoveOperationRewrite : public OperationRewrite { + public: + MoveOperationRewrite(ContextAwareConversionPatternRewriterImpl &rewriterImpl, + Operation *op, Block *block, Operation *insertBeforeOp) + : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), + block(block), + insertBeforeOp(insertBeforeOp) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::MoveOperation; + } + + void commit(RewriterBase &rewriter) override { + // The operation was already moved. Just inform the listener. + if (auto *listener = rewriter.getListener()) { + // Note: `previousIt` cannot be passed because this is a delayed + // notification and iterators into past IR state cannot be represented. + listener->notifyOperationInserted( + op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block, + /*insertPt=*/{})); + } + } + + void rollback() override { + // Move the operation back to its original position. + Block::iterator before = + insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end(); + block->getOperations().splice(before, op->getBlock()->getOperations(), op); + } + + private: + // The block in which this operation was previously contained. + Block *block; + + // The original successor of this operation before it was moved. "nullptr" + // if this operation was the only operation in the region. + Operation *insertBeforeOp; +}; + +/// In-place modification of an op. This rewrite is immediately reflected in +/// the IR. The previous state of the operation is stored in this object. +class ModifyOperationRewrite : public OperationRewrite { + public: + ModifyOperationRewrite( + ContextAwareConversionPatternRewriterImpl &rewriterImpl, Operation *op) + : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op), + name(op->getName()), + loc(op->getLoc()), + attrs(op->getAttrDictionary()), + operands(op->operand_begin(), op->operand_end()), + successors(op->successor_begin(), op->successor_end()) { + if (OpaqueProperties prop = op->getPropertiesStorage()) { + // Make a copy of the properties. + propertiesStorage = operator new(op->getPropertiesStorageSize()); + OpaqueProperties propCopy(propertiesStorage); + name.initOpProperties(propCopy, /*init=*/prop); + } + } + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::ModifyOperation; + } + + ~ModifyOperationRewrite() override { + assert(!propertiesStorage && + "rewrite was neither committed nor rolled back"); + } + + void commit(RewriterBase &rewriter) override { + // Notify the listener that the operation was modified in-place. + if (auto *listener = + dyn_cast_or_null(rewriter.getListener())) + listener->notifyOperationModified(op); + + if (propertiesStorage) { + OpaqueProperties propCopy(propertiesStorage); + // Note: The operation may have been erased in the mean time, so + // OperationName must be stored in this object. + name.destroyOpProperties(propCopy); + operator delete(propertiesStorage); + propertiesStorage = nullptr; + } + } + + void rollback() override { + op->setLoc(loc); + op->setAttrs(attrs); + op->setOperands(operands); + for (const auto &it : llvm::enumerate(successors)) + op->setSuccessor(it.value(), it.index()); + if (propertiesStorage) { + OpaqueProperties propCopy(propertiesStorage); + op->copyProperties(propCopy); + name.destroyOpProperties(propCopy); + operator delete(propertiesStorage); + propertiesStorage = nullptr; + } + } + + private: + OperationName name; + LocationAttr loc; + DictionaryAttr attrs; + SmallVector operands; + SmallVector successors; + void *propertiesStorage = nullptr; +}; + +/// Replacing an operation. Erasing an operation is treated as a special case +/// with "null" replacements. This rewrite is not immediately reflected in the +/// IR. An internal IR mapping is updated, but values are not replaced and the +/// original op is not erased until the rewrite is committed. +class ReplaceOperationRewrite : public OperationRewrite { + public: + ReplaceOperationRewrite( + ContextAwareConversionPatternRewriterImpl &rewriterImpl, Operation *op, + const ContextAwareTypeConverter *converter) + : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op), + converter(converter) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::ReplaceOperation; + } + + void commit(RewriterBase &rewriter) override; + + void rollback() override; + + void cleanup(RewriterBase &rewriter) override; + + private: + /// An optional type converter that can be used to materialize conversions + /// between the new and old values if necessary. + const ContextAwareTypeConverter *converter; +}; + +class CreateOperationRewrite : public OperationRewrite { + public: + CreateOperationRewrite( + ContextAwareConversionPatternRewriterImpl &rewriterImpl, Operation *op) + : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::CreateOperation; + } + + void commit(RewriterBase &rewriter) override { + // The operation was already created and inserted. Just inform the listener. + if (auto *listener = rewriter.getListener()) + listener->notifyOperationInserted(op, /*previous=*/{}); + } + + void rollback() override; +}; + +/// The type of materialization. +enum MaterializationKind { + /// This materialization materializes a conversion from an illegal type to a + /// legal one. + Target, + + /// This materialization materializes a conversion from a legal type back to + /// an illegal one. + Source +}; + +/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast" +/// op. Unresolved materializations are erased at the end of the dialect +/// conversion. +class UnresolvedMaterializationRewrite : public OperationRewrite { + public: + UnresolvedMaterializationRewrite( + ContextAwareConversionPatternRewriterImpl &rewriterImpl, + UnrealizedConversionCastOp op, const ContextAwareTypeConverter *converter, + MaterializationKind kind, Type originalType, ValueVector mappedValues); + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::UnresolvedMaterialization; + } + + void rollback() override; + + UnrealizedConversionCastOp getOperation() const { + return cast(op); + } + + /// Return the type converter of this materialization (which may be null). + const ContextAwareTypeConverter *getConverter() const { + return converterAndKind.getPointer(); + } + + /// Return the kind of this materialization. + MaterializationKind getMaterializationKind() const { + return converterAndKind.getInt(); + } + + /// Return the original type of the SSA value. + Type getOriginalType() const { return originalType; } + + private: + /// The corresponding type converter to use when resolving this + /// materialization, and the kind of this materialization. + llvm::PointerIntPair + converterAndKind; + + /// The original type of the SSA value. Only used for target + /// materializations. + Type originalType; + + /// The values in the conversion value mapping that are being replaced by the + /// results of this unresolved materialization. + ValueVector mappedValues; +}; +} // namespace + +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS +/// Return "true" if there is an operation rewrite that matches the specified +/// rewrite type and operation among the given rewrites. +template +static bool hasRewrite(R &&rewrites, Operation *op) { + return any_of(std::forward(rewrites), [&](auto &rewrite) { + auto *rewriteTy = dyn_cast(rewrite.get()); + return rewriteTy && rewriteTy->getOperation() == op; + }); +} + +/// Return "true" if there is a block rewrite that matches the specified +/// rewrite type and block among the given rewrites. +template +static bool hasRewrite(R &&rewrites, Block *block) { + return any_of(std::forward(rewrites), [&](auto &rewrite) { + auto *rewriteTy = dyn_cast(rewrite.get()); + return rewriteTy && rewriteTy->getBlock() == block; + }); +} +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + +//===----------------------------------------------------------------------===// +// ContextAwareConversionPatternRewriterImpl +//===----------------------------------------------------------------------===// +namespace detail { +struct ContextAwareConversionPatternRewriterImpl + : public RewriterBase::Listener { + explicit ContextAwareConversionPatternRewriterImpl( + MLIRContext *ctx, const ConversionConfig &config) + : context(ctx), eraseRewriter(ctx), config(config) {} + + //===--------------------------------------------------------------------===// + // State Management + //===--------------------------------------------------------------------===// + + /// Return the current state of the rewriter. + RewriterState getCurrentState(); + + /// Apply all requested operation rewrites. This method is invoked when the + /// conversion process succeeds. + void applyRewrites(); + + /// Reset the state of the rewriter to a previously saved point. + void resetState(RewriterState state); + + /// Append a rewrite. Rewrites are committed upon success and rolled back upon + /// failure. + template + void appendRewrite(Args &&...args) { + rewrites.push_back( + std::make_unique(*this, std::forward(args)...)); + } + + /// Undo the rewrites (motions, splits) one by one in reverse order until + /// "numRewritesToKeep" rewrites remains. + void undoRewrites(unsigned numRewritesToKeep = 0); + + /// Remap the given values to those with potentially different types. Returns + /// success if the values could be remapped, failure otherwise. `valueDiagTag` + /// is the tag used when describing a value within a diagnostic, e.g. + /// "operand". + LogicalResult remapValues(StringRef valueDiagTag, + std::optional inputLoc, + PatternRewriter &rewriter, ValueRange values, + SmallVector &remapped); + + /// Return "true" if the given operation is ignored, and does not need to be + /// converted. + bool isOpIgnored(Operation *op) const; + + /// Return "true" if the given operation was replaced or erased. + bool wasOpReplaced(Operation *op) const; + + //===--------------------------------------------------------------------===// + // Type Conversion + //===--------------------------------------------------------------------===// + + /// Convert the types of block arguments within the given region. + FailureOr convertRegionTypes( + ContextAwareConversionPatternRewriter &rewriter, Region *region, + const ContextAwareTypeConverter &converter, + ContextAwareTypeConverter::SignatureConversion *entryConversion); + + /// Apply the given signature conversion on the given block. The new block + /// containing the updated signature is returned. If no conversions were + /// necessary, e.g. if the block has no arguments, `block` is returned. + /// `converter` is used to generate any necessary cast operations that + /// translate between the origin argument types and those specified in the + /// signature conversion. + Block *applySignatureConversion( + ContextAwareConversionPatternRewriter &rewriter, Block *block, + const ContextAwareTypeConverter *converter, + ContextAwareTypeConverter::SignatureConversion &signatureConversion); + + //===--------------------------------------------------------------------===// + // Materializations + //===--------------------------------------------------------------------===// + + /// Build an unresolved materialization operation given a range of output + /// types and a list of input operands. Returns the inputs if they their + /// types match the output types. + /// + /// If a cast op was built, it can optionally be returned with the `castOp` + /// output argument. + /// + /// If `valuesToMap` is set to a non-null Value, then that value is mapped to + /// the results of the unresolved materialization in the conversion value + /// mapping. + ValueRange buildUnresolvedMaterialization( + MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, + ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, + Type originalType, const ContextAwareTypeConverter *converter, + UnrealizedConversionCastOp *castOp = nullptr); + + /// Find a replacement value for the given SSA value in the conversion value + /// mapping. The replacement value must have the same type as the given SSA + /// value. If there is no replacement value with the correct type, find the + /// latest replacement value (regardless of the type) and build a source + /// materialization. + Value findOrBuildReplacementValue(Value value, + const ContextAwareTypeConverter *converter); + + //===--------------------------------------------------------------------===// + // Rewriter Notification Hooks + //===--------------------------------------------------------------------===// + + //// Notifies that an op was inserted. + void notifyOperationInserted(Operation *op, + OpBuilder::InsertPoint previous) override; + + /// Notifies that an op is about to be replaced with the given values. + void notifyOpReplaced(Operation *op, ArrayRef newValues); + + /// Notifies that a block is about to be erased. + void notifyBlockIsBeingErased(Block *block); + + /// Notifies that a block was inserted. + void notifyBlockInserted(Block *block, Region *previous, + Region::iterator previousIt) override; + + /// Notifies that a block is being inlined into another block. + void notifyBlockBeingInlined(Block *block, Block *srcBlock, + Block::iterator before); + + /// Notifies that a pattern match failed for the given reason. + void notifyMatchFailure( + Location loc, function_ref reasonCallback) override; + + //===--------------------------------------------------------------------===// + // IR Erasure + //===--------------------------------------------------------------------===// + + /// A rewriter that keeps track of erased ops and blocks. It ensures that no + /// operation or block is erased multiple times. This rewriter assumes that + /// no new IR is created between calls to `eraseOp`/`eraseBlock`. + struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener { + public: + SingleEraseRewriter(MLIRContext *context) + : RewriterBase(context, /*listener=*/this) {} + + /// Erase the given op (unless it was already erased). + void eraseOp(Operation *op) override { + if (wasErased(op)) return; + op->dropAllUses(); + RewriterBase::eraseOp(op); + } + + /// Erase the given block (unless it was already erased). + void eraseBlock(Block *block) override { + if (wasErased(block)) return; + assert(block->empty() && "expected empty block"); + block->dropAllDefinedValueUses(); + RewriterBase::eraseBlock(block); + } + + bool wasErased(void *ptr) const { return erased.contains(ptr); } + + void notifyOperationErased(Operation *op) override { erased.insert(op); } + + void notifyBlockErased(Block *block) override { erased.insert(block); } + + private: + /// Pointers to all erased operations and blocks. + DenseSet erased; + }; + + //===--------------------------------------------------------------------===// + // State + //===--------------------------------------------------------------------===// + + /// MLIR context. + MLIRContext *context; + + /// A rewriter that keeps track of ops/block that were already erased and + /// skips duplicate op/block erasures. This rewriter is used during the + /// "cleanup" phase. + SingleEraseRewriter eraseRewriter; + + // Mapping between replaced values that differ in type. This happens when + // replacing a value with one of a different type. + ConversionValueMapping mapping; + + /// Ordered list of block operations (creations, splits, motions). + SmallVector> rewrites; + + /// A set of operations that should no longer be considered for legalization. + /// E.g., ops that are recursively legal. Ops that were replaced/erased are + /// tracked separately. + SetVector ignoredOps; + + /// A set of operations that were replaced/erased. Such ops are not erased + /// immediately but only when the dialect conversion succeeds. In the mean + /// time, they should no longer be considered for legalization and any attempt + /// to modify/access them is invalid rewriter API usage. + SetVector replacedOps; + + /// A mapping of all unresolved materializations (UnrealizedConversionCastOp) + /// to the corresponding rewrite objects. + DenseMap + unresolvedMaterializations; + + /// The current type converter, or nullptr if no type converter is currently + /// active. + const ContextAwareTypeConverter *currentContextAwareTypeConverter = nullptr; + + /// A mapping of regions to type converters that should be used when + /// converting the arguments of blocks within that region. + DenseMap regionToConverter; + + /// Dialect conversion configuration. + const ConversionConfig &config; + +#ifndef NDEBUG + /// A set of operations that have pending updates. This tracking isn't + /// strictly necessary, and is thus only active during debug builds for extra + /// verification. + SmallPtrSet pendingRootUpdates; + + /// A logger used to emit diagnostics during the conversion process. + llvm::ScopedPrinter logger{llvm::dbgs()}; +#endif +}; +} // namespace detail + +const ConversionConfig &IRRewrite::getConfig() const { + return rewriterImpl.config; +} + +void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) { + // Inform the listener about all IR modifications that have already taken + // place: References to the original block have been replaced with the new + // block. + if (auto *listener = + dyn_cast_or_null(rewriter.getListener())) + for (Operation *op : getNewBlock()->getUsers()) + listener->notifyOperationModified(op); +} + +void BlockTypeConversionRewrite::rollback() { + getNewBlock()->replaceAllUsesWith(getOrigBlock()); +} + +void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { + Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); + if (!repl) return; + + if (isa(repl)) { + rewriter.replaceAllUsesWith(arg, repl); + return; + } + + // If the replacement value is an operation, we check to make sure that we + // don't replace uses that are within the parent operation of the + // replacement value. + Operation *replOp = cast(repl).getOwner(); + Block *replBlock = replOp->getBlock(); + rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) { + Operation *user = operand.getOwner(); + return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); + }); +} + +void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); } + +void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { + auto *listener = + dyn_cast_or_null(rewriter.getListener()); + + // Compute replacement values. + SmallVector replacements = + llvm::map_to_vector(op->getResults(), [&](OpResult result) { + return rewriterImpl.findOrBuildReplacementValue(result, converter); + }); + + // Notify the listener that the operation is about to be replaced. + if (listener) listener->notifyOperationReplaced(op, replacements); + + // Replace all uses with the new values. + for (auto [result, newValue] : + llvm::zip_equal(op->getResults(), replacements)) + if (newValue) rewriter.replaceAllUsesWith(result, newValue); + + // The original op will be erased, so remove it from the set of unlegalized + // ops. + if (getConfig().unlegalizedOps) getConfig().unlegalizedOps->erase(op); + + // Notify the listener that the operation (and its nested operations) was + // erased. + if (listener) { + op->walk( + [&](Operation *op) { listener->notifyOperationErased(op); }); + } + + // Do not erase the operation yet. It may still be referenced in `mapping`. + // Just unlink it for now and erase it during cleanup. + op->getBlock()->getOperations().remove(op); +} + +void ReplaceOperationRewrite::rollback() { + for (auto result : op->getResults()) rewriterImpl.mapping.erase({result}); +} + +void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) { + rewriter.eraseOp(op); +} + +void CreateOperationRewrite::rollback() { + for (Region ®ion : op->getRegions()) { + while (!region.getBlocks().empty()) + region.getBlocks().remove(region.getBlocks().begin()); + } + op->dropAllUses(); + op->erase(); +} + +UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite( + ContextAwareConversionPatternRewriterImpl &rewriterImpl, + UnrealizedConversionCastOp op, const ContextAwareTypeConverter *converter, + MaterializationKind kind, Type originalType, ValueVector mappedValues) + : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), + converterAndKind(converter, kind), + originalType(originalType), + mappedValues(std::move(mappedValues)) { + assert((!originalType || kind == MaterializationKind::Target) && + "original type is valid only for target materializations"); + rewriterImpl.unresolvedMaterializations[op] = this; +} + +void UnresolvedMaterializationRewrite::rollback() { + if (!mappedValues.empty()) rewriterImpl.mapping.erase(mappedValues); + rewriterImpl.unresolvedMaterializations.erase(getOperation()); + op->erase(); +} + +void ContextAwareConversionPatternRewriterImpl::applyRewrites() { + // Commit all rewrites. + IRRewriter rewriter(context, config.listener); + // Note: New rewrites may be added during the "commit" phase and the + // `rewrites` vector may reallocate. + for (const auto &rewrite : rewrites) rewrite->commit(rewriter); + + // Clean up all rewrites. + for (auto &rewrite : rewrites) rewrite->cleanup(eraseRewriter); +} + +//===----------------------------------------------------------------------===// +// State Management + +RewriterState ContextAwareConversionPatternRewriterImpl::getCurrentState() { + return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size()); +} + +void ContextAwareConversionPatternRewriterImpl::resetState( + RewriterState state) { + // Undo any rewrites. + undoRewrites(state.numRewrites); + + // Pop all of the recorded ignored operations that are no longer valid. + while (ignoredOps.size() != state.numIgnoredOperations) ignoredOps.pop_back(); + + while (replacedOps.size() != state.numReplacedOps) replacedOps.pop_back(); +} + +void ContextAwareConversionPatternRewriterImpl::undoRewrites( + unsigned numRewritesToKeep) { + for (auto &rewrite : + llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) + rewrite->rollback(); + rewrites.resize(numRewritesToKeep); +} + +LogicalResult ContextAwareConversionPatternRewriterImpl::remapValues( + StringRef valueDiagTag, std::optional inputLoc, + PatternRewriter &rewriter, ValueRange values, + SmallVector &remapped) { + remapped.reserve(llvm::size(values)); + + for (const auto &it : llvm::enumerate(values)) { + Value operand = it.value(); + Type origType = operand.getType(); + Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); + + if (!currentContextAwareTypeConverter) { + // The current pattern does not have a type converter. I.e., it does not + // distinguish between legal and illegal types. For each operand, simply + // pass through the most recently mapped values. + remapped.push_back(mapping.lookupOrDefault(operand)); + continue; + } + + // If there is no legal conversion, fail to match this pattern. + SmallVector legalTypes; + if (failed(currentContextAwareTypeConverter->convertType(origType, operand, + legalTypes))) { + notifyMatchFailure(operandLoc, [=](Diagnostic &diag) { + diag << "unable to convert type for " << valueDiagTag << " #" + << it.index() << ", type was " << origType; + }); + return failure(); + } + // If a type is converted to 0 types, there is nothing to do. + if (legalTypes.empty()) { + remapped.push_back({}); + continue; + } + + ValueVector repl = mapping.lookupOrDefault(operand, legalTypes); + if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) { + // Mapped values have the correct type or there is an existing + // materialization. Or the operand is not mapped at all and has the + // correct type. + remapped.push_back(std::move(repl)); + continue; + } + + // Create a materialization for the most recently mapped values. + repl = mapping.lookupOrDefault(operand); + ValueRange castValues = buildUnresolvedMaterialization( + MaterializationKind::Target, computeInsertPoint(repl), operandLoc, + /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes, + /*originalType=*/origType, currentContextAwareTypeConverter); + remapped.push_back(castValues); + } + return success(); +} + +bool ContextAwareConversionPatternRewriterImpl::isOpIgnored( + Operation *op) const { + // Check to see if this operation is ignored or was replaced. + return replacedOps.count(op) || ignoredOps.count(op); +} + +bool ContextAwareConversionPatternRewriterImpl::wasOpReplaced( + Operation *op) const { + // Check to see if this operation was replaced. + return replacedOps.count(op); +} + +//===----------------------------------------------------------------------===// +// Type Conversion + +FailureOr +ContextAwareConversionPatternRewriterImpl::convertRegionTypes( + ContextAwareConversionPatternRewriter &rewriter, Region *region, + const ContextAwareTypeConverter &converter, + ContextAwareTypeConverter::SignatureConversion *entryConversion) { + regionToConverter[region] = &converter; + if (region->empty()) return nullptr; + + // Convert the arguments of each non-entry block within the region. + for (Block &block : + llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) { + // Compute the signature for the block with the provided converter. + std::optional conversion = + converter.convertBlockSignature(&block); + if (!conversion) return failure(); + // Convert the block with the computed signature. + applySignatureConversion(rewriter, &block, &converter, *conversion); + } + + // Convert the entry block. If an entry signature conversion was provided, + // use that one. Otherwise, compute the signature with the type converter. + if (entryConversion) + return applySignatureConversion(rewriter, ®ion->front(), &converter, + *entryConversion); + std::optional conversion = + converter.convertBlockSignature(®ion->front()); + if (!conversion) return failure(); + return applySignatureConversion(rewriter, ®ion->front(), &converter, + *conversion); +} + +Block *ContextAwareConversionPatternRewriterImpl::applySignatureConversion( + ContextAwareConversionPatternRewriter &rewriter, Block *block, + const ContextAwareTypeConverter *converter, + ContextAwareTypeConverter::SignatureConversion &signatureConversion) { +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + // A block cannot be converted multiple times. + if (hasRewrite(rewrites, block)) + llvm::report_fatal_error("block was already converted"); +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + + OpBuilder::InsertionGuard g(rewriter); + + // If no arguments are being changed or added, there is nothing to do. + unsigned origArgCount = block->getNumArguments(); + auto convertedTypes = signatureConversion.getConvertedTypes(); + if (llvm::equal(block->getArgumentTypes(), convertedTypes)) return block; + + // Compute the locations of all block arguments in the new block. + SmallVector newLocs(convertedTypes.size(), + rewriter.getUnknownLoc()); + for (unsigned i = 0; i < origArgCount; ++i) { + auto inputMap = signatureConversion.getInputMapping(i); + if (!inputMap || inputMap->replacementValue) continue; + Location origLoc = block->getArgument(i).getLoc(); + for (unsigned j = 0; j < inputMap->size; ++j) + newLocs[inputMap->inputNo + j] = origLoc; + } + + // Insert a new block with the converted block argument types and move all ops + // from the old block to the new block. + Block *newBlock = + rewriter.createBlock(block->getParent(), std::next(block->getIterator()), + convertedTypes, newLocs); + + // If a listener is attached to the dialect conversion, ops cannot be moved + // to the destination block in bulk ("fast path"). This is because at the time + // the notifications are sent, it is unknown which ops were moved. Instead, + // ops should be moved one-by-one ("slow path"), so that a separate + // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is + // a bit more efficient, so we try to do that when possible. + bool fastPath = !config.listener; + if (fastPath) { + appendRewrite(newBlock, block, newBlock->end()); + newBlock->getOperations().splice(newBlock->end(), block->getOperations()); + } else { + while (!block->empty()) + rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end()); + } + + // Replace all uses of the old block with the new block. + block->replaceAllUsesWith(newBlock); + + for (unsigned i = 0; i != origArgCount; ++i) { + BlockArgument origArg = block->getArgument(i); + Type origArgType = origArg.getType(); + + std::optional + inputMap = signatureConversion.getInputMapping(i); + if (!inputMap) { + // This block argument was dropped and no replacement value was provided. + // Materialize a replacement value "out of thin air". + buildUnresolvedMaterialization( + MaterializationKind::Source, + OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), + /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(), + /*outputTypes=*/origArgType, /*originalType=*/Type(), converter); + appendRewrite(block, origArg, converter); + continue; + } + + if (Value repl = inputMap->replacementValue) { + // This block argument was dropped and a replacement value was provided. + assert(inputMap->size == 0 && + "invalid to provide a replacement value when the argument isn't " + "dropped"); + mapping.map(origArg, repl); + appendRewrite(block, origArg, converter); + continue; + } + + // This is a 1->1+ mapping. + auto replArgs = + newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); + ValueVector replArgVals = llvm::to_vector_of(replArgs); + mapping.map(origArg, std::move(replArgVals)); + appendRewrite(block, origArg, converter); + } + + appendRewrite(/*origBlock=*/block, newBlock); + + // Erase the old block. (It is just unlinked for now and will be erased during + // cleanup.) + rewriter.eraseBlock(block); + + return newBlock; +} + +//===----------------------------------------------------------------------===// +// Materializations +//===----------------------------------------------------------------------===// + +/// Build an unresolved materialization operation given an output type and set +/// of input operands. +ValueRange +ContextAwareConversionPatternRewriterImpl::buildUnresolvedMaterialization( + MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, + ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, + Type originalType, const ContextAwareTypeConverter *converter, + UnrealizedConversionCastOp *castOp) { + assert((!originalType || kind == MaterializationKind::Target) && + "original type is valid only for target materializations"); + assert(TypeRange(inputs) != outputTypes && + "materialization is not necessary"); + + // Create an unresolved materialization. We use a new OpBuilder to avoid + // tracking the materialization like we do for other operations. + OpBuilder builder(outputTypes.front().getContext()); + builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); + auto convertOp = + builder.create(loc, outputTypes, inputs); + if (!valuesToMap.empty()) mapping.map(valuesToMap, convertOp.getResults()); + if (castOp) *castOp = convertOp; + appendRewrite( + convertOp, converter, kind, originalType, std::move(valuesToMap)); + return convertOp.getResults(); +} + +Value ContextAwareConversionPatternRewriterImpl::findOrBuildReplacementValue( + Value value, const ContextAwareTypeConverter *converter) { + // Try to find a replacement value with the same type in the conversion value + // mapping. This includes cached materializations. We try to reuse those + // instead of generating duplicate IR. + ValueVector repl = mapping.lookupOrNull(value, value.getType()); + if (!repl.empty()) return repl.front(); + + // Check if the value is dead. No replacement value is needed in that case. + // This is an approximate check that may have false negatives but does not + // require computing and traversing an inverse mapping. (We may end up + // building source materializations that are never used and that fold away.) + if (llvm::all_of(value.getUsers(), + [&](Operation *op) { return replacedOps.contains(op); }) && + !mapping.isMappedTo(value)) + return Value(); + + // No replacement value was found. Get the latest replacement value + // (regardless of the type) and build a source materialization to the + // original type. + repl = mapping.lookupOrNull(value); + if (repl.empty()) { + // No replacement value is registered in the mapping. This means that the + // value is dropped and no longer needed. (If the value were still needed, + // a source materialization producing a replacement value "out of thin air" + // would have already been created during `replaceOp` or + // `applySignatureConversion`.) + return Value(); + } + + // Note: `computeInsertPoint` computes the "earliest" insertion point at + // which all values in `repl` are defined. It is important to emit the + // materialization at that location because the same materialization may be + // reused in a different context. (That's because materializations are cached + // in the conversion value mapping.) The insertion point of the + // materialization must be valid for all future users that may be created + // later in the conversion process. + Value castValue = + buildUnresolvedMaterialization(MaterializationKind::Source, + computeInsertPoint(repl), value.getLoc(), + /*valuesToMap=*/repl, /*inputs=*/repl, + /*outputTypes=*/value.getType(), + /*originalType=*/Type(), converter) + .front(); + return castValue; +} + +//===----------------------------------------------------------------------===// +// Rewriter Notification Hooks + +void ContextAwareConversionPatternRewriterImpl::notifyOperationInserted( + Operation *op, OpBuilder::InsertPoint previous) { + LLVM_DEBUG({ + logger.startLine() << "** Insert : '" << op->getName() << "'(" << op + << ")\n"; + }); + assert(!wasOpReplaced(op->getParentOp()) && + "attempting to insert into a block within a replaced/erased op"); + + if (!previous.isSet()) { + // This is a newly created op. + appendRewrite(op); + return; + } + Operation *prevOp = previous.getPoint() == previous.getBlock()->end() + ? nullptr + : &*previous.getPoint(); + appendRewrite(op, previous.getBlock(), prevOp); +} + +void ContextAwareConversionPatternRewriterImpl::notifyOpReplaced( + Operation *op, ArrayRef newValues) { + assert(newValues.size() == op->getNumResults()); + assert(!ignoredOps.contains(op) && "operation was already replaced"); + + // Check if replaced op is an unresolved materialization, i.e., an + // unrealized_conversion_cast op that was created by the conversion driver. + bool isUnresolvedMaterialization = false; + if (auto castOp = dyn_cast(op)) + if (unresolvedMaterializations.contains(castOp)) + isUnresolvedMaterialization = true; + + // Create mappings for each of the new result values. + for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) { + if (repl.empty()) { + // This result was dropped and no replacement value was provided. + if (isUnresolvedMaterialization) { + // Do not create another materializations if we are erasing a + // materialization. + continue; + } + + // Materialize a replacement value "out of thin air". + buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(result), + result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(), + /*outputTypes=*/result.getType(), /*originalType=*/Type(), + currentContextAwareTypeConverter); + continue; + } + + // Make sure that the user does not mess with unresolved materializations + // that were inserted by the conversion driver. We keep track of these + // ops in internal data structures. Erasing them must be allowed because + // this can happen when the user is erasing an entire block (including + // its body). But replacing them with another value should be forbidden + // to avoid problems with the `mapping`. + assert(!isUnresolvedMaterialization && + "attempting to replace an unresolved materialization"); + + // Remap result to replacement value. + if (repl.empty()) continue; + mapping.map(result, repl); + } + + appendRewrite(op, currentContextAwareTypeConverter); + // Mark this operation and all nested ops as replaced. + op->walk([&](Operation *op) { replacedOps.insert(op); }); +} + +void ContextAwareConversionPatternRewriterImpl::notifyBlockIsBeingErased( + Block *block) { + appendRewrite(block); +} + +void ContextAwareConversionPatternRewriterImpl::notifyBlockInserted( + Block *block, Region *previous, Region::iterator previousIt) { + assert(!wasOpReplaced(block->getParentOp()) && + "attempting to insert into a region within a replaced/erased op"); + LLVM_DEBUG( + { + Operation *parent = block->getParentOp(); + if (parent) { + logger.startLine() << "** Insert Block into : '" << parent->getName() + << "'(" << parent << ")\n"; + } else { + logger.startLine() + << "** Insert Block into detached Region (nullptr parent op)'"; + } + }); + + if (!previous) { + // This is a newly created block. + appendRewrite(block); + return; + } + Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt; + appendRewrite(block, previous, prevBlock); +} + +void ContextAwareConversionPatternRewriterImpl::notifyBlockBeingInlined( + Block *block, Block *srcBlock, Block::iterator before) { + appendRewrite(block, srcBlock, before); +} + +void ContextAwareConversionPatternRewriterImpl::notifyMatchFailure( + Location loc, function_ref reasonCallback) { + LLVM_DEBUG({ + Diagnostic diag(loc, DiagnosticSeverity::Remark); + reasonCallback(diag); + logger.startLine() << "** Failure : " << diag.str() << "\n"; + if (config.notifyCallback) config.notifyCallback(diag); + }); +} + +//===----------------------------------------------------------------------===// +// ContextAwareConversionPatternRewriter +//===----------------------------------------------------------------------===// + +ContextAwareConversionPatternRewriter::ContextAwareConversionPatternRewriter( + MLIRContext *ctx, const ConversionConfig &config) + : PatternRewriter(ctx), + impl(new detail::ContextAwareConversionPatternRewriterImpl(ctx, config)) { + setListener(impl.get()); +} + +ContextAwareConversionPatternRewriter:: + ~ContextAwareConversionPatternRewriter() = default; + +void ContextAwareConversionPatternRewriter::replaceOp(Operation *op, + Operation *newOp) { + assert(op && newOp && "expected non-null op"); + replaceOp(op, newOp->getResults()); +} + +void ContextAwareConversionPatternRewriter::replaceOp(Operation *op, + ValueRange newValues) { + assert(op->getNumResults() == newValues.size() && + "incorrect # of replacement values"); + LLVM_DEBUG({ + impl->logger.startLine() + << "** Replace : '" << op->getName() << "'(" << op << ")\n"; + }); + SmallVector newVals; + for (size_t i = 0; i < newValues.size(); ++i) { + if (newValues[i]) { + newVals.push_back(newValues.slice(i, 1)); + } else { + newVals.push_back(ValueRange()); + } + } + impl->notifyOpReplaced(op, newVals); +} + +void ContextAwareConversionPatternRewriter::replaceOpWithMultiple( + Operation *op, ArrayRef newValues) { + assert(op->getNumResults() == newValues.size() && + "incorrect # of replacement values"); + LLVM_DEBUG({ + impl->logger.startLine() + << "** Replace : '" << op->getName() << "'(" << op << ")\n"; + }); + impl->notifyOpReplaced(op, newValues); +} + +void ContextAwareConversionPatternRewriter::eraseOp(Operation *op) { + LLVM_DEBUG({ + impl->logger.startLine() + << "** Erase : '" << op->getName() << "'(" << op << ")\n"; + }); + SmallVector nullRepls(op->getNumResults(), {}); + impl->notifyOpReplaced(op, nullRepls); +} + +void ContextAwareConversionPatternRewriter::eraseBlock(Block *block) { + assert(!impl->wasOpReplaced(block->getParentOp()) && + "attempting to erase a block within a replaced/erased op"); + + // Mark all ops for erasure. + for (Operation &op : *block) eraseOp(&op); + + // Unlink the block from its parent region. The block is kept in the rewrite + // object and will be actually destroyed when rewrites are applied. This + // allows us to keep the operations in the block live and undo the removal by + // re-inserting the block. + impl->notifyBlockIsBeingErased(block); + block->getParent()->getBlocks().remove(block); +} + +Block *ContextAwareConversionPatternRewriter::applySignatureConversion( + Block *block, ContextAwareTypeConverter::SignatureConversion &conversion, + const ContextAwareTypeConverter *converter) { + assert(!impl->wasOpReplaced(block->getParentOp()) && + "attempting to apply a signature conversion to a block within a " + "replaced/erased op"); + return impl->applySignatureConversion(*this, block, converter, conversion); +} + +FailureOr ContextAwareConversionPatternRewriter::convertRegionTypes( + Region *region, const ContextAwareTypeConverter &converter, + ContextAwareTypeConverter::SignatureConversion *entryConversion) { + assert(!impl->wasOpReplaced(region->getParentOp()) && + "attempting to apply a signature conversion to a block within a " + "replaced/erased op"); + return impl->convertRegionTypes(*this, region, converter, entryConversion); +} + +void ContextAwareConversionPatternRewriter::replaceUsesOfBlockArgument( + BlockArgument from, Value to) { + LLVM_DEBUG({ + impl->logger.startLine() << "** Replace Argument : '" << from << "'"; + if (Operation *parentOp = from.getOwner()->getParentOp()) { + impl->logger.getOStream() << " (in region of '" << parentOp->getName() + << "' (" << parentOp << ")\n"; + } else { + impl->logger.getOStream() << " (unlinked block)\n"; + } + }); + impl->appendRewrite( + from.getOwner(), from, impl->currentContextAwareTypeConverter); + impl->mapping.map(impl->mapping.lookupOrDefault(from), to); +} + +Value ContextAwareConversionPatternRewriter::getRemappedValue(Value key) { + SmallVector remappedValues; + if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key, + remappedValues))) + return nullptr; + assert(remappedValues.front().size() == 1 && "1:N conversion not supported"); + return remappedValues.front().front(); +} + +LogicalResult ContextAwareConversionPatternRewriter::getRemappedValues( + ValueRange keys, SmallVectorImpl &results) { + if (keys.empty()) return success(); + SmallVector remapped; + if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, + remapped))) + return failure(); + for (const auto &values : remapped) { + assert(values.size() == 1 && "1:N conversion not supported"); + results.push_back(values.front()); + } + return success(); +} + +void ContextAwareConversionPatternRewriter::inlineBlockBefore( + Block *source, Block *dest, Block::iterator before, ValueRange argValues) { +#ifndef NDEBUG + assert(argValues.size() == source->getNumArguments() && + "incorrect # of argument replacement values"); + assert(!impl->wasOpReplaced(source->getParentOp()) && + "attempting to inline a block from a replaced/erased op"); + assert(!impl->wasOpReplaced(dest->getParentOp()) && + "attempting to inline a block into a replaced/erased op"); + auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); }; + // The source block will be deleted, so it should not have any users (i.e., + // there should be no predecessors). + assert(llvm::all_of(source->getUsers(), opIgnored) && + "expected 'source' to have no predecessors"); +#endif // NDEBUG + + // If a listener is attached to the dialect conversion, ops cannot be moved + // to the destination block in bulk ("fast path"). This is because at the time + // the notifications are sent, it is unknown which ops were moved. Instead, + // ops should be moved one-by-one ("slow path"), so that a separate + // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is + // a bit more efficient, so we try to do that when possible. + bool fastPath = !impl->config.listener; + + if (fastPath) impl->notifyBlockBeingInlined(dest, source, before); + + // Replace all uses of block arguments. + for (auto it : llvm::zip(source->getArguments(), argValues)) + replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it)); + + if (fastPath) { + // Move all ops at once. + dest->getOperations().splice(before, source->getOperations()); + } else { + // Move op by op. + while (!source->empty()) moveOpBefore(&source->front(), dest, before); + } + + // Erase the source block. + eraseBlock(source); +} + +void ContextAwareConversionPatternRewriter::startOpModification(Operation *op) { + assert(!impl->wasOpReplaced(op) && + "attempting to modify a replaced/erased op"); +#ifndef NDEBUG + impl->pendingRootUpdates.insert(op); +#endif + impl->appendRewrite(op); +} + +void ContextAwareConversionPatternRewriter::finalizeOpModification( + Operation *op) { + assert(!impl->wasOpReplaced(op) && + "attempting to modify a replaced/erased op"); + PatternRewriter::finalizeOpModification(op); + // There is nothing to do here, we only need to track the operation at the + // start of the update. +#ifndef NDEBUG + assert(impl->pendingRootUpdates.erase(op) && + "operation did not have a pending in-place update"); +#endif +} + +void ContextAwareConversionPatternRewriter::cancelOpModification( + Operation *op) { +#ifndef NDEBUG + assert(impl->pendingRootUpdates.erase(op) && + "operation did not have a pending in-place update"); +#endif + // Erase the last update for this operation. + auto it = llvm::find_if( + llvm::reverse(impl->rewrites), [&](std::unique_ptr &rewrite) { + auto *modifyRewrite = dyn_cast(rewrite.get()); + return modifyRewrite && modifyRewrite->getOperation() == op; + }); + assert(it != impl->rewrites.rend() && "no root update started on op"); + (*it)->rollback(); + int updateIdx = std::prev(impl->rewrites.rend()) - it; + impl->rewrites.erase(impl->rewrites.begin() + updateIdx); +} + +detail::ContextAwareConversionPatternRewriterImpl & +ContextAwareConversionPatternRewriter::getImpl() { + return *impl; +} + +//===----------------------------------------------------------------------===// +// ConversionPattern +//===----------------------------------------------------------------------===// + +LogicalResult ContextAwareConversionPattern::matchAndRewrite( + Operation *op, PatternRewriter &rewriter) const { + auto &dialectRewriter = + static_cast(rewriter); + auto &rewriterImpl = dialectRewriter.getImpl(); + + // Track the current conversion pattern type converter in the rewriter. + llvm::SaveAndRestore currentConverterGuard( + rewriterImpl.currentContextAwareTypeConverter, getTypeConverter()); + + // Remap the operands of the operation. + SmallVector remapped; + if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, + op->getOperands(), remapped))) { + return failure(); + } + SmallVector remappedAsRange = + llvm::to_vector_of(remapped); + return matchAndRewrite(op, remappedAsRange, dialectRewriter); +} + +SmallVector ContextAwareConversionPattern::getOneToOneAdaptorOperands( + ArrayRef operands) const { + SmallVector oneToOneOperands; + oneToOneOperands.reserve(operands.size()); + for (ValueRange operand : operands) { + if (operand.size() != 1) + llvm::report_fatal_error("pattern '" + getDebugName() + + "' does not support 1:N conversion"); + oneToOneOperands.push_back(operand.front()); + } + return oneToOneOperands; +} + +//===----------------------------------------------------------------------===// +// OperationLegalizer +//===----------------------------------------------------------------------===// + +namespace { +/// A set of rewrite patterns that can be used to legalize a given operation. +using LegalizationPatterns = SmallVector; + +/// This class defines a recursive operation legalizer. +class OperationLegalizer { + public: + using LegalizationAction = ConversionTarget::LegalizationAction; + + OperationLegalizer(const ConversionTarget &targetInfo, + const FrozenRewritePatternSet &patterns, + const ConversionConfig &config); + + /// Returns true if the given operation is known to be illegal on the target. + bool isIllegal(Operation *op) const; + + /// Attempt to legalize the given operation. Returns success if the operation + /// was legalized, failure otherwise. + LogicalResult legalize(Operation *op, + ContextAwareConversionPatternRewriter &rewriter); + + /// Returns the conversion target in use by the legalizer. + const ConversionTarget &getTarget() { return target; } + + private: + /// Attempt to legalize the given operation by folding it. + LogicalResult legalizeWithFold( + Operation *op, ContextAwareConversionPatternRewriter &rewriter); + + /// Attempt to legalize the given operation by applying a pattern. Returns + /// success if the operation was legalized, failure otherwise. + LogicalResult legalizeWithPattern( + Operation *op, ContextAwareConversionPatternRewriter &rewriter); + + /// Return true if the given pattern may be applied to the given operation, + /// false otherwise. + bool canApplyPattern(Operation *op, const Pattern &pattern, + ContextAwareConversionPatternRewriter &rewriter); + + /// Legalize the resultant IR after successfully applying the given pattern. + LogicalResult legalizePatternResult( + Operation *op, const Pattern &pattern, + ContextAwareConversionPatternRewriter &rewriter, RewriterState &curState); + + /// Legalizes the actions registered during the execution of a pattern. + LogicalResult legalizePatternBlockRewrites( + Operation *op, ContextAwareConversionPatternRewriter &rewriter, + ContextAwareConversionPatternRewriterImpl &impl, RewriterState &state, + RewriterState &newState); + LogicalResult legalizePatternCreatedOperations( + ContextAwareConversionPatternRewriter &rewriter, + ContextAwareConversionPatternRewriterImpl &impl, RewriterState &state, + RewriterState &newState); + LogicalResult legalizePatternRootUpdates( + ContextAwareConversionPatternRewriter &rewriter, + ContextAwareConversionPatternRewriterImpl &impl, RewriterState &state, + RewriterState &newState); + + //===--------------------------------------------------------------------===// + // Cost Model + //===--------------------------------------------------------------------===// + + /// Build an optimistic legalization graph given the provided patterns. This + /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with + /// patterns for operations that are not directly legal, but may be + /// transitively legal for the current target given the provided patterns. + void buildLegalizationGraph( + LegalizationPatterns &anyOpLegalizerPatterns, + DenseMap &legalizerPatterns); + + /// Compute the benefit of each node within the computed legalization graph. + /// This orders the patterns within 'legalizerPatterns' based upon two + /// criteria: + /// 1) Prefer patterns that have the lowest legalization depth, i.e. + /// represent the more direct mapping to the target. + /// 2) When comparing patterns with the same legalization depth, prefer the + /// pattern with the highest PatternBenefit. This allows for users to + /// prefer specific legalizations over others. + void computeLegalizationGraphBenefit( + LegalizationPatterns &anyOpLegalizerPatterns, + DenseMap &legalizerPatterns); + + /// Compute the legalization depth when legalizing an operation of the given + /// type. + unsigned computeOpLegalizationDepth( + OperationName op, DenseMap &minOpPatternDepth, + DenseMap &legalizerPatterns); + + /// Apply the conversion cost model to the given set of patterns, and return + /// the smallest legalization depth of any of the patterns. See + /// `computeLegalizationGraphBenefit` for the breakdown of the cost model. + unsigned applyCostModelToPatterns( + LegalizationPatterns &patterns, + DenseMap &minOpPatternDepth, + DenseMap &legalizerPatterns); + + /// The current set of patterns that have been applied. + SmallPtrSet appliedPatterns; + + /// The legalization information provided by the target. + const ConversionTarget ⌖ + + /// The pattern applicator to use for conversions. + PatternApplicator applicator; + + /// Dialect conversion configuration. + const ConversionConfig &config; +}; +} // namespace + +OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo, + const FrozenRewritePatternSet &patterns, + const ConversionConfig &config) + : target(targetInfo), applicator(patterns), config(config) { + // The set of patterns that can be applied to illegal operations to transform + // them into legal ones. + DenseMap legalizerPatterns; + LegalizationPatterns anyOpLegalizerPatterns; + + buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns); + computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns); +} + +bool OperationLegalizer::isIllegal(Operation *op) const { + return target.isIllegal(op); +} + +LogicalResult OperationLegalizer::legalize( + Operation *op, ContextAwareConversionPatternRewriter &rewriter) { +#ifndef NDEBUG + const char *logLineComment = + "//===-------------------------------------------===//\n"; + + auto &logger = rewriter.getImpl().logger; +#endif + LLVM_DEBUG({ + logger.getOStream() << "\n"; + logger.startLine() << logLineComment; + logger.startLine() << "Legalizing operation : '" << op->getName() << "'(" + << op << ") {\n"; + logger.indent(); + + // If the operation has no regions, just print it here. + if (op->getNumRegions() == 0) { + op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm()); + logger.getOStream() << "\n\n"; + } + }); + + // Check if this operation is legal on the target. + if (auto legalityInfo = target.isLegal(op)) { + LLVM_DEBUG({ + logSuccess( + logger, "operation marked legal by the target{0}", + legalityInfo->isRecursivelyLegal + ? "; NOTE: operation is recursively legal; skipping internals" + : ""); + logger.startLine() << logLineComment; + }); + + // If this operation is recursively legal, mark its children as ignored so + // that we don't consider them for legalization. + if (legalityInfo->isRecursivelyLegal) { + op->walk([&](Operation *nested) { + if (op != nested) rewriter.getImpl().ignoredOps.insert(nested); + }); + } + + return success(); + } + + // Check to see if the operation is ignored and doesn't need to be converted. + if (rewriter.getImpl().isOpIgnored(op)) { + LLVM_DEBUG({ + logSuccess(logger, "operation marked 'ignored' during conversion"); + logger.startLine() << logLineComment; + }); + return success(); + } + + // If the operation isn't legal, try to fold it in-place. + // TODO: Should we always try to do this, even if the op is + // already legal? + if (succeeded(legalizeWithFold(op, rewriter))) { + LLVM_DEBUG({ + logSuccess(logger, "operation was folded"); + logger.startLine() << logLineComment; + }); + return success(); + } + + // Otherwise, we need to apply a legalization pattern to this operation. + if (succeeded(legalizeWithPattern(op, rewriter))) { + LLVM_DEBUG({ + logSuccess(logger, ""); + logger.startLine() << logLineComment; + }); + return success(); + } + + LLVM_DEBUG({ + logFailure(logger, "no matched legalization pattern"); + logger.startLine() << logLineComment; + }); + return failure(); +} + +LogicalResult OperationLegalizer::legalizeWithFold( + Operation *op, ContextAwareConversionPatternRewriter &rewriter) { + auto &rewriterImpl = rewriter.getImpl(); + RewriterState curState = rewriterImpl.getCurrentState(); + + LLVM_DEBUG({ + rewriterImpl.logger.startLine() << "* Fold {\n"; + rewriterImpl.logger.indent(); + }); + + // Try to fold the operation. + SmallVector replacementValues; + rewriter.setInsertionPoint(op); + if (failed(rewriter.tryFold(op, replacementValues))) { + LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold")); + return failure(); + } + // An empty list of replacement values indicates that the fold was in-place. + // As the operation changed, a new legalization needs to be attempted. + if (replacementValues.empty()) return legalize(op, rewriter); + + // Insert a replacement for 'op' with the folded replacement values. + rewriter.replaceOp(op, replacementValues); + + // Recursively legalize any new constant operations. + for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size(); + i != e; ++i) { + auto *createOp = + dyn_cast(rewriterImpl.rewrites[i].get()); + if (!createOp) continue; + if (failed(legalize(createOp->getOperation(), rewriter))) { + LLVM_DEBUG(logFailure(rewriterImpl.logger, + "failed to legalize generated constant '{0}'", + createOp->getOperation()->getName())); + rewriterImpl.resetState(curState); + return failure(); + } + } + + LLVM_DEBUG(logSuccess(rewriterImpl.logger, "")); + return success(); +} + +LogicalResult OperationLegalizer::legalizeWithPattern( + Operation *op, ContextAwareConversionPatternRewriter &rewriter) { + auto &rewriterImpl = rewriter.getImpl(); + + // Functor that returns if the given pattern may be applied. + auto canApply = [&](const Pattern &pattern) { + bool canApply = canApplyPattern(op, pattern, rewriter); + if (canApply && config.listener) + config.listener->notifyPatternBegin(pattern, op); + return canApply; + }; + + // Functor that cleans up the rewriter state after a pattern failed to match. + RewriterState curState = rewriterImpl.getCurrentState(); + auto onFailure = [&](const Pattern &pattern) { + assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); + LLVM_DEBUG({ + logFailure(rewriterImpl.logger, "pattern failed to match"); + if (rewriterImpl.config.notifyCallback) { + Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark); + diag << "Failed to apply pattern \"" << pattern.getDebugName() + << "\" on op:\n" + << *op; + rewriterImpl.config.notifyCallback(diag); + } + }); + if (config.listener) config.listener->notifyPatternEnd(pattern, failure()); + rewriterImpl.resetState(curState); + appliedPatterns.erase(&pattern); + }; + + // Functor that performs additional legalization when a pattern is + // successfully applied. + auto onSuccess = [&](const Pattern &pattern) { + assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); + auto result = legalizePatternResult(op, pattern, rewriter, curState); + appliedPatterns.erase(&pattern); + if (failed(result)) rewriterImpl.resetState(curState); + if (config.listener) config.listener->notifyPatternEnd(pattern, result); + return result; + }; + + // Try to match and rewrite a pattern on this operation. + return applicator.matchAndRewrite(op, rewriter, canApply, onFailure, + onSuccess); +} + +bool OperationLegalizer::canApplyPattern( + Operation *op, const Pattern &pattern, + ContextAwareConversionPatternRewriter &rewriter) { + LLVM_DEBUG({ + auto &os = rewriter.getImpl().logger; + os.getOStream() << "\n"; + os.startLine() << "* Pattern : '" << op->getName() << " -> ("; + llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream()); + os.getOStream() << ")' {\n"; + os.indent(); + }); + + // Ensure that we don't cycle by not allowing the same pattern to be + // applied twice in the same recursion stack if it is not known to be safe. + if (!pattern.hasBoundedRewriteRecursion() && + !appliedPatterns.insert(&pattern).second) { + LLVM_DEBUG( + logFailure(rewriter.getImpl().logger, "pattern was already applied")); + return false; + } + return true; +} + +LogicalResult OperationLegalizer::legalizePatternResult( + Operation *op, const Pattern &pattern, + ContextAwareConversionPatternRewriter &rewriter, RewriterState &curState) { + auto &impl = rewriter.getImpl(); + assert(impl.pendingRootUpdates.empty() && "dangling root updates"); + +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + // Check that the root was either replaced or updated in place. + auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites); + auto replacedRoot = [&] { + return hasRewrite(newRewrites, op); + }; + auto updatedRootInPlace = [&] { + return hasRewrite(newRewrites, op); + }; + if (!replacedRoot() && !updatedRootInPlace()) + llvm::report_fatal_error("expected pattern to replace the root operation"); +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + + // Legalize each of the actions registered during application. + RewriterState newState = impl.getCurrentState(); + if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState, + newState)) || + failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) || + failed(legalizePatternCreatedOperations(rewriter, impl, curState, + newState))) { + return failure(); + } + + LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully")); + return success(); +} + +LogicalResult OperationLegalizer::legalizePatternBlockRewrites( + Operation *op, ContextAwareConversionPatternRewriter &rewriter, + ContextAwareConversionPatternRewriterImpl &impl, RewriterState &state, + RewriterState &newState) { + SmallPtrSet operationsToIgnore; + + // If the pattern moved or created any blocks, make sure the types of block + // arguments get legalized. + for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { + BlockRewrite *rewrite = dyn_cast(impl.rewrites[i].get()); + if (!rewrite) continue; + Block *block = rewrite->getBlock(); + if (isa(rewrite)) + continue; + // Only check blocks outside of the current operation. + Operation *parentOp = block->getParentOp(); + if (!parentOp || parentOp == op || block->getNumArguments() == 0) continue; + + // If the region of the block has a type converter, try to convert the block + // directly. + if (auto *converter = impl.regionToConverter.lookup(block->getParent())) { + std::optional conversion = + converter->convertBlockSignature(block); + if (!conversion) { + LLVM_DEBUG(logFailure(impl.logger, + "failed to convert types of moved " + "block")); + return failure(); + } + impl.applySignatureConversion(rewriter, block, converter, *conversion); + continue; + } + + // Otherwise, check that this operation isn't one generated by this pattern. + // This is because we will attempt to legalize the parent operation, and + // blocks in regions created by this pattern will already be legalized later + // on. If we haven't built the set yet, build it now. + if (operationsToIgnore.empty()) { + for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e; + ++i) { + auto *createOp = + dyn_cast(impl.rewrites[i].get()); + if (!createOp) continue; + operationsToIgnore.insert(createOp->getOperation()); + } + } + + // If this operation should be considered for re-legalization, try it. + if (operationsToIgnore.insert(parentOp).second && + failed(legalize(parentOp, rewriter))) { + LLVM_DEBUG(logFailure(impl.logger, + "operation '{0}'({1}) became illegal after rewrite", + parentOp->getName(), parentOp)); + return failure(); + } + } + return success(); +} + +LogicalResult OperationLegalizer::legalizePatternCreatedOperations( + ContextAwareConversionPatternRewriter &rewriter, + ContextAwareConversionPatternRewriterImpl &impl, RewriterState &state, + RewriterState &newState) { + for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { + auto *createOp = dyn_cast(impl.rewrites[i].get()); + if (!createOp) continue; + Operation *op = createOp->getOperation(); + if (failed(legalize(op, rewriter))) { + LLVM_DEBUG(logFailure(impl.logger, + "failed to legalize generated operation '{0}'({1})", + op->getName(), op)); + return failure(); + } + } + return success(); +} + +LogicalResult OperationLegalizer::legalizePatternRootUpdates( + ContextAwareConversionPatternRewriter &rewriter, + ContextAwareConversionPatternRewriterImpl &impl, RewriterState &state, + RewriterState &newState) { + for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { + auto *rewrite = dyn_cast(impl.rewrites[i].get()); + if (!rewrite) continue; + Operation *op = rewrite->getOperation(); + if (failed(legalize(op, rewriter))) { + LLVM_DEBUG(logFailure( + impl.logger, "failed to legalize operation updated in-place '{0}'", + op->getName())); + return failure(); + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Cost Model + +void OperationLegalizer::buildLegalizationGraph( + LegalizationPatterns &anyOpLegalizerPatterns, + DenseMap &legalizerPatterns) { + // A mapping between an operation and a set of operations that can be used to + // generate it. + DenseMap> parentOps; + // A mapping between an operation and any currently invalid patterns it has. + DenseMap> invalidPatterns; + // A worklist of patterns to consider for legality. + SetVector patternWorklist; + + // Build the mapping from operations to the parent ops that may generate them. + applicator.walkAllPatterns([&](const Pattern &pattern) { + std::optional root = pattern.getRootKind(); + + // If the pattern has no specific root, we can't analyze the relationship + // between the root op and generated operations. Given that, add all such + // patterns to the legalization set. + if (!root) { + anyOpLegalizerPatterns.push_back(&pattern); + return; + } + + // Skip operations that are always known to be legal. + if (target.getOpAction(*root) == LegalizationAction::Legal) return; + + // Add this pattern to the invalid set for the root op and record this root + // as a parent for any generated operations. + invalidPatterns[*root].insert(&pattern); + for (auto op : pattern.getGeneratedOps()) parentOps[op].insert(*root); + + // Add this pattern to the worklist. + patternWorklist.insert(&pattern); + }); + + // If there are any patterns that don't have a specific root kind, we can't + // make direct assumptions about what operations will never be legalized. + // Note: Technically we could, but it would require an analysis that may + // recurse into itself. It would be better to perform this kind of filtering + // at a higher level than here anyways. + if (!anyOpLegalizerPatterns.empty()) { + for (const Pattern *pattern : patternWorklist) + legalizerPatterns[*pattern->getRootKind()].push_back(pattern); + return; + } + + while (!patternWorklist.empty()) { + auto *pattern = patternWorklist.pop_back_val(); + + // Check to see if any of the generated operations are invalid. + if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) { + std::optional action = target.getOpAction(op); + return !legalizerPatterns.count(op) && + (!action || action == LegalizationAction::Illegal); + })) + continue; + + // Otherwise, if all of the generated operation are valid, this op is now + // legal so add all of the child patterns to the worklist. + legalizerPatterns[*pattern->getRootKind()].push_back(pattern); + invalidPatterns[*pattern->getRootKind()].erase(pattern); + + // Add any invalid patterns of the parent operations to see if they have now + // become legal. + for (auto op : parentOps[*pattern->getRootKind()]) + patternWorklist.set_union(invalidPatterns[op]); + } +} + +void OperationLegalizer::computeLegalizationGraphBenefit( + LegalizationPatterns &anyOpLegalizerPatterns, + DenseMap &legalizerPatterns) { + // The smallest pattern depth, when legalizing an operation. + DenseMap minOpPatternDepth; + + // For each operation that is transitively legal, compute a cost for it. + for (auto &opIt : legalizerPatterns) + if (!minOpPatternDepth.count(opIt.first)) + computeOpLegalizationDepth(opIt.first, minOpPatternDepth, + legalizerPatterns); + + // Apply the cost model to the patterns that can match any operation. Those + // with a specific operation type are already resolved when computing the op + // legalization depth. + if (!anyOpLegalizerPatterns.empty()) + applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth, + legalizerPatterns); + + // Apply a cost model to the pattern applicator. We order patterns first by + // depth then benefit. `legalizerPatterns` contains per-op patterns by + // decreasing benefit. + applicator.applyCostModel([&](const Pattern &pattern) { + ArrayRef orderedPatternList; + if (std::optional rootName = pattern.getRootKind()) + orderedPatternList = legalizerPatterns[*rootName]; + else + orderedPatternList = anyOpLegalizerPatterns; + + // If the pattern is not found, then it was removed and cannot be matched. + auto *it = llvm::find(orderedPatternList, &pattern); + if (it == orderedPatternList.end()) + return PatternBenefit::impossibleToMatch(); + + // Patterns found earlier in the list have higher benefit. + return PatternBenefit(std::distance(it, orderedPatternList.end())); + }); +} + +unsigned OperationLegalizer::computeOpLegalizationDepth( + OperationName op, DenseMap &minOpPatternDepth, + DenseMap &legalizerPatterns) { + // Check for existing depth. + auto depthIt = minOpPatternDepth.find(op); + if (depthIt != minOpPatternDepth.end()) return depthIt->second; + + // If a mapping for this operation does not exist, then this operation + // is always legal. Return 0 as the depth for a directly legal operation. + auto opPatternsIt = legalizerPatterns.find(op); + if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty()) + return 0u; + + // Record this initial depth in case we encounter this op again when + // recursively computing the depth. + minOpPatternDepth.try_emplace(op, std::numeric_limits::max()); + + // Apply the cost model to the operation patterns, and update the minimum + // depth. + unsigned minDepth = applyCostModelToPatterns( + opPatternsIt->second, minOpPatternDepth, legalizerPatterns); + minOpPatternDepth[op] = minDepth; + return minDepth; +} + +unsigned OperationLegalizer::applyCostModelToPatterns( + LegalizationPatterns &patterns, + DenseMap &minOpPatternDepth, + DenseMap &legalizerPatterns) { + unsigned minDepth = std::numeric_limits::max(); + + // Compute the depth for each pattern within the set. + SmallVector, 4> patternsByDepth; + patternsByDepth.reserve(patterns.size()); + for (const Pattern *pattern : patterns) { + unsigned depth = 1; + for (auto generatedOp : pattern->getGeneratedOps()) { + unsigned generatedOpDepth = computeOpLegalizationDepth( + generatedOp, minOpPatternDepth, legalizerPatterns); + depth = std::max(depth, generatedOpDepth + 1); + } + patternsByDepth.emplace_back(pattern, depth); + + // Update the minimum depth of the pattern list. + minDepth = std::min(minDepth, depth); + } + + // If the operation only has one legalization pattern, there is no need to + // sort them. + if (patternsByDepth.size() == 1) return minDepth; + + // Sort the patterns by those likely to be the most beneficial. + std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(), + [](const std::pair &lhs, + const std::pair &rhs) { + // First sort by the smaller pattern legalization + // depth. + if (lhs.second != rhs.second) + return lhs.second < rhs.second; + + // Then sort by the larger pattern benefit. + auto lhsBenefit = lhs.first->getBenefit(); + auto rhsBenefit = rhs.first->getBenefit(); + return lhsBenefit > rhsBenefit; + }); + + // Update the legalization pattern to use the new sorted list. + patterns.clear(); + for (auto &patternIt : patternsByDepth) patterns.push_back(patternIt.first); + return minDepth; +} + +//===----------------------------------------------------------------------===// +// OperationConverter +//===----------------------------------------------------------------------===// +namespace { +enum OpConversionMode { + /// In this mode, the conversion will ignore failed conversions to allow + /// illegal operations to co-exist in the IR. + Partial, + + /// In this mode, all operations must be legal for the given target for the + /// conversion to succeed. + Full, + + /// In this mode, operations are analyzed for legality. No actual rewrites are + /// applied to the operations on success. + Analysis, +}; +} // namespace + +// This class converts operations to a given conversion target via a set of +// rewrite patterns. The conversion behaves differently depending on the +// conversion mode. +struct OperationConverter { + explicit OperationConverter(const ConversionTarget &target, + const FrozenRewritePatternSet &patterns, + const ConversionConfig &config, + OpConversionMode mode) + : config(config), + opLegalizer(target, patterns, this->config), + mode(mode) {} + + /// Converts the given operations to the conversion target. + LogicalResult convertOperations(ArrayRef ops); + + private: + /// Converts an operation with the given rewriter. + LogicalResult convert(ContextAwareConversionPatternRewriter &rewriter, + Operation *op); + + /// Dialect conversion configuration. + ConversionConfig config; + + /// The legalizer to use when converting operations. + OperationLegalizer opLegalizer; + + /// The conversion mode to use when legalizing operations. + OpConversionMode mode; +}; + +LogicalResult OperationConverter::convert( + ContextAwareConversionPatternRewriter &rewriter, Operation *op) { + // Legalize the given operation. + if (failed(opLegalizer.legalize(op, rewriter))) { + // Handle the case of a failed conversion for each of the different modes. + // Full conversions expect all operations to be converted. + if (mode == OpConversionMode::Full) + return op->emitError() + << "failed to legalize operation '" << op->getName() << "'"; + // Partial conversions allow conversions to fail iff the operation was not + // explicitly marked as illegal. If the user provided a `unlegalizedOps` + // set, non-legalizable ops are added to that set. + if (mode == OpConversionMode::Partial) { + if (opLegalizer.isIllegal(op)) + return op->emitError() + << "failed to legalize operation '" << op->getName() + << "' that was explicitly marked illegal"; + if (config.unlegalizedOps) config.unlegalizedOps->insert(op); + } + } else if (mode == OpConversionMode::Analysis) { + // Analysis conversions don't fail if any operations fail to legalize, + // they are only interested in the operations that were successfully + // legalized. + if (config.legalizableOps) config.legalizableOps->insert(op); + } + return success(); +} + +static LogicalResult legalizeUnresolvedMaterialization( + RewriterBase &rewriter, UnresolvedMaterializationRewrite *rewrite) { + UnrealizedConversionCastOp op = rewrite->getOperation(); + assert(!op.use_empty() && + "expected that dead materializations have already been DCE'd"); + Operation::operand_range inputOperands = op.getOperands(); + + // Try to materialize the conversion. + if (const ContextAwareTypeConverter *converter = rewrite->getConverter()) { + rewriter.setInsertionPoint(op); + SmallVector newMaterialization; + switch (rewrite->getMaterializationKind()) { + case MaterializationKind::Target: + newMaterialization = converter->materializeTargetConversion( + rewriter, op->getLoc(), op.getResultTypes(), inputOperands, + rewrite->getOriginalType()); + break; + case MaterializationKind::Source: + assert(op->getNumResults() == 1 && "expected single result"); + Value sourceMat = converter->materializeSourceConversion( + rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands); + if (sourceMat) newMaterialization.push_back(sourceMat); + break; + } + if (!newMaterialization.empty()) { +#ifndef NDEBUG + ValueRange newMaterializationRange(newMaterialization); + assert(TypeRange(newMaterializationRange) == op.getResultTypes() && + "materialization callback produced value of incorrect type"); +#endif // NDEBUG + rewriter.replaceOp(op, newMaterialization); + return success(); + } + } + + InFlightDiagnostic diag = op->emitError() + << "failed to legalize unresolved materialization " + "from (" + << inputOperands.getTypes() << ") to (" + << op.getResultTypes() + << ") that remained live after conversion"; + diag.attachNote(op->getUsers().begin()->getLoc()) + << "see existing live user here: " << *op->getUsers().begin(); + return failure(); +} + +LogicalResult OperationConverter::convertOperations(ArrayRef ops) { + if (ops.empty()) return success(); + const ConversionTarget &target = opLegalizer.getTarget(); + + // Compute the set of operations and blocks to convert. + SmallVector toConvert; + for (auto *op : ops) { + op->walk>( + [&](Operation *op) { + toConvert.push_back(op); + // Don't check this operation's children for conversion if the + // operation is recursively legal. + auto legalityInfo = target.isLegal(op); + if (legalityInfo && legalityInfo->isRecursivelyLegal) + return WalkResult::skip(); + return WalkResult::advance(); + }); + } + + // Convert each operation and discard rewrites on failure. + ContextAwareConversionPatternRewriter rewriter(ops.front()->getContext(), + config); + ContextAwareConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); + + for (auto *op : toConvert) + if (failed(convert(rewriter, op))) + return rewriterImpl.undoRewrites(), failure(); + + // After a successful conversion, apply rewrites. + rewriterImpl.applyRewrites(); + + // Gather all unresolved materializations. + SmallVector allCastOps; + const DenseMap + &materializations = rewriterImpl.unresolvedMaterializations; + for (auto it : materializations) { + if (rewriterImpl.eraseRewriter.wasErased(it.first)) continue; + allCastOps.push_back(it.first); + } + + // Reconcile all UnrealizedConversionCastOps that were inserted by the + // dialect conversion frameworks. (Not the one that were inserted by + // patterns.) + SmallVector remainingCastOps; + reconcileUnrealizedCasts(allCastOps, &remainingCastOps); + + // Try to legalize all unresolved materializations. + if (config.buildMaterializations) { + IRRewriter rewriter(rewriterImpl.context, config.listener); + for (UnrealizedConversionCastOp castOp : remainingCastOps) { + auto it = materializations.find(castOp); + assert(it != materializations.end() && "inconsistent state"); + if (failed(legalizeUnresolvedMaterialization(rewriter, it->second))) + return failure(); + } + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// FunctionOpInterfaceSignatureConversion +//===----------------------------------------------------------------------===// + +static LogicalResult convertFuncOpTypes( + FunctionOpInterface funcOp, const ContextAwareTypeConverter &typeConverter, + ContextAwareConversionPatternRewriter &rewriter) { + FunctionType type = dyn_cast(funcOp.getFunctionType()); + if (!type) return failure(); + + // Convert the original function types. + ContextAwareTypeConverter::SignatureConversion result(type.getNumInputs()); + SmallVector newResults; + if (failed(typeConverter.convertSignatureArgs(funcOp, type.getInputs(), + result)) || + failed( + typeConverter.convertTypes(type.getResults(), funcOp, newResults)) || + failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(), + typeConverter, &result))) + return failure(); + + // Update the function signature in-place. + auto newType = FunctionType::get(rewriter.getContext(), + result.getConvertedTypes(), newResults); + + rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); }); + + return success(); +} + +/// Create a default conversion pattern that rewrites the type signature of a +/// FunctionOpInterface op. This only supports ops which use FunctionType to +/// represent their type. +namespace { +struct FunctionOpInterfaceSignatureConversion + : public ContextAwareConversionPattern { + FunctionOpInterfaceSignatureConversion( + StringRef functionLikeOpName, MLIRContext *ctx, + const ContextAwareTypeConverter &converter) + : ContextAwareConversionPattern(converter, functionLikeOpName, + /*benefit=*/1, ctx) {} + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef /*operands*/, + ContextAwareConversionPatternRewriter &rewriter) const override { + FunctionOpInterface funcOp = cast(op); + return convertFuncOpTypes(funcOp, *typeConverter, rewriter); + } +}; + +struct AnyFunctionOpInterfaceSignatureConversion + : public OpInterfaceContextAwareConversionPattern { + using OpInterfaceContextAwareConversionPattern:: + OpInterfaceContextAwareConversionPattern; + + LogicalResult matchAndRewrite( + FunctionOpInterface funcOp, ArrayRef /*operands*/, + ContextAwareConversionPatternRewriter &rewriter) const override { + return convertFuncOpTypes(funcOp, *typeConverter, rewriter); + } +}; +} // namespace + +FailureOr convertOpResultTypes( + Operation *op, ValueRange operands, + const ContextAwareTypeConverter &converter, + ContextAwareConversionPatternRewriter &rewriter) { + assert(op && "Invalid op"); + Location loc = op->getLoc(); + if (converter.isLegal(op)) + return rewriter.notifyMatchFailure(loc, "op already legal"); + + OperationState newOp(loc, op->getName()); + newOp.addOperands(operands); + + SmallVector newResultTypes; + if (failed(converter.convertTypes(op->getResultTypes(), op->getResults(), + newResultTypes))) + return rewriter.notifyMatchFailure(loc, "couldn't convert return types"); + + newOp.addTypes(newResultTypes); + newOp.addAttributes(op->getAttrs()); + return rewriter.create(newOp); +} + +void populateFunctionOpInterfaceTypeConversionPattern( + StringRef functionLikeOpName, RewritePatternSet &patterns, + const ContextAwareTypeConverter &converter) { + patterns.add( + functionLikeOpName, patterns.getContext(), converter); +} + +void populateAnyFunctionOpInterfaceTypeConversionPattern( + RewritePatternSet &patterns, const ContextAwareTypeConverter &converter) { + patterns.add( + converter, patterns.getContext()); +} + +//===----------------------------------------------------------------------===// +// Op Conversion Entry Points +//===----------------------------------------------------------------------===// + +LogicalResult applyContextAwarePartialConversion( + ArrayRef ops, const ConversionTarget &target, + const FrozenRewritePatternSet &patterns, ConversionConfig config) { + OperationConverter opConverter(target, patterns, config, + OpConversionMode::Partial); + return opConverter.convertOperations(ops); +} +LogicalResult applyPartialConversion(Operation *op, + const ConversionTarget &target, + const FrozenRewritePatternSet &patterns, + ConversionConfig config) { + return applyContextAwarePartialConversion(llvm::ArrayRef(op), target, + patterns, config); +} + +} // namespace heir +} // namespace mlir diff --git a/lib/Utils/ContextAwareDialectConversion.h b/lib/Utils/ContextAwareDialectConversion.h new file mode 100644 index 000000000..58cb3f6dd --- /dev/null +++ b/lib/Utils/ContextAwareDialectConversion.h @@ -0,0 +1,408 @@ +#ifndef LIB_UTILS_CONTEXTAWAREDIALECTCONVERSION_H_ +#define LIB_UTILS_CONTEXTAWAREDIALECTCONVERSION_H_ + +#include "lib/Utils/ContextAwareTypeConversion.h" +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir { +namespace heir { + +// This file is a port of DialectConversion.h from MLIR upstream, with the +// following changes: +// +// - Replaced uses of TypeConverter with AttributeAwareTypeConverter +// - Changed ConversionPattern to ContextAwareConversionPattern +// - Removed ConversionTarget so it can be reused from upstream +// - Removed reconcileUnrealizedCasts as it can be reused from upstream +// - Removed 1-N dialect conversion hooks, though some support for 1-N +// conversion remains in the TypeConverter internals. +// - Removed PDL stuff +// - Removed all drivers except applyPartialConversion + +class ContextAwareConversionPatternRewriter; + +//===----------------------------------------------------------------------===// +// Conversion Patterns +//===----------------------------------------------------------------------===// + +/// Base class for the conversion patterns. This pattern class enables type +/// conversions, and other uses specific to the conversion framework. As such, +/// patterns of this type can only be used with the 'apply*' methods below. +class ContextAwareConversionPattern : public RewritePattern { + public: + /// Hook for derived classes to implement rewriting. `op` is the (first) + /// operation matched by the pattern, `operands` is a list of the rewritten + /// operand values that are passed to `op`, `rewriter` can be used to emit the + /// new operations. This function should not fail. If some specific cases of + /// the operation are not supported, these cases should not be matched. + virtual void rewrite(Operation *op, ArrayRef operands, + ContextAwareConversionPatternRewriter &rewriter) const { + llvm_unreachable("unimplemented rewrite"); + } + + /// Hook for derived classes to implement combined matching and rewriting. + /// This overload supports only 1:1 replacements. + virtual LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ContextAwareConversionPatternRewriter &rewriter) const { + if (failed(match(op))) return failure(); + rewrite(op, operands, rewriter); + return success(); + } + virtual LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ContextAwareConversionPatternRewriter &rewriter) const { + return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + } + + /// Attempt to match and rewrite the IR root at the specified operation. + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final; + + /// Return the type converter held by this pattern, or nullptr if the pattern + /// does not require type conversion. + const ContextAwareTypeConverter *getTypeConverter() const { + return typeConverter; + } + + protected: + /// See `RewritePattern::RewritePattern` for information on the other + /// available constructors. + using RewritePattern::RewritePattern; + /// Construct a conversion pattern with the given converter, and forward the + /// remaining arguments to RewritePattern. + template + ContextAwareConversionPattern(const ContextAwareTypeConverter &typeConverter, + Args &&...args) + : RewritePattern(std::forward(args)...), + typeConverter(&typeConverter) {} + + /// Given an array of value ranges, which are the inputs to a 1:N adaptor, + /// try to extract the single value of each range to construct a the inputs + /// for a 1:1 adaptor. + /// + /// This function produces a fatal error if at least one range has 0 or + /// more than 1 value: "pattern 'name' does not support 1:N conversion" + SmallVector getOneToOneAdaptorOperands( + ArrayRef operands) const; + + protected: + /// An optional type converter for use by this pattern. + const ContextAwareTypeConverter *typeConverter = nullptr; + + private: + using RewritePattern::rewrite; +}; + +/// ContextAwareOpConversionPattern is a wrapper around +/// ContextAwareConversionPattern that allows for matching and rewriting +/// against an instance of a derived operation class as opposed to a raw +/// Operation. +template +class ContextAwareOpConversionPattern : public ContextAwareConversionPattern { + public: + using OpAdaptor = typename SourceOp::Adaptor; + + ContextAwareOpConversionPattern(MLIRContext *context, + PatternBenefit benefit = 1) + : ContextAwareConversionPattern(SourceOp::getOperationName(), benefit, + context) {} + ContextAwareOpConversionPattern( + const ContextAwareTypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1) + : ContextAwareConversionPattern( + typeConverter, SourceOp::getOperationName(), benefit, context) {} + + /// Wrappers around the ContextAwareConversionPattern methods that pass the + /// derived op type. + LogicalResult match(Operation *op) const final { + return match(cast(op)); + } + void rewrite(Operation *op, ArrayRef operands, + ContextAwareConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast(op); + rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); + } + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ContextAwareConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast(op); + return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); + } + + /// Rewrite and Match methods that operate on the SourceOp type. These must be + /// overridden by the derived pattern class. + virtual LogicalResult match(SourceOp op) const { + llvm_unreachable("must override match or matchAndRewrite"); + } + virtual void rewrite(SourceOp op, OpAdaptor adaptor, + ContextAwareConversionPatternRewriter &rewriter) const { + llvm_unreachable("must override matchAndRewrite or a rewrite method"); + } + virtual LogicalResult matchAndRewrite( + SourceOp op, OpAdaptor adaptor, + ContextAwareConversionPatternRewriter &rewriter) const { + if (failed(match(op))) return failure(); + rewrite(op, adaptor, rewriter); + return success(); + } + + private: + using ContextAwareConversionPattern::matchAndRewrite; +}; + +/// OpInterfaceContextAwareConversionPattern is a wrapper around +/// ContextAwareConversionPattern that allows for matching and rewriting against +/// an instance of an OpInterface class as opposed to a raw Operation. +template +class OpInterfaceContextAwareConversionPattern + : public ContextAwareConversionPattern { + public: + OpInterfaceContextAwareConversionPattern(MLIRContext *context, + PatternBenefit benefit = 1) + : ContextAwareConversionPattern(Pattern::MatchInterfaceOpTypeTag(), + SourceOp::getInterfaceID(), benefit, + context) {} + OpInterfaceContextAwareConversionPattern( + const ContextAwareTypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1) + : ContextAwareConversionPattern( + typeConverter, Pattern::MatchInterfaceOpTypeTag(), + SourceOp::getInterfaceID(), benefit, context) {} + + /// Wrappers around the ContextAwareConversionPattern methods that pass the + /// derived op type. + void rewrite(Operation *op, ArrayRef operands, + ContextAwareConversionPatternRewriter &rewriter) const final { + rewrite(cast(op), operands, rewriter); + } + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ContextAwareConversionPatternRewriter &rewriter) const final { + return matchAndRewrite(cast(op), operands, rewriter); + } + + /// Rewrite and Match methods that operate on the SourceOp type. These must be + /// overridden by the derived pattern class. + virtual void rewrite(SourceOp op, ArrayRef operands, + ContextAwareConversionPatternRewriter &rewriter) const { + llvm_unreachable("must override matchAndRewrite or a rewrite method"); + } + virtual LogicalResult matchAndRewrite( + SourceOp op, ArrayRef operands, + ContextAwareConversionPatternRewriter &rewriter) const { + if (failed(match(op))) return failure(); + rewrite(op, operands, rewriter); + return success(); + } + + private: + using ContextAwareConversionPattern::matchAndRewrite; +}; + +/// OpTraitContextAwareConversionPattern is a wrapper around +/// ContextAwareConversionPattern that allows for matching and rewriting against +/// instances of an operation that possess a given trait. +template