Skip to content

Commit 270ade8

Browse files
[mlir][Transform] apply_conversion_patterns: Update handles
Until now, `transform.apply_conversion_patterns` consumed the target handle and potentially invalidated handles. This commit adds tracking functionality similar to `transform.apply_patterns`, such that handles are no longer invalidated, but updated based on op replacements performed by the dialect conversion. This new functionality is hidden behind a `preserve_handles` attribute for now.
1 parent 9b4aa86 commit 270ade8

File tree

5 files changed

+129
-35
lines changed

5 files changed

+129
-35
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -921,20 +921,36 @@ TransformState::RegionScope TransformState::make_region_scope(Region &region) {
921921
return RegionScope(*this, region);
922922
}
923923

924+
/// A configuration object for customizing a `TrackingListener`.
925+
struct TrackingListenerConfig {
926+
using SkipHandleFn = std::function<bool(Value)>;
927+
928+
/// An optional function that returns "true" for handles that do not have to
929+
/// be updated. These are typically dead or consumed handles.
930+
SkipHandleFn skipHandleFn = nullptr;
931+
932+
/// If set to "true", the name of a replacement op must match the name of the
933+
/// original op. If set to "false", the names of the payload ops tracked in a
934+
/// handle may change as the tracking listener updates the transform state.
935+
bool requireMatchingReplacementOpName = true;
936+
937+
/// If set to "true", cast ops (that implement the CastOpInterface) are
938+
/// skipped and the replacement op search continues with the operands of the
939+
/// cast op.
940+
bool skipCastOps = true;
941+
};
942+
924943
/// A listener that updates a TransformState based on IR modifications. This
925944
/// listener can be used during a greedy pattern rewrite to keep the transform
926945
/// state up-to-date.
927946
class TrackingListener : public RewriterBase::Listener,
928947
public TransformState::Extension {
929948
public:
930-
/// A function that returns "true" for handles that do not have to be updated.
931-
using SkipHandleFn = std::function<bool(Value)>;
932-
933949
/// Create a new TrackingListener for usage in the specified transform op.
934950
/// Optionally, a function can be specified to identify handles that should
935951
/// do not have to be updated.
936952
TrackingListener(TransformState &state, TransformOpInterface op,
937-
SkipHandleFn skipHandleFn = nullptr);
953+
TrackingListenerConfig config = TrackingListenerConfig());
938954

939955
protected:
940956
/// Return a replacement payload op for the given op, which is going to be
@@ -959,7 +975,8 @@ class TrackingListener : public RewriterBase::Listener,
959975
/// same computation; e.g., there may be tiled "linalg.generic" inside the
960976
/// loop body that represents the original computation. Therefore, the
961977
/// TrackingListener is conservative by default: it drops the mapping and
962-
/// triggers the "payload replacement not found" notification.
978+
/// triggers the "payload replacement not found" notification. This default
979+
/// behavior can be customized in `TrackingListenerConfig`.
963980
///
964981
/// If no replacement op could be found according to the rules mentioned
965982
/// above, this function tries to skip over cast-like ops that implement
@@ -1023,9 +1040,8 @@ class TrackingListener : public RewriterBase::Listener,
10231040
/// The handles that are consumed by the transform op.
10241041
DenseSet<Value> consumedHandles;
10251042

1026-
/// Handles for which this function evaluates to "true" do not have to be
1027-
/// updated. These are typically dead or consumed handles.
1028-
SkipHandleFn skipHandleFn;
1043+
/// Tracking listener configuration.
1044+
TrackingListenerConfig config;
10291045
};
10301046

10311047
/// A specialized listener that keeps track of cases in which no replacement

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,19 +190,29 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns",
190190
The `legal_ops`, `illegal_ops`, `legal_dialects`, `illegal_dialects`
191191
attributes specify the conversion target.
192192

193-
This transform consumes the `target` handle and modifies the payload. It
194-
does not produce any handles.
193+
This transform modifies the payload. By default, it consumes the `target`
194+
handle. It does not produce any handles.
195+
196+
If the `preserve_handles` attribute is set, this transform does not consume
197+
the `target` handle and instead updates handles based on notifications from
198+
a tracking listener that is attached to the dialect conversion, similar to
199+
`transform.apply_patterns`. Only replacements via `RewriterBase::replaceOp`
200+
or `replaceOpWithNewOp` are considered "payload op replacements". In
201+
contrast to `transform.apply_patterns`, we allow replacement ops even if the
202+
op name has changed. More details can be found at the documentation site of
203+
`TrackingListener`.
195204

196205
This transform produces a silenceable failure if the dialect conversion was
197-
unsuccessful.
206+
unsuccessful or the tracking listener failed to find a replacement op.
198207
}];
199208

200209
let arguments = (ins TransformHandleTypeInterface:$target,
201210
OptionalAttr<StrArrayAttr>:$legal_ops,
202211
OptionalAttr<StrArrayAttr>:$illegal_ops,
203212
OptionalAttr<StrArrayAttr>:$legal_dialects,
204213
OptionalAttr<StrArrayAttr>:$illegal_dialects,
205-
UnitAttr:$partial_conversion);
214+
UnitAttr:$partial_conversion,
215+
UnitAttr:$preserve_handles);
206216
let results = (outs);
207217
let regions = (region
208218
MaxSizedRegion<1>:$patterns,

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
918918
}
919919

920920
// Prepare rewriter and listener.
921-
TrackingListener::SkipHandleFn skipHandleFn = [&](Value handle) {
921+
TrackingListenerConfig config;
922+
config.skipHandleFn = [&](Value handle) {
922923
// Skip handle if it is dead.
923924
auto scopeIt =
924925
llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
@@ -935,7 +936,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
935936
return true;
936937
};
937938
transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
938-
skipHandleFn);
939+
config);
939940
transform::TransformRewriter rewriter(transform->getContext(),
940941
&trackingListener);
941942

@@ -1184,9 +1185,8 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const {
11841185

11851186
transform::TrackingListener::TrackingListener(TransformState &state,
11861187
TransformOpInterface op,
1187-
SkipHandleFn skipHandleFn)
1188-
: TransformState::Extension(state), transformOp(op),
1189-
skipHandleFn(skipHandleFn) {
1188+
TrackingListenerConfig config)
1189+
: TransformState::Extension(state), transformOp(op), config(config) {
11901190
if (op) {
11911191
for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
11921192
consumedHandles.insert(opOperand->get());
@@ -1228,8 +1228,19 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
12281228
return diag;
12291229
}
12301230

1231-
// If the defining op has the same type, we take it as a replacement.
1232-
if (op->getName() == defOp->getName()) {
1231+
// Skip through ops that implement CastOpInterface.
1232+
if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
1233+
values.clear();
1234+
values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
1235+
diag.attachNote(defOp->getLoc())
1236+
<< "using output of 'CastOpInterface' op";
1237+
continue;
1238+
}
1239+
1240+
// If the defining op has the same name or we do not care about the name of
1241+
// op replacements at all, we take it as a replacement.
1242+
if (!config.requireMatchingReplacementOpName ||
1243+
op->getName() == defOp->getName()) {
12331244
result = defOp;
12341245
return DiagnosedSilenceableFailure::success();
12351246
}
@@ -1251,14 +1262,6 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
12511262
"'FindPayloadReplacementOpInterface'";
12521263
continue;
12531264
}
1254-
1255-
// Skip through ops that implement CastOpInterface.
1256-
if (isa<CastOpInterface>(defOp)) {
1257-
values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
1258-
diag.attachNote(defOp->getLoc())
1259-
<< "using output of 'CastOpInterface' op";
1260-
continue;
1261-
}
12621265
} while (!values.empty());
12631266

