diff --git a/CHANGELOG.md b/CHANGELOG.md index 190d6e8..224b8fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,9 @@ # Changelog -## Unreleased +## Unreleased +- Optimise `StrExt::replace_smolstr`, `StrExt::replacen_smolstr` for single ascii replace, + ~3x speedup inline & heap. - Optimise `StrExt::to_ascii_lowercase_smolstr`, `StrExt::to_ascii_uppercase_smolstr` ~2x speedup inline, ~4-22x for heap. diff --git a/src/lib.rs b/src/lib.rs index ff25651..5dde297 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -682,7 +682,28 @@ impl StrExt for str { } #[inline] - fn replacen_smolstr(&self, from: &str, to: &str, count: usize) -> SmolStr { + fn replacen_smolstr(&self, from: &str, to: &str, mut count: usize) -> SmolStr { + // Fast path: `from` and `to` are one byte each, i.e. single ASCII characters, so every input byte maps to exactly one output byte. + if let [from_u8] = from.as_bytes() { + if let [to_u8] = to.as_bytes() { + return if self.len() <= count { // `count` is at least the byte length, so the limit can never run out + // SAFETY: `from` & `to` are one-byte `str`s, so `from_u8` & `to_u8` are ascii + unsafe { replacen_1_ascii(self, |b| if b == from_u8 { *to_u8 } else { *b }) } + } else { + unsafe { // SAFETY: as above; `count` only limits how many matching bytes get swapped
+ replacen_1_ascii(self, |b| { + if b == from_u8 && count != 0 { + count -= 1; + *to_u8 + } else { + *b + } + }) + } + }; + } + } + let mut result = SmolStrBuilder::new(); let mut last_end = 0; for (start, part) in self.match_indices(from).take(count) { @@ -699,6 +720,26 @@ impl StrExt for str { } } +/// Maps every byte of `src` through `map` into a new `SmolStr`. SAFETY: `map` fn must only replace ascii with ascii or return unchanged bytes. +#[inline] +unsafe fn replacen_1_ascii(src: &str, mut map: impl FnMut(&u8) -> u8) -> SmolStr { + if src.len() <= INLINE_CAP { + let mut buf = [0u8; INLINE_CAP]; + for (idx, b) in src.as_bytes().iter().enumerate() { + buf[idx] = map(b); + } + SmolStr(Repr::Inline { + // SAFETY: `src.len() <= INLINE_CAP` was checked above, so `len` is in bounds + len: unsafe { InlineSize::transmute_from_u8(src.len() as u8) }, + buf, + }) + } else { + let out = src.as_bytes().iter().map(map).collect(); + // SAFETY: We replaced ascii with ascii on valid utf8 strings. 
+ unsafe { String::from_utf8_unchecked(out).into() } + } +} + impl ToSmolStr for T where T: fmt::Display + ?Sized, diff --git a/tests/test.rs b/tests/test.rs index 0070b3a..8f7d9ec 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -389,6 +389,13 @@ mod test_str_ext { assert_eq!(result, "foo_dor_baz"); assert!(!result.is_heap_allocated()); } + + #[test] + fn replacen_1_ascii() { + let result = "foo_bar_baz".replacen_smolstr("o", "u", 1); + assert_eq!(result, "fuo_bar_baz"); + assert!(!result.is_heap_allocated()); + } } #[cfg(feature = "borsh")]