From 120161a567f3740116ca75f1444d4e85a1ff18ad Mon Sep 17 00:00:00 2001 From: Stein Somers Date: Wed, 9 Oct 2019 01:07:57 +0200 Subject: [PATCH] BTreeSet difference, symmetric_difference & union optimized --- src/liballoc/collections/btree/set.rs | 269 ++++++++++++++------------ src/liballoc/tests/btree/set.rs | 26 ++- 2 files changed, 171 insertions(+), 124 deletions(-) diff --git a/src/liballoc/collections/btree/set.rs b/src/liballoc/collections/btree/set.rs index 8250fc38ccd1c..dcfd3ebb343dd 100644 --- a/src/liballoc/collections/btree/set.rs +++ b/src/liballoc/collections/btree/set.rs @@ -2,10 +2,10 @@ // to TreeMap use core::borrow::Borrow; -use core::cmp::Ordering::{self, Less, Greater, Equal}; +use core::cmp::Ordering::{Less, Greater, Equal}; use core::cmp::{max, min}; use core::fmt::{self, Debug}; -use core::iter::{Peekable, FromIterator, FusedIterator}; +use core::iter::{FromIterator, FusedIterator}; use core::ops::{BitOr, BitAnd, BitXor, Sub, RangeBounds}; use crate::collections::btree_map::{self, BTreeMap, Keys}; @@ -109,6 +109,95 @@ pub struct Range<'a, T: 'a> { iter: btree_map::Range<'a, T, ()>, } +/// Core of an alternative to `Peekable`, that is more efficient +/// if copying items is cheap and if peeking early does no harm. +/// It always keeps one item (`head`) read ahead from the +/// wrapped iterator, if any is left, and the remainder (`tail`) +/// of the wrapped iterator is always one step beyond what the +/// caller sees. +/// This `struct` behaves as ExactSizeIterator and FusedIterator, +/// but is not formally defined as such, because nobody needs it. +#[derive(Clone, Debug)] +struct Peeking +where + I: Iterator, + I::Item: Copy, +{ + head: Option, + tail: I, +} + +impl Peeking +where + I: ExactSizeIterator + FusedIterator, + I::Item: Copy, +{ + fn new(mut iter: I) -> Self { + let head = iter.next(); + Peeking { head, tail: iter } + } + + fn next(&mut self) -> Option { + let next = self.head; + self.head = self.tail.next(); + next + } + + fn len(&self) -> usize { + self.head.map_or(0, |_| 1 + self.tail.len()) + } +} + +/// Core of SymmetricDifference and Union. +/// More efficient than btree.map.MergeIter thanks to Peeking, +/// and crucially for SymmetricDifference, next() reports on both sides. +#[derive(Clone)] +struct MergeIterInner +where + I: Iterator, + I::Item: Copy, +{ + a: Peeking, + b: Peeking, +} + +impl MergeIterInner +where + I: ExactSizeIterator + FusedIterator, + I::Item: Copy + Ord, +{ + fn new(a: I, b: I) -> Self { + MergeIterInner { + a: Peeking::new(a), + b: Peeking::new(b), + } + } + + fn next(&mut self) -> (Option, Option) { + let ord = match (self.a.head, self.b.head) { + (None, None) => return (None, None), + (Some(_), None) => Less, + (None, Some(_)) => Greater, + (Some(a1), Some(b1)) => a1.cmp(&b1), + }; + match ord { + Less => (self.a.next(), None), + Equal => (self.a.next(), self.b.next()), + Greater => (None, self.b.next()), + } + } +} + +impl Debug for MergeIterInner +where + I: Iterator + Debug, + I::Item: Copy + Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("MergeIterInner").field(&self.a).field(&self.b).finish() + } +} + /// A lazy iterator producing elements in the difference of `BTreeSet`s. /// /// This `struct` is created by the [`difference`] method on [`BTreeSet`]. @@ -120,11 +209,12 @@ pub struct Range<'a, T: 'a> { pub struct Difference<'a, T: 'a> { inner: DifferenceInner<'a, T>, } +#[derive(Debug)] enum DifferenceInner<'a, T: 'a> { Stitch { // iterate all of self and some of other, spotting matches along the way self_iter: Iter<'a, T>, - other_iter: Peekable>, + other_iter: Peeking>, }, Search { // iterate a small set, look up in the large set @@ -137,21 +227,7 @@ enum DifferenceInner<'a, T: 'a> { #[stable(feature = "collection_debug", since = "1.17.0")] impl fmt::Debug for Difference<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self.inner { - DifferenceInner::Stitch { - self_iter, - other_iter, - } => f - .debug_tuple("Difference") - .field(&self_iter) - .field(&other_iter) - .finish(), - DifferenceInner::Search { - self_iter, - other_set: _, - } => f.debug_tuple("Difference").field(&self_iter).finish(), - DifferenceInner::Iterate(iter) => f.debug_tuple("Difference").field(&iter).finish(), - } + f.debug_tuple("Difference").field(&self.inner).finish() } } @@ -163,18 +239,12 @@ impl fmt::Debug for Difference<'_, T> { /// [`BTreeSet`]: struct.BTreeSet.html /// [`symmetric_difference`]: struct.BTreeSet.html#method.symmetric_difference #[stable(feature = "rust1", since = "1.0.0")] -pub struct SymmetricDifference<'a, T: 'a> { - a: Peekable>, - b: Peekable>, -} +pub struct SymmetricDifference<'a, T: 'a>(MergeIterInner>); #[stable(feature = "collection_debug", since = "1.17.0")] impl fmt::Debug for SymmetricDifference<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("SymmetricDifference") - .field(&self.a) - .field(&self.b) - .finish() + f.debug_tuple("SymmetricDifference").field(&self.0).finish() } } @@ -189,6 +259,7 @@ impl fmt::Debug for SymmetricDifference<'_, T> { pub struct Intersection<'a, T: 'a> { inner: IntersectionInner<'a, T>, } +#[derive(Debug)] enum IntersectionInner<'a, T: 'a> { Stitch { // iterate similarly sized sets jointly, spotting matches along the way @@ -206,23 +277,7 @@ enum IntersectionInner<'a, T: 'a> { #[stable(feature = "collection_debug", since = "1.17.0")] impl fmt::Debug for Intersection<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self.inner { - IntersectionInner::Stitch { - a, - b, - } => f - .debug_tuple("Intersection") - .field(&a) - .field(&b) - .finish(), - IntersectionInner::Search { - small_iter, - large_set: _, - } => f.debug_tuple("Intersection").field(&small_iter).finish(), - IntersectionInner::Answer(answer) => { - f.debug_tuple("Intersection").field(&answer).finish() - } - } + f.debug_tuple("Intersection").field(&self.inner).finish() } } @@ -234,18 +289,12 @@ impl fmt::Debug for Intersection<'_, T> { /// [`BTreeSet`]: struct.BTreeSet.html /// [`union`]: struct.BTreeSet.html#method.union #[stable(feature = "rust1", since = "1.0.0")] -pub struct Union<'a, T: 'a> { - a: Peekable>, - b: Peekable>, -} +pub struct Union<'a, T: 'a>(MergeIterInner>); #[stable(feature = "collection_debug", since = "1.17.0")] impl fmt::Debug for Union<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("Union") - .field(&self.a) - .field(&self.b) - .finish() + f.debug_tuple("Union").field(&self.0).finish() } } @@ -355,19 +404,16 @@ impl BTreeSet { self_iter.next_back(); DifferenceInner::Iterate(self_iter) } - _ => { - if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF { - DifferenceInner::Search { - self_iter: self.iter(), - other_set: other, - } - } else { - DifferenceInner::Stitch { - self_iter: self.iter(), - other_iter: other.iter().peekable(), - } + _ if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => { + DifferenceInner::Search { + self_iter: self.iter(), + other_set: other, } } + _ => DifferenceInner::Stitch { + self_iter: self.iter(), + other_iter: Peeking::new(other.iter()), + }, }, } } @@ -396,10 +442,7 @@ impl BTreeSet { pub fn symmetric_difference<'a>(&'a self, other: &'a BTreeSet) -> SymmetricDifference<'a, T> { - SymmetricDifference { - a: self.iter().peekable(), - b: other.iter().peekable(), - } + SymmetricDifference(MergeIterInner::new(self.iter(), other.iter())) } /// Visits the values representing the intersection, @@ -447,24 +490,22 @@ impl BTreeSet { (Greater, _) | (_, Less) => IntersectionInner::Answer(None), (Equal, _) => IntersectionInner::Answer(Some(self_min)), (_, Equal) => IntersectionInner::Answer(Some(self_max)), - _ => { - if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF { - IntersectionInner::Search { - small_iter: self.iter(), - large_set: other, - } - } else if other.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF { - IntersectionInner::Search { - small_iter: other.iter(), - large_set: self, - } - } else { - IntersectionInner::Stitch { - a: self.iter(), - b: other.iter(), - } + _ if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => { + IntersectionInner::Search { + small_iter: self.iter(), + large_set: other, + } + } + _ if other.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => { + IntersectionInner::Search { + small_iter: other.iter(), + large_set: self, } } + _ => IntersectionInner::Stitch { + a: self.iter(), + b: other.iter(), + }, }, } } @@ -489,10 +530,7 @@ impl BTreeSet { /// ``` #[stable(feature = "rust1", since = "1.0.0")] pub fn union<'a>(&'a self, other: &'a BTreeSet) -> Union<'a, T> { - Union { - a: self.iter().peekable(), - b: other.iter().peekable(), - } + Union(MergeIterInner::new(self.iter(), other.iter())) } /// Clears the set, removing all values. @@ -1166,15 +1204,6 @@ impl<'a, T> DoubleEndedIterator for Range<'a, T> { #[stable(feature = "fused", since = "1.26.0")] impl FusedIterator for Range<'_, T> {} -/// Compares `x` and `y`, but return `short` if x is None and `long` if y is None -fn cmp_opt(x: Option<&T>, y: Option<&T>, short: Ordering, long: Ordering) -> Ordering { - match (x, y) { - (None, _) => short, - (_, None) => long, - (Some(x1), Some(y1)) => x1.cmp(y1), - } -} - #[stable(feature = "rust1", since = "1.0.0")] impl Clone for Difference<'_, T> { fn clone(&self) -> Self { @@ -1212,7 +1241,7 @@ impl<'a, T: Ord> Iterator for Difference<'a, T> { let mut self_next = self_iter.next()?; loop { match other_iter - .peek() + .head .map_or(Less, |other_next| self_next.cmp(other_next)) { Less => return Some(self_next), @@ -1261,10 +1290,7 @@ impl FusedIterator for Difference<'_, T> {} #[stable(feature = "rust1", since = "1.0.0")] impl Clone for SymmetricDifference<'_, T> { fn clone(&self) -> Self { - SymmetricDifference { - a: self.a.clone(), - b: self.b.clone(), - } + SymmetricDifference(self.0.clone()) } } #[stable(feature = "rust1", since = "1.0.0")] @@ -1273,19 +1299,22 @@ impl<'a, T: Ord> Iterator for SymmetricDifference<'a, T> { fn next(&mut self) -> Option<&'a T> { loop { - match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) { - Less => return self.a.next(), - Equal => { - self.a.next(); - self.b.next(); - } - Greater => return self.b.next(), + match self.0.next() { + (None, None) => return None, + (Some(item), None) => return Some(item), + (None, Some(item)) => return Some(item), + (Some(_), Some(_)) => (), } } } fn size_hint(&self) -> (usize, Option) { - (0, Some(self.a.len() + self.b.len())) + let a_len = self.0.a.len(); + let b_len = self.0.b.len(); + // No checked_add, because even if a and b refer to the same set, + // and T is empty, a set's storage overhead implies that the address + // space allows fewer elements than half of the usize range. + (0, Some(a_len + b_len)) } } @@ -1311,7 +1340,7 @@ impl Clone for Intersection<'_, T> { small_iter: small_iter.clone(), large_set, }, - IntersectionInner::Answer(answer) => IntersectionInner::Answer(answer.clone()), + IntersectionInner::Answer(answer) => IntersectionInner::Answer(*answer), }, } } @@ -1365,10 +1394,7 @@ impl FusedIterator for Intersection<'_, T> {} #[stable(feature = "rust1", since = "1.0.0")] impl Clone for Union<'_, T> { fn clone(&self) -> Self { - Union { - a: self.a.clone(), - b: self.b.clone(), - } + Union(self.0.clone()) } } #[stable(feature = "rust1", since = "1.0.0")] @@ -1376,19 +1402,16 @@ impl<'a, T: Ord> Iterator for Union<'a, T> { type Item = &'a T; fn next(&mut self) -> Option<&'a T> { - match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) { - Less => self.a.next(), - Equal => { - self.b.next(); - self.a.next() - } - Greater => self.b.next(), - } + let (a_next, b_next) = self.0.next(); + a_next.or(b_next) } fn size_hint(&self) -> (usize, Option) { - let a_len = self.a.len(); - let b_len = self.b.len(); + let a_len = self.0.a.len(); + let b_len = self.0.b.len(); + // No checked_add, because even if a and b refer to the same set, + // and T is empty, a set's storage overhead implies that the address + // space allows fewer elements than half of the usize range. (max(a_len, b_len), Some(a_len + b_len)) } } diff --git a/src/liballoc/tests/btree/set.rs b/src/liballoc/tests/btree/set.rs index 5c611fd21d21b..e4883abc8b56c 100644 --- a/src/liballoc/tests/btree/set.rs +++ b/src/liballoc/tests/btree/set.rs @@ -221,6 +221,18 @@ fn test_symmetric_difference() { &[-2, 1, 5, 11, 14, 22]); } +#[test] +fn test_symmetric_difference_size_hint() { + let x: BTreeSet = [2, 4].iter().copied().collect(); + let y: BTreeSet = [1, 2, 3].iter().copied().collect(); + let mut iter = x.symmetric_difference(&y); + assert_eq!(iter.size_hint(), (0, Some(5))); + assert_eq!(iter.next(), Some(&1)); + assert_eq!(iter.size_hint(), (0, Some(4))); + assert_eq!(iter.next(), Some(&3)); + assert_eq!(iter.size_hint(), (0, Some(1))); +} + #[test] fn test_union() { fn check_union(a: &[i32], b: &[i32], expected: &[i32]) { @@ -235,6 +247,18 @@ fn test_union() { &[-2, 1, 3, 5, 9, 11, 13, 16, 19, 24]); } +#[test] +fn test_union_size_hint() { + let x: BTreeSet = [2, 4].iter().copied().collect(); + let y: BTreeSet = [1, 2, 3].iter().copied().collect(); + let mut iter = x.union(&y); + assert_eq!(iter.size_hint(), (3, Some(5))); + assert_eq!(iter.next(), Some(&1)); + assert_eq!(iter.size_hint(), (2, Some(4))); + assert_eq!(iter.next(), Some(&2)); + assert_eq!(iter.size_hint(), (1, Some(2))); +} + #[test] // Only tests the simple function definition with respect to intersection fn test_is_disjoint() { @@ -244,7 +268,7 @@ fn test_is_disjoint() { } #[test] -// Also tests the trivial function definition of is_superset +// Also implicitly tests the trivial function definition of is_superset fn test_is_subset() { fn is_subset(a: &[i32], b: &[i32]) -> bool { let set_a = a.iter().collect::>();