@@ -7,49 +7,55 @@
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
 #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d0, d7, d3, d4, d5, d6)>
 func.func @fuse_by_collapsing(%arg0 : tensor<2x12x5x336x9xi32>,
-    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> {
+    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> (tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>) {
   %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
-  %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
-  %generic = linalg.generic {
-      indexing_maps = [#map0, #map1, #map2, #map3],
+  %init_0 = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+  %init_1 = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
+  %generic:2 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3, #map4],
       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
       ins(%expand, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
-      outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
-    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+      outs(%init_0, %init_1 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, %b4 : i32):
       %t0 = arith.addi %b0, %b1 : i32
       %t1 = arith.addi %t0, %b2 : i32
-      linalg.yield %t1 : i32
-  } -> tensor<2x3x4x5x6x7x8x9xi32>
-  return %generic : tensor<2x3x4x5x6x7x8x9xi32>
+      linalg.yield %t1, %t1 : i32, i32
+  } -> (tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>)
+  return %generic#0, %generic#1 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>
 }
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d0, d4, d2, d3)>
 // CHECK: func @fuse_by_collapsing(
 // CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<2x3x4xi32>
 // CHECK-SAME: %[[ARG2:.+]]: tensor<5x6x7x8xi32>
-// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
+// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
 // CHECK-DAG: %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2]{{\]}}
 // CHECK-DAG: %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
-// CHECK-DAG: %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
-// CHECK: %[[COLLAPSED_OP:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
+// CHECK-DAG: %[[INIT0_RESHAPE:.+]] = tensor.collapse_shape %[[INIT0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+// CHECK-DAG: %[[INIT1_RESHAPE:.+]] = tensor.collapse_shape %[[INIT1]] {{\[}}[0, 1], [2], [3], [4], [5, 6, 7]{{\]}}
+// CHECK: %[[COLLAPSED_OP:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]], #[[MAP3]]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
-// CHECK-SAME: outs(%[[INIT_RESHAPE]] :
-// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} output_shape [2, 3, 4, 5, 6, 7, 8, 9]
-// CHECK: return %[[RESULT_RESHAPE]]
+// CHECK-SAME: outs(%[[INIT0_RESHAPE]], %[[INIT1_RESHAPE]] :
+// CHECK: %[[RESULT0_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]]#0 {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} output_shape [2, 3, 4, 5, 6, 7, 8, 9]
+// CHECK: %[[RESULT1_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]]#1 {{\[}}[0, 1], [2], [3], [4], [5, 6, 7]{{\]}} output_shape [3, 4, 2, 9, 5, 6, 7, 8]
+// CHECK: return %[[RESULT0_RESHAPE]], %[[RESULT1_RESHAPE]]
 
 // CONTROL: func @fuse_by_collapsing(
 // CONTROL-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
 // CONTROL-SAME: %[[ARG1:.+]]: tensor<2x3x4xi32>
 // CONTROL-SAME: %[[ARG2:.+]]: tensor<5x6x7x8xi32>
 // CONTROL: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
-// CONTROL: %[[GENERIC:.+]] = linalg.generic
+// CONTROL: %[[GENERIC:.+]]:2 = linalg.generic
 // CONTROL-SAME: ins(%[[EXPAND]],
-// CONTROL: return %[[GENERIC]]
+// CONTROL: return %[[GENERIC]]#0, %[[GENERIC]]#1
 
 // -----
 