diff --git a/src/bst_slice.rs b/src/bst_slice.rs new file mode 100644 index 0000000..47c01ae --- /dev/null +++ b/src/bst_slice.rs @@ -0,0 +1,31 @@ +#[cfg(target_pointer_width = "64")] +use crate::BST_BITS; +use crate::{BitSliceType, INLINE_SLICE_PARTS, SmolBitSet}; + +pub enum BstSlice<'a> { + Inline([BitSliceType; INLINE_SLICE_PARTS]), + Heap(&'a [BitSliceType]), +} + +impl<'a> BstSlice<'a> { + pub fn new(sbs: &'a SmolBitSet) -> Self { + if sbs.is_inline() { + let data = unsafe { sbs.get_inline_data_unchecked() }; + #[cfg(target_pointer_width = "32")] + let data = [data as BitSliceType]; + #[cfg(target_pointer_width = "64")] + let data = [(data as BitSliceType), ((data >> BST_BITS) as BitSliceType)]; + return Self::Inline(data); + } else { + let slice = unsafe { sbs.as_slice_unchecked() }; + return Self::Heap(slice); + } + } + + pub fn slice(&self) -> &[BitSliceType] { + match self { + Self::Inline(items) => items, + Self::Heap(items) => *items, + } + } +} diff --git a/src/cmp.rs b/src/cmp.rs index 1c13ad7..d52d095 100644 --- a/src/cmp.rs +++ b/src/cmp.rs @@ -1,21 +1,24 @@ -use crate::SmolBitSet; +use crate::{ BitSliceType, SmolBitSet}; +use crate::bst_slice::BstSlice; -use core::cmp; +use core::{cmp, iter}; impl cmp::PartialEq for SmolBitSet { fn eq(&self, other: &Self) -> bool { - match (self.len(), other.len()) { - (0, 0) => unsafe { - self.get_inline_data_unchecked() == other.get_inline_data_unchecked() - }, - (a, b) if a == b => { - let a = unsafe { self.as_slice_unchecked() }; - let b = unsafe { other.as_slice_unchecked() }; - - a == b - } - _ => false, - } + let this = BstSlice::new(self); + let other = BstSlice::new(other); + + let this = this.slice(); + let other = other.slice(); + + let (long, short) = if this.len() >= other.len() { + (this, other) + } else { + (other, this) + }; + + let (prefix, suffix) = long.split_at(short.len()); + prefix == short && suffix.iter().all(|&x| x == 0) } } @@ -29,27 +32,32 @@ impl cmp::PartialOrd for SmolBitSet { impl cmp::Ord for SmolBitSet { fn cmp(&self, other: &Self) -> cmp::Ordering { - match (self.len(), other.len()) { - (0, 0) => unsafe { - self.get_inline_data_unchecked() - .cmp(&other.get_inline_data_unchecked()) - }, - (0, _) => cmp::Ordering::Less, - (_, 0) => cmp::Ordering::Greater, - (a, b) if a == b => unsafe { - let a = self.as_slice_unchecked(); - let b = other.as_slice_unchecked(); - - for (a, b) in a.iter().zip(b.iter()).rev() { - let cmp = a.cmp(b); - if cmp != cmp::Ordering::Equal { - return cmp; - } + fn inner(long: &[BitSliceType], short: &[BitSliceType]) -> cmp::Ordering { + let (prefix, suffix) = long.split_at(short.len()); + if suffix.iter().any(|&x| x != 0) { + return cmp::Ordering::Greater; + } + + for (a, b) in iter::zip(prefix, short).rev() { + let cmp = a.cmp(b); + if cmp != cmp::Ordering::Equal { + return cmp; } + } + + cmp::Ordering::Equal + } + + let this = BstSlice::new(self); + let other = BstSlice::new(other); + + let this = this.slice(); + let other = other.slice(); - cmp::Ordering::Equal - }, - (a, b) => a.cmp(&b), + if this.len() >= other.len() { + inner(this, other) + } else { + inner(other, this).reverse() } } } diff --git a/src/fmt.rs b/src/fmt.rs index bbd6d49..6cb4ea6 100644 --- a/src/fmt.rs +++ b/src/fmt.rs @@ -1,4 +1,5 @@ -use crate::{BST_BITS, BitSliceType, SmolBitSet}; +use crate::bst_slice::BstSlice; +use crate::{BST_BITS, SmolBitSet}; #[cfg(feature = "std")] use std::fmt; @@ -10,14 +11,7 @@ use fmt::{Binary, Debug, Display, Formatter, LowerHex, Octal, Result, UpperHex}; impl Debug for SmolBitSet { fn fmt(&self, f: &mut Formatter<'_>) -> Result { - let data = if self.is_inline() { - let d = unsafe { self.get_inline_data_unchecked() }; - &[d as BitSliceType, (d >> BST_BITS) as BitSliceType] - } else { - unsafe { self.as_slice_unchecked() } - }; - - f.debug_list().entries(data).finish() + f.debug_list().entries(BstSlice::new(self).slice()).finish() } } diff --git a/src/lib.rs b/src/lib.rs index 87db400..df83037 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,6 +62,7 @@ macro_rules! highest_set_bit { } mod bitop; +mod bst_slice; mod cmp; mod fmt; mod from; @@ -1123,5 +1124,49 @@ mod tests { b <<= 72; assert!(a < b); } + + #[test] + fn eq_but_not_physical_same() { + let mut a = SmolBitSet::from(u16::MAX); + let mut b = SmolBitSet::from(0xFFFFu16); + + // ensure a is larger than b in memory + a.spill(256); + + assert_eq!(a, b); + assert_eq!(b, a); + + a <<= 55; + assert_ne!(a, b); + assert_ne!(b, a); + + b <<= 55; + assert_eq!(a, b); + assert_eq!(b, a); + } + + #[test] + fn ord_but_not_physical_same() { + let mut a = SmolBitSet::from(0xBEEFu16); + let mut b = SmolBitSet::from(0x00C5_F00Du32); + + // ensure a is larger than b in memory + a.spill(256); + + assert!(a < b); + assert!(b > a); + + a <<= 18; + assert!(a > b); + assert!(b < a); + + a <<= 54; + assert!(a > b); + assert!(b < a); + + b <<= 72; + assert!(a < b); + assert!(b > a); + } } }