
Commit fd39676

[mlir] Allow multi-result ops in reshape fusion

1 parent b9198a1

File tree: 2 files changed, 25 insertions(+), 19 deletions(-)

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Lines changed: 1 addition & 1 deletion

@@ -1254,7 +1254,7 @@ static SmallVector<ReassociationIndices>
 getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                  ArrayRef<ReassociationIndices> reassociation) {
   // Some basic checks for this fusion to be valid.
-  if (!genericOp.hasPureTensorSemantics() || genericOp.getNumDpsInits() != 1)
+  if (!genericOp.hasPureTensorSemantics())
     return {};

   if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
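
In words: getCollapsableIterationSpaceDims previously gave up on any linalg.generic with more than one DPS init, so reshape fusion by collapsing only fired on single-result ops; dropping the getNumDpsInits() != 1 guard lets the existing per-operand logic handle multi-result ops too. For orientation, a minimal sketch of how these patterns are typically driven is shown below; the wrapper name fuseReshapesByCollapsing and the accept-everything control callback are illustrative assumptions, not part of this commit:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Illustrative driver (not from this commit): collect the reshape-by-collapsing
// fusion patterns and run them greedily over `rootOp`.
static LogicalResult fuseReshapesByCollapsing(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  // Control callback: returning true allows fusing through `fusedOperand`.
  // Here every candidate is accepted; a real pass would filter.
  linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
    return true;
  };
  linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
  return applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}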

mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
Lines changed: 24 additions & 18 deletions

@@ -7,49 +7,55 @@
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
 #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d0, d7, d3, d4, d5, d6)>
 func.func @fuse_by_collapsing(%arg0 : tensor<2x12x5x336x9xi32>,
-    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> {
+    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> (tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>) {
   %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
-  %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
-  %generic = linalg.generic {
-      indexing_maps = [#map0, #map1, #map2, #map3],
+  %init_0 = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+  %init_1 = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
+  %generic:2 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3, #map4],
       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
       ins(%expand, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
-      outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
-    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+      outs(%init_0, %init_1 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, %b4 : i32):
       %t0 = arith.addi %b0, %b1 : i32
       %t1 = arith.addi %t0, %b2 : i32
-      linalg.yield %t1 : i32
-  } -> tensor<2x3x4x5x6x7x8x9xi32>
-  return %generic : tensor<2x3x4x5x6x7x8x9xi32>
+      linalg.yield %t1, %t1 : i32, i32
+  } -> (tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>)
+  return %generic#0, %generic#1 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>
 }
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d0, d4, d2, d3)>
 // CHECK: func @fuse_by_collapsing(
 // CHECK-SAME:     %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
 // CHECK-SAME:     %[[ARG1:.+]]: tensor<2x3x4xi32>
 // CHECK-SAME:     %[[ARG2:.+]]: tensor<5x6x7x8xi32>
-// CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
+// CHECK-DAG:   %[[INIT0:.+]] = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK-DAG:   %[[INIT1:.+]] = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
 // CHECK-DAG:   %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2]{{\]}}
 // CHECK-DAG:   %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
-// CHECK-DAG:   %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
-// CHECK:       %[[COLLAPSED_OP:.+]] = linalg.generic
-// CHECK-SAME:      indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
+// CHECK-DAG:   %[[INIT0_RESHAPE:.+]] = tensor.collapse_shape %[[INIT0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+// CHECK-DAG:   %[[INIT1_RESHAPE:.+]] = tensor.collapse_shape %[[INIT1]] {{\[}}[0, 1], [2], [3], [4], [5, 6, 7]{{\]}}
+// CHECK:       %[[COLLAPSED_OP:.+]]:2 = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]], #[[MAP3]]]
 // CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:      ins(%[[ARG0]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
-// CHECK-SAME:      outs(%[[INIT_RESHAPE]] :
-// CHECK:       %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} output_shape [2, 3, 4, 5, 6, 7, 8, 9]
-// CHECK:       return %[[RESULT_RESHAPE]]
+// CHECK-SAME:      outs(%[[INIT0_RESHAPE]], %[[INIT1_RESHAPE]] :
+// CHECK:       %[[RESULT0_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]]#0 {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} output_shape [2, 3, 4, 5, 6, 7, 8, 9]
+// CHECK:       %[[RESULT1_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]]#1 {{\[}}[0, 1], [2], [3], [4], [5, 6, 7]{{\]}} output_shape [3, 4, 2, 9, 5, 6, 7, 8]
+// CHECK:       return %[[RESULT0_RESHAPE]], %[[RESULT1_RESHAPE]]

 // CONTROL: func @fuse_by_collapsing(
 // CONTROL-SAME:     %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
 // CONTROL-SAME:     %[[ARG1:.+]]: tensor<2x3x4xi32>
 // CONTROL-SAME:     %[[ARG2:.+]]: tensor<5x6x7x8xi32>
 // CONTROL:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
-// CONTROL:   %[[GENERIC:.+]] = linalg.generic
+// CONTROL:   %[[GENERIC:.+]]:2 = linalg.generic
 // CONTROL-SAME:     ins(%[[EXPAND]],
-// CONTROL:   return %[[GENERIC]]
+// CONTROL:   return %[[GENERIC]]#0, %[[GENERIC]]#1

 // -----
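
The CONTROL run covers the opposite decision: when the control callback rejects the candidate operand, the tensor.expand_shape stays unfused and the now two-result linalg.generic keeps consuming the expanded tensor directly. A hypothetical callback with that effect might look like the following sketch (the name keepExpandShape is illustrative; the actual test drives fusion through its own control hook):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

// Hypothetical control callback: refuse to fuse any operand produced by a
// tensor.expand_shape, leaving the reshape in place as in the CONTROL checks.
linalg::ControlFusionFn keepExpandShape = [](OpOperand *fusedOperand) {
  return !fusedOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
};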
