Skip to content

[mlir][Transforms][NFC] Dialect conversion: Cache UnresolvedMaterializationRewrite #108359

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 1 commit into from
Sep 13, 2024

Conversation

matthias-springer
Copy link
Member

The dialect conversion maintains a set of unresolved materializations (UnrealizedConversionCastOp). Turn that set into a DenseMap that maps from ops to UnresolvedMaterializationRewrite *. This improves efficiency a bit, because an iteration over ConversionPatternRewriterImpl::rewrites can be avoided.

Also delete some dead code.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Sep 12, 2024
@llvmbot
Copy link
Member

llvmbot commented Sep 12, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

The dialect conversion maintains a set of unresolved materializations (UnrealizedConversionCastOp). Turn that set into a DenseMap that maps from ops to UnresolvedMaterializationRewrite *. This improves efficiency a bit, because an iteration over ConversionPatternRewriterImpl::rewrites can be avoided.

Also delete some dead code.


Full diff: https://github.com/llvm/llvm-project/pull/108359.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+20-40)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b58a95c3baf70a..ed15b571f01883 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -688,9 +688,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   UnresolvedMaterializationRewrite(
       ConversionPatternRewriterImpl &rewriterImpl,
       UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
-      MaterializationKind kind = MaterializationKind::Target)
-      : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
-        converterAndKind(converter, kind) {}
+      MaterializationKind kind = MaterializationKind::Target);
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -730,26 +728,6 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
   });
 }
 
-/// Find the single rewrite object of the specified type and block among the
-/// given rewrites. In debug mode, asserts that there is mo more than one such
-/// object. Return "nullptr" if no object was found.
-template <typename RewriteTy, typename R>
-static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
-  RewriteTy *result = nullptr;
-  for (auto &rewrite : rewrites) {
-    auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
-    if (rewriteTy && rewriteTy->getBlock() == block) {
-#ifndef NDEBUG
-      assert(!result && "expected single matching rewrite");
-      result = rewriteTy;
-#else
-      return rewriteTy;
-#endif // NDEBUG
-    }
-  }
-  return result;
-}
-
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriterImpl
 //===----------------------------------------------------------------------===//
@@ -892,10 +870,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 
     bool wasErased(void *ptr) const { return erased.contains(ptr); }
 
-    bool wasErased(OperationRewrite *rewrite) const {
-      return wasErased(rewrite->getOperation());
-    }
-
     void notifyOperationErased(Operation *op) override { erased.insert(op); }
 
     void notifyBlockErased(Block *block) override { erased.insert(block); }
@@ -935,8 +909,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// to modify/access them is invalid rewriter API usage.
   SetVector<Operation *> replacedOps;
 
-  /// A set of all unresolved materializations.
-  DenseSet<Operation *> unresolvedMaterializations;
+  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
+  /// to the corresponding rewrite objects.
+  DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+      unresolvedMaterializations;
 
   /// The current type converter, or nullptr if no type converter is currently
   /// active.
@@ -1058,6 +1034,14 @@ void CreateOperationRewrite::rollback() {
   op->erase();
 }
 
+UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
+    ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
+    const TypeConverter *converter, MaterializationKind kind)
+    : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+      converterAndKind(converter, kind) {
+  rewriterImpl.unresolvedMaterializations[op] = this;
+}
+
 void UnresolvedMaterializationRewrite::rollback() {
   if (getMaterializationKind() == MaterializationKind::Target) {
     for (Value input : op->getOperands())
@@ -1345,7 +1329,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
-  unresolvedMaterializations.insert(convertOp);
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
   return convertOp.getResult(0);
 }
@@ -2499,15 +2482,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
 
   // Gather all unresolved materializations.
   SmallVector<UnrealizedConversionCastOp> allCastOps;
-  DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
-  for (std::unique_ptr<IRRewrite> &rewrite : rewriterImpl.rewrites) {
-    auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
-    if (!mat)
-      continue;
-    if (rewriterImpl.eraseRewriter.wasErased(mat))
+  const DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+      &materializations = rewriterImpl.unresolvedMaterializations;
+  for (auto it : materializations) {
+    if (rewriterImpl.eraseRewriter.wasErased(it.first))
       continue;
-    allCastOps.push_back(mat->getOperation());
-    rewriteMap[mat->getOperation()] = mat;
+    allCastOps.push_back(cast<UnrealizedConversionCastOp>(it.first));
   }
 
   // Reconcile all UnrealizedConversionCastOps that were inserted by the
@@ -2520,8 +2500,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   if (config.buildMaterializations) {
     IRRewriter rewriter(rewriterImpl.context, config.listener);
     for (UnrealizedConversionCastOp castOp : remainingCastOps) {
-      auto it = rewriteMap.find(castOp.getOperation());
-      assert(it != rewriteMap.end() && "inconsistent state");
+      auto it = materializations.find(castOp.getOperation());
+      assert(it != materializations.end() && "inconsistent state");
       if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
         return failure();
     }

Base automatically changed from users/matthias-springer/replace_op_source_mat to main September 12, 2024 13:30
@matthias-springer matthias-springer force-pushed the users/matthias-springer/mat_cache branch 2 times, most recently from 4cb4bcf to 066359e Compare September 12, 2024 13:36
…izationRewrite`

The dialect conversion already maintains a set of unresolved materializations (`UnrealizedConversionCastOp`). Turn that set into a map that maps from ops to `UnresolvedMaterializationRewrite *`. This improves efficiency a bit, because an iteration over `ConversionPatternRewriterImpl::rewrites` can be avoided.

Also delete some dead code.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/mat_cache branch from 066359e to e724e44 Compare September 13, 2024 17:55
@matthias-springer matthias-springer merged commit d588e49 into main Sep 13, 2024
6 of 7 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/mat_cache branch September 13, 2024 18:16
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants