Skip to content

Commit 38eb248

Browse files
Auto merge of #142038 - Kmeakin:km/optimize-str-chars-iterator, r=<try>
Optimize `std::str::Chars::next` and `std::str::Chars::next_back`
2 parents 91edc3e + 54a699b commit 38eb248

File tree

4 files changed

+183
-73
lines changed

4 files changed

+183
-73
lines changed

library/alloctests/tests/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
#![feature(iter_next_chunk)]
2727
#![feature(slice_partition_dedup)]
2828
#![feature(string_from_utf8_lossy_owned)]
29+
#![feature(str_internals)]
30+
#![feature(char_internals)]
2931
#![feature(string_remove_matches)]
3032
#![feature(const_btree_len)]
3133
#![feature(const_trait_impl)]

library/alloctests/tests/str.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,99 @@ fn test_total_ord() {
11561156
assert_eq!("22".cmp("1234"), Greater);
11571157
}
11581158

1159+
// There are only 1,114,112 code points (including surrogates for WTF-8). So we
1160+
// can test `next_code_point` and `next_code_point_reverse` exhaustively on all
1161+
// possible inputs.
1162+
1163+
/// Assert that encoding a codepoint with `encode_utf8_raw` and then decoding it
1164+
/// with `next_code_point` preserves the codepoint.
1165+
fn test_next_code_point(codepoint: u32) {
1166+
let mut bytes = [0; 4];
1167+
let mut bytes = std::char::encode_utf8_raw(codepoint, &mut bytes).iter();
1168+
1169+
// SAFETY: `bytes` iterates over the output of `encode_utf8_raw`, which is UTF-8-like
1170+
let got = unsafe { core::str::next_code_point(&mut bytes) };
1171+
assert_eq!(got, Some(codepoint));
1172+
1173+
// SAFETY: `bytes` is UTF-8-like
1174+
let got = unsafe { core::str::next_code_point(&mut bytes) };
1175+
assert_eq!(got, None);
1176+
}
1177+
1178+
/// The same but for `next_code_point_reverse`.
1179+
fn test_next_code_point_reverse(codepoint: u32) {
1180+
let mut bytes = [0; 4];
1181+
let mut bytes = std::char::encode_utf8_raw(codepoint, &mut bytes).iter();
1182+
1183+
// SAFETY: `bytes` iterates over the output of `encode_utf8_raw`, which is UTF-8-like
1184+
let got = unsafe { core::str::next_code_point_reverse(&mut bytes) };
1185+
assert_eq!(got, Some(codepoint));
1186+
1187+
// SAFETY: `bytes` is UTF-8-like
1188+
let got = unsafe { core::str::next_code_point_reverse(&mut bytes) };
1189+
assert_eq!(got, None);
1190+
}
1191+
1192+
#[test]
1193+
fn test_next_code_point_1byte() {
1194+
for c in 0..0x80 {
1195+
test_next_code_point(c);
1196+
}
1197+
}
1198+
1199+
#[test]
1200+
fn test_next_code_point_2byte() {
1201+
for c in 0x80..0x800 {
1202+
test_next_code_point(c);
1203+
}
1204+
}
1205+
1206+
#[test]
1207+
#[cfg(not(miri))] // Disabled on Miri because it is too slow
1208+
fn test_next_code_point_3byte() {
1209+
for c in 0x800..0x10_000 {
1210+
test_next_code_point(c);
1211+
}
1212+
}
1213+
1214+
#[test]
1215+
// #[cfg(not(miri))] // Disabled on Miri because it is too slow
1216+
fn test_next_code_point_4byte() {
1217+
for c in 0x10_000..=u32::from(char::MAX) {
1218+
test_next_code_point(c);
1219+
}
1220+
}
1221+
1222+
#[test]
1223+
fn test_next_code_point_reverse_1byte() {
1224+
for c in 0..0x80 {
1225+
test_next_code_point_reverse(c);
1226+
}
1227+
}
1228+
1229+
#[test]
1230+
fn test_next_code_point_reverse_2byte() {
1231+
for c in 0x80..0x800 {
1232+
test_next_code_point_reverse(c);
1233+
}
1234+
}
1235+
1236+
#[test]
1237+
#[cfg(not(miri))] // Disabled on Miri because it is too slow
1238+
fn test_next_code_point_reverse_3byte() {
1239+
for c in 0x800..0x10_000 {
1240+
test_next_code_point_reverse(c);
1241+
}
1242+
}
1243+
1244+
#[test]
1245+
#[cfg(not(miri))] // Disabled on Miri because it is too slow
1246+
fn test_next_code_point_reverse_4byte() {
1247+
for c in 0x10_000..=u32::from(char::MAX) {
1248+
test_next_code_point_reverse(c);
1249+
}
1250+
}
1251+
11591252
#[test]
11601253
fn test_iterator() {
11611254
let s = "ศไทย中华Việt Nam";

library/core/src/str/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ pub use lossy::{Utf8Chunk, Utf8Chunks};
5858
#[stable(feature = "rust1", since = "1.0.0")]
5959
pub use traits::FromStr;
6060
#[unstable(feature = "str_internals", issue = "none")]
61-
pub use validations::{next_code_point, utf8_char_width};
61+
pub use validations::{next_code_point, next_code_point_reverse, utf8_char_width};
6262

6363
#[inline(never)]
6464
#[cold]

library/core/src/str/validations.rs

Lines changed: 87 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,7 @@
11
//! Operations related to UTF-8 validation.
22
33
use super::Utf8Error;
4-
use crate::intrinsics::const_eval_select;
5-
6-
/// Returns the initial codepoint accumulator for the first byte.
7-
/// The first byte is special, only want bottom 5 bits for width 2, 4 bits
8-
/// for width 3, and 3 bits for width 4.
9-
#[inline]
10-
const fn utf8_first_byte(byte: u8, width: u32) -> u32 {
11-
(byte & (0x7F >> width)) as u32
12-
}
13-
14-
/// Returns the value of `ch` updated with continuation byte `byte`.
15-
#[inline]
16-
const fn utf8_acc_cont_byte(ch: u32, byte: u8) -> u32 {
17-
(ch << 6) | (byte & CONT_MASK) as u32
18-
}
4+
use crate::intrinsics::{assume, const_eval_select, disjoint_bitor};
195

206
/// Checks whether the byte is a UTF-8 continuation byte (i.e., starts with the
217
/// bits `10`).
@@ -33,39 +19,51 @@ pub(super) const fn utf8_is_cont_byte(byte: u8) -> bool {
3319
#[unstable(feature = "str_internals", issue = "none")]
3420
#[inline]
3521
pub unsafe fn next_code_point<'a, I: Iterator<Item = &'a u8>>(bytes: &mut I) -> Option<u32> {
36-
// Decode UTF-8
37-
let x = *bytes.next()?;
38-
if x < 128 {
39-
return Some(x as u32);
22+
let b1 = *bytes.next()?;
23+
if b1 < 0x80 {
24+
// 1 byte case (U+00_00 ..= U+00_7F):
25+
// c = b1
26+
return Some(u32::from(b1));
4027
}
4128

42-
// Multibyte case follows
43-
// Decode from a byte combination out of: [[[x y] z] w]
44-
// NOTE: Performance is sensitive to the exact formulation here
45-
let init = utf8_first_byte(x, 2);
46-
// SAFETY: `bytes` produces an UTF-8-like string,
47-
// so the iterator must produce a value here.
48-
let y = unsafe { *bytes.next().unwrap_unchecked() };
49-
let mut ch = utf8_acc_cont_byte(init, y);
50-
if x >= 0xE0 {
51-
// [[x y z] w] case
52-
// 5th bit in 0xE0 .. 0xEF is always clear, so `init` is still valid
53-
// SAFETY: `bytes` produces an UTF-8-like string,
54-
// so the iterator must produce a value here.
55-
let z = unsafe { *bytes.next().unwrap_unchecked() };
56-
let y_z = utf8_acc_cont_byte((y & CONT_MASK) as u32, z);
57-
ch = init << 12 | y_z;
58-
if x >= 0xF0 {
59-
// [x y z w] case
60-
// use only the lower 3 bits of `init`
61-
// SAFETY: `bytes` produces an UTF-8-like string,
62-
// so the iterator must produce a value here.
63-
let w = unsafe { *bytes.next().unwrap_unchecked() };
64-
ch = (init & 7) << 18 | utf8_acc_cont_byte(y_z, w);
65-
}
29+
// SAFETY: `bytes` produces a UTF-8-like string
30+
let mut next_byte = || unsafe {
31+
let b = *bytes.next().unwrap_unchecked();
32+
assume(utf8_is_cont_byte(b));
33+
b
34+
};
35+
36+
// SAFETY: the low 6 bits of `c << 6` are zero and `b & CONT_MASK` fits in 6 bits, so the operands share no set bits
37+
let combine = |c: u32, b: u8| unsafe { disjoint_bitor(c << 6, u32::from(b & CONT_MASK)) };
38+
39+
let b2 = next_byte();
40+
let c = u32::from(b1 & 0x1F);
41+
let c = combine(c, b2);
42+
if b1 < 0xE0 {
43+
// 2 byte case (U+00_80 ..= U+07_FF):
44+
// c = (b1 & 0x1F) << 6
45+
// | (b2 & 0x3F) << 0
46+
return Some(c);
6647
}
6748

68-
Some(ch)
49+
let b3 = next_byte();
50+
let c = combine(c, b3);
51+
if b1 < 0xF0 {
52+
// 3 byte case (U+08_00 ..= U+FF_FF):
53+
// c = (b1 & 0x1F) << 12
54+
// | (b2 & 0x3F) << 6
55+
// | (b3 & 0x3F) << 0
56+
return Some(c);
57+
}
58+
59+
let b4 = next_byte();
60+
let c = combine(c, b4);
61+
// 4 byte case (U+01_00_00 ..= U+10_FF_FF):
62+
// c = ((b1 & 0x1F) << 18
63+
// | (b2 & 0x3F) << 12
64+
// | (b3 & 0x3F) << 6
65+
// | (b4 & 0x3F) << 0) & 0x1F_FF_FF
66+
Some(c & 0x1F_FF_FF)
6967
}
7068

7169
/// Reads the last code point out of a byte iterator (assuming a
@@ -74,41 +72,58 @@ pub unsafe fn next_code_point<'a, I: Iterator<Item = &'a u8>>(bytes: &mut I) ->
7472
/// # Safety
7573
///
7674
/// `bytes` must produce a valid UTF-8-like (UTF-8 or WTF-8) string
75+
#[unstable(feature = "str_internals", issue = "none")]
7776
#[inline]
78-
pub(super) unsafe fn next_code_point_reverse<'a, I>(bytes: &mut I) -> Option<u32>
77+
pub unsafe fn next_code_point_reverse<'a, I>(bytes: &mut I) -> Option<u32>
7978
where
8079
I: DoubleEndedIterator<Item = &'a u8>,
8180
{
82-
// Decode UTF-8
83-
let w = match *bytes.next_back()? {
84-
next_byte if next_byte < 128 => return Some(next_byte as u32),
85-
back_byte => back_byte,
81+
let b1 = *bytes.next_back()?;
82+
if b1 < 0x80 {
83+
// 1 byte case (U+00_00 ..= U+00_7F):
84+
// c = b1
85+
return Some(u32::from(b1));
86+
}
87+
88+
// SAFETY: `bytes` produces a UTF-8-like string
89+
let mut next_byte = || unsafe {
90+
let b = *bytes.next_back().unwrap_unchecked();
91+
assume(!b.is_ascii());
92+
b
8693
};
8794

88-
// Multibyte case follows
89-
// Decode from a byte combination out of: [x [y [z w]]]
90-
let mut ch;
91-
// SAFETY: `bytes` produces an UTF-8-like string,
92-
// so the iterator must produce a value here.
93-
let z = unsafe { *bytes.next_back().unwrap_unchecked() };
94-
ch = utf8_first_byte(z, 2);
95-
if utf8_is_cont_byte(z) {
96-
// SAFETY: `bytes` produces an UTF-8-like string,
97-
// so the iterator must produce a value here.
98-
let y = unsafe { *bytes.next_back().unwrap_unchecked() };
99-
ch = utf8_first_byte(y, 3);
100-
if utf8_is_cont_byte(y) {
101-
// SAFETY: `bytes` produces an UTF-8-like string,
102-
// so the iterator must produce a value here.
103-
let x = unsafe { *bytes.next_back().unwrap_unchecked() };
104-
ch = utf8_first_byte(x, 4);
105-
ch = utf8_acc_cont_byte(ch, y);
106-
}
107-
ch = utf8_acc_cont_byte(ch, z);
95+
// SAFETY: at every call site `c` fits in the low `n` bits, so it cannot overlap the 6 payload bits shifted up by `n`
96+
let combine = |c: u32, b: u8, n| unsafe { disjoint_bitor(c, u32::from(b & CONT_MASK) << n) };
97+
98+
let b2 = next_byte();
99+
let c = u32::from(b1 & CONT_MASK);
100+
let c = combine(c, b2, 6);
101+
if !utf8_is_cont_byte(b2) {
102+
// 2 byte case (U+00_80 ..= U+07_FF):
103+
// c = (b2 & 0x3F) << 6
104+
// | (b1 & 0x3F) << 0
105+
return Some(c);
106+
}
107+
108+
let b3 = next_byte();
109+
let c = combine(c, b3, 12);
110+
if !utf8_is_cont_byte(b3) {
111+
// 3 byte case (U+08_00 ..= U+FF_FF):
112+
// c = ((b3 & 0x3F) << 12
113+
// | (b2 & 0x3F) << 6
114+
// | (b1 & 0x3F) << 0) & 0xFF_FF
115+
return Some(c & 0xFF_FF);
108116
}
109-
ch = utf8_acc_cont_byte(ch, w);
110117

111-
Some(ch)
118+
let b4 = next_byte();
119+
let c = combine(c, b4, 18);
120+
// let c = c | u32::from(b4 & CONT_MASK) << 18;
121+
// 4 byte case (U+01_00_00 ..= U+10_FF_FF):
122+
// c = ((b4 & 0x3F) << 18
123+
// | (b3 & 0x3F) << 12
124+
// | (b2 & 0x3F) << 6
125+
// | (b1 & 0x3F) << 0) & 0x1F_FF_FF
126+
Some(c & 0x1F_FF_FF)
112127
}
113128

114129
const NONASCII_MASK: usize = usize::repeat_u8(0x80);
@@ -279,5 +294,5 @@ pub const fn utf8_char_width(b: u8) -> usize {
279294
UTF8_CHAR_WIDTH[b as usize] as usize
280295
}
281296

282-
/// Mask of the value bits of a continuation byte.
297+
/// Mask of the value bits of a continuation byte (i.e., the lowest 6 bits).
283298
const CONT_MASK: u8 = 0b0011_1111;

0 commit comments

Comments
 (0)