From d90a7b3b006be912493ddf15d0dc5895a6929b38 Mon Sep 17 00:00:00 2001
From: f001
Date: Sun, 5 Feb 2017 17:39:52 +0800
Subject: [PATCH] std: Add retain method for HashMap and HashSet

Fix #36648
---
 src/libstd/collections/hash/map.rs   | 114 ++++++++++++++++++---------
 src/libstd/collections/hash/set.rs   |  33 ++++++++
 src/libstd/collections/hash/table.rs |  81 +++++++++++++++++--
 3 files changed, 182 insertions(+), 46 deletions(-)

diff --git a/src/libstd/collections/hash/map.rs b/src/libstd/collections/hash/map.rs
index 8058972e75093..f689589dfa25f 100644
--- a/src/libstd/collections/hash/map.rs
+++ b/src/libstd/collections/hash/map.rs
@@ -416,22 +416,26 @@ fn search_hashed<K, V, M, F>(table: M, hash: SafeHash, mut is_match: F) -> InternalEntry<K, V, M>
     }
 }
 
-fn pop_internal<K, V>(starting_bucket: FullBucketMut<K, V>) -> (K, V) {
+fn pop_internal<K, V>(starting_bucket: FullBucketMut<K, V>)
+    -> (K, V, &mut RawTable<K, V>)
+{
     let (empty, retkey, retval) = starting_bucket.take();
     let mut gap = match empty.gap_peek() {
-        Some(b) => b,
-        None => return (retkey, retval),
+        Ok(b) => b,
+        Err(b) => return (retkey, retval, b.into_table()),
     };
 
     while gap.full().displacement() != 0 {
         gap = match gap.shift() {
-            Some(b) => b,
-            None => break,
+            Ok(b) => b,
+            Err(b) => {
+                return (retkey, retval, b.into_table());
+            },
         };
     }
 
     // Now we've done all our shifting. Return the value we grabbed earlier.
-    (retkey, retval)
+    (retkey, retval, gap.into_bucket().into_table())
 }
 
 /// Perform robin hood bucket stealing at the given `bucket`. You must
@@ -721,38 +725,7 @@ impl<K, V, S> HashMap<K, V, S>
             return;
         }
 
-        // Grow the table.
-        // Specialization of the other branch.
-        let mut bucket = Bucket::first(&mut old_table);
-
-        // "So a few of the first shall be last: for many be called,
-        // but few chosen."
-        //
-        // We'll most likely encounter a few buckets at the beginning that
-        // have their initial buckets near the end of the table. They were
-        // placed at the beginning as the probe wrapped around the table
-        // during insertion. We must skip forward to a bucket that won't
-        // get reinserted too early and won't unfairly steal others spot.
-        // This eliminates the need for robin hood.
-        loop {
-            bucket = match bucket.peek() {
-                Full(full) => {
-                    if full.displacement() == 0 {
-                        // This bucket occupies its ideal spot.
-                        // It indicates the start of another "cluster".
-                        bucket = full.into_bucket();
-                        break;
-                    }
-                    // Leaving this bucket in the last cluster for later.
-                    full.into_bucket()
-                }
-                Empty(b) => {
-                    // Encountered a hole between clusters.
-                    b.into_bucket()
-                }
-            };
-            bucket.next();
-        }
+        let mut bucket = Bucket::head_bucket(&mut old_table);
 
         // This is how the buckets might be laid out in memory:
         // ($ marks an initialized bucket)
@@ -1208,6 +1181,57 @@ impl<K, V, S> HashMap<K, V, S>
         self.search_mut(k).into_occupied_bucket().map(|bucket| pop_internal(bucket).1)
     }
+
+    /// Retains only the elements specified by the predicate.
+    ///
+    /// In other words, remove all pairs `(k, v)` such that `f(&k,&mut v)` returns `false`.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// #![feature(retain_hash_collection)]
+    /// use std::collections::HashMap;
+    ///
+    /// let mut map: HashMap<i32, i32> = (0..8).map(|x|(x, x*10)).collect();
+    /// map.retain(|&k, _| k % 2 == 0);
+    /// assert_eq!(map.len(), 4);
+    /// ```
+    #[unstable(feature = "retain_hash_collection", issue = "36648")]
+    pub fn retain<F>(&mut self, mut f: F)
+        where F: FnMut(&K, &mut V) -> bool
+    {
+        if self.table.capacity() == 0 || self.table.size() == 0 {
+            return;
+        }
+        let mut bucket = Bucket::head_bucket(&mut self.table);
+        bucket.prev();
+        let tail = bucket.index();
+        loop {
+            bucket = match bucket.peek() {
+                Full(mut full) => {
+                    let should_remove = {
+                        let (k, v) = full.read_mut();
+                        !f(k, v)
+                    };
+                    if should_remove {
+                        let prev_idx = full.index();
+                        let prev_raw = full.raw();
+                        let (_, _, t) = pop_internal(full);
+                        Bucket::new_from(prev_raw, prev_idx, t)
+                    } else {
+                        full.into_bucket()
+                    }
+                },
+                Empty(b) => {
+                    b.into_bucket()
+                }
+            };
+            bucket.prev();  // reverse iteration
+            if bucket.index() == tail {
+                break;
+            }
+        }
+    }
 }
 
 #[stable(feature = "rust1", since = "1.0.0")]
