Skip to content

Commit abf69a1

Browse files
authored
[InstCombine] Fold (x < y) ? -1 : zext(x != y) into u/scmp(x,y) (#101049)
This patch adds the aforementioned fold to InstCombine. This pattern is produced after naive implementations of 3-way comparison in high-level languages are transformed into LLVM IR and then optimized. Proofs: https://alive2.llvm.org/ce/z/w4QLq_
1 parent b8dccb7 commit abf69a1

File tree

6 files changed

+426
-37
lines changed

6 files changed

+426
-37
lines changed

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

+1
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
729729

730730
// Helpers of visitSelectInst().
731731
Instruction *foldSelectOfBools(SelectInst &SI);
732+
Instruction *foldSelectToCmp(SelectInst &SI);
732733
Instruction *foldSelectExtConst(SelectInst &Sel);
733734
Instruction *foldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI);
734735
Instruction *foldSelectIntoOp(SelectInst &SI, Value *, Value *);

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

+52
Original file line numberDiff line numberDiff line change
@@ -3558,6 +3558,55 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
35583558
Masked);
35593559
}
35603560

3561+
// This function tries to fold the following operations:
3562+
// (x < y) ? -1 : zext(x != y)
3563+
// (x > y) ? 1 : sext(x != y)
3564+
// into ucmp/scmp(x, y), where signedness is determined by the signedness
3565+
// of the comparison in the original sequence.
// Returns the result of replaceInstUsesWith() when one of the patterns
// matches, and nullptr otherwise.
3566+
Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
3567+
Value *TV = SI.getTrueValue();
3568+
Value *FV = SI.getFalseValue();
3569+
3570+
ICmpInst::Predicate Pred;
3571+
Value *LHS, *RHS;
// The select condition must be an icmp; capture its predicate and operands.
3572+
if (!match(SI.getCondition(), m_ICmp(Pred, m_Value(LHS), m_Value(RHS))))
3573+
return nullptr;
3574+
// Give up unless the compared values are integers or integer vectors.
3575+
if (!LHS->getType()->isIntOrIntVectorTy())
3576+
return nullptr;
3577+
3578+
// Try to swap operands and the predicate. We need to be careful when doing
3579+
// so because two of the patterns have opposite predicates, so use the
3580+
// constant inside select to determine if swapping operands would be
3581+
// beneficial to us.
// After this step an all-ones true value is paired with a less-than
// predicate and a true value of one with a greater-than predicate, which
// is the shape the two matchers below expect.
3582+
if ((ICmpInst::isGT(Pred) && match(TV, m_AllOnes())) ||
3583+
(ICmpInst::isLT(Pred) && match(TV, m_One()))) {
3584+
Pred = ICmpInst::getSwappedPredicate(Pred);
3585+
std::swap(LHS, RHS);
3586+
}
3587+
// A signed predicate selects scmp; any other predicate selects ucmp.
3588+
Intrinsic::ID IID =
3589+
ICmpInst::isSigned(Pred) ? Intrinsic::scmp : Intrinsic::ucmp;
3590+
3591+
bool Replace = false;
3592+
// (x < y) ? -1 : zext(x != y)
// The inequality is matched commutatively, so (y != x) is accepted too.
3593+
if (ICmpInst::isLT(Pred) && match(TV, m_AllOnes()) &&
3594+
match(FV, m_ZExt(m_c_SpecificICmp(ICmpInst::ICMP_NE, m_Specific(LHS),
3595+
m_Specific(RHS)))))
3596+
Replace = true;
3597+
3598+
// (x > y) ? 1 : sext(x != y)
// The inequality is matched commutatively here as well.
3599+
if (ICmpInst::isGT(Pred) && match(TV, m_One()) &&
3600+
match(FV, m_SExt(m_c_SpecificICmp(ICmpInst::ICMP_NE, m_Specific(LHS),
3601+
m_Specific(RHS)))))
3602+
Replace = true;
3603+
// On a match, build the intrinsic call with the select's result type and
// the (possibly swapped) comparison operands, and use it in place of SI.
3604+
if (Replace)
3605+
return replaceInstUsesWith(
3606+
SI, Builder.CreateIntrinsic(SI.getType(), IID, {LHS, RHS}));
3607+
return nullptr;
3608+
}
3609+
35613610
bool InstCombinerImpl::fmulByZeroIsZero(Value *MulVal, FastMathFlags FMF,
35623611
const Instruction *CtxI) const {
35633612
KnownFPClass Known = computeKnownFPClass(MulVal, FMF, fcNegative, CtxI);
@@ -4061,6 +4110,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
40614110
if (Instruction *I = foldBitCeil(SI, Builder))
40624111
return I;
40634112

4113+
if (Instruction *I = foldSelectToCmp(SI))
4114+
return I;
4115+
40644116
// Fold:
40654117
// (select A && B, T, F) -> (select A, (select B, T, F), F)
40664118
// (select A || B, T, F) -> (select A, T, (select B, T, F))

llvm/test/Transforms/InstCombine/scmp.ll

+56
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,59 @@ define i8 @scmp_negated_multiuse(i32 %x, i32 %y) {
208208
%2 = sub i8 0, %1
209209
ret i8 %2
210210
}
211+
212+
; Fold ((x s< y) ? -1 : zext(x != y)) into scmp(x, y)
213+
define i8 @scmp_from_select_lt(i32 %x, i32 %y) {
214+
; CHECK-LABEL: define i8 @scmp_from_select_lt(
215+
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
216+
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
217+
; CHECK-NEXT: ret i8 [[R]]
218+
;
219+
%ne_bool = icmp ne i32 %x, %y
220+
%ne = zext i1 %ne_bool to i8
221+
%lt = icmp slt i32 %x, %y
222+
%r = select i1 %lt, i8 -1, i8 %ne
223+
ret i8 %r
224+
}
225+
226+
; Vector version of the ((x s< y) ? -1 : zext(x != y)) fold into scmp(x, y)
227+
define <4 x i8> @scmp_from_select_vec_lt(<4 x i32> %x, <4 x i32> %y) {
228+
; CHECK-LABEL: define <4 x i8> @scmp_from_select_vec_lt(
229+
; CHECK-SAME: <4 x i32> [[X:%.*]], <4 x i32> [[Y:%.*]]) {
230+
; CHECK-NEXT: [[R:%.*]] = call <4 x i8> @llvm.scmp.v4i8.v4i32(<4 x i32> [[X]], <4 x i32> [[Y]])
231+
; CHECK-NEXT: ret <4 x i8> [[R]]
232+
;
233+
%ne_bool = icmp ne <4 x i32> %x, %y
234+
%ne = zext <4 x i1> %ne_bool to <4 x i8>
235+
%lt = icmp slt <4 x i32> %x, %y
236+
%r = select <4 x i1> %lt, <4 x i8> splat(i8 -1), <4 x i8> %ne
237+
ret <4 x i8> %r
238+
}
239+
240+
; Fold (x s<= y) ? sext(x != y) : 1 into scmp(x, y)
241+
define i8 @scmp_from_select_le(i32 %x, i32 %y) {
242+
; CHECK-LABEL: define i8 @scmp_from_select_le(
243+
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
244+
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
245+
; CHECK-NEXT: ret i8 [[R]]
246+
;
247+
%ne_bool = icmp ne i32 %x, %y
248+
%ne = sext i1 %ne_bool to i8
249+
%le = icmp sle i32 %x, %y
250+
%r = select i1 %le, i8 %ne, i8 1
251+
ret i8 %r
252+
}
253+
254+
; Fold (x s>= y) ? zext(x != y) : -1 into scmp(x, y)
255+
define i8 @scmp_from_select_ge(i32 %x, i32 %y) {
256+
; CHECK-LABEL: define i8 @scmp_from_select_ge(
257+
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
258+
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
259+
; CHECK-NEXT: ret i8 [[R]]
260+
;
261+
%ne_bool = icmp ne i32 %x, %y
262+
%ne = zext i1 %ne_bool to i8
263+
%ge = icmp sge i32 %x, %y
264+
%r = select i1 %ge, i8 %ne, i8 -1
265+
ret i8 %r
266+
}

llvm/test/Transforms/InstCombine/select-select.ll

+9-35
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,7 @@ define i8 @strong_order_cmp_ugt_eq(i32 %a, i32 %b) {
282282

283283
define i8 @strong_order_cmp_eq_slt(i32 %a, i32 %b) {
284284
; CHECK-LABEL: @strong_order_cmp_eq_slt(
285-
; CHECK-NEXT: [[CMP_EQ:%.*]] = icmp ne i32 [[A:%.*]], [[B:%.*]]
286-
; CHECK-NEXT: [[SEL_EQ:%.*]] = zext i1 [[CMP_EQ]] to i8
287-
; CHECK-NEXT: [[CMP_LT:%.*]] = icmp slt i32 [[A]], [[B]]
288-
; CHECK-NEXT: [[SEL_LT:%.*]] = select i1 [[CMP_LT]], i8 -1, i8 [[SEL_EQ]]
285+
; CHECK-NEXT: [[SEL_LT:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
289286
; CHECK-NEXT: ret i8 [[SEL_LT]]
290287
;
291288
%cmp.eq = icmp eq i32 %a, %b
@@ -297,10 +294,7 @@ define i8 @strong_order_cmp_eq_slt(i32 %a, i32 %b) {
297294

298295
define i8 @strong_order_cmp_eq_sgt(i32 %a, i32 %b) {
299296
; CHECK-LABEL: @strong_order_cmp_eq_sgt(
300-
; CHECK-NEXT: [[CMP_EQ:%.*]] = icmp ne i32 [[A:%.*]], [[B:%.*]]
301-
; CHECK-NEXT: [[SEL_EQ:%.*]] = sext i1 [[CMP_EQ]] to i8
302-
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp sgt i32 [[A]], [[B]]
303-
; CHECK-NEXT: [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEL_EQ]]
297+
; CHECK-NEXT: [[SEL_GT:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
304298
; CHECK-NEXT: ret i8 [[SEL_GT]]
305299
;
306300
%cmp.eq = icmp eq i32 %a, %b
@@ -312,10 +306,7 @@ define i8 @strong_order_cmp_eq_sgt(i32 %a, i32 %b) {
312306

313307
define i8 @strong_order_cmp_eq_ult(i32 %a, i32 %b) {
314308
; CHECK-LABEL: @strong_order_cmp_eq_ult(
315-
; CHECK-NEXT: [[CMP_EQ:%.*]] = icmp ne i32 [[A:%.*]], [[B:%.*]]
316-
; CHECK-NEXT: [[SEL_EQ:%.*]] = zext i1 [[CMP_EQ]] to i8
317-
; CHECK-NEXT: [[CMP_LT:%.*]] = icmp ult i32 [[A]], [[B]]
318-
; CHECK-NEXT: [[SEL_LT:%.*]] = select i1 [[CMP_LT]], i8 -1, i8 [[SEL_EQ]]
309+
; CHECK-NEXT: [[SEL_LT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
319310
; CHECK-NEXT: ret i8 [[SEL_LT]]
320311
;
321312
%cmp.eq = icmp eq i32 %a, %b
@@ -327,10 +318,7 @@ define i8 @strong_order_cmp_eq_ult(i32 %a, i32 %b) {
327318

328319
define i8 @strong_order_cmp_eq_ugt(i32 %a, i32 %b) {
329320
; CHECK-LABEL: @strong_order_cmp_eq_ugt(
330-
; CHECK-NEXT: [[CMP_EQ:%.*]] = icmp ne i32 [[A:%.*]], [[B:%.*]]
331-
; CHECK-NEXT: [[SEL_EQ:%.*]] = sext i1 [[CMP_EQ]] to i8
332-
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt i32 [[A]], [[B]]
333-
; CHECK-NEXT: [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEL_EQ]]
321+
; CHECK-NEXT: [[SEL_GT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
334322
; CHECK-NEXT: ret i8 [[SEL_GT]]
335323
;
336324
%cmp.eq = icmp eq i32 %a, %b
@@ -404,9 +392,7 @@ define i8 @strong_order_cmp_ne_ugt_ne_not_one_use(i32 %a, i32 %b) {
404392
; CHECK-LABEL: @strong_order_cmp_ne_ugt_ne_not_one_use(
405393
; CHECK-NEXT: [[CMP_NE:%.*]] = icmp ne i32 [[A:%.*]], [[B:%.*]]
406394
; CHECK-NEXT: call void @use1(i1 [[CMP_NE]])
407-
; CHECK-NEXT: [[SEL_EQ:%.*]] = sext i1 [[CMP_NE]] to i8
408-
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt i32 [[A]], [[B]]
409-
; CHECK-NEXT: [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEL_EQ]]
395+
; CHECK-NEXT: [[SEL_GT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
410396
; CHECK-NEXT: ret i8 [[SEL_GT]]
411397
;
412398
%cmp.ne = icmp ne i32 %a, %b
@@ -535,10 +521,7 @@ define <2 x i8> @strong_order_cmp_ugt_ult_vector_poison(<2 x i32> %a, <2 x i32>
535521

536522
define <2 x i8> @strong_order_cmp_eq_ugt_vector(<2 x i32> %a, <2 x i32> %b) {
537523
; CHECK-LABEL: @strong_order_cmp_eq_ugt_vector(
538-
; CHECK-NEXT: [[CMP_EQ:%.*]] = icmp ne <2 x i32> [[A:%.*]], [[B:%.*]]
539-
; CHECK-NEXT: [[SEL_EQ:%.*]] = sext <2 x i1> [[CMP_EQ]] to <2 x i8>
540-
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt <2 x i32> [[A]], [[B]]
541-
; CHECK-NEXT: [[SEL_GT:%.*]] = select <2 x i1> [[CMP_GT]], <2 x i8> <i8 1, i8 1>, <2 x i8> [[SEL_EQ]]
524+
; CHECK-NEXT: [[SEL_GT:%.*]] = call <2 x i8> @llvm.ucmp.v2i8.v2i32(<2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]])
542525
; CHECK-NEXT: ret <2 x i8> [[SEL_GT]]
543526
;
544527
%cmp.eq = icmp eq <2 x i32> %a, %b
@@ -550,10 +533,7 @@ define <2 x i8> @strong_order_cmp_eq_ugt_vector(<2 x i32> %a, <2 x i32> %b) {
550533

551534
define <2 x i8> @strong_order_cmp_eq_ugt_vector_poison1(<2 x i32> %a, <2 x i32> %b) {
552535
; CHECK-LABEL: @strong_order_cmp_eq_ugt_vector_poison1(
553-
; CHECK-NEXT: [[CMP_EQ:%.*]] = icmp ne <2 x i32> [[A:%.*]], [[B:%.*]]
554-
; CHECK-NEXT: [[SEL_EQ:%.*]] = sext <2 x i1> [[CMP_EQ]] to <2 x i8>
555-
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt <2 x i32> [[A]], [[B]]
556-
; CHECK-NEXT: [[SEL_GT:%.*]] = select <2 x i1> [[CMP_GT]], <2 x i8> <i8 1, i8 1>, <2 x i8> [[SEL_EQ]]
536+
; CHECK-NEXT: [[SEL_GT:%.*]] = call <2 x i8> @llvm.ucmp.v2i8.v2i32(<2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]])
557537
; CHECK-NEXT: ret <2 x i8> [[SEL_GT]]
558538
;
559539
%cmp.eq = icmp eq <2 x i32> %a, %b
@@ -565,10 +545,7 @@ define <2 x i8> @strong_order_cmp_eq_ugt_vector_poison1(<2 x i32> %a, <2 x i32>
565545

566546
define <2 x i8> @strong_order_cmp_eq_ugt_vector_poison2(<2 x i32> %a, <2 x i32> %b) {
567547
; CHECK-LABEL: @strong_order_cmp_eq_ugt_vector_poison2(
568-
; CHECK-NEXT: [[CMP_EQ:%.*]] = icmp ne <2 x i32> [[A:%.*]], [[B:%.*]]
569-
; CHECK-NEXT: [[SEL_EQ:%.*]] = sext <2 x i1> [[CMP_EQ]] to <2 x i8>
570-
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt <2 x i32> [[A]], [[B]]
571-
; CHECK-NEXT: [[SEL_GT:%.*]] = select <2 x i1> [[CMP_GT]], <2 x i8> <i8 1, i8 1>, <2 x i8> [[SEL_EQ]]
548+
; CHECK-NEXT: [[SEL_GT:%.*]] = call <2 x i8> @llvm.ucmp.v2i8.v2i32(<2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]])
572549
; CHECK-NEXT: ret <2 x i8> [[SEL_GT]]
573550
;
574551
%cmp.eq = icmp eq <2 x i32> %a, %b
@@ -580,10 +557,7 @@ define <2 x i8> @strong_order_cmp_eq_ugt_vector_poison2(<2 x i32> %a, <2 x i32>
580557

581558
define <2 x i8> @strong_order_cmp_eq_ugt_vector_poison3(<2 x i32> %a, <2 x i32> %b) {
582559
; CHECK-LABEL: @strong_order_cmp_eq_ugt_vector_poison3(
583-
; CHECK-NEXT: [[CMP_EQ:%.*]] = icmp ne <2 x i32> [[A:%.*]], [[B:%.*]]
584-
; CHECK-NEXT: [[SEL_EQ:%.*]] = sext <2 x i1> [[CMP_EQ]] to <2 x i8>
585-
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt <2 x i32> [[A]], [[B]]
586-
; CHECK-NEXT: [[SEL_GT:%.*]] = select <2 x i1> [[CMP_GT]], <2 x i8> <i8 1, i8 poison>, <2 x i8> [[SEL_EQ]]
560+
; CHECK-NEXT: [[SEL_GT:%.*]] = call <2 x i8> @llvm.ucmp.v2i8.v2i32(<2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]])
587561
; CHECK-NEXT: ret <2 x i8> [[SEL_GT]]
588562
;
589563
%cmp.eq = icmp eq <2 x i32> %a, %b

0 commit comments

Comments
 (0)