Skip to content

[mlir][ArmSVE] Add convert.from/to.svbool intrinsics #68418

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 1 commit into from
Oct 10, 2023

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Oct 6, 2023

These will be used in future pass to ensure that loads/stores of masks
are legal (as the LLVM backend does not support this for any type smaller
than an svbool, which is vector<[16]xi1>).

Depends on #68399

@MacDue
Copy link
Member Author

MacDue commented Oct 6, 2023

cc @c-rhodes, @banach-space

These will be used in future pass to ensure that loads/stores of masks
are legal (as the LLVM backend does not suppor this for any type smaller
than an svbool, which is vector<[16]xi1>).
@MacDue MacDue force-pushed the arm_sve_add_svbool_intrs branch from fac5375 to 2ee3ec5 Compare October 9, 2023 09:35
@MacDue MacDue marked this pull request as ready for review October 9, 2023 09:35
@llvmbot
Copy link
Member

llvmbot commented Oct 9, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-sve

@llvm/pr-subscribers-mlir-llvm

Changes

These will be used in future pass to ensure that loads/stores of masks
are legal (as the LLVM backend does not support this for any type smaller
than an svbool, which is vector<[16]xi1>).

Depends on #68399


Full diff: https://github.com/llvm/llvm-project/pull/68418.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+24)
  • (modified) mlir/test/Target/LLVMIR/arm-sve.mlir (+44)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 58dec6091f27f6e..d4294b4dd9fd4e8 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -30,6 +30,16 @@ def ArmSVE_Dialect : Dialect {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ArmSVE type definitions
+//===----------------------------------------------------------------------===//
+
+def SVBool : ScalableVectorOfRankAndLengthAndType<
+  [1], [16], [I1]>;
+
+def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
+  [1], [16, 8, 4, 2, 1], [I1]>;
+
 //===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
@@ -302,4 +312,18 @@ def ScalableMaskedDivFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
   Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
 
+def ConvertFromSvboolIntrOp :
+  ArmSVE_IntrOp<"convert.from.svbool",
+    [TypeIs<"res", SVEPredicate>],
+    /*overloadedOperands=*/[],
+    /*overloadedResults=*/[0]>,
+  Arguments<(ins SVBool:$svbool)>;
+
+def ConvertToSvboolIntrOp :
+  ArmSVE_IntrOp<"convert.to.svbool",
+    [TypeIs<"res", SVBool>],
+    /*overloadedOperands=*/[0],
+    /*overloadedResults=*/[]>,
+    Arguments<(ins SVEPredicate:$mask)>;
+
 #endif // ARMSVE_OPS
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 999df8079e0727a..172a2f7d12d440e 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -272,3 +272,47 @@ llvm.func @get_vector_scale() -> i64 {
   %0 = "llvm.intr.vscale"() : () -> i64
   llvm.return %0 : i64
 }
+
+// CHECK-LABEL: @arm_sve_convert_from_svbool(
+// CHECK-SAME:                               <vscale x 16 x i1> %[[SVBOOL:[0-9]+]])
+llvm.func @arm_sve_convert_from_svbool(%nxv16i1 : vector<[16]xi1>) {
+  // CHECK: %[[RES0:.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %[[SVBOOL]])
+  %res0 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+    : (vector<[16]xi1>) -> vector<[8]xi1>
+  // CHECK: %[[RES1:.*]] = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %[[SVBOOL]])
+  %res1 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+    : (vector<[16]xi1>) -> vector<[4]xi1>
+  // CHECK: %[[RES2:.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> %[[SVBOOL]])
+  %res2 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+    : (vector<[16]xi1>) -> vector<[2]xi1>
+  // CHECK: %[[RES3:.*]] = call <vscale x 1 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv1i1(<vscale x 16 x i1> %[[SVBOOL]])
+  %res3 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+    : (vector<[16]xi1>) -> vector<[1]xi1>
+  llvm.return
+}
+
+// CHECK-LABEL: arm_sve_convert_to_svbool(
+// CHECK-SAME:                            <vscale x 8 x i1> %[[P8:[0-9]+]],
+// CHECK-SAME:                            <vscale x 4 x i1> %[[P4:[0-9]+]],
+// CHECK-SAME:                            <vscale x 2 x i1> %[[P2:[0-9]+]],
+// CHECK-SAME:                            <vscale x 1 x i1> %[[P1:[0-9]+]])
+llvm.func @arm_sve_convert_to_svbool(
+                                       %nxv8i1  : vector<[8]xi1>,
+                                       %nxv4i1  : vector<[4]xi1>,
+                                       %nxv2i1  : vector<[2]xi1>,
+                                       %nxv1i1  : vector<[1]xi1>
+) {
+  // CHECK-NEXT: %[[RES0:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %[[P8]])
+  %res0 = "arm_sve.intr.convert.to.svbool"(%nxv8i1)
+    : (vector<[8]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[RES1:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %[[P4]])
+  %res1 = "arm_sve.intr.convert.to.svbool"(%nxv4i1)
+    : (vector<[4]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[RES2:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %[[P2]])
+  %res2 = "arm_sve.intr.convert.to.svbool"(%nxv2i1)
+    : (vector<[2]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[RES3:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv1i1(<vscale x 1 x i1> %[[P1]])
+  %res3 = "arm_sve.intr.convert.to.svbool"(%nxv1i1)
+    : (vector<[1]xi1>) -> vector<[16]xi1>
+  llvm.return
+}

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM cheers

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, ta!

@MacDue MacDue merged commit 3d70ba6 into llvm:main Oct 10, 2023
MacDue added a commit that referenced this pull request Oct 12, 2023
This adds slightly higher-level ops for converting masks between svbool
and SVE predicate types. The main reason to use these over the
intrinsics is these ops support vectors of masks (via unrolling).

E.g.

```
// Convert a svbool mask to a mask of SVE predicates:
%svbool = vector.load %memref[%c0, %c0]
                       : memref<2x?xi1>, vector<2x[16]xi1>
%mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1>
// => Results in vector<2x[8]xi1>
```
Or:
```
// Convert a mask of SVE predicates to a svbool mask:
%mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
%svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1>
// => Results in vector<2x[16]xi1>
```

Depends on #68418
# for free to join this conversation on GitHub. Already have an account? # to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants