Skip to content

Commit

Permalink
Auto merge of rust-lang#130223 - LaihoE:faster_str_replace, r=thomcc
Browse files Browse the repository at this point in the history
optimize str.replace

Adds a fast path for str.replace for the ascii to ascii case. This allows for autovectorizing the code. Also should this instead be done with specialization? This way we could remove one branch. I think it is the kind of branch that is easy to predict though.

Benchmark for the fast path (replace all "a" with "b" in the rust wikipedia article, using criterion) :
| N        | Speedup | Time New (ns) | Time Old (ns) |
|----------|---------|---------------|---------------|
| 2        | 2.03    | 13.567        | 27.576        |
| 8        | 1.73    | 17.478        | 30.259        |
| 11       | 2.46    | 18.296        | 45.055        |
| 16       | 2.71    | 17.181        | 46.526        |
| 37       | 4.43    | 18.526        | 81.997        |
| 64       | 8.54    | 18.670        | 159.470       |
| 200      | 9.82    | 29.634        | 291.010       |
| 2000     | 24.34   | 81.114        | 1974.300      |
| 20000    | 30.61   | 598.520       | 18318.000     |
| 1000000  | 29.31   | 33458.000     | 980540.000    |
  • Loading branch information
bors committed Oct 17, 2024
2 parents 6c85d31 + 27136c4 commit 0d7c889
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 2 deletions.
25 changes: 24 additions & 1 deletion alloc/src/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub use core::str::SplitInclusive;
pub use core::str::SplitWhitespace;
#[stable(feature = "rust1", since = "1.0.0")]
pub use core::str::pattern;
use core::str::pattern::{DoubleEndedSearcher, Pattern, ReverseSearcher, Searcher};
use core::str::pattern::{DoubleEndedSearcher, Pattern, ReverseSearcher, Searcher, Utf8Pattern};
#[stable(feature = "rust1", since = "1.0.0")]
pub use core::str::{Bytes, CharIndices, Chars, from_utf8, from_utf8_mut};
#[stable(feature = "str_escape", since = "1.34.0")]
Expand Down Expand Up @@ -269,6 +269,18 @@ impl str {
#[stable(feature = "rust1", since = "1.0.0")]
#[inline]
pub fn replace<P: Pattern>(&self, from: P, to: &str) -> String {
// Fast path for ASCII to ASCII case.

if let Some(from_byte) = match from.as_utf8_pattern() {
Some(Utf8Pattern::StringPattern([from_byte])) => Some(*from_byte),
Some(Utf8Pattern::CharPattern(c)) => c.as_ascii().map(|ascii_char| ascii_char.to_u8()),
_ => None,
} {
if let [to_byte] = to.as_bytes() {
return unsafe { replace_ascii(self.as_bytes(), from_byte, *to_byte) };
}
}

let mut result = String::new();
let mut last_end = 0;
for (start, part) in self.match_indices(from) {
Expand Down Expand Up @@ -686,3 +698,14 @@ pub fn convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> (String, &str) {
(ascii_string, rest)
}
}
#[inline]
#[cfg(not(test))]
#[cfg(not(no_global_oom_handling))]
#[allow(dead_code)]
/// Faster implementation of string replacement for ASCII to ASCII cases.
/// Should produce fast vectorized code.
unsafe fn replace_ascii(utf8_bytes: &[u8], from: u8, to: u8) -> String {
let result: Vec<u8> = utf8_bytes.iter().map(|b| if *b == from { to } else { *b }).collect();
// SAFETY: We replaced ascii with ascii on valid utf8 strings.
unsafe { String::from_utf8_unchecked(result) }
}
7 changes: 6 additions & 1 deletion alloc/src/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ use core::ops::AddAssign;
#[cfg(not(no_global_oom_handling))]
use core::ops::Bound::{Excluded, Included, Unbounded};
use core::ops::{self, Range, RangeBounds};
use core::str::pattern::Pattern;
use core::str::pattern::{Pattern, Utf8Pattern};
use core::{fmt, hash, ptr, slice};

#[cfg(not(no_global_oom_handling))]
Expand Down Expand Up @@ -2436,6 +2436,11 @@ impl<'b> Pattern for &'b String {
{
self[..].strip_suffix_of(haystack)
}

#[inline]
fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>> {
Some(Utf8Pattern::StringPattern(self.as_bytes()))
}
}

macro_rules! impl_eq {
Expand Down
33 changes: 33 additions & 0 deletions core/src/str/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,19 @@ pub trait Pattern: Sized {
None
}
}

/// Returns the pattern as utf-8 bytes if possible.
fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>>;
}
/// Result of calling [`Pattern::as_utf8_pattern()`].
/// Can be used for inspecting the contents of a [`Pattern`] in cases
/// where the underlying representation can be represented as UTF-8.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum Utf8Pattern<'a> {
/// Type returned by String and str types.
StringPattern(&'a [u8]),
/// Type returned by char types.
CharPattern(char),
}

// Searcher
Expand Down Expand Up @@ -599,6 +612,11 @@ impl Pattern for char {
{
self.encode_utf8(&mut [0u8; 4]).strip_suffix_of(haystack)
}

#[inline]
fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>> {
Some(Utf8Pattern::CharPattern(*self))
}
}

/////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -657,6 +675,11 @@ impl<C: MultiCharEq> Pattern for MultiCharEqPattern<C> {
fn into_searcher(self, haystack: &str) -> MultiCharEqSearcher<'_, C> {
MultiCharEqSearcher { haystack, char_eq: self.0, char_indices: haystack.char_indices() }
}

#[inline]
fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>> {
None
}
}

unsafe impl<'a, C: MultiCharEq> Searcher<'a> for MultiCharEqSearcher<'a, C> {
Expand Down Expand Up @@ -747,6 +770,11 @@ macro_rules! pattern_methods {
{
($pmap)(self).strip_suffix_of(haystack)
}

#[inline]
fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>> {
None
}
};
}

Expand Down Expand Up @@ -1022,6 +1050,11 @@ impl<'b> Pattern for &'b str {
None
}
}

#[inline]
fn as_utf8_pattern(&self) -> Option<Utf8Pattern<'_>> {
Some(Utf8Pattern::StringPattern(self.as_bytes()))
}
}

/////////////////////////////////////////////////////////////////////////////
Expand Down

0 comments on commit 0d7c889

Please # to comment.