Skip to content

[InstCombine] Fix miscompilation in PR83947 #83993

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 2 commits into from
Mar 5, 2024
Merged

Conversation

dtcxzyw
Copy link
Member

@dtcxzyw dtcxzyw commented Mar 5, 2024

// If the mask is all zeros, a scatter does nothing.
if (ConstMask->isNullValue())
return eraseInstFromFunction(II);
// Vector splat address -> scalar store
if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
// scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
StoreInst *S =
new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
S->copyMetadata(II);
return S;
}

Comment from @topperc:

This transforms assumes the mask is a non-zero splat. We only know its a splat and not provably all 0s. The mask is a constexpr that includes the address of the global variable. We can't resolve the constant expression to an exact value.

Fixes #83947.

@llvmbot
Copy link
Member

llvmbot commented Mar 5, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-analysis

Author: Yingwei Zheng (dtcxzyw)

Changes

// If the mask is all zeros, a scatter does nothing.
if (ConstMask->isNullValue())
return eraseInstFromFunction(II);
// Vector splat address -> scalar store
if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
// scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
StoreInst *S =
new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
S->copyMetadata(II);
return S;
}

Comment from @topperc:
> This transforms assumes the mask is a non-zero splat. We only know its a splat and not provably all 0s. The mask is a constexpr that includes the address of the global variable. We can't resolve the constant expression to an exact value.

Fixes #83947.


Full diff: https://github.com/llvm/llvm-project/pull/83993.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Analysis/VectorUtils.h (+5)
  • (modified) llvm/lib/Analysis/VectorUtils.cpp (+25)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+8-5)
  • (added) llvm/test/Transforms/InstCombine/pr83947.ll (+67)
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 7a92e62b53c53d..c6eb66cc9660ca 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -406,6 +406,11 @@ bool maskIsAllZeroOrUndef(Value *Mask);
 /// lanes can be assumed active.
 bool maskIsAllOneOrUndef(Value *Mask);
 
+/// Given a mask vector of i1, Return true if any of the elements of this
+/// predicate mask are known to be true or undef.  That is, return true if at
+/// least one lane can be assumed active.
+bool maskContainsAllOneOrUndef(Value *Mask);
+
 /// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
 /// for each lane which may be active.
 APInt possiblyDemandedEltsInMask(Value *Mask);
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 73facc76a92b2c..bf7bc0ba84a033 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -1012,6 +1012,31 @@ bool llvm::maskIsAllOneOrUndef(Value *Mask) {
   return true;
 }
 
+bool llvm::maskContainsAllOneOrUndef(Value *Mask) {
+  assert(isa<VectorType>(Mask->getType()) &&
+         isa<IntegerType>(Mask->getType()->getScalarType()) &&
+         cast<IntegerType>(Mask->getType()->getScalarType())->getBitWidth() ==
+             1 &&
+         "Mask must be a vector of i1");
+
+  auto *ConstMask = dyn_cast<Constant>(Mask);
+  if (!ConstMask)
+    return false;
+  if (ConstMask->isAllOnesValue() || isa<UndefValue>(ConstMask))
+    return true;
+  if (isa<ScalableVectorType>(ConstMask->getType()))
+    return false;
+  for (unsigned
+           I = 0,
+           E = cast<FixedVectorType>(ConstMask->getType())->getNumElements();
+       I != E; ++I) {
+    if (auto *MaskElt = ConstMask->getAggregateElement(I))
+      if (MaskElt->isAllOnesValue() || isa<UndefValue>(MaskElt))
+        return true;
+  }
+  return false;
+}
+
 /// TODO: This is a lot like known bits, but for
 /// vectors.  Is there something we can common this with?
 APInt llvm::possiblyDemandedEltsInMask(Value *Mask) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 50c0f9a913f32a..89ca40f1c20d53 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -399,11 +399,14 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
   if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
     // scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
     if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
-      Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
-      StoreInst *S =
-          new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
-      S->copyMetadata(II);
-      return S;
+      if (maskContainsAllOneOrUndef(ConstMask)) {
+        Align Alignment =
+            cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+        StoreInst *S = new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false,
+                                     Alignment);
+        S->copyMetadata(II);
+        return S;
+      }
     }
     // scatter(vector, splat(ptr), splat(true)) -> store extract(vector,
     // lastlane), ptr
