Skip to content

Add Range parameter to BTreeMap::extract_if and BTreeSet::extract_if #140825

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 4 commits into from
May 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 48 additions & 14 deletions library/alloc/src/collections/btree/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1151,7 +1151,7 @@ impl<K, V, A: Allocator + Clone> BTreeMap<K, V, A> {
K: Ord,
F: FnMut(&K, &mut V) -> bool,
{
self.extract_if(|k, v| !f(k, v)).for_each(drop);
self.extract_if(.., |k, v| !f(k, v)).for_each(drop);
}

/// Moves all elements from `other` into `self`, leaving `other` empty.
Expand Down Expand Up @@ -1397,7 +1397,7 @@ impl<K, V, A: Allocator + Clone> BTreeMap<K, V, A> {
}
}

/// Creates an iterator that visits all elements (key-value pairs) in
/// Creates an iterator that visits elements (key-value pairs) in the specified range in
/// ascending key order and uses a closure to determine if an element
/// should be removed.
///
Expand All @@ -1423,33 +1423,42 @@ impl<K, V, A: Allocator + Clone> BTreeMap<K, V, A> {
/// use std::collections::BTreeMap;
///
/// let mut map: BTreeMap<i32, i32> = (0..8).map(|x| (x, x)).collect();
/// let evens: BTreeMap<_, _> = map.extract_if(|k, _v| k % 2 == 0).collect();
/// let evens: BTreeMap<_, _> = map.extract_if(.., |k, _v| k % 2 == 0).collect();
/// let odds = map;
/// assert_eq!(evens.keys().copied().collect::<Vec<_>>(), [0, 2, 4, 6]);
/// assert_eq!(odds.keys().copied().collect::<Vec<_>>(), [1, 3, 5, 7]);
///
/// let mut map: BTreeMap<i32, i32> = (0..8).map(|x| (x, x)).collect();
/// let low: BTreeMap<_, _> = map.extract_if(0..4, |_k, _v| true).collect();
/// let high = map;
/// assert_eq!(low.keys().copied().collect::<Vec<_>>(), [0, 1, 2, 3]);
/// assert_eq!(high.keys().copied().collect::<Vec<_>>(), [4, 5, 6, 7]);
/// ```
#[unstable(feature = "btree_extract_if", issue = "70530")]
pub fn extract_if<F>(&mut self, pred: F) -> ExtractIf<'_, K, V, F, A>
pub fn extract_if<F, R>(&mut self, range: R, pred: F) -> ExtractIf<'_, K, V, R, F, A>
where
K: Ord,
R: RangeBounds<K>,
F: FnMut(&K, &mut V) -> bool,
{
let (inner, alloc) = self.extract_if_inner();
let (inner, alloc) = self.extract_if_inner(range);
ExtractIf { pred, inner, alloc }
}

pub(super) fn extract_if_inner(&mut self) -> (ExtractIfInner<'_, K, V>, A)
pub(super) fn extract_if_inner<R>(&mut self, range: R) -> (ExtractIfInner<'_, K, V, R>, A)
where
K: Ord,
R: RangeBounds<K>,
{
if let Some(root) = self.root.as_mut() {
let (root, dormant_root) = DormantMutRef::new(root);
let front = root.borrow_mut().first_leaf_edge();
let first = root.borrow_mut().lower_bound(SearchBound::from_range(range.start_bound()));
(
ExtractIfInner {
length: &mut self.length,
dormant_root: Some(dormant_root),
cur_leaf_edge: Some(front),
cur_leaf_edge: Some(first),
range,
},
(*self.alloc).clone(),
)
Expand All @@ -1459,6 +1468,7 @@ impl<K, V, A: Allocator + Clone> BTreeMap<K, V, A> {
length: &mut self.length,
dormant_root: None,
cur_leaf_edge: None,
range,
},
(*self.alloc).clone(),
)
Expand Down Expand Up @@ -1917,18 +1927,19 @@ pub struct ExtractIf<
'a,
K,
V,
R,
F,
#[unstable(feature = "allocator_api", issue = "32838")] A: Allocator + Clone = Global,
> {
pred: F,
inner: ExtractIfInner<'a, K, V>,
inner: ExtractIfInner<'a, K, V, R>,
/// The BTreeMap will outlive this IntoIter so we don't care about drop order for `alloc`.
alloc: A,
}

/// Most of the implementation of ExtractIf are generic over the type
/// of the predicate, thus also serving for BTreeSet::ExtractIf.
pub(super) struct ExtractIfInner<'a, K, V> {
pub(super) struct ExtractIfInner<'a, K, V, R> {
/// Reference to the length field in the borrowed map, updated live.
length: &'a mut usize,
/// Buried reference to the root field in the borrowed map.
Expand All @@ -1938,10 +1949,13 @@ pub(super) struct ExtractIfInner<'a, K, V> {
/// Empty if the map has no root, if iteration went beyond the last leaf edge,
/// or if a panic occurred in the predicate.
cur_leaf_edge: Option<Handle<NodeRef<marker::Mut<'a>, K, V, marker::Leaf>, marker::Edge>>,
/// Range over which iteration was requested. We don't need the left side, but we
/// can't extract the right side without requiring K: Clone.
range: R,
}

#[unstable(feature = "btree_extract_if", issue = "70530")]
impl<K, V, F, A> fmt::Debug for ExtractIf<'_, K, V, F, A>
impl<K, V, R, F, A> fmt::Debug for ExtractIf<'_, K, V, R, F, A>
where
K: fmt::Debug,
V: fmt::Debug,
Expand All @@ -1953,8 +1967,10 @@ where
}

#[unstable(feature = "btree_extract_if", issue = "70530")]
impl<K, V, F, A: Allocator + Clone> Iterator for ExtractIf<'_, K, V, F, A>
impl<K, V, R, F, A: Allocator + Clone> Iterator for ExtractIf<'_, K, V, R, F, A>
where
K: PartialOrd,
R: RangeBounds<K>,
F: FnMut(&K, &mut V) -> bool,
{
type Item = (K, V);
Expand All @@ -1968,7 +1984,7 @@ where
}
}

impl<'a, K, V> ExtractIfInner<'a, K, V> {
impl<'a, K, V, R> ExtractIfInner<'a, K, V, R> {
/// Allow Debug implementations to predict the next element.
pub(super) fn peek(&self) -> Option<(&K, &V)> {
let edge = self.cur_leaf_edge.as_ref()?;
Expand All @@ -1978,10 +1994,22 @@ impl<'a, K, V> ExtractIfInner<'a, K, V> {
/// Implementation of a typical `ExtractIf::next` method, given the predicate.
pub(super) fn next<F, A: Allocator + Clone>(&mut self, pred: &mut F, alloc: A) -> Option<(K, V)>
where
K: PartialOrd,
R: RangeBounds<K>,
F: FnMut(&K, &mut V) -> bool,
{
while let Ok(mut kv) = self.cur_leaf_edge.take()?.next_kv() {
let (k, v) = kv.kv_mut();

// On creation, we navigated directly to the left bound, so we need only check the
// right bound here to decide whether to stop.
match self.range.end_bound() {
Bound::Included(ref end) if (*k).le(end) => (),
Bound::Excluded(ref end) if (*k).lt(end) => (),
Bound::Unbounded => (),
_ => return None,
}

if pred(k, v) {
*self.length -= 1;
let (kv, pos) = kv.remove_kv_tracking(
Expand Down Expand Up @@ -2013,7 +2041,13 @@ impl<'a, K, V> ExtractIfInner<'a, K, V> {
}

#[unstable(feature = "btree_extract_if", issue = "70530")]
impl<K, V, F> FusedIterator for ExtractIf<'_, K, V, F> where F: FnMut(&K, &mut V) -> bool {}
impl<K, V, R, F> FusedIterator for ExtractIf<'_, K, V, R, F>
where
K: PartialOrd,
R: RangeBounds<K>,
F: FnMut(&K, &mut V) -> bool,
{
}

#[stable(feature = "btree_range", since = "1.17.0")]
impl<'a, K, V> Iterator for Range<'a, K, V> {
Expand Down
Loading
Loading