diff --git a/library/alloctests/tests/lib.rs b/library/alloctests/tests/lib.rs index 8c3ce156f3c1d..ad049a51bce14 100644 --- a/library/alloctests/tests/lib.rs +++ b/library/alloctests/tests/lib.rs @@ -25,6 +25,8 @@ #![feature(iter_next_chunk)] #![feature(slice_partition_dedup)] #![feature(string_from_utf8_lossy_owned)] +#![feature(str_internals)] +#![feature(char_internals)] #![feature(string_remove_matches)] #![feature(const_btree_len)] #![feature(const_trait_impl)] diff --git a/library/alloctests/tests/str.rs b/library/alloctests/tests/str.rs index 906fa2d425e77..e54d7920f843f 100644 --- a/library/alloctests/tests/str.rs +++ b/library/alloctests/tests/str.rs @@ -1156,6 +1156,84 @@ fn test_total_ord() { assert_eq!("22".cmp("1234"), Greater); } +// There are only 1,114,112 code points (including surrogates for WTF-8). So we +// can test `next_code_point` and `next_code_point_reverse` exhaustively on all +// possible inputs. + +/// Assert that encoding a codepoint with `encode_utf8_raw` and then decoding it +/// with `next_code_point` preserves the codepoint. +fn test_next_code_point(codepoint: u32) { + let mut bytes = [0; 4]; + let mut bytes = std::char::encode_utf8_raw(codepoint, &mut bytes).iter(); + + // SAFETY: `bytes` is UTF8-like + let got = unsafe { core::str::next_code_point(&mut bytes) }; + assert_eq!(got, Some(codepoint)); + + // SAFETY: `bytes` is UTF8-like + let got = unsafe { core::str::next_code_point(&mut bytes) }; + assert_eq!(got, None); +} + +/// The same but for `next_code_point_reverse`. +fn test_next_code_point_reverse(codepoint: u32) { + let mut bytes = [0; 4]; + let mut bytes = std::char::encode_utf8_raw(codepoint, &mut bytes).iter(); + + // SAFETY: `bytes` is UTF8-like + let got = unsafe { core::str::next_code_point_reverse(&mut bytes) }; + assert_eq!(got, Some(codepoint)); + + // SAFETY: `bytes` is UTF8-like + let got = unsafe { core::str::next_code_point_reverse(&mut bytes) }; + assert_eq!(got, None); +} + +#[test] +#[cfg_attr(miri, ignore)] // Disabled on Miri because it is too slow +fn test_next_code_point_exhaustive() { + for c in 0..=u32::from(char::MAX) { + test_next_code_point(c); + } +} + +#[test] +#[cfg_attr(miri, ignore)] // Disabled on Miri because it is too slow +fn test_next_code_point_reverse_exhaustive() { + for c in 0..=u32::from(char::MAX) { + test_next_code_point_reverse(c); + } +} + +#[rustfmt::skip] +const CODEPOINT_BOUNDARIES: &[u32] = &[ + // 1 byte codepoints (U+0000 ..= U+007F): + 0x0000, 0x007F, + + // 2 byte codepoints (U+0080 ..= U+07FF): + 0x0080, 0x07FF, + + // 3 byte codepoints (U+0800 ..= U+FFFF): + 0800, 0xFFFF, + + // 4 byte codepoints (U+01_0000 ..= U+10_FFFF): + 0x01_0000, 0x10_FFFF, +]; + +#[test] +fn test_next_code_point_boundary_conditions() { + for c in CODEPOINT_BOUNDARIES { + test_next_code_point(*c); + } +} + +#[test] +fn test_next_code_point_reverse_boundary_conditions() { + for c in CODEPOINT_BOUNDARIES { + test_next_code_point_reverse(*c); + } +} + #[test] fn test_iterator() { let s = "ศไทย中华Việt Nam"; diff --git a/library/core/src/str/mod.rs b/library/core/src/str/mod.rs index 04fdaa8143eff..f5c621c36e2d8 100644 --- a/library/core/src/str/mod.rs +++ b/library/core/src/str/mod.rs @@ -58,7 +58,7 @@ pub use lossy::{Utf8Chunk, Utf8Chunks}; #[stable(feature = "rust1", since = "1.0.0")] pub use traits::FromStr; #[unstable(feature = "str_internals", issue = "none")] -pub use validations::{next_code_point, utf8_char_width}; +pub use validations::{next_code_point, next_code_point_reverse, utf8_char_width}; #[inline(never)] #[cold] diff --git a/library/core/src/str/validations.rs b/library/core/src/str/validations.rs index b54d6478e584d..c5794eaa41936 100644 --- a/library/core/src/str/validations.rs +++ b/library/core/src/str/validations.rs @@ -1,22 +1,9 @@ //! Operations related to UTF-8 validation. use super::Utf8Error; +use crate::hint::assert_unchecked; use crate::intrinsics::const_eval_select; -/// Returns the initial codepoint accumulator for the first byte. -/// The first byte is special, only want bottom 5 bits for width 2, 4 bits -/// for width 3, and 3 bits for width 4. -#[inline] -const fn utf8_first_byte(byte: u8, width: u32) -> u32 { - (byte & (0x7F >> width)) as u32 -} - -/// Returns the value of `ch` updated with continuation byte `byte`. -#[inline] -const fn utf8_acc_cont_byte(ch: u32, byte: u8) -> u32 { - (ch << 6) | (byte & CONT_MASK) as u32 -} - /// Checks whether the byte is a UTF-8 continuation byte (i.e., starts with the /// bits `10`). #[inline] @@ -33,39 +20,49 @@ pub(super) const fn utf8_is_cont_byte(byte: u8) -> bool { #[unstable(feature = "str_internals", issue = "none")] #[inline] pub unsafe fn next_code_point<'a, I: Iterator>(bytes: &mut I) -> Option { - // Decode UTF-8 - let x = *bytes.next()?; - if x < 128 { - return Some(x as u32); + let b1 = *bytes.next()?; + if b1 < 0x80 { + // 1 byte case (U+0000 ..= U+007F): + // c = b1 + return Some(u32::from(b1)); } - // Multibyte case follows - // Decode from a byte combination out of: [[[x y] z] w] - // NOTE: Performance is sensitive to the exact formulation here - let init = utf8_first_byte(x, 2); - // SAFETY: `bytes` produces an UTF-8-like string, - // so the iterator must produce a value here. - let y = unsafe { *bytes.next().unwrap_unchecked() }; - let mut ch = utf8_acc_cont_byte(init, y); - if x >= 0xE0 { - // [[x y z] w] case - // 5th bit in 0xE0 .. 0xEF is always clear, so `init` is still valid - // SAFETY: `bytes` produces an UTF-8-like string, - // so the iterator must produce a value here. - let z = unsafe { *bytes.next().unwrap_unchecked() }; - let y_z = utf8_acc_cont_byte((y & CONT_MASK) as u32, z); - ch = init << 12 | y_z; - if x >= 0xF0 { - // [x y z w] case - // use only the lower 3 bits of `init` - // SAFETY: `bytes` produces an UTF-8-like string, - // so the iterator must produce a value here. - let w = unsafe { *bytes.next().unwrap_unchecked() }; - ch = (init & 7) << 18 | utf8_acc_cont_byte(y_z, w); - } + // SAFETY: `bytes` produces a UTF-8-like string + let mut next_byte = || unsafe { + let b = *bytes.next().unwrap_unchecked(); + assert_unchecked(utf8_is_cont_byte(b)); + b + }; + let combine = |c: u32, byte: u8| c << 6 | u32::from(byte & CONT_MASK); + + let b2 = next_byte(); + let c = u32::from(b1 & 0x1F); + let c = combine(c, b2); + if b1 < 0xE0 { + // 2 byte case (U+0080 ..= U+07FF): + // c = (b1 & 0x1F) << 6 + // | (b2 & 0x3F) << 0 + return Some(c); } - Some(ch) + let b3 = next_byte(); + let c = combine(c, b3); + if b1 < 0xF0 { + // 3 byte case (U+0800 ..= U+FFFF): + // c = (b1 & 0x1F) << 12 + // | (b2 & 0x3F) << 6 + // | (b3 & 0x3F) << 0 + return Some(c); + } + + let b4 = next_byte(); + let c = combine(c, b4); + // 4 byte case (U+01_0000 ..= U+10_FFFF): + // c = ((b1 & 0x1F) << 18 + // | (b2 & 0x3F) << 12 + // | (b3 & 0x3F) << 6 + // | (b4 & 0x3F) << 0) & 0x1F_FFFF + Some(c & 0x1F_FFFF) } /// Reads the last code point out of a byte iterator (assuming a @@ -74,41 +71,55 @@ pub unsafe fn next_code_point<'a, I: Iterator>(bytes: &mut I) -> /// # Safety /// /// `bytes` must produce a valid UTF-8-like (UTF-8 or WTF-8) string +#[unstable(feature = "str_internals", issue = "none")] #[inline] -pub(super) unsafe fn next_code_point_reverse<'a, I>(bytes: &mut I) -> Option +pub unsafe fn next_code_point_reverse<'a, I>(bytes: &mut I) -> Option where I: DoubleEndedIterator, { - // Decode UTF-8 - let w = match *bytes.next_back()? { - next_byte if next_byte < 128 => return Some(next_byte as u32), - back_byte => back_byte, + let b1 = *bytes.next_back()?; + if b1 < 0x80 { + // 1 byte case (U+0000 ..= U+007F): + // c = b1 + return Some(u32::from(b1)); + } + + // SAFETY: `bytes` produces a UTF-8-like string + let mut next_byte = || unsafe { + let b = *bytes.next_back().unwrap_unchecked(); + assert_unchecked(!b.is_ascii()); + b }; + let combine = |c: u32, byte: u8, shift| c | u32::from(byte & CONT_MASK) << shift; - // Multibyte case follows - // Decode from a byte combination out of: [x [y [z w]]] - let mut ch; - // SAFETY: `bytes` produces an UTF-8-like string, - // so the iterator must produce a value here. - let z = unsafe { *bytes.next_back().unwrap_unchecked() }; - ch = utf8_first_byte(z, 2); - if utf8_is_cont_byte(z) { - // SAFETY: `bytes` produces an UTF-8-like string, - // so the iterator must produce a value here. - let y = unsafe { *bytes.next_back().unwrap_unchecked() }; - ch = utf8_first_byte(y, 3); - if utf8_is_cont_byte(y) { - // SAFETY: `bytes` produces an UTF-8-like string, - // so the iterator must produce a value here. - let x = unsafe { *bytes.next_back().unwrap_unchecked() }; - ch = utf8_first_byte(x, 4); - ch = utf8_acc_cont_byte(ch, y); - } - ch = utf8_acc_cont_byte(ch, z); + let b2 = next_byte(); + let c = u32::from(b1 & CONT_MASK); + let c = combine(c, b2, 6); + if !utf8_is_cont_byte(b2) { + // 2 byte case (U+0080 ..= U+07FF): + // c = (b2 & 0x3F) << 6 + // | (b1 & 0x3F) << 0 + return Some(c); + } + + let b3 = next_byte(); + let c = combine(c, b3, 12); + if !utf8_is_cont_byte(b3) { + // 3 byte case (U+0800 ..= U+FFFF): + // c = ((b3 & 0x3F) << 12 + // | (b2 & 0x3F) << 6 + // | (b1 & 0x3F) << 0) & 0xFFFF + return Some(c & 0xFFFF); } - ch = utf8_acc_cont_byte(ch, w); - Some(ch) + let b4 = next_byte(); + let c = combine(c, b4, 18); + // 4 byte case (U+01_0000 ..= U+10_FFFF): + // c = ((b4 & 0x3F) << 18 + // | (b3 & 0x3F) << 12 + // | (b2 & 0x3F) << 6 + // | (b1 & 0x3F) << 0) & 0x1F_FFFF + Some(c & 0x1F_FFFF) } const NONASCII_MASK: usize = usize::repeat_u8(0x80); @@ -279,5 +290,5 @@ pub const fn utf8_char_width(b: u8) -> usize { UTF8_CHAR_WIDTH[b as usize] as usize } -/// Mask of the value bits of a continuation byte. +/// Mask of the value bits of a continuation byte (ie the lowest 6 bits). const CONT_MASK: u8 = 0b0011_1111; diff --git a/library/coretests/benches/str/iter.rs b/library/coretests/benches/str/iter.rs index d2586cef25871..3bc0f097c1154 100644 --- a/library/coretests/benches/str/iter.rs +++ b/library/coretests/benches/str/iter.rs @@ -16,3 +16,59 @@ fn chars_advance_by_0010(b: &mut Bencher) { fn chars_advance_by_0001(b: &mut Bencher) { b.iter(|| black_box(corpora::ru::LARGE).chars().advance_by(1)); } + +mod chars_sum { + use super::*; + + fn bench(b: &mut Bencher, corpus: &str) { + b.iter(|| corpus.chars().map(|c| c as u32).sum::()) + } + + #[bench] + fn en(b: &mut Bencher) { + bench(b, corpora::en::HUGE); + } + + #[bench] + fn zh(b: &mut Bencher) { + bench(b, corpora::zh::HUGE); + } + + #[bench] + fn ru(b: &mut Bencher) { + bench(b, corpora::zh::HUGE); + } + + #[bench] + fn emoji(b: &mut Bencher) { + bench(b, corpora::zh::HUGE); + } +} + +mod chars_sum_rev { + use super::*; + + fn bench(b: &mut Bencher, corpus: &str) { + b.iter(|| corpus.chars().rev().map(|c| c as u32).sum::()) + } + + #[bench] + fn en(b: &mut Bencher) { + bench(b, corpora::en::HUGE); + } + + #[bench] + fn zh(b: &mut Bencher) { + bench(b, corpora::zh::HUGE); + } + + #[bench] + fn ru(b: &mut Bencher) { + bench(b, corpora::zh::HUGE); + } + + #[bench] + fn emoji(b: &mut Bencher) { + bench(b, corpora::zh::HUGE); + } +}