Skip to content

Commit 3ecec09

Browse files
[mlir][IR] Add listener notifications for pattern begin/end
1 parent 270ade8 commit 3ecec09

File tree

3 files changed

+77
-35
lines changed

3 files changed

+77
-35
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder {
432432
/// Note: This notification is not triggered when unlinking an operation.
433433
virtual void notifyOperationErased(Operation *op) {}
434434

435-
/// Notify the listener that the pattern failed to match the given
436-
/// operation, and provide a callback to populate a diagnostic with the
437-
/// reason why the failure occurred. This method allows for derived
438-
/// listeners to optionally hook into the reason why a rewrite failed, and
439-
/// display it to users.
435+
/// Notify the listener that the specified pattern is about to be applied
436+
/// at the specified root operation.
437+
virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {}
438+
439+
/// Notify the listener that a pattern application finished with the
440+
/// specified status. "success" indicates that the pattern was applied
441+
/// successfully. "failure" indicates that the pattern could not be
442+
/// applied. The pattern may have communicated the reason for the failure
443+
/// with `notifyMatchFailure`.
444+
virtual void notifyPatternEnd(const Pattern &pattern,
445+
LogicalResult status) {}
446+
447+
/// Notify the listener that the pattern failed to match, and provide a
448+
/// callback to populate a diagnostic with the reason why the failure
449+
/// occurred. This method allows for derived listeners to optionally hook
450+
/// into the reason why a rewrite failed, and display it to users.
440451
virtual void
441452
notifyMatchFailure(Location loc,
442453
function_ref<void(Diagnostic &)> reasonCallback) {}
@@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder {
478489
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
479490
rewriteListener->notifyOperationErased(op);
480491
}
492+
void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
493+
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
494+
rewriteListener->notifyPatternBegin(pattern, op);
495+
}
496+
void notifyPatternEnd(const Pattern &pattern,
497+
LogicalResult status) override {
498+
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
499+
rewriteListener->notifyPatternEnd(pattern, status);
500+
}
481501
void notifyMatchFailure(
482502
Location loc,
483503
function_ref<void(Diagnostic &)> reasonCallback) override {

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,7 +1863,8 @@ class OperationLegalizer {
18631863
using LegalizationAction = ConversionTarget::LegalizationAction;
18641864

18651865
OperationLegalizer(const ConversionTarget &targetInfo,
1866-
const FrozenRewritePatternSet &patterns);
1866+
const FrozenRewritePatternSet &patterns,
1867+
const ConversionConfig &config);
18671868

18681869
/// Returns true if the given operation is known to be illegal on the target.
18691870
bool isIllegal(Operation *op) const;
@@ -1955,12 +1956,16 @@ class OperationLegalizer {
19551956

19561957
/// The pattern applicator to use for conversions.
19571958
PatternApplicator applicator;
1959+
1960+
/// Dialect conversion configuration.
1961+
const ConversionConfig &config;
19581962
};
19591963
} // namespace
19601964

19611965
OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
1962-
const FrozenRewritePatternSet &patterns)
1963-
: target(targetInfo), applicator(patterns) {
1966+
const FrozenRewritePatternSet &patterns,
1967+
const ConversionConfig &config)
1968+
: target(targetInfo), applicator(patterns), config(config) {
19641969
// The set of patterns that can be applied to illegal operations to transform
19651970
// them into legal ones.
19661971
DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2105,7 +2110,10 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
21052110

21062111
// Functor that returns if the given pattern may be applied.
21072112
auto canApply = [&](const Pattern &pattern) {
2108-
return canApplyPattern(op, pattern, rewriter);
2113+
bool canApply = canApplyPattern(op, pattern, rewriter);
2114+
if (canApply && config.listener)
2115+
config.listener->notifyPatternBegin(pattern, op);
2116+
return canApply;
21092117
};
21102118

21112119
// Functor that cleans up the rewriter state after a pattern failed to match.
@@ -2122,6 +2130,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
21222130
rewriterImpl.config.notifyCallback(diag);
21232131
}
21242132
});
2133+
if (config.listener)
2134+
config.listener->notifyPatternEnd(pattern, failure());
21252135
rewriterImpl.resetState(curState);
21262136
appliedPatterns.erase(&pattern);
21272137
};
@@ -2134,6 +2144,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
21342144
appliedPatterns.erase(&pattern);
21352145
if (failed(result))
21362146
rewriterImpl.resetState(curState);
2147+
if (config.listener)
2148+
config.listener->notifyPatternEnd(pattern, result);
21372149
return result;
21382150
};
21392151

