Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[WIP] Rewrite linalg-to-tensor-ext and materialize layouts #1428

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
3 changes: 2 additions & 1 deletion lib/Dialect/Secret/Conversions/SecretToBGV/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ cc_library(
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Parameters/BGV:Params",
"@heir//lib/Utils",
"@heir//lib/Utils:ConversionUtils",
"@heir//lib/Utils:ContextAwareConversionUtils",
"@heir//lib/Utils:ContextAwareTypeConversion",
"@heir//lib/Utils/Polynomial",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
Expand Down
38 changes: 15 additions & 23 deletions lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.h"

#include <cassert>
#include <cmath>
#include <cstdint>
#include <optional>
Expand All @@ -23,7 +22,8 @@
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "lib/Dialect/Secret/IR/SecretTypes.h"
#include "lib/Parameters/BGV/Params.h"
#include "lib/Utils/ConversionUtils.h"
#include "lib/Utils/ContextAwareConversionUtils.h"
#include "lib/Utils/ContextAwareTypeConversion.h"
#include "lib/Utils/Polynomial/Polynomial.h"
#include "lib/Utils/Utils.h"
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
Expand Down Expand Up @@ -88,17 +88,21 @@ polynomial::RingAttr getRlweRNSRingWithLevel(polynomial::RingAttr ringAttr,

} // namespace

class SecretToBGVTypeConverter : public TypeWithAttrTypeConverter {
class SecretToBGVTypeConverter
: public UniquelyNamedAttributeAwareTypeConverter {
public:
SecretToBGVTypeConverter(MLIRContext *ctx,
::mlir::heir::polynomial::RingAttr rlweRing,
int64_t ptm)
: TypeWithAttrTypeConverter(mgmt::MgmtDialect::kArgMgmtAttrName) {
: UniquelyNamedAttributeAwareTypeConverter(
mgmt::MgmtDialect::kArgMgmtAttrName) {
ring = rlweRing;
plaintextModulus = ptm;

// isLegal/isSignatureLegal will always be true
addConversion([](Type type) { return type; });
addConversion([this](secret::SecretType type, mgmt::MgmtAttr mgmtAttr) {
return convertSecretTypeWithMgmtAttr(type, mgmtAttr);
});
addConversion([](Type type, Attribute attr) { return type; });
}

Type convertSecretTypeWithMgmtAttr(secret::SecretType type,
Expand Down Expand Up @@ -132,18 +136,6 @@ class SecretToBGVTypeConverter : public TypeWithAttrTypeConverter {
lwe::ModulusChainAttr::get(ctx, moduliChain, level));
}

Type convertTypeWithAttr(Type type, Attribute attr) const override {
auto secretTy = dyn_cast<secret::SecretType>(type);
// guard against null attribute
if (secretTy && attr) {
auto mgmtAttr = dyn_cast<mgmt::MgmtAttr>(attr);
if (mgmtAttr) {
return convertSecretTypeWithMgmtAttr(secretTy, mgmtAttr);
}
}
return type;
}

private:
::mlir::heir::polynomial::RingAttr ring;
int64_t plaintextModulus;
Expand Down Expand Up @@ -283,12 +275,11 @@ struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {
target.addIllegalDialect<secret::SecretDialect>();
target.addIllegalOp<mgmt::ModReduceOp, mgmt::RelinearizeOp>();
target.addIllegalOp<secret::GenericOp>();
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return typeConverter.isFuncArgumentAndResultLegal(op);
});
target.addDynamicallyLegalOp<func::FuncOp>(
[&](func::FuncOp op) { return typeConverter.isLegal(op); });

patterns.add<
ConvertFuncWithContextAwareTypeConverter,
ContextAwareFuncConversion,
SecretGenericOpCipherConversion<arith::AddIOp, bgv::AddOp>,
SecretGenericOpCipherConversion<arith::SubIOp, bgv::SubOp>,
SecretGenericOpCipherConversion<arith::MulIOp, bgv::MulOp>,
Expand All @@ -301,7 +292,8 @@ struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {
SecretGenericOpCipherPlainConversion<arith::MulIOp, bgv::MulPlainOp>>(
typeConverter, context);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
if (failed(applyContextAwarePartialConversion(module, target,
std::move(patterns)))) {
return signalPassFailure();
}

Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Secret/Conversions/SecretToCGGI/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ cc_library(
"@heir//lib/Dialect/Comb/IR:Dialect",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Utils:ContextAwareConversionUtils",
"@heir//lib/Utils:ContextAwareTypeConversion",
"@heir//lib/Utils:ConversionUtils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
Expand Down
98 changes: 60 additions & 38 deletions lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
#include "lib/Dialect/ModuleAttributes.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "lib/Dialect/Secret/IR/SecretTypes.h"
#include "lib/Utils/ConversionUtils.h"
#include "lib/Utils/ContextAwareConversionUtils.h"
#include "lib/Utils/ContextAwareTypeConversion.h"
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "llvm/include/llvm/ADT/Sequence.h" // from @llvm-project
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
Expand Down Expand Up @@ -76,10 +77,10 @@ Value buildSelectTruthTable(Location loc, OpBuilder &b, Value t, Value f,
selectFalse);
}

Operation *convertWriteOpInterface(Operation *op, SmallVector<Value> indices,
Value valueToStore,
TypedValue<MemRefType> toMemRef,
ConversionPatternRewriter &rewriter) {
Operation *convertWriteOpInterface(
Operation *op, SmallVector<Value> indices, Value valueToStore,
TypedValue<MemRefType> toMemRef,
ContextAwareConversionPatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);

MemRefType toMemRefTy = toMemRef.getType();
Expand Down Expand Up @@ -152,9 +153,9 @@ Operation *convertWriteOpInterface(Operation *op, SmallVector<Value> indices,
llvm_unreachable("expected integer or memref to store in ciphertext memref");
}

Operation *convertReadOpInterface(Operation *op, SmallVector<Value> indices,
Value fromMemRef, Type outputType,
ConversionPatternRewriter &rewriter) {
Operation *convertReadOpInterface(
Operation *op, SmallVector<Value> indices, Value fromMemRef,
Type outputType, ContextAwareConversionPatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
MemRefType outputMemRefType = dyn_cast<MemRefType>(outputType);
MemRefType fromMemRefType = cast<MemRefType>(fromMemRef.getType());
Expand Down Expand Up @@ -194,8 +195,9 @@ Operation *convertReadOpInterface(Operation *op, SmallVector<Value> indices,
return subViewOp;
}

SmallVector<Value> encodeInputs(Operation *op, ValueRange inputs,
ConversionPatternRewriter &rewriter) {
SmallVector<Value> encodeInputs(
Operation *op, ValueRange inputs,
ContextAwareConversionPatternRewriter &rewriter) {
// Get the ciphertext type.
lwe::LWECiphertextType ctxtTy;
for (auto input : inputs) {
Expand Down Expand Up @@ -227,16 +229,15 @@ SmallVector<Value> encodeInputs(Operation *op, ValueRange inputs,

} // namespace

class SecretTypeConverter : public TypeConverter {
class SecretTypeConverter : public ContextAwareTypeConverter {
public:
SecretTypeConverter(MLIRContext *ctx, int minBitWidth)
: minBitWidth(minBitWidth) {
addConversion([](Type type) { return type; });

// Convert secret types to LWE ciphertext types
addConversion([ctx, this](secret::SecretType type) -> Type {
addConversion([ctx, this](secret::SecretType type, Attribute attr) -> Type {
return getLWECiphertextForInt(ctx, type.getValueType());
});
addConversion([](Type type, Attribute attr) { return type; });
}

Type getLWECiphertextForInt(MLIRContext *ctx, Type type) const {
Expand Down Expand Up @@ -267,6 +268,17 @@ class SecretTypeConverter : public TypeConverter {
return shapedType.cloneWith(newShape, elementType);
}

FailureOr<Attribute> getContextualAttr(Value value) const override {
return FailureOr<Attribute>(nullptr);
}

LogicalResult convertFuncSignature(
FunctionOpInterface funcOp, SmallVectorImpl<Type> &newArgTypes,
SmallVectorImpl<Type> &newResultTypes) const override {
// FIXME: implement
return failure();
}

int minBitWidth;
};

Expand All @@ -278,7 +290,7 @@ class SecretGenericOpLUTConversion
LogicalResult matchAndRewriteInner(
secret::GenericOp op, TypeRange outputTypes, ValueRange inputs,
ArrayRef<NamedAttribute> attributes,
ConversionPatternRewriter &rewriter) const override {
ContextAwareConversionPatternRewriter &rewriter) const override {
SmallVector<Value> encodedInputs =
encodeInputs(op.getOperation(), inputs, rewriter);

Expand All @@ -299,7 +311,7 @@ class SecretGenericOpMemRefLoadConversion
LogicalResult matchAndRewriteInner(
secret::GenericOp op, TypeRange outputTypes, ValueRange inputs,
ArrayRef<NamedAttribute> attributes,
ConversionPatternRewriter &rewriter) const override {
ContextAwareConversionPatternRewriter &rewriter) const override {
memref::LoadOp loadOp =
cast<memref::LoadOp>(op.getBody()->getOperations().front());
if (auto lweType = dyn_cast<lwe::LWECiphertextType>(outputTypes[0])) {
Expand All @@ -323,7 +335,7 @@ class SecretGenericOpGateConversion
LogicalResult matchAndRewriteInner(
secret::GenericOp op, TypeRange outputTypes, ValueRange inputs,
ArrayRef<NamedAttribute> attributes,
ConversionPatternRewriter &rewriter) const override {
ContextAwareConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<CGGIGateOp>(
op, outputTypes, encodeInputs(op.getOperation(), inputs, rewriter),
attributes);
Expand Down Expand Up @@ -354,7 +366,7 @@ class SecretGenericOpAffineLoadConversion
LogicalResult matchAndRewriteInner(
secret::GenericOp op, TypeRange outputTypes, ValueRange inputs,
ArrayRef<NamedAttribute> attributes,
ConversionPatternRewriter &rewriter) const override {
ContextAwareConversionPatternRewriter &rewriter) const override {
affine::AffineLoadOp loadOp =
cast<affine::AffineLoadOp>(op.getBody()->getOperations().front());
if (auto lweType = dyn_cast<lwe::LWECiphertextType>(outputTypes[0])) {
Expand Down Expand Up @@ -383,7 +395,7 @@ class SecretGenericOpAffineStoreConversion
LogicalResult matchAndRewriteInner(
secret::GenericOp op, TypeRange outputTypes, ValueRange inputs,
ArrayRef<NamedAttribute> attributes,
ConversionPatternRewriter &rewriter) const override {
ContextAwareConversionPatternRewriter &rewriter) const override {
affine::AffineStoreOp storeOp =
cast<affine::AffineStoreOp>(op.getBody()->getOperations().front());
auto toMemRef = cast<TypedValue<MemRefType>>(inputs[1]);
Expand All @@ -407,7 +419,7 @@ class SecretGenericOpMemRefStoreConversion
LogicalResult matchAndRewriteInner(
secret::GenericOp op, TypeRange outputTypes, ValueRange inputs,
ArrayRef<NamedAttribute> attributes,
ConversionPatternRewriter &rewriter) const override {
ContextAwareConversionPatternRewriter &rewriter) const override {
memref::StoreOp storeOp =
cast<memref::StoreOp>(op.getBody()->getOperations().front());
auto toMemRef = cast<TypedValue<MemRefType>>(inputs[1]);
Expand All @@ -421,15 +433,16 @@ class SecretGenericOpMemRefStoreConversion
};

// ConvertTruthTableOp converts truth table ops with fully plaintext values.
struct ConvertTruthTableOp : public OpConversionPattern<comb::TruthTableOp> {
struct ConvertTruthTableOp
: public ContextAwareOpConversionPattern<comb::TruthTableOp> {
ConvertTruthTableOp(mlir::MLIRContext *context)
: OpConversionPattern<comb::TruthTableOp>(context) {}
: ContextAwareOpConversionPattern<comb::TruthTableOp>(context) {}

using OpConversionPattern::OpConversionPattern;
using ContextAwareOpConversionPattern::ContextAwareOpConversionPattern;

LogicalResult matchAndRewrite(
comb::TruthTableOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ContextAwareConversionPatternRewriter &rewriter) const override {
if (op->getNumOperands() != 3) {
op->emitError() << "expected 3 truth table arguments to lower to CGGI";
}
Expand All @@ -451,15 +464,16 @@ struct ConvertTruthTableOp : public OpConversionPattern<comb::TruthTableOp> {

// ConvertSecretCastOp removes secret.cast operations between multi-bit secret
// integers and tensors of single-bit secrets.
struct ConvertSecretCastOp : public OpConversionPattern<secret::CastOp> {
struct ConvertSecretCastOp
: public ContextAwareOpConversionPattern<secret::CastOp> {
ConvertSecretCastOp(mlir::MLIRContext *context)
: OpConversionPattern<secret::CastOp>(context) {}
: ContextAwareOpConversionPattern<secret::CastOp>(context) {}

using OpConversionPattern::OpConversionPattern;
using ContextAwareOpConversionPattern::ContextAwareOpConversionPattern;

LogicalResult matchAndRewrite(
secret::CastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ContextAwareConversionPatternRewriter &rewriter) const override {
// If this is a cast from secret<i8> to secret<memref<8xi1>> or vice
// versa, replace with the cast's input.
auto lhsType =
Expand Down Expand Up @@ -498,8 +512,10 @@ struct ConvertSecretCastOp : public OpConversionPattern<secret::CastOp> {
// when converting op results or operands in Yosys Optimizer when the
// original result or operand was also a memref.
if (lhsMemRefTy && rhsMemRefTy) {
auto outRhsType = cast<MemRefType>(
this->typeConverter->convertType(op.getOutput().getType()));
auto *typeConverter = getTypeConverter();
auto outRhsType = typeConverter->convertType<MemRefType>(
op.getOutput().getType(),
typeConverter->getContextualAttr(op.getOutput()).value_or(nullptr));
if (lhsMemRefTy.getRank() > rhsMemRefTy.getRank() &&
rhsMemRefTy.getRank() == 1) {
// This case happens when converting a high dimension memref into a flat
Expand Down Expand Up @@ -546,19 +562,22 @@ struct ConvertSecretCastOp : public OpConversionPattern<secret::CastOp> {

// ConvertSecretConcealOp lowers secret.conceal to a series of trivial_encrypt
// ops stored into a memref.
struct ConvertSecretConcealOp : public OpConversionPattern<secret::ConcealOp> {
struct ConvertSecretConcealOp
: public ContextAwareOpConversionPattern<secret::ConcealOp> {
ConvertSecretConcealOp(mlir::MLIRContext *context)
: OpConversionPattern<secret::ConcealOp>(context) {}
: ContextAwareOpConversionPattern<secret::ConcealOp>(context) {}

using OpConversionPattern::OpConversionPattern;
using ContextAwareOpConversionPattern::ContextAwareOpConversionPattern;

LogicalResult matchAndRewrite(
secret::ConcealOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ContextAwareConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

Type convertedTy =
getTypeConverter()->convertType(op.getResult().getType());
Type convertedTy = getTypeConverter()->convertType(
op.getResult().getType(), getTypeConverter()
->getContextualAttr(op.getResult())
.value_or(nullptr));
auto memrefTy = dyn_cast<MemRefType>(convertedTy);
auto ctTy = cast<lwe::LWECiphertextType>(memrefTy.getElementType());
auto ptxtTy =
Expand Down Expand Up @@ -709,9 +728,12 @@ struct SecretToCGGI : public impl::SecretToCGGIBase<SecretToCGGI> {
});
target.addDynamicallyLegalOp<affine::AffineYieldOp>(
[&](auto op) { return typeConverter.isLegal(op); });
addStructuralConversionPatterns(typeConverter, patterns, target);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
// FIXME: reinstate with context awareness!
// addStructuralConversionPatterns(typeConverter, patterns, target);

if (failed(applyContextAwarePartialConversion(module, target,
std::move(patterns)))) {
return signalPassFailure();
}

Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/Secret/Conversions/SecretToCKKS/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ cc_library(
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@heir//lib/Parameters/CKKS:Params",
"@heir//lib/Utils",
"@heir//lib/Utils:ConversionUtils",
"@heir//lib/Utils:ContextAwareConversionUtils",
"@heir//lib/Utils:ContextAwareTypeConversion",
"@heir//lib/Utils/Polynomial",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
Expand Down
Loading
Loading