Commit 8ae5290
Auto merge of #430 - Amanieu:neon2, r=Amanieu

Add NEON backend for RawTable

The core algorithm is based on the NEON support in [SwissTable], adapted for the different control byte encodings used in hashbrown.

[SwissTable]: abseil/abseil-cpp@6481443

2 parents 6033fa1 + 18ef2d7, commit 8ae5290
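For context on the "different control byte encodings" the commit message refers to, here is a rough summary as Rust comments. The hashbrown constants are those defined in src/raw/mod.rs; the abseil values are quoted from memory of the referenced SwissTable code and should be treated as approximate:

    // hashbrown control bytes: the sign bit alone distinguishes FULL from not-FULL.
    const EMPTY: u8 = 0b1111_1111;   // slot never used
    const DELETED: u8 = 0b1000_0000; // tombstone
    // FULL slots store 0b0xxx_xxxx: top bit clear plus 7 bits of the hash.
    //
    // abseil's SwissTable instead uses signed bytes, roughly:
    //   kEmpty = -128, kDeleted = -2, kSentinel = -1, full = 0..=127
    // so the group-scan predicates (e.g. "empty or deleted") must be expressed
    // differently even though the SIMD scanning technique is the same. In
    // hashbrown, "not FULL" is exactly "the high (sign) bit is set", which is
    // what the NEON backend below tests with vcltz_s8 / vcgez_s8.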

File tree

5 files changed: 167 additions, 30 deletions

src/raw/bitmask.rs
src/raw/generic.rs
src/raw/mod.rs
src/raw/neon.rs
src/raw/sse2.rs

src/raw/bitmask.rs

Lines changed: 32 additions & 20 deletions

@@ -1,4 +1,4 @@
-use super::imp::{BitMaskWord, BITMASK_MASK, BITMASK_STRIDE};
+use super::imp::{BitMaskWord, BITMASK_ITER_MASK, BITMASK_MASK, BITMASK_STRIDE};
 #[cfg(feature = "nightly")]
 use core::intrinsics;

@@ -8,11 +8,16 @@ use core::intrinsics;
 /// The bit mask is arranged so that low-order bits represent lower memory
 /// addresses for group match results.
 ///
-/// For implementation reasons, the bits in the set may be sparsely packed, so
-/// that there is only one bit-per-byte used (the high bit, 7). If this is the
+/// For implementation reasons, the bits in the set may be sparsely packed with
+/// groups of 8 bits representing one element. If any of these bits are non-zero
+/// then this element is considered to true in the mask. If this is the
 /// case, `BITMASK_STRIDE` will be 8 to indicate a divide-by-8 should be
 /// performed on counts/indices to normalize this difference. `BITMASK_MASK` is
 /// similarly a mask of all the actually-used bits.
+///
+/// To iterate over a bit mask, it must be converted to a form where only 1 bit
+/// is set per element. This is done by applying `BITMASK_ITER_MASK` on the
+/// mask bits.
 #[derive(Copy, Clone)]
 pub(crate) struct BitMask(pub(crate) BitMaskWord);

@@ -21,30 +26,18 @@ impl BitMask {
     /// Returns a new `BitMask` with all bits inverted.
     #[inline]
     #[must_use]
+    #[allow(dead_code)]
     pub(crate) fn invert(self) -> Self {
         BitMask(self.0 ^ BITMASK_MASK)
     }

-    /// Flip the bit in the mask for the entry at the given index.
-    ///
-    /// Returns the bit's previous state.
-    #[inline]
-    #[allow(clippy::cast_ptr_alignment)]
-    #[cfg(feature = "raw")]
-    pub(crate) unsafe fn flip(&mut self, index: usize) -> bool {
-        // NOTE: The + BITMASK_STRIDE - 1 is to set the high bit.
-        let mask = 1 << (index * BITMASK_STRIDE + BITMASK_STRIDE - 1);
-        self.0 ^= mask;
-        // The bit was set if the bit is now 0.
-        self.0 & mask == 0
-    }
-
     /// Returns a new `BitMask` with the lowest bit removed.
     #[inline]
     #[must_use]
-    pub(crate) fn remove_lowest_bit(self) -> Self {
+    fn remove_lowest_bit(self) -> Self {
         BitMask(self.0 & (self.0 - 1))
     }
+
     /// Returns whether the `BitMask` has at least one set bit.
     #[inline]
     pub(crate) fn any_bit_set(self) -> bool {

@@ -102,13 +95,32 @@ impl IntoIterator for BitMask {

     #[inline]
     fn into_iter(self) -> BitMaskIter {
-        BitMaskIter(self)
+        // A BitMask only requires each element (group of bits) to be non-zero.
+        // However for iteration we need each element to only contain 1 bit.
+        BitMaskIter(BitMask(self.0 & BITMASK_ITER_MASK))
     }
 }

 /// Iterator over the contents of a `BitMask`, returning the indices of set
 /// bits.
-pub(crate) struct BitMaskIter(BitMask);
+#[derive(Copy, Clone)]
+pub(crate) struct BitMaskIter(pub(crate) BitMask);
+
+impl BitMaskIter {
+    /// Flip the bit in the mask for the entry at the given index.
+    ///
+    /// Returns the bit's previous state.
+    #[inline]
+    #[allow(clippy::cast_ptr_alignment)]
+    #[cfg(feature = "raw")]
+    pub(crate) unsafe fn flip(&mut self, index: usize) -> bool {
+        // NOTE: The + BITMASK_STRIDE - 1 is to set the high bit.
+        let mask = 1 << (index * BITMASK_STRIDE + BITMASK_STRIDE - 1);
+        self.0 .0 ^= mask;
+        // The bit was set if the bit is now 0.
+        self.0 .0 & mask == 0
+    }
+}

 impl Iterator for BitMaskIter {
     type Item = usize;
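To see why the iteration mask is needed, here is a standalone scalar sketch (not part of the diff) of how a NEON-style sparse match result is normalized and stepped through. The constants mirror the NEON backend's values, but the example itself is only illustrative:

    fn main() {
        // Suppose match_byte found hits in elements 1 and 4 of an 8-byte group.
        // A NEON byte-compare yields 0xFF for each matching byte, i.e. 8 bits
        // per element:
        let sparse: u64 = 0x0000_00FF_0000_FF00;

        // BITMASK_ITER_MASK keeps a single bit (the high bit) per element so
        // that trailing_zeros / remove_lowest_bit step one element at a time:
        const BITMASK_ITER_MASK: u64 = 0x8080_8080_8080_8080;
        const BITMASK_STRIDE: usize = 8;
        let mut mask = sparse & BITMASK_ITER_MASK;

        let mut indices = Vec::new();
        while mask != 0 {
            // lowest set bit / BITMASK_STRIDE gives the element index.
            indices.push(mask.trailing_zeros() as usize / BITMASK_STRIDE);
            mask &= mask - 1; // remove_lowest_bit
        }
        assert_eq!(indices, [1, 4]);
    }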

src/raw/generic.rs

Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@ pub(crate) const BITMASK_STRIDE: usize = 8;
 // We only care about the highest bit of each byte for the mask.
 #[allow(clippy::cast_possible_truncation, clippy::unnecessary_cast)]
 pub(crate) const BITMASK_MASK: BitMaskWord = 0x8080_8080_8080_8080_u64 as GroupWord;
+pub(crate) const BITMASK_ITER_MASK: BitMaskWord = !0;

 /// Helper function to replicate a byte across a `GroupWord`.
 #[inline]
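For the generic (portable) backend, `!0` is enough as the iteration mask because its SWAR match functions only ever set the high bit of each byte, so the mask is already in one-bit-per-element form. A standalone sketch of that approach (assumed to mirror the existing generic backend, not copied from it):

    fn main() {
        fn repeat(byte: u8) -> u64 {
            u64::from_ne_bytes([byte; 8])
        }
        // SWAR byte-equality: the final `& repeat(0x80)` means at most the
        // high bit of each byte can be set in the result.
        fn match_byte(group: u64, byte: u8) -> u64 {
            let cmp = group ^ repeat(byte);
            cmp.wrapping_sub(repeat(0x01)) & !cmp & repeat(0x80)
        }
        // Look for 0x23 in a group whose bytes 0, 2 and 5 match.
        let group = u64::from_le_bytes([0x23, 0xFF, 0x23, 0x80, 0x51, 0x23, 0x00, 0x7F]);
        let mask = match_byte(group, 0x23);
        assert_eq!(mask, 0x0000_8000_0080_0080);
    }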

src/raw/mod.rs

Lines changed: 11 additions & 10 deletions

@@ -25,8 +25,10 @@ cfg_if! {
     ))] {
         mod sse2;
         use sse2 as imp;
+    } else if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] {
+        mod neon;
+        use neon as imp;
     } else {
-        #[path = "generic.rs"]
         mod generic;
         use generic as imp;
     }

@@ -37,7 +39,7 @@ pub(crate) use self::alloc::{do_alloc, Allocator, Global};

 mod bitmask;

-use self::bitmask::{BitMask, BitMaskIter};
+use self::bitmask::BitMaskIter;
 use self::imp::Group;

 // Branch prediction hint. This is currently only available on nightly but it

@@ -2751,7 +2753,7 @@ impl<T, A: Allocator + Clone> IntoIterator for RawTable<T, A> {
 pub(crate) struct RawIterRange<T> {
     // Mask of full buckets in the current group. Bits are cleared from this
     // mask as each element is processed.
-    current_group: BitMask,
+    current_group: BitMaskIter,

     // Pointer to the buckets for the current group.
     data: Bucket<T>,

@@ -2779,7 +2781,7 @@ impl<T> RawIterRange<T> {
         let next_ctrl = ctrl.add(Group::WIDTH);

         Self {
-            current_group,
+            current_group: current_group.into_iter(),
             data,
             next_ctrl,
             end,

@@ -2836,8 +2838,7 @@
     #[cfg_attr(feature = "inline-more", inline)]
     unsafe fn next_impl<const DO_CHECK_PTR_RANGE: bool>(&mut self) -> Option<Bucket<T>> {
         loop {
-            if let Some(index) = self.current_group.lowest_set_bit() {
-                self.current_group = self.current_group.remove_lowest_bit();
+            if let Some(index) = self.current_group.next() {
                 return Some(self.data.next_n(index));
             }

@@ -2850,7 +2851,7 @@ impl<T> RawIterRange<T> {
             // than the group size where the trailing control bytes are all
             // EMPTY. On larger tables self.end is guaranteed to be aligned
             // to the group size (since tables are power-of-two sized).
-            self.current_group = Group::load_aligned(self.next_ctrl).match_full();
+            self.current_group = Group::load_aligned(self.next_ctrl).match_full().into_iter();
             self.data = self.data.next_n(Group::WIDTH);
             self.next_ctrl = self.next_ctrl.add(Group::WIDTH);
         }

@@ -2990,7 +2991,7 @@ impl<T> RawIter<T> {
            // - Otherwise, update the iterator cached group so that it won't
            //   yield a to-be-removed bucket, or _will_ yield a to-be-added bucket.
            //   We'll also need to update the item count accordingly.
-            if let Some(index) = self.iter.current_group.lowest_set_bit() {
+            if let Some(index) = self.iter.current_group.0.lowest_set_bit() {
                 let next_bucket = self.iter.data.next_n(index);
                 if b.as_ptr() > next_bucket.as_ptr() {
                     // The toggled bucket is "before" the bucket the iterator would yield next. We

@@ -3023,10 +3024,10 @@
                 if cfg!(debug_assertions) {
                     if b.as_ptr() == next_bucket.as_ptr() {
                         // The removed bucket should no longer be next
-                        debug_assert_ne!(self.iter.current_group.lowest_set_bit(), Some(index));
+                        debug_assert_ne!(self.iter.current_group.0.lowest_set_bit(), Some(index));
                     } else {
                         // We should not have changed what bucket comes next.
-                        debug_assert_eq!(self.iter.current_group.lowest_set_bit(), Some(index));
+                        debug_assert_eq!(self.iter.current_group.0.lowest_set_bit(), Some(index));
                     }
                 }
             }
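The iterator change above works because `BitMaskIter` already implements `Iterator` in bitmask.rs (unchanged by this commit); shown here roughly for context, its `next` pairs `lowest_set_bit` with `remove_lowest_bit` so the call sites no longer have to:

    impl Iterator for BitMaskIter {
        type Item = usize;

        #[inline]
        fn next(&mut self) -> Option<usize> {
            // Yield the index of the lowest set bit, then clear it. This is
            // why RawIterRange::next_impl can simply call `.next()` instead of
            // manipulating the mask itself.
            let bit = self.0.lowest_set_bit()?;
            self.0 = self.0.remove_lowest_bit();
            Some(bit)
        }
    }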

src/raw/neon.rs

Lines changed: 122 additions & 0 deletions (new file)

@@ -0,0 +1,122 @@
+use super::bitmask::BitMask;
+use super::EMPTY;
+use core::arch::aarch64 as neon;
+use core::mem;
+
+pub(crate) type BitMaskWord = u64;
+pub(crate) const BITMASK_STRIDE: usize = 8;
+pub(crate) const BITMASK_MASK: BitMaskWord = !0;
+pub(crate) const BITMASK_ITER_MASK: BitMaskWord = 0x8080_8080_8080_8080;
+
+/// Abstraction over a group of control bytes which can be scanned in
+/// parallel.
+///
+/// This implementation uses a 64-bit NEON value.
+#[derive(Copy, Clone)]
+pub(crate) struct Group(neon::uint8x8_t);
+
+#[allow(clippy::use_self)]
+impl Group {
+    /// Number of bytes in the group.
+    pub(crate) const WIDTH: usize = mem::size_of::<Self>();
+
+    /// Returns a full group of empty bytes, suitable for use as the initial
+    /// value for an empty hash table.
+    ///
+    /// This is guaranteed to be aligned to the group size.
+    #[inline]
+    pub(crate) const fn static_empty() -> &'static [u8; Group::WIDTH] {
+        #[repr(C)]
+        struct AlignedBytes {
+            _align: [Group; 0],
+            bytes: [u8; Group::WIDTH],
+        }
+        const ALIGNED_BYTES: AlignedBytes = AlignedBytes {
+            _align: [],
+            bytes: [EMPTY; Group::WIDTH],
+        };
+        &ALIGNED_BYTES.bytes
+    }
+
+    /// Loads a group of bytes starting at the given address.
+    #[inline]
+    #[allow(clippy::cast_ptr_alignment)] // unaligned load
+    pub(crate) unsafe fn load(ptr: *const u8) -> Self {
+        Group(neon::vld1_u8(ptr))
+    }
+
+    /// Loads a group of bytes starting at the given address, which must be
+    /// aligned to `mem::align_of::<Group>()`.
+    #[inline]
+    #[allow(clippy::cast_ptr_alignment)]
+    pub(crate) unsafe fn load_aligned(ptr: *const u8) -> Self {
+        // FIXME: use align_offset once it stabilizes
+        debug_assert_eq!(ptr as usize & (mem::align_of::<Self>() - 1), 0);
+        Group(neon::vld1_u8(ptr))
+    }
+
+    /// Stores the group of bytes to the given address, which must be
+    /// aligned to `mem::align_of::<Group>()`.
+    #[inline]
+    #[allow(clippy::cast_ptr_alignment)]
+    pub(crate) unsafe fn store_aligned(self, ptr: *mut u8) {
+        // FIXME: use align_offset once it stabilizes
+        debug_assert_eq!(ptr as usize & (mem::align_of::<Self>() - 1), 0);
+        neon::vst1_u8(ptr, self.0);
+    }
+
+    /// Returns a `BitMask` indicating all bytes in the group which *may*
+    /// have the given value.
+    #[inline]
+    pub(crate) fn match_byte(self, byte: u8) -> BitMask {
+        unsafe {
+            let cmp = neon::vceq_u8(self.0, neon::vdup_n_u8(byte));
+            BitMask(neon::vget_lane_u64(neon::vreinterpret_u64_u8(cmp), 0))
+        }
+    }
+
+    /// Returns a `BitMask` indicating all bytes in the group which are
+    /// `EMPTY`.
+    #[inline]
+    pub(crate) fn match_empty(self) -> BitMask {
+        self.match_byte(EMPTY)
+    }
+
+    /// Returns a `BitMask` indicating all bytes in the group which are
+    /// `EMPTY` or `DELETED`.
+    #[inline]
+    pub(crate) fn match_empty_or_deleted(self) -> BitMask {
+        unsafe {
+            let cmp = neon::vcltz_s8(neon::vreinterpret_s8_u8(self.0));
+            BitMask(neon::vget_lane_u64(neon::vreinterpret_u64_u8(cmp), 0))
+        }
+    }
+
+    /// Returns a `BitMask` indicating all bytes in the group which are full.
+    #[inline]
+    pub(crate) fn match_full(self) -> BitMask {
+        unsafe {
+            let cmp = neon::vcgez_s8(neon::vreinterpret_s8_u8(self.0));
+            BitMask(neon::vget_lane_u64(neon::vreinterpret_u64_u8(cmp), 0))
+        }
+    }
+
+    /// Performs the following transformation on all bytes in the group:
+    /// - `EMPTY => EMPTY`
+    /// - `DELETED => EMPTY`
+    /// - `FULL => DELETED`
+    #[inline]
+    pub(crate) fn convert_special_to_empty_and_full_to_deleted(self) -> Self {
+        // Map high_bit = 1 (EMPTY or DELETED) to 1111_1111
+        // and high_bit = 0 (FULL) to 1000_0000
+        //
+        // Here's this logic expanded to concrete values:
+        //   let special = 0 > byte = 1111_1111 (true) or 0000_0000 (false)
+        //   1111_1111 | 1000_0000 = 1111_1111
+        //   0000_0000 | 1000_0000 = 1000_0000
+        unsafe {
+            let special = neon::vcltz_s8(neon::vreinterpret_s8_u8(self.0));
+            Group(neon::vorr_u8(special, neon::vdup_n_u8(0x80)))
+        }
+    }
+}
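A quick sanity check of the sign-bit trick behind `match_empty_or_deleted` and `match_full`, written as a scalar sketch rather than NEON intrinsics (the EMPTY and DELETED values follow hashbrown's control byte encoding):

    fn main() {
        const EMPTY: u8 = 0b1111_1111;
        const DELETED: u8 = 0b1000_0000;
        let group: [u8; 8] = [0x23, EMPTY, DELETED, 0x51, EMPTY, 0x00, 0x7F, DELETED];

        // vcltz_s8: byte < 0 as a signed value, i.e. the high bit is set,
        // which is true exactly for EMPTY and DELETED.
        let empty_or_deleted: Vec<bool> = group.iter().map(|&b| (b as i8) < 0).collect();
        assert_eq!(empty_or_deleted, [false, true, true, false, true, false, false, false]);

        // vcgez_s8: byte >= 0, i.e. the high bit is clear, which marks FULL slots.
        let full: Vec<bool> = group.iter().map(|&b| (b as i8) >= 0).collect();
        assert_eq!(full, [true, false, false, true, false, true, true, false]);
    }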

src/raw/sse2.rs

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@ use core::arch::x86_64 as x86;
 pub(crate) type BitMaskWord = u16;
 pub(crate) const BITMASK_STRIDE: usize = 1;
 pub(crate) const BITMASK_MASK: BitMaskWord = 0xffff;
+pub(crate) const BITMASK_ITER_MASK: BitMaskWord = !0;

 /// Abstraction over a group of control bytes which can be scanned in
 /// parallel.
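For SSE2 the iteration mask can likewise be `!0`, because `_mm_movemask_epi8` already compresses the compare result to one bit per byte, which is also why `BITMASK_STRIDE` is 1 there. A standalone sketch of that pattern (illustrative, not copied from the crate; runs only on x86_64, where SSE2 is baseline):

    #[cfg(target_arch = "x86_64")]
    fn main() {
        use core::arch::x86_64 as x86;
        // 16 control bytes; look for 0x23.
        let group: [u8; 16] = [
            0x23, 0xFF, 0x80, 0x23, 0x51, 0x00, 0x7F, 0xFF,
            0x23, 0x80, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66,
        ];
        let mask = unsafe {
            let g = x86::_mm_loadu_si128(group.as_ptr().cast());
            let cmp = x86::_mm_cmpeq_epi8(g, x86::_mm_set1_epi8(0x23));
            // movemask packs one bit per byte: bytes 0, 3 and 8 matched.
            x86::_mm_movemask_epi8(cmp) as u16
        };
        assert_eq!(mask, 0b0000_0001_0000_1001);
    }

    #[cfg(not(target_arch = "x86_64"))]
    fn main() {}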
