Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

A few more chain optimizations #4170

Merged
merged 11 commits into from
Apr 11, 2022
6 changes: 6 additions & 0 deletions bench/src/main/scala-2.12/cats/bench/ChainBench.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,10 @@ class ChainBench {

@Benchmark def createChainSeqOption: Chain[Int] = Chain.fromSeq(intOption.toSeq)
@Benchmark def createChainOption: Chain[Int] = Chain.fromOption(intOption)

// Reverses `largeList`; paired with `reverseLargeChain` so the two results can be compared.
@Benchmark def reverseLargeList: List[Int] = largeList.reverse
// Reverses `largeChain` via `Chain#reverse`.
@Benchmark def reverseLargeChain: Chain[Int] = largeChain.reverse

// Computes the length of `largeList`; paired with `lengthLargeChain` so the two can be compared.
@Benchmark def lengthLargeList: Int = largeList.length
// Computes the length of `largeChain`; note that `Chain#length` yields a `Long`, not an `Int`.
@Benchmark def lengthLargeChain: Long = largeChain.length
}
134 changes: 97 additions & 37 deletions core/src/main/scala/cats/data/Chain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,38 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
/**
* Reverses this `Chain`
*/
def reverse: Chain[A] =
fromSeq(reverseIterator.toVector)
def reverse: Chain[A] = {
// Visits the leaves of the chain from left to right using an explicit work list
// instead of nested calls (the compiler checks this via `@tailrec`):
//   - `h` is the node currently being visited,
//   - `tail` holds the right-hand subtrees that still have to be visited,
//   - `acc` is the reversed result built so far.
// Every leaf is concatenated in front of `acc`, so leaves seen later end up earlier
// in the result.
@annotation.tailrec
def loop[B <: A](h: Chain.NonEmpty[B], tail: List[Chain.NonEmpty[B]], acc: Chain[A]): Chain[A] =
h match {
// Go down the left side first and remember the right side for later.
case Append(l, r) => loop(l, r :: tail, acc)
case sing @ Singleton(_) =>
// A single element is its own reverse, so the existing node is reused.
val nextAcc = sing.concat(acc)
tail match {
case h1 :: t1 =>
loop(h1, t1, nextAcc)
case _ =>
// No pending subtrees: `nextAcc` is the complete result.
nextAcc
}
case Wrap(seq) =>
// Reverse the wrapped sequence on its own, then place it in front of `acc`.
val nextAcc = Wrap(seq.reverse).concat(acc)
tail match {
case h1 :: t1 =>
loop(h1, t1, nextAcc)
case _ =>
// No pending subtrees: `nextAcc` is the complete result.
nextAcc
}
}

this match {
case Append(l, r) =>
// Only an `Append` needs the traversal; start with an empty accumulator.
loop(l, r :: Nil, Empty)
// A lone wrapped sequence is reversed directly, without going through `loop`.
case Wrap(seq) => Wrap(seq.reverse)
case _ =>
// Empty | Singleton(_)
this
}
}

