Skip to content

Commit aca8b9d

Browse files
committed
[DAG] SimplifyDemandedBits - if we're only demanding the signbits, a MIN/MAX node can be simplified to an OR or AND node
Extension to the signbit case: if the signbits extend down through all the demanded bits, then SMIN/SMAX/UMIN/UMAX nodes can be simplified to OR/AND/AND/OR respectively. Alive2: https://alive2.llvm.org/ce/z/mFVFAn (general case) Differential Revision: https://reviews.llvm.org/D158364
1 parent 6ef767c commit aca8b9d

File tree

2 files changed

+62
-83
lines changed

2 files changed

+62
-83
lines changed

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

+38-43
Original file line numberDiff line numberDiff line change
@@ -2156,54 +2156,49 @@ bool TargetLowering::SimplifyDemandedBits(
21562156
}
21572157
break;
21582158
}
2159-
case ISD::SMIN: {
2160-
SDValue Op0 = Op.getOperand(0);
2161-
SDValue Op1 = Op.getOperand(1);
2162-
// If we're only wanting the signbit, then we can simplify to OR node.
2163-
// TODO: Extend this based on ComputeNumSignBits.
2164-
if (DemandedBits.isSignMask())
2165-
return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, Op0, Op1));
2166-
break;
2167-
}
2168-
case ISD::SMAX: {
2169-
SDValue Op0 = Op.getOperand(0);
2170-
SDValue Op1 = Op.getOperand(1);
2171-
// If we're only wanting the signbit, then we can simplify to AND node.
2172-
// TODO: Extend this based on ComputeNumSignBits.
2173-
if (DemandedBits.isSignMask())
2174-
return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::AND, dl, VT, Op0, Op1));
2175-
break;
2176-
}
2177-
case ISD::UMIN: {
2178-
SDValue Op0 = Op.getOperand(0);
2179-
SDValue Op1 = Op.getOperand(1);
2180-
// If we're only wanting the msb, then we can simplify to AND node.
2181-
if (DemandedBits.isSignMask())
2182-
return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::AND, dl, VT, Op0, Op1));
2183-
// Check if one arg is always less than (or equal) to the other arg.
2184-
KnownBits Known0 = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth + 1);
2185-
KnownBits Known1 = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
2186-
Known = KnownBits::umin(Known0, Known1);
2187-
if (std::optional<bool> IsULE = KnownBits::ule(Known0, Known1))
2188-
return TLO.CombineTo(Op, *IsULE ? Op0 : Op1);
2189-
if (std::optional<bool> IsULT = KnownBits::ult(Known0, Known1))
2190-
return TLO.CombineTo(Op, *IsULT ? Op0 : Op1);
2191-
break;
2192-
}
2159+
case ISD::SMIN:
2160+
case ISD::SMAX:
2161+
case ISD::UMIN:
21932162
case ISD::UMAX: {
2163+
unsigned Opc = Op.getOpcode();
21942164
SDValue Op0 = Op.getOperand(0);
21952165
SDValue Op1 = Op.getOperand(1);
2196-
// If we're only wanting the msb, then we can simplify to OR node.
2197-
if (DemandedBits.isSignMask())
2198-
return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, Op0, Op1));
2199-
// Check if one arg is always greater than (or equal) to the other arg.
2166+
2167+
// If we're only demanding signbits, then we can simplify to OR/AND node.
2168+
unsigned BitOp =
2169+
(Opc == ISD::SMIN || Opc == ISD::UMAX) ? ISD::OR : ISD::AND;
2170+
unsigned NumSignBits =
2171+
std::min(TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1),
2172+
TLO.DAG.ComputeNumSignBits(Op1, DemandedElts, Depth + 1));
2173+
unsigned NumDemandedUpperBits = BitWidth - DemandedBits.countr_zero();
2174+
if (NumSignBits >= NumDemandedUpperBits)
2175+
return TLO.CombineTo(Op, TLO.DAG.getNode(BitOp, SDLoc(Op), VT, Op0, Op1));
2176+
2177+
// Check if one arg is always less/greater than (or equal) to the other arg.
22002178
KnownBits Known0 = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth + 1);
22012179
KnownBits Known1 = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
2202-
Known = KnownBits::umax(Known0, Known1);
2203-
if (std::optional<bool> IsUGE = KnownBits::uge(Known0, Known1))
2204-
return TLO.CombineTo(Op, *IsUGE ? Op0 : Op1);
2205-
if (std::optional<bool> IsUGT = KnownBits::ugt(Known0, Known1))
2206-
return TLO.CombineTo(Op, *IsUGT ? Op0 : Op1);
2180+
switch (Opc) {
2181+
case ISD::SMIN:
2182+
// TODO: Add KnownBits::sle/slt handling.
2183+
break;
2184+
case ISD::SMAX:
2185+
// TODO: Add KnownBits::sge/sgt handling.
2186+
break;
2187+
case ISD::UMIN:
2188+
if (std::optional<bool> IsULE = KnownBits::ule(Known0, Known1))
2189+
return TLO.CombineTo(Op, *IsULE ? Op0 : Op1);
2190+
if (std::optional<bool> IsULT = KnownBits::ult(Known0, Known1))
2191+
return TLO.CombineTo(Op, *IsULT ? Op0 : Op1);
2192+
Known = KnownBits::umin(Known0, Known1);
2193+
break;
2194+
case ISD::UMAX:
2195+
if (std::optional<bool> IsUGE = KnownBits::uge(Known0, Known1))
2196+
return TLO.CombineTo(Op, *IsUGE ? Op0 : Op1);
2197+
if (std::optional<bool> IsUGT = KnownBits::ugt(Known0, Known1))
2198+
return TLO.CombineTo(Op, *IsUGT ? Op0 : Op1);
2199+
Known = KnownBits::umax(Known0, Known1);
2200+
break;
2201+
}
22072202
break;
22082203
}
22092204
case ISD::BITREVERSE: {

llvm/test/CodeGen/X86/known-signbits-vector.ll

+24-40
Original file line numberDiff line numberDiff line change
@@ -483,28 +483,24 @@ define <4 x float> @signbits_ashr_sext_select_shuffle_sitofp(<4 x i64> %a0, <4 x
483483
define <4 x i32> @signbits_mask_ashr_smax(<4 x i32> %a0, <4 x i32> %a1) {
484484
; X86-LABEL: signbits_mask_ashr_smax:
485485
; X86: # %bb.0:
486-
; X86-NEXT: vpsrad $25, %xmm0, %xmm0
487-
; X86-NEXT: vpsrad $25, %xmm1, %xmm1
488-
; X86-NEXT: vpmaxsd %xmm1, %xmm0, %xmm0
486+
; X86-NEXT: vpand %xmm1, %xmm0, %xmm0
489487
; X86-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
488+
; X86-NEXT: vpsrad $25, %xmm0, %xmm0
490489
; X86-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
491490
; X86-NEXT: retl
492491
;
493492
; X64-AVX1-LABEL: signbits_mask_ashr_smax:
494493
; X64-AVX1: # %bb.0:
495-
; X64-AVX1-NEXT: vpsrad $25, %xmm0, %xmm0
496-
; X64-AVX1-NEXT: vpsrad $25, %xmm1, %xmm1
497-
; X64-AVX1-NEXT: vpmaxsd %xmm1, %xmm0, %xmm0
494+
; X64-AVX1-NEXT: vpand %xmm1, %xmm0, %xmm0
498495
; X64-AVX1-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
496+
; X64-AVX1-NEXT: vpsrad $25, %xmm0, %xmm0
499497
; X64-AVX1-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
500498
; X64-AVX1-NEXT: retq
501499
;
502500
; X64-AVX2-LABEL: signbits_mask_ashr_smax:
503501
; X64-AVX2: # %bb.0:
504-
; X64-AVX2-NEXT: vmovdqa {{.*#+}} xmm2 = [25,26,27,0]
505-
; X64-AVX2-NEXT: vpsravd %xmm2, %xmm0, %xmm0
506-
; X64-AVX2-NEXT: vpsravd %xmm2, %xmm1, %xmm1
507-
; X64-AVX2-NEXT: vpmaxsd %xmm1, %xmm0, %xmm0
502+
; X64-AVX2-NEXT: vpand %xmm1, %xmm0, %xmm0
503+
; X64-AVX2-NEXT: vpsrad $25, %xmm0, %xmm0
508504
; X64-AVX2-NEXT: vpbroadcastd %xmm0, %xmm0
509505
; X64-AVX2-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
510506
; X64-AVX2-NEXT: retq
@@ -521,28 +517,24 @@ declare <4 x i32> @llvm.smax.v4i32(<4 x i32>, <4 x i32>) nounwind readnone
521517
define <4 x i32> @signbits_mask_ashr_smin(<4 x i32> %a0, <4 x i32> %a1) {
522518
; X86-LABEL: signbits_mask_ashr_smin:
523519
; X86: # %bb.0:
524-
; X86-NEXT: vpsrad $25, %xmm0, %xmm0
525-
; X86-NEXT: vpsrad $25, %xmm1, %xmm1
526-
; X86-NEXT: vpminsd %xmm1, %xmm0, %xmm0
520+
; X86-NEXT: vpor %xmm1, %xmm0, %xmm0
527521
; X86-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
522+
; X86-NEXT: vpsrad $25, %xmm0, %xmm0
528523
; X86-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
529524
; X86-NEXT: retl
530525
;
531526
; X64-AVX1-LABEL: signbits_mask_ashr_smin:
532527
; X64-AVX1: # %bb.0:
533-
; X64-AVX1-NEXT: vpsrad $25, %xmm0, %xmm0
534-
; X64-AVX1-NEXT: vpsrad $25, %xmm1, %xmm1
535-
; X64-AVX1-NEXT: vpminsd %xmm1, %xmm0, %xmm0
528+
; X64-AVX1-NEXT: vpor %xmm1, %xmm0, %xmm0
536529
; X64-AVX1-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
530+
; X64-AVX1-NEXT: vpsrad $25, %xmm0, %xmm0
537531
; X64-AVX1-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
538532
; X64-AVX1-NEXT: retq
539533
;
540534
; X64-AVX2-LABEL: signbits_mask_ashr_smin:
541535
; X64-AVX2: # %bb.0:
542-
; X64-AVX2-NEXT: vmovdqa {{.*#+}} xmm2 = [25,26,27,0]
543-
; X64-AVX2-NEXT: vpsravd %xmm2, %xmm0, %xmm0
544-
; X64-AVX2-NEXT: vpsravd %xmm2, %xmm1, %xmm1
545-
; X64-AVX2-NEXT: vpminsd %xmm1, %xmm0, %xmm0
536+
; X64-AVX2-NEXT: vpor %xmm1, %xmm0, %xmm0
537+
; X64-AVX2-NEXT: vpsrad $25, %xmm0, %xmm0
546538
; X64-AVX2-NEXT: vpbroadcastd %xmm0, %xmm0
547539
; X64-AVX2-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
548540
; X64-AVX2-NEXT: retq
@@ -559,28 +551,24 @@ declare <4 x i32> @llvm.smin.v4i32(<4 x i32>, <4 x i32>) nounwind readnone
559551
define <4 x i32> @signbits_mask_ashr_umax(<4 x i32> %a0, <4 x i32> %a1) {
560552
; X86-LABEL: signbits_mask_ashr_umax:
561553
; X86: # %bb.0:
562-
; X86-NEXT: vpsrad $25, %xmm0, %xmm0
563-
; X86-NEXT: vpsrad $25, %xmm1, %xmm1
564-
; X86-NEXT: vpmaxud %xmm1, %xmm0, %xmm0
554+
; X86-NEXT: vpor %xmm1, %xmm0, %xmm0
565555
; X86-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
556+
; X86-NEXT: vpsrad $25, %xmm0, %xmm0
566557
; X86-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
567558
; X86-NEXT: retl
568559
;
569560
; X64-AVX1-LABEL: signbits_mask_ashr_umax:
570561
; X64-AVX1: # %bb.0:
571-
; X64-AVX1-NEXT: vpsrad $25, %xmm0, %xmm0
572-
; X64-AVX1-NEXT: vpsrad $25, %xmm1, %xmm1
573-
; X64-AVX1-NEXT: vpmaxud %xmm1, %xmm0, %xmm0
562+
; X64-AVX1-NEXT: vpor %xmm1, %xmm0, %xmm0
574563
; X64-AVX1-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
564+
; X64-AVX1-NEXT: vpsrad $25, %xmm0, %xmm0
575565
; X64-AVX1-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
576566
; X64-AVX1-NEXT: retq
577567
;
578568
; X64-AVX2-LABEL: signbits_mask_ashr_umax:
579569
; X64-AVX2: # %bb.0:
580-
; X64-AVX2-NEXT: vmovdqa {{.*#+}} xmm2 = [25,26,27,0]
581-
; X64-AVX2-NEXT: vpsravd %xmm2, %xmm0, %xmm0
582-
; X64-AVX2-NEXT: vpsravd %xmm2, %xmm1, %xmm1
583-
; X64-AVX2-NEXT: vpmaxud %xmm1, %xmm0, %xmm0
570+
; X64-AVX2-NEXT: vpor %xmm1, %xmm0, %xmm0
571+
; X64-AVX2-NEXT: vpsrad $25, %xmm0, %xmm0
584572
; X64-AVX2-NEXT: vpbroadcastd %xmm0, %xmm0
585573
; X64-AVX2-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
586574
; X64-AVX2-NEXT: retq
@@ -597,28 +585,24 @@ declare <4 x i32> @llvm.umax.v4i32(<4 x i32>, <4 x i32>) nounwind readnone
597585
define <4 x i32> @signbits_mask_ashr_umin(<4 x i32> %a0, <4 x i32> %a1) {
598586
; X86-LABEL: signbits_mask_ashr_umin:
599587
; X86: # %bb.0:
600-
; X86-NEXT: vpsrad $25, %xmm0, %xmm0
601-
; X86-NEXT: vpsrad $25, %xmm1, %xmm1
602-
; X86-NEXT: vpminud %xmm1, %xmm0, %xmm0
588+
; X86-NEXT: vpand %xmm1, %xmm0, %xmm0
603589
; X86-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
590+
; X86-NEXT: vpsrad $25, %xmm0, %xmm0
604591
; X86-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
605592
; X86-NEXT: retl
606593
;
607594
; X64-AVX1-LABEL: signbits_mask_ashr_umin:
608595
; X64-AVX1: # %bb.0:
609-
; X64-AVX1-NEXT: vpsrad $25, %xmm0, %xmm0
610-
; X64-AVX1-NEXT: vpsrad $25, %xmm1, %xmm1
611-
; X64-AVX1-NEXT: vpminud %xmm1, %xmm0, %xmm0
596+
; X64-AVX1-NEXT: vpand %xmm1, %xmm0, %xmm0
612597
; X64-AVX1-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
598+
; X64-AVX1-NEXT: vpsrad $25, %xmm0, %xmm0
613599
; X64-AVX1-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
614600
; X64-AVX1-NEXT: retq
615601
;
616602
; X64-AVX2-LABEL: signbits_mask_ashr_umin:
617603
; X64-AVX2: # %bb.0:
618-
; X64-AVX2-NEXT: vmovdqa {{.*#+}} xmm2 = [25,26,27,0]
619-
; X64-AVX2-NEXT: vpsravd %xmm2, %xmm0, %xmm0
620-
; X64-AVX2-NEXT: vpsravd %xmm2, %xmm1, %xmm1
621-
; X64-AVX2-NEXT: vpminud %xmm1, %xmm0, %xmm0
604+
; X64-AVX2-NEXT: vpand %xmm1, %xmm0, %xmm0
605+
; X64-AVX2-NEXT: vpsrad $25, %xmm0, %xmm0
622606
; X64-AVX2-NEXT: vpbroadcastd %xmm0, %xmm0
623607
; X64-AVX2-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
624608
; X64-AVX2-NEXT: retq

0 commit comments

Comments
 (0)