@@ -2509,7 +2521,8 @@ struct OperationConverter {
25092521
const FrozenRewritePatternSet &patterns,
25102522
const ConversionConfig &config,
25112523
OpConversionMode mode)
2512-
: opLegalizer(target, patterns), config(config), mode(mode) {}
2524+
: config(config), opLegalizer(target, patterns, this->config),
2525+
mode(mode) {}
25132526

25142527
/// Converts the given operations to the conversion target.
25152528
LogicalResult convertOperations(ArrayRef<Operation *> ops);
@@ -2546,12 +2559,12 @@ struct OperationConverter {
25462559
ConversionPatternRewriterImpl &rewriterImpl,
25472560
const DenseMap<Value, SmallVector<Value>> &inverseMapping);
25482561

2549-
/// The legalizer to use when converting operations.
2550-
OperationLegalizer opLegalizer;
2551-
25522562
/// Dialect conversion configuration.
25532563
ConversionConfig config;
25542564

2565+
/// The legalizer to use when converting operations.
2566+
OperationLegalizer opLegalizer;
2567+
25552568
/// The conversion mode to use when legalizing operations.
25562569
OpConversionMode mode;
25572570
};

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -562,30 +562,39 @@ bool GreedyPatternRewriteDriver::processWorklist() {
562562
// Try to match one of the patterns. The rewriter is automatically
563563
// notified of any necessary changes, so there is nothing else to do
564564
// here.
565-
#ifndef NDEBUG
566-
auto canApply = [&](const Pattern &pattern) {
567-
LLVM_DEBUG({
568-
logger.getOStream() << "\n";
569-
logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
570-
<< op->getName() << " -> (";
571-
llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
572-
logger.getOStream() << ")' {\n";
573-
logger.indent();
574-
});
575-
return true;
576-
};
577-
auto onFailure = [&](const Pattern &pattern) {
578-
LLVM_DEBUG(logResult("failure", "pattern failed to match"));
579-
};
580-
auto onSuccess = [&](const Pattern &pattern) {
581-
LLVM_DEBUG(logResult("success", "pattern applied successfully"));
582-
return success();
583-
};
584-
#else
585565
function_ref<bool(const Pattern &)> canApply = {};
586566
function_ref<void(const Pattern &)> onFailure = {};
587567
function_ref<LogicalResult(const Pattern &)> onSuccess = {};
588-
#endif
568+
bool debugBuild = false;
569+
#ifdef NDEBUG
570+
debugBuild = true;
571+
#endif // NDEBUG
572+
if (debugBuild || config.listener) {
573+
canApply = [&](const Pattern &pattern) {
574+
LLVM_DEBUG({
575+
logger.getOStream() << "\n";
576+
logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
577+
<< op->getName() << " -> (";
578+
llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
579+
logger.getOStream() << ")' {\n";
580+
logger.indent();
581+
});
582+
if (config.listener)
583+
config.listener->notifyPatternBegin(pattern, op);
584+
return true;
585+
};
586+
onFailure = [&](const Pattern &pattern) {
587+
LLVM_DEBUG(logResult("failure", "pattern failed to match"));
588+
if (config.listener)
589+
config.listener->notifyPatternEnd(pattern, failure());
590+
};
591+
onSuccess = [&](const Pattern &pattern) {
592+
LLVM_DEBUG(logResult("success", "pattern applied successfully"));
593+
if (config.listener)
594+
config.listener->notifyPatternEnd(pattern, success());
595+
return success();
596+
};
597+
}
589598

590599
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
591600
if (config.scope) {
@@ -731,7 +740,7 @@ void GreedyPatternRewriteDriver::notifyMatchFailure(
731740
LLVM_DEBUG({
732741
Diagnostic diag(loc, DiagnosticSeverity::Remark);
733742
reasonCallback(diag);
734-
logger.startLine() << "** Failure : " << diag.str() << "\n";
743+
logger.startLine() << "** Match Failure : " << diag.str() << "\n";
735744
});
736745
if (config.listener)
737746
config.listener->notifyMatchFailure(loc, reasonCallback);

0 commit comments

Comments
 (0)