diff --git a/llvm/test/Transforms/InstCombine/pr83947.ll b/llvm/test/Transforms/InstCombine/pr83947.ll
new file mode 100644
index 00000000000000..c1d601ff637183
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/pr83947.ll
@@ -0,0 +1,67 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+@c = global i32 0, align 4
+@b = global i32 0, align 4
+
+define void @masked_scatter1() {
+; CHECK-LABEL: define void @masked_scatter1() {
+; CHECK-NEXT:    call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> zeroinitializer, <vscale x 4 x ptr> shufflevector (<vscale x 4 x ptr> insertelement (<vscale x 4 x ptr> poison, ptr @c, i64 0), <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer), i32 4, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c), i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> zeroinitializer, <vscale x 4 x ptr> splat (ptr @c), i32 4, <vscale x 4 x i1> splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c)))
+  ret void
+}
+
+define void @masked_scatter2() {
+; CHECK-LABEL: define void @masked_scatter2() {
+; CHECK-NEXT:    store i32 0, ptr @c, align 4
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 true))
+  ret void
+}
+
+define void @masked_scatter3() {
+; CHECK-LABEL: define void @masked_scatter3() {
+; CHECK-NEXT:    store i32 0, ptr @c, align 4
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> undef)
+  ret void
+}
+
+define void @masked_scatter4() {
+; CHECK-LABEL: define void @masked_scatter4() {
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 false))
+  ret void
+}
+
+define void @masked_scatter5() {
+; CHECK-LABEL: define void @masked_scatter5() {
+; CHECK-NEXT:    store i32 0, ptr @c, align 4
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+define void @masked_scatter6() {
+; CHECK-LABEL: define void @masked_scatter6() {
+; CHECK-NEXT:    store i32 0, ptr @c, align 4
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> <i1 undef, i1 false>)
+  ret void
+}
+
+define void @masked_scatter7() {
+; CHECK-LABEL: define void @masked_scatter7() {
+; CHECK-NEXT:    call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> <ptr @c, ptr @c>, i32 4, <2 x i1> <i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c), i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c)>)
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c)))
+  ret void
+}

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, though the need to the new helper is a bit unfortunate...

@dtcxzyw dtcxzyw merged commit a1a590e into llvm:main Mar 5, 2024
@dtcxzyw dtcxzyw deleted the fix-pr83947 branch March 5, 2024 14:34
@dtcxzyw dtcxzyw added this to the LLVM 18.X Release milestone Mar 5, 2024
@dtcxzyw
Copy link
Member Author

dtcxzyw commented Mar 5, 2024

/cherry-pick a1a590e

llvmbot pushed a commit to llvmbot/llvm-project that referenced this pull request Mar 5, 2024
https://github.com/llvm/llvm-project/blob/762f762504967efbe159db5c737154b989afc9bb/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp#L394-L407

Comment from @topperc:
> This transforms assumes the mask is a non-zero splat. We only know its
a splat and not provably all 0s. The mask is a constexpr that includes
the address of the global variable. We can't resolve the constant
expression to an exact value.

Fixes llvm#83947.

(cherry picked from commit a1a590e)
@llvmbot
Copy link
Member

llvmbot commented Mar 5, 2024

/pull-request #84021

dtcxzyw added a commit to dtcxzyw/llvm-project that referenced this pull request Mar 7, 2024
https://github.com/llvm/llvm-project/blob/762f762504967efbe159db5c737154b989afc9bb/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp#L394-L407

Comment from @topperc:
> This transforms assumes the mask is a non-zero splat. We only know its
a splat and not provably all 0s. The mask is a constexpr that includes
the address of the global variable. We can't resolve the constant
expression to an exact value.

Fixes llvm#83947.
tstellar pushed a commit to dtcxzyw/llvm-project that referenced this pull request Mar 13, 2024
https://github.com/llvm/llvm-project/blob/762f762504967efbe159db5c737154b989afc9bb/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp#L394-L407

Comment from @topperc:
> This transforms assumes the mask is a non-zero splat. We only know its
a splat and not provably all 0s. The mask is a constexpr that includes
the address of the global variable. We can't resolve the constant
expression to an exact value.

Fixes llvm#83947.
@pointhex pointhex mentioned this pull request May 7, 2024
xgupta pushed a commit to xgupta/llvm-project that referenced this pull request Sep 9, 2024
https://github.com/llvm/llvm-project/blob/762f762504967efbe159db5c737154b989afc9bb/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp#L394-L407

Comment from @topperc:
> This transforms assumes the mask is a non-zero splat. We only know its
a splat and not provably all 0s. The mask is a constexpr that includes
the address of the global variable. We can't resolve the constant
expression to an exact value.

Fixes llvm#83947.
xgupta pushed a commit to xgupta/llvm-project that referenced this pull request Oct 10, 2024
https://github.com/llvm/llvm-project/blob/762f762504967efbe159db5c737154b989afc9bb/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp#L394-L407

Comment from @topperc:
> This transforms assumes the mask is a non-zero splat. We only know its
a splat and not provably all 0s. The mask is a constexpr that includes
the address of the global variable. We can't resolve the constant
expression to an exact value.

Fixes llvm#83947.
# for free to join this conversation on GitHub. Already have an account? # to comment
Projects
Development

Successfully merging this pull request may close these issues.

[RISC-V] Miscompile at -O2
3 participants