/**
* Yields to Some(a, Chain[A]) with `a` removed where `f` holds for the first time,
Expand Down Expand Up @@ -587,19 +617,35 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
* Returns the number of elements in this structure
*/
final def length: Long = {
// This is an optimized (unboxed) implementation
// of the same code as foldLeft
@annotation.tailrec
def loop(chains: List[Chain[A]], acc: Long): Long =
chains match {
case Nil => acc
case h :: tail =>
h match {
case Empty => loop(tail, acc)
case Wrap(seq) => loop(tail, acc + seq.length)
case Singleton(a) => loop(tail, acc + 1)
case Append(l, r) => loop(l :: r :: tail, acc)
def loop(head: Chain.NonEmpty[A], tail: List[Chain.NonEmpty[A]], acc: Long): Long =
head match {
case Append(l, r) => loop(l, r :: tail, acc)
case Singleton(_) =>
val nextAcc = acc + 1L
tail match {
case h1 :: t1 =>
loop(h1, t1, nextAcc)
case _ =>
nextAcc
}
case Wrap(seq) =>
val nextAcc = acc + seq.length.toLong
tail match {
case h1 :: t1 =>
loop(h1, t1, nextAcc)
case _ =>
nextAcc
}
}
loop(this :: Nil, 0L)

this match {
case ne: Chain.NonEmpty[A] =>
loop(ne, Nil, 0L)
case _ => 0L
}
}

/**
Expand Down Expand Up @@ -632,28 +678,37 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
* }}}
*/
final def lengthCompare(len: Long): Int = {
import java.lang.Long
this match {
// `isEmpty` check should be faster than `== Chain.Empty`,
// but the compiler fails to prove that the match is still exhaustive.
case _ if isEmpty => Long.compare(0L, len)
case Chain.Singleton(_) => Long.compare(1L, len)
case _ if len < 2 => 1 // the following cases should always have `length >= 2`
case Chain.Wrap(seq) =>
if (len > Int.MaxValue) -1 // `Seq#length` has `Int` type so cannot be `> Int.MaxValue`
else
seq.lengthCompare(len.toInt)
case _ => // should always be `Chain.Append` (i.e. `NonEmpty` with 2+ elements)
var sz = 2L
val it = new ChainIterator(this)
it.next()
it.next()
while (it.hasNext) {
if (sz == len) return 1
it.next()
sz += 1L
// This is an optimized (unboxed) implementation
// of the same code as foldLeft
@annotation.tailrec
def loop(head: Chain.NonEmpty[A], tail: List[Chain.NonEmpty[A]], len: Long): Int =
if (len <= 0L) 1 // head is nonempty
else
head match {
case Append(l, r) => loop(l, r :: tail, len)
case Singleton(_) =>
tail match {
case h1 :: t1 =>
loop(h1, t1, len - 1L)
case _ =>
java.lang.Long.compare(1L, len)
}
case Wrap(seq) =>
val c =
if (len <= Int.MaxValue) seq.lengthCompare(len.toInt)
else -1
tail match {
case h1 :: t1 =>
if (c >= 0) 1 // there is definitely more in tail
else loop(h1, t1, len - seq.length)
case _ => c
}
}
Long.compare(sz, len)

this match {
case ne: Chain.NonEmpty[A] =>
loop(ne, Nil, len)
case _ => java.lang.Long.compare(0L, len)
}
}

Expand Down Expand Up @@ -767,18 +822,20 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {

final def sortBy[B](f: A => B)(implicit B: Order[B]): Chain[A] =
this match {
case Singleton(_) => this
case Append(_, _) => Wrap(toVector.sortBy(f)(B.toOrdering))
case Wrap(seq) => Wrap(seq.sortBy(f)(B.toOrdering))
case _ => this
case _ =>
// Empty | Singleton(_)
this
}

final def sorted[AA >: A](implicit AA: Order[AA]): Chain[AA] =
this match {
case Singleton(_) => this
case Append(_, _) => Wrap(toVector.sorted(AA.toOrdering))
case Wrap(seq) => Wrap(seq.sorted(AA.toOrdering))
case _ => this
case _ =>
// Empty | Singleton(_)
this
}
}

Expand Down Expand Up @@ -1115,6 +1172,9 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 {
Eval.defer(loop(fa))
}

// Maps every element through `f` lazily via the chain's iterator and lets the
// `Monoid` combine the mapped values in a single `combineAll` call.
override def foldMap[A, B](fa: Chain[A])(f: A => B)(implicit B: Monoid[B]): B = {
  val mapped = fa.iterator.map(f)
  B.combineAll(mapped)
}

override def map[A, B](fa: Chain[A])(f: A => B): Chain[B] = fa.map(f)
override def toList[A](fa: Chain[A]): List[A] = fa.toList
override def isEmpty[A](fa: Chain[A]): Boolean = fa.isEmpty
Expand Down
6 changes: 6 additions & 0 deletions tests/shared/src/test/scala/cats/tests/ChainSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -434,4 +434,10 @@ class ChainSuite extends CatsSuite {

assert(sumAll == chain.iterator.sum)
}

// Property: folding a Chain from the right must agree with folding its List form.
test("foldRight(b)(fn) == toList.foldRight(b)(fn)") {
  forAll { (chain: Chain[Int], init: Long, fn: (Int, Long) => Long) =>
    val viaChain = chain.foldRight(init)(fn)
    val viaList = chain.toList.foldRight(init)(fn)
    assert(viaChain == viaList)
  }
}
}