@@ -1862,7 +1886,8 @@ impl<'a, K, V> OccupiedEntry<'a, K, V> {
     /// ```
     #[stable(feature = "map_entry_recover_keys2", since = "1.12.0")]
     pub fn remove_entry(self) -> (K, V) {
-        pop_internal(self.elem)
+        let (k, v, _) = pop_internal(self.elem);
+        (k, v)
     }
 
     /// Gets a reference to the value in the entry.
@@ -3156,4 +3181,15 @@ mod test_map {
         assert_eq!(a.len(), 1);
         assert_eq!(a[key], value);
     }
+
+    #[test]
+    fn test_retain() {
+        let mut map: HashMap<i32, i32> = (0..100).map(|x|(x, x*10)).collect();
+
+        map.retain(|&k, _| k % 2 == 0);
+        assert_eq!(map.len(), 50);
+        assert_eq!(map[&2], 20);
+        assert_eq!(map[&4], 40);
+        assert_eq!(map[&6], 60);
+    }
 }
diff --git a/src/libstd/collections/hash/set.rs b/src/libstd/collections/hash/set.rs
index 341b050862f5c..8de742db46110 100644
--- a/src/libstd/collections/hash/set.rs
+++ b/src/libstd/collections/hash/set.rs
@@ -624,6 +624,28 @@ impl<T, S> HashSet<T, S>
         Recover::take(&mut self.map, value)
     }
+
+    /// Retains only the elements specified by the predicate.
+    ///
+    /// In other words, remove all elements `e` such that `f(&e)` returns `false`.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// #![feature(retain_hash_collection)]
+    /// use std::collections::HashSet;
+    ///
+    /// let xs = [1,2,3,4,5,6];
+    /// let mut set: HashSet<i32> = xs.iter().cloned().collect();
+    /// set.retain(|&k| k % 2 == 0);
+    /// assert_eq!(set.len(), 3);
+    /// ```
+    #[unstable(feature = "retain_hash_collection", issue = "36648")]
+    pub fn retain<F>(&mut self, mut f: F)
+        where F: FnMut(&T) -> bool
+    {
+        self.map.retain(|k, _| f(k));
+    }
 }
 
 #[stable(feature = "rust1", since = "1.0.0")]
@@ -1605,4 +1627,15 @@ mod test_set {
         assert!(a.contains(&5));
         assert!(a.contains(&6));
     }
+
+    #[test]
+    fn test_retain() {
+        let xs = [1,2,3,4,5,6];
+        let mut set: HashSet<i32> = xs.iter().cloned().collect();
+        set.retain(|&k| k % 2 == 0);
+        assert_eq!(set.len(), 3);
+        assert!(set.contains(&2));
+        assert!(set.contains(&4));
+        assert!(set.contains(&6));
+    }
 }
diff --git a/src/libstd/collections/hash/table.rs b/src/libstd/collections/hash/table.rs
index 1ab62130cd3dd..9e92b4750145e 100644
--- a/src/libstd/collections/hash/table.rs
+++ b/src/libstd/collections/hash/table.rs
@@ -85,7 +85,7 @@ pub struct RawTable<K, V> {
 unsafe impl<K: Send, V: Send> Send for RawTable<K, V> {}
 unsafe impl<K: Sync, V: Sync> Sync for RawTable<K, V> {}
 
-struct RawBucket<K, V> {
+pub struct RawBucket<K, V> {
     hash: *mut HashUint,
     // We use *const to ensure covariance with respect to K and V
     pair: *const (K, V),
@@ -216,6 +216,10 @@ impl<K, V, M> FullBucket<K, V, M> {
     pub fn index(&self) -> usize {
         self.idx
     }
+    /// Get the raw bucket.
+    pub fn raw(&self) -> RawBucket<K, V> {
+        self.raw
+    }
 }
 
 impl<K, V, M> EmptyBucket<K, V, M> {
@@ -230,6 +234,10 @@ impl<K, V, M> Bucket<K, V, M> {
     pub fn index(&self) -> usize {
         self.idx
     }
+    /// get the table.
+    pub fn into_table(self) -> M {
+        self.table
+    }
 }
 
 impl<K, V, M> Deref for FullBucket<K, V, M>
@@ -275,6 +283,16 @@ impl<K, V, M: Deref<Target = RawTable<K, V>>> Bucket<K, V, M> {
         Bucket::at_index(table, hash.inspect() as usize)
     }
 
+    pub fn new_from(r: RawBucket<K, V>, i: usize, t: M)
+        -> Bucket<K, V, M>
+    {
+        Bucket {
+            raw: r,
+            idx: i,
+            table: t,
+        }
+    }
+
     pub fn at_index(table: M, ib_index: usize) -> Bucket<K, V, M> {
         // if capacity is 0, then the RawBucket will be populated with bogus pointers.
         // This is an uncommon case though, so avoid it in release builds.
@@ -296,6 +314,40 @@ impl<K, V, M: Deref<Target = RawTable<K, V>>> Bucket<K, V, M> {
         }
     }
 
+    // "So a few of the first shall be last: for many be called,
+    // but few chosen."
+    //
+    // We'll most likely encounter a few buckets at the beginning that
+    // have their initial buckets near the end of the table. They were
+    // placed at the beginning as the probe wrapped around the table
+    // during insertion. We must skip forward to a bucket that won't
+    // get reinserted too early and won't unfairly steal others spot.
+    // This eliminates the need for robin hood.
+    pub fn head_bucket(table: M) -> Bucket<K, V, M> {
+        let mut bucket = Bucket::first(table);
+
+        loop {
+            bucket = match bucket.peek() {
+                Full(full) => {
+                    if full.displacement() == 0 {
+                        // This bucket occupies its ideal spot.
+                        // It indicates the start of another "cluster".
+                        bucket = full.into_bucket();
+                        break;
+                    }
+                    // Leaving this bucket in the last cluster for later.
+                    full.into_bucket()
+                }
+                Empty(b) => {
+                    // Encountered a hole between clusters.
+                    b.into_bucket()
+                }
+            };
+            bucket.next();
+        }
+        bucket
+    }
+
     /// Reads a bucket at a given index, returning an enum indicating whether
     /// it's initialized or not. You need to match on this enum to get
     /// the appropriate types to call most of the other functions in
@@ -333,6 +385,17 @@ impl<K, V, M: Deref<Target = RawTable<K, V>>> Bucket<K, V, M> {
             self.raw = self.raw.offset(dist);
         }
     }
+
+    /// Modifies the bucket pointer in place to make it point to the previous slot.
+    pub fn prev(&mut self) {
+        let range = self.table.capacity();
+        let new_idx = self.idx.wrapping_sub(1) & (range - 1);
+        let dist = (new_idx as isize).wrapping_sub(self.idx as isize);
+        self.idx = new_idx;
+        unsafe {
+            self.raw = self.raw.offset(dist);
+        }
+    }
 }
 
 impl<K, V, M: Deref<Target = RawTable<K, V>>> EmptyBucket<K, V, M> {
@@ -352,7 +415,7 @@ impl<K, V, M: Deref<Target = RawTable<K, V>>> EmptyBucket<K, V, M> {
         }
     }
 
-    pub fn gap_peek(self) -> Option<GapThenFull<K, V, M>> {
+    pub fn gap_peek(self) -> Result<GapThenFull<K, V, M>, Bucket<K, V, M>> {
         let gap = EmptyBucket {
             raw: self.raw,
             idx: self.idx,
@@ -361,12 +424,12 @@ impl<K, V, M: Deref<Target = RawTable<K, V>>> EmptyBucket<K, V, M> {
 
         match self.next().peek() {
             Full(bucket) => {
-                Some(GapThenFull {
+                Ok(GapThenFull {
                     gap: gap,
                     full: bucket,
                 })
             }
-            Empty(..) => None,
+            Empty(e) => Err(e.into_bucket()),
         }
     }
 }
@@ -529,7 +592,11 @@ impl<K, V, M> GapThenFull<K, V, M>
         &self.full
     }
 
-    pub fn shift(mut self) -> Option<GapThenFull<K, V, M>> {
+    pub fn into_bucket(self) -> Bucket<K, V, M> {
+        self.full.into_bucket()
+    }
+
+    pub fn shift(mut self) -> Result<GapThenFull<K, V, M>, Bucket<K, V, M>> {
         unsafe {
             *self.gap.raw.hash = mem::replace(&mut *self.full.raw.hash, EMPTY_BUCKET);
             ptr::copy_nonoverlapping(self.full.raw.pair, self.gap.raw.pair as *mut (K, V), 1);
@@ -544,9 +611,9 @@ impl<K, V, M> GapThenFull<K, V, M>
 
                 self.full = bucket;
 
-                Some(self)
+                Ok(self)
             }
-            Empty(..) => None,
+            Empty(b) => Err(b.into_bucket()),
         }
     }
 }