Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Changelog

## Unreleased
## Unreleased

- Optimise `StrExt::to_lowercase_smolstr` and `StrExt::to_uppercase_smolstr`: ~2x speedup for inline strings, ~5-50x for heap strings.
- Optimise `StrExt::to_ascii_lowercase_smolstr` and `StrExt::to_ascii_uppercase_smolstr`:
~2x speedup for inline strings, ~4-22x for heap strings.

Expand Down
121 changes: 114 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,18 +233,30 @@ impl iter::FromIterator<char> for SmolStr {
}
}

fn from_char_iter(mut iter: impl Iterator<Item = char>) -> SmolStr {
let (min_size, _) = iter.size_hint();
/// Builds a `SmolStr` from the chars yielded by `iter`.
///
/// Thin wrapper over `from_buf_and_chars`: it passes a zeroed buffer and a
/// prefix length of 0, i.e. no bytes have been written ahead of the iterator.
#[inline]
fn from_char_iter(iter: impl Iterator<Item = char>) -> SmolStr {
// `[0; _]`: the array length is left to inference from the `buf` parameter type.
from_buf_and_chars([0; _], 0, iter)
}

fn from_buf_and_chars(
mut buf: [u8; INLINE_CAP],
buf_len: usize,
mut iter: impl Iterator<Item = char>,
) -> SmolStr {
let min_size = iter.size_hint().0 + buf_len;
if min_size > INLINE_CAP {
let heap: String = iter.collect();
let heap: String = core::str::from_utf8(&buf[..buf_len])
.unwrap()
.chars()
.chain(iter)
.collect();
if heap.len() <= INLINE_CAP {
// size hint lied
return SmolStr::new_inline(&heap);
}
return SmolStr(Repr::Heap(heap.into_boxed_str().into()));
}
let mut len = 0;
let mut buf = [0u8; INLINE_CAP];
let mut len = buf_len;
while let Some(ch) = iter.next() {
let size = ch.len_utf8();
if size + len > INLINE_CAP {
Expand Down Expand Up @@ -634,12 +646,32 @@ pub trait StrExt: private::Sealed {
impl StrExt for str {
/// Returns the lowercase form of `self` as a `SmolStr`.
#[inline]
fn to_lowercase_smolstr(&self) -> SmolStr {
// NOTE(review): the next line looks like the removed side of a diff view — it has no
// terminator and is followed by a second, complete body. Confirm against the real file.
from_char_iter(self.chars().flat_map(|c| c.to_lowercase()))
let len = self.len();
if len <= INLINE_CAP {
// Lowercase the leading ASCII run straight into an inline-sized buffer;
// `rest` is the part of `self` that was not converted.
let (buf, rest) = inline_convert_while_ascii(self, u8::to_ascii_lowercase);
from_buf_and_chars(
buf,
// Number of bytes already written into `buf`.
len - rest.len(),
// The remainder goes through `char::to_lowercase`, which yields an
// iterator of chars per input char, hence `flat_map`.
rest.chars().flat_map(|c| c.to_lowercase()),
)
} else {
// Input is longer than `INLINE_CAP` bytes: lowercase via `str::to_lowercase`
// and convert the resulting `String`.
self.to_lowercase().into()
}
}

/// Returns the uppercase form of `self` as a `SmolStr`.
#[inline]
fn to_uppercase_smolstr(&self) -> SmolStr {
// NOTE(review): the next line looks like the removed side of a diff view — it has no
// terminator and is followed by a second, complete body. Confirm against the real file.
from_char_iter(self.chars().flat_map(|c| c.to_uppercase()))
let len = self.len();
if len <= INLINE_CAP {
// Uppercase the leading ASCII run straight into an inline-sized buffer;
// `rest` is the part of `self` that was not converted.
let (buf, rest) = inline_convert_while_ascii(self, u8::to_ascii_uppercase);
from_buf_and_chars(
buf,
// Number of bytes already written into `buf`.
len - rest.len(),
// The remainder goes through `char::to_uppercase`, which yields an
// iterator of chars per input char, hence `flat_map`.
rest.chars().flat_map(|c| c.to_uppercase()),
)
} else {
// Input is longer than `INLINE_CAP` bytes: uppercase via `str::to_uppercase`
// and convert the resulting `String`.
self.to_uppercase().into()
}
}

#[inline]
Expand Down Expand Up @@ -699,6 +731,70 @@ impl StrExt for str {
}
}

/// Inline version of std fn `convert_while_ascii`. `s` must have len <= 23.
///
/// Applies `convert` to every byte of the leading ASCII run of `s` and writes the
/// results into a zero-initialised `[u8; INLINE_CAP]` at the same byte offsets.
///
/// Returns `(out, rest)`:
/// - `out`: the buffer; only the first `s.len() - rest.len()` bytes were written,
///   the remaining bytes are still `0`.
/// - `rest`: the unconverted tail of `s`, starting at the first non-ASCII byte
///   (empty if `s` is entirely ASCII).
///
/// # Panics
///
/// The length precondition is a `debug_assert!`; independently of that, the
/// bounds-checked `out[..slice.len()]` below panics if `s` is longer than `INLINE_CAP`.
#[inline]
fn inline_convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> ([u8; INLINE_CAP], &str) {
// Process the input in chunks of 16 bytes to enable auto-vectorization.
// Previously the chunk size depended on the size of `usize`,
// but on 32-bit platforms with SSE or NEON, 16 is also the better choice.
// The only downside on other platforms would be a bit more loop-unrolling.
const N: usize = 16;

debug_assert!(s.len() <= INLINE_CAP, "only for inline-able strings");

// `slice` and `out_slice` are advanced in lockstep, so they always have equal length.
let mut slice = s.as_bytes();
let mut out = [0u8; INLINE_CAP];
let mut out_slice = &mut out[..slice.len()];
let mut is_ascii = [false; N];

while slice.len() >= N {
// SAFETY: checked in loop condition
let chunk = unsafe { slice.get_unchecked(..N) };
// SAFETY: out_slice has at least same length as input slice and gets sliced with the same offsets
let out_chunk = unsafe { out_slice.get_unchecked_mut(..N) };

for j in 0..N {
is_ascii[j] = chunk[j] <= 127;
}

// Auto-vectorization for this check is a bit fragile, sum and comparing against the chunk
// size gives the best result, specifically a pmovmsk instruction on x86.
// See https://github.com/llvm/llvm-project/issues/96395 for why LLVM does not
// currently recognize other similar idioms.
if is_ascii.iter().map(|x| *x as u8).sum::<u8>() as usize != N {
// The chunk contains a non-ASCII byte: leave it untouched here and let the
// byte-wise loop below convert up to (not including) that byte.
break;
}

for j in 0..N {
out_chunk[j] = convert(&chunk[j]);
}

// SAFETY: `slice.len() >= N` was checked in the loop condition, and `out_slice`
// has the same length as `slice`, so `N..` is in bounds for both.
slice = unsafe { slice.get_unchecked(N..) };
out_slice = unsafe { out_slice.get_unchecked_mut(N..) };
}

// handle the remainder as individual bytes
while !slice.is_empty() {
let byte = slice[0];
if byte > 127 {
break;
}
// SAFETY: out_slice has at least same length as input slice
unsafe {
*out_slice.get_unchecked_mut(0) = convert(&byte);
}
// SAFETY: `slice` is non-empty (loop condition) and `out_slice` has the same
// length, so `1..` is in bounds for both.
slice = unsafe { slice.get_unchecked(1..) };
out_slice = unsafe { out_slice.get_unchecked_mut(1..) };
}

unsafe {
// SAFETY: we know this is a valid char boundary
// since we only skipped over leading ascii bytes
let rest = core::str::from_utf8_unchecked(slice);
(out, rest)
}
}

impl<T> ToSmolStr for T
where
T: fmt::Display + ?Sized,
Expand Down Expand Up @@ -848,3 +944,14 @@ impl<'a> arbitrary::Arbitrary<'a> for SmolStr {
mod borsh;
#[cfg(feature = "serde")]
mod serde;

/// Checks that `from_buf_and_chars` keeps the pre-filled prefix when the prefix
/// plus the remaining chars are far longer than the 23-byte buffer.
#[test]
fn from_buf_and_chars_size_hinted_heap() {
    // 18 meaningful bytes followed by five `0` filler characters that must not
    // appear in the result, because only `prefix_len` bytes count as written.
    let prefix_buf = *b"abcdefghijklmnopqr00000";
    let prefix_len = 18;
    let tail = "_0x1x2x3x4x5x6x7x8x9x10x11x12x13";

    let built = from_buf_and_chars(prefix_buf, prefix_len, tail.chars());

    assert_eq!(built, "abcdefghijklmnopqr_0x1x2x3x4x5x6x7x8x9x10x11x12x13");
}