diff --git a/CHANGELOG.md b/CHANGELOG.md index 190d6e8..b2522fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,8 @@ # Changelog -## Unreleased +## Unreleased +- Optimise `StrExt::to_lowercase_smolstr`, `StrExt::to_uppercase_smolstr` ~2x speedup inline, ~5-50x for heap. - Optimise `StrExt::to_ascii_lowercase_smolstr`, `StrExt::to_ascii_uppercase_smolstr` ~2x speedup inline, ~4-22x for heap. diff --git a/src/lib.rs b/src/lib.rs index ff25651..5ef6260 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -233,18 +233,30 @@ impl iter::FromIterator for SmolStr { } } -fn from_char_iter(mut iter: impl Iterator) -> SmolStr { - let (min_size, _) = iter.size_hint(); +#[inline] +fn from_char_iter(iter: impl Iterator) -> SmolStr { + from_buf_and_chars([0; _], 0, iter) +} + +fn from_buf_and_chars( + mut buf: [u8; INLINE_CAP], + buf_len: usize, + mut iter: impl Iterator, +) -> SmolStr { + let min_size = iter.size_hint().0 + buf_len; if min_size > INLINE_CAP { - let heap: String = iter.collect(); + let heap: String = core::str::from_utf8(&buf[..buf_len]) + .unwrap() + .chars() + .chain(iter) + .collect(); if heap.len() <= INLINE_CAP { // size hint lied return SmolStr::new_inline(&heap); } return SmolStr(Repr::Heap(heap.into_boxed_str().into())); } - let mut len = 0; - let mut buf = [0u8; INLINE_CAP]; + let mut len = buf_len; while let Some(ch) = iter.next() { let size = ch.len_utf8(); if size + len > INLINE_CAP { @@ -634,12 +646,32 @@ pub trait StrExt: private::Sealed { impl StrExt for str { #[inline] fn to_lowercase_smolstr(&self) -> SmolStr { - from_char_iter(self.chars().flat_map(|c| c.to_lowercase())) + let len = self.len(); + if len <= INLINE_CAP { + let (buf, rest) = inline_convert_while_ascii(self, u8::to_ascii_lowercase); + from_buf_and_chars( + buf, + len - rest.len(), + rest.chars().flat_map(|c| c.to_lowercase()), + ) + } else { + self.to_lowercase().into() + } } #[inline] fn to_uppercase_smolstr(&self) -> SmolStr { - from_char_iter(self.chars().flat_map(|c| c.to_uppercase())) + let len = self.len(); + if len <= INLINE_CAP { + let (buf, rest) = inline_convert_while_ascii(self, u8::to_ascii_uppercase); + from_buf_and_chars( + buf, + len - rest.len(), + rest.chars().flat_map(|c| c.to_uppercase()), + ) + } else { + self.to_uppercase().into() + } } #[inline] @@ -699,6 +731,70 @@ impl StrExt for str { } } +/// Inline version of std fn `convert_while_ascii`. `s` must have len <= 23. +#[inline] +fn inline_convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> ([u8; INLINE_CAP], &str) { + // Process the input in chunks of 16 bytes to enable auto-vectorization. + // Previously the chunk size depended on the size of `usize`, + // but on 32-bit platforms with sse or neon is also the better choice. + // The only downside on other platforms would be a bit more loop-unrolling. + const N: usize = 16; + + debug_assert!(s.len() <= INLINE_CAP, "only for inline-able strings"); + + let mut slice = s.as_bytes(); + let mut out = [0u8; INLINE_CAP]; + let mut out_slice = &mut out[..slice.len()]; + let mut is_ascii = [false; N]; + + while slice.len() >= N { + // SAFETY: checked in loop condition + let chunk = unsafe { slice.get_unchecked(..N) }; + // SAFETY: out_slice has at least same length as input slice and gets sliced with the same offsets + let out_chunk = unsafe { out_slice.get_unchecked_mut(..N) }; + + for j in 0..N { + is_ascii[j] = chunk[j] <= 127; + } + + // Auto-vectorization for this check is a bit fragile, sum and comparing against the chunk + // size gives the best result, specifically a pmovmsk instruction on x86. + // See https://github.com/llvm/llvm-project/issues/96395 for why llvm currently does not + // currently recognize other similar idioms. + if is_ascii.iter().map(|x| *x as u8).sum::() as usize != N { + break; + } + + for j in 0..N { + out_chunk[j] = convert(&chunk[j]); + } + + slice = unsafe { slice.get_unchecked(N..) }; + out_slice = unsafe { out_slice.get_unchecked_mut(N..) }; + } + + // handle the remainder as individual bytes + while !slice.is_empty() { + let byte = slice[0]; + if byte > 127 { + break; + } + // SAFETY: out_slice has at least same length as input slice + unsafe { + *out_slice.get_unchecked_mut(0) = convert(&byte); + } + slice = unsafe { slice.get_unchecked(1..) }; + out_slice = unsafe { out_slice.get_unchecked_mut(1..) }; + } + + unsafe { + // SAFETY: we know this is a valid char boundary + // since we only skipped over leading ascii bytes + let rest = core::str::from_utf8_unchecked(slice); + (out, rest) + } +} + impl ToSmolStr for T where T: fmt::Display + ?Sized, @@ -848,3 +944,14 @@ impl<'a> arbitrary::Arbitrary<'a> for SmolStr { mod borsh; #[cfg(feature = "serde")] mod serde; + +#[test] +fn from_buf_and_chars_size_hinted_heap() { + let str = from_buf_and_chars( + *b"abcdefghijklmnopqr00000", + 18, + "_0x1x2x3x4x5x6x7x8x9x10x11x12x13".chars(), + ); + + assert_eq!(str, "abcdefghijklmnopqr_0x1x2x3x4x5x6x7x8x9x10x11x12x13"); +}