12641267
diag.attachNote() << "ran out of suitable replacement values";
@@ -1318,9 +1321,9 @@ void transform::TrackingListener::notifyOperationReplaced(
13181321

13191322
// Check if there are any handles that must be updated.
13201323
Value aliveHandle;
1321-
if (skipHandleFn) {
1322-
auto it =
1323-
llvm::find_if(opHandles, [&](Value v) { return !skipHandleFn(v); });
1324+
if (config.skipHandleFn) {
1325+
auto it = llvm::find_if(opHandles,
1326+
[&](Value v) { return !config.skipHandleFn(v); });
13241327
if (it != opHandles.end())
13251328
aliveHandle = *it;
13261329
} else if (!opHandles.empty()) {

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,17 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
563563
}
564564
}
565565

566+
// Attach a tracking listener if handles should be preserved. We configure the
567+
// listener to allow op replacements with different names, as conversion
568+
// patterns typically replace ops with replacement ops that have a different
569+
// name.
570+
TrackingListenerConfig trackingConfig;
571+
trackingConfig.requireMatchingReplacementOpName = false;
572+
ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
573+
ConversionConfig conversionConfig;
574+
if (getPreserveHandles())
575+
conversionConfig.listener = &trackingListener;
576+
566577
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
567578
for (Operation *target : state.getPayloadOps(getTarget())) {
568579
// Make sure that this transform is not applied to itself. Modifying the
@@ -574,16 +585,36 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
574585

575586
LogicalResult status = failure();
576587
if (getPartialConversion()) {
577-
status = applyPartialConversion(target, conversionTarget, frozenPatterns);
588+
status = applyPartialConversion(target, conversionTarget, frozenPatterns,
589+
conversionConfig);
578590
} else {
579-
status = applyFullConversion(target, conversionTarget, frozenPatterns);
591+
status = applyFullConversion(target, conversionTarget, frozenPatterns,
592+
conversionConfig);
580593
}
581594

595+
// Check dialect conversion state.
596+
DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
582597
if (failed(status)) {
583-
auto diag = emitSilenceableError() << "dialect conversion failed";
598+
diag = emitSilenceableError() << "dialect conversion failed";
584599
diag.attachNote(target->getLoc()) << "target op";
585-
return diag;
586600
}
601+
602+
// Check tracking listener error state.
603+
DiagnosedSilenceableFailure trackingFailure =
604+
trackingListener.checkAndResetError();
605+
if (!trackingFailure.succeeded()) {
606+
if (diag.succeeded()) {
607+
// Tracking failure is the only failure.
608+
return trackingFailure;
609+
} else {
610+
diag.attachNote() << "tracking listener also failed: "
611+
<< trackingFailure.getMessage();
612+
(void)trackingFailure.silence();
613+
}
614+
}
615+
616+
if (!diag.succeeded())
617+
return diag;
587618
}
588619

589620
return DiagnosedSilenceableFailure::success();
@@ -632,7 +663,11 @@ LogicalResult transform::ApplyConversionPatternsOp::verify() {
632663

633664
void transform::ApplyConversionPatternsOp::getEffects(
634665
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
635-
transform::consumesHandle(getTarget(), effects);
666+
if (!getPreserveHandles()) {
667+
transform::consumesHandle(getTarget(), effects);
668+
} else {
669+
transform::onlyReadsHandle(getTarget(), effects);
670+
}
636671
transform::modifiesPayload(effects);
637672
}
638673

mlir/test/Dialect/Transform/test-pattern-application.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,3 +417,33 @@ module attributes { transform.with_named_sequence } {
417417
transform.yield
418418
}
419419
}
420+
421+
// -----
422+
423+
// "test.foo" is tracked and replaced with "test.new_op" during a dialect
424+
// conversion. Make sure that the handle is updated accordingly.
425+
426+
// CHECK-LABEL: func @dialect_conversion_tracking
427+
// CHECK-NEXT: %[[m:.*]] = "test.new_op"() {annotated} : () -> memref<5xf32>
428+
// CHECK-NEXT: %[[cast:.*]] = builtin.unrealized_conversion_cast %0 : memref<5xf32> to tensor<5xf32>
429+
// CHECK-NEXT: return %[[cast]]
430+
func.func @dialect_conversion_tracking() -> tensor<5xf32> {
431+
%0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (tensor<5xf32>)
432+
return %0 : tensor<5xf32>
433+
}
434+
435+
module attributes {transform.with_named_sequence} {
436+
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
437+
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
438+
%1 = transform.structured.match ops{["test.foo"]} in %0 : (!transform.any_op) -> !transform.any_op
439+
transform.apply_conversion_patterns to %0 {
440+
transform.apply_conversion_patterns.transform.test_conversion_patterns
441+
} with type_converter {
442+
transform.apply_conversion_patterns.transform.test_type_converter
443+
} {legal_ops = ["func.func", "func.return", "test.new_op"], preserve_handles}
444+
: !transform.any_op
445+
// Add an attribute to %1, which is now mapped to a new op.
446+
transform.annotate %1 "annotated" : !transform.any_op
447+
transform.yield
448+
}
449+
}

0 commit comments

Comments
 (0)