From 23903f79028efc75a755288effb20bbdd1c3e6ac Mon Sep 17 00:00:00 2001 From: Ewan Mount Date: Wed, 16 Apr 2025 18:12:32 +0100 Subject: [PATCH] Updating Unstructured::int_in_range to minimize entropy consumption - int_in_range now uses bits from the given data individually when the span of the range given to it allows - Adds a new public method, int_range_bytes_needed to query how much data from the buffer is needed for an int_in_range call to succeed without using dummy values - int_in_range_impl is given a slice instead of a iterator so its use of the data is clearer - Adjusts internal (but not public) APIs to reflect certain functions being infallible - Docs and *most* tests are updated accordingly. Tests that check for specific values being produced from arbitrary() calls are left failing. --- src/unstructured.rs | 298 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 241 insertions(+), 57 deletions(-) diff --git a/src/unstructured.rs b/src/unstructured.rs index 519528e..216cce3 100644 --- a/src/unstructured.rs +++ b/src/unstructured.rs @@ -71,6 +71,8 @@ use std::{mem, ops}; #[derive(Debug)] pub struct Unstructured<'a> { data: &'a [u8], + leftover_bits: u8, + remaining_bit_count: u8, } impl<'a> Unstructured<'a> { @@ -84,7 +86,11 @@ impl<'a> Unstructured<'a> { /// let u = Unstructured::new(&[1, 2, 3, 4]); /// ``` pub fn new(data: &'a [u8]) -> Self { - Unstructured { data } + Unstructured { + data, + leftover_bits: 0, + remaining_bit_count: 0, + } } /// Get the number of remaining bytes of underlying data that are still @@ -215,19 +221,19 @@ impl<'a> Unstructured<'a> { where ElementType: Arbitrary<'a>, { - let byte_size = self.arbitrary_byte_size()?; + let byte_size = self.arbitrary_byte_size(); let (lower, upper) = ::size_hint(0); let elem_size = upper.unwrap_or(lower * 2); let elem_size = std::cmp::max(1, elem_size); Ok(byte_size / elem_size) } - fn arbitrary_byte_size(&mut self) -> Result { + fn arbitrary_byte_size(&mut self) -> usize { if self.data.is_empty() { - Ok(0) + 0 } else if self.data.len() == 1 { self.data = &[]; - Ok(0) + 0 } else { // Take lengths from the end of the data, since the `libFuzzer` folks // found that this lets fuzzers more efficiently explore the input @@ -238,45 +244,97 @@ impl<'a> Unstructured<'a> { // We only consume as many bytes as necessary to cover the entire // range of the byte string. // Note: We cast to u64 so we don't overflow when checking u32::MAX + 4 on 32-bit archs - let len = if self.data.len() as u64 <= u8::MAX as u64 + 1 { + if self.data.len() as u64 <= u8::MAX as u64 + 1 { let bytes = 1; let max_size = self.data.len() - bytes; let (rest, for_size) = self.data.split_at(max_size); self.data = rest; - Self::int_in_range_impl(0..=max_size as u8, for_size.iter().copied())?.0 as usize + self.int_in_range_impl(0..=max_size as u8, for_size).0 as usize } else if self.data.len() as u64 <= u16::MAX as u64 + 2 { let bytes = 2; let max_size = self.data.len() - bytes; let (rest, for_size) = self.data.split_at(max_size); self.data = rest; - Self::int_in_range_impl(0..=max_size as u16, for_size.iter().copied())?.0 as usize + self.int_in_range_impl(0..=max_size as u16, for_size).0 as usize } else if self.data.len() as u64 <= u32::MAX as u64 + 4 { let bytes = 4; let max_size = self.data.len() - bytes; let (rest, for_size) = self.data.split_at(max_size); self.data = rest; - Self::int_in_range_impl(0..=max_size as u32, for_size.iter().copied())?.0 as usize + self.int_in_range_impl(0..=max_size as u32, for_size).0 as usize } else { let bytes = 8; let max_size = self.data.len() - bytes; let (rest, for_size) = self.data.split_at(max_size); self.data = rest; - Self::int_in_range_impl(0..=max_size as u64, for_size.iter().copied())?.0 as usize - }; + self.int_in_range_impl(0..=max_size as u64, for_size).0 as usize + } + } + } + + /// Returns the minimum number of whole bytes needed from the data to call `int_in_range` over the given range without resorting to dummy values. + /// This takes into account any bits stored from previous calls - it is not a function of the range alone + /// + /// # Example + /// + /// ``` + /// + /// use arbitrary::Unstructured; + /// + /// let mut u = Unstructured::new(&[1]); + /// assert_eq!(u.int_range_bytes_needed(0..=255u8), 1); + /// + /// let _ = u.int_in_range(0..=1); // Use a partial byte + /// assert_eq!(u.int_range_bytes_needed(0..=127u8), 0, "Unit test"); + /// ``` + /// + pub fn int_range_bytes_needed(&self, range: ops::RangeInclusive) -> usize { + if range.is_empty() { + return 0; + } + let start = *range.start(); + let end = *range.end(); + let raw_bits = Self::int_range_bits_needed(start, end); + let external_bits = raw_bits.saturating_sub(self.remaining_bit_count as _); + let round_up = if external_bits % 8 > 0 { 1 } else { 0 }; + (external_bits / 8) + round_up + } - Ok(len) + /// Count how many bits needed to be pulled out of the data to cover the whole range + /// This is more awkward than it needs to be to avoid adding requirements to `Int` unnecessarily + /// This does *not* attempt to take into account how many stored bits we have + /// It also assumes we've checked for empty ranges + fn int_range_bits_needed(start: T, end: T) -> usize { + let start = start.to_unsigned(); + let end = end.to_unsigned(); + + let orig_option_count = end.wrapping_sub(start); + // Try to add 1 to account for the range being inclusive + let option_count = orig_option_count + .checked_add(T::Unsigned::ONE) + .unwrap_or(orig_option_count); + // If the addition overflowed it's because we're selecting from the entire range of the type + // But reusing the original value is fine in that case because it is always one less than the + // power of two we wanted anyway + + // Find the binary exponent needed to cover from 0 up to `option_count` (or beyond) + let mut counter = T::Unsigned::ONE; + let mut bits_needed = 0; + while counter < option_count && counter > T::Unsigned::ZERO { + counter = counter << 1; + bits_needed += 1; } + bits_needed } - /// Generate an integer within the given range. + /// Generate an integer within the given *inclusive* range. /// /// Do not use this to generate the size of a collection. Use /// `arbitrary_len` instead. /// - /// The probability distribution of the return value is not necessarily uniform. + /// Attempts to cache unused bits from the data buffer if covering the given range needs some but not all of a byte. (e.g. calling this function with `0..=15` *twice* will advance the buffer only by a single byte, after the first call.) The leftover bits will be reused in future calls to methods in this type, *except* those methods returning entire segments of the buffer, namely [`fill_buffer`], [`peek_buffer`] and [`take_rest`]. /// - /// Returns `range.start()`, not an error, - /// if this `Unstructured` [is empty][Unstructured::is_empty]. + /// The probability distribution of the return value is not necessarily uniform. /// /// # Panics /// @@ -301,20 +359,25 @@ impl<'a> Unstructured<'a> { where T: Int, { - let (result, bytes_consumed) = Self::int_in_range_impl(range, self.data.iter().cloned())?; - self.data = &self.data[bytes_consumed..]; + let (result, data_bits_consumed) = self.int_in_range_impl(range, self.data); + + // If `bits_consumed` is not a multiple of 8, the partial byte has been stored in `leftover_bits` so we should remove it from the data buffer + let whole_bytes = (data_bits_consumed / 8) + if data_bits_consumed % 8 > 0 { 1 } else { 0 }; + self.data = &self.data[whole_bytes..]; Ok(result) } + /// Core algorithm for `int_in_range`, split out so different uses can provide different data buffers fn int_in_range_impl( - range: ops::RangeInclusive, - mut bytes: impl Iterator, - ) -> Result<(T, usize)> + &mut self, + value_range: ops::RangeInclusive, + bytes: &[u8], + ) -> (T, usize) where T: Int, { - let start = *range.start(); - let end = *range.end(); + let start = *value_range.start(); + let end = *value_range.end(); assert!( start <= end, "`arbitrary::Unstructured::int_in_range` requires a non-empty range" @@ -323,7 +386,7 @@ impl<'a> Unstructured<'a> { // When there is only one possible choice, don't waste any entropy from // the underlying data. if start == end { - return Ok((start, 0)); + return (start, 0); } // From here on out we work with the unsigned representation. All of the @@ -335,31 +398,72 @@ impl<'a> Unstructured<'a> { let delta = end.wrapping_sub(start); debug_assert_ne!(delta, T::Unsigned::ZERO); - // Compute an arbitrary integer offset from the start of the range. We - // do this by consuming `size_of(T)` bytes from the input to create an - // arbitrary integer and then clamping that int into our range bounds - // with a modulo operation. - let mut arbitrary_int = T::Unsigned::ZERO; - let mut bytes_consumed: usize = 0; - - while (bytes_consumed < mem::size_of::()) - && (delta >> T::Unsigned::from_usize(bytes_consumed * 8)) > T::Unsigned::ZERO - { - let byte = match bytes.next() { - None => break, - Some(b) => b, - }; - bytes_consumed += 1; - - // Combine this byte into our arbitrary integer, but avoid - // overflowing the shift for `u8` and `i8`. - arbitrary_int = if mem::size_of::() == 1 { - T::Unsigned::from_u8(byte) - } else { - (arbitrary_int << 8) | T::Unsigned::from_u8(byte) - }; + let total_bits_needed = Self::int_range_bits_needed(start, end); + + // Compute an arbitrary integer offset from the start of the range. + + // This will be built from three parts - any stored bits we have from previous calls, a number of whole bytes, then a partial byte + // Any of these parts may be empty depending on how exactly how many bits are needed and are saved. + + // Use any leftover bits from the last byte consumed + let (arbitrary_int_head, internal_bits_consumed) = if self.remaining_bit_count > 0 { + self.shift_bits_through(None, total_bits_needed) + } else { + (0, 0) + }; + + let mut arbitrary_int = T::Unsigned::from_u8(arbitrary_int_head); + + let bits_still_needed = total_bits_needed - internal_bits_consumed; + let whole_bytes_needed = bits_still_needed / 8; + let trailing_bits = bits_still_needed % 8; + let part_byte_needed = if trailing_bits > 0 { 1 } else { 0 }; + + let external_bytes_used = if whole_bytes_needed + part_byte_needed < bytes.len() { + bytes.split_at(whole_bytes_needed + part_byte_needed).0 + } else { + bytes + }; + + let mut external_bits_consumed = 0; + + match whole_bytes_needed { + 0 => { /* nothing to do */ } + 1 => { + // We handle this case specially because this can happen when `sizeof() == 1` and we have no stored bits + // In that case, shifting up a whole byte causes an overflow. + // If `sizeof() == 1`, we cannot need more than 8 bits in total, so this doesn't need to be handled in other cases + let b = external_bytes_used.first().copied().unwrap_or_default(); + arbitrary_int = T::Unsigned::from_u8(b); + external_bits_consumed += 8; + } + _ => { + // This `get` call will fail if we don't have enough data to cover the whole range but then just use what we have + for &b in external_bytes_used + .get(..whole_bytes_needed) + .unwrap_or(external_bytes_used) + { + arbitrary_int = (arbitrary_int << 8) | T::Unsigned::from_u8(b); + external_bits_consumed += 8 + } + } } + // Get any fractional byte we still need, defaulting to 0 if there's no more data. + let (arbitrary_int_tail, tail_bits_consumed) = if trailing_bits > 0 { + let last_byte = *external_bytes_used.last().unwrap_or(&0); + self.shift_bits_through(Some(last_byte), trailing_bits) + } else { + (0, 0) + }; + + let arbitrary_int_tail = T::Unsigned::from_u8(arbitrary_int_tail); + external_bits_consumed += tail_bits_consumed; + + // Shift the partial byte into the end result + // `trailing_bits < 8` by definition so this will be safe even if `sizeof() == 1` + let arbitrary_int = (arbitrary_int << trailing_bits) | arbitrary_int_tail; + let offset = if delta == T::Unsigned::MAX { arbitrary_int } else { @@ -372,10 +476,47 @@ impl<'a> Unstructured<'a> { // And convert back to our maybe-signed representation. let result = T::from_unsigned(result); - debug_assert!(*range.start() <= result); - debug_assert!(result <= *range.end()); + debug_assert!(*value_range.start() <= result); + debug_assert!(result <= *value_range.end()); + + (result, external_bits_consumed) + } + + /// Returns the least significant N bits, up to all 8, from `external_bits` or from `self.leftover_bits` if the former isn't provided. + /// + /// It additionally returns the number of bits actually used. This wil be the smaller of + /// a) the number of bits needed + /// b) the number of previously stored bits, if `external_bits` was not provided + /// c) 8, meaning the entire byte of input was used (although this doesn't happen given current use in `int_in_range_impl`) + /// + /// If less than a whole byte is used, the excess high bits are stored to use for a later call. + /// The unused bits from the input are stored unconditionally - if `external_bits` is `Some` while `leftover_bits` is non-zero, the stored bits will be lost, unused. + fn shift_bits_through(&mut self, external_bits: Option, bits_needed: usize) -> (u8, usize) { + let (bit_data, remaining_bits) = if let Some(bit_data) = external_bits { + (bit_data, 8) + } else { + (self.leftover_bits, self.remaining_bit_count) + }; + let bits_used = std::cmp::min(bits_needed, remaining_bits as usize); + debug_assert!(bits_used <= 8, "Tried to use {bits_used}>8 bits"); + + let mask = (1usize << bits_used) - 1; + let low_bits = bit_data & (mask as u8); + let high_bits = if bits_used < 8 { + bit_data >> bits_used + } else { + 0 + }; - Ok((result, bytes_consumed)) + // If we're being asked for 8 or more bits, both of these should end up 0 + self.remaining_bit_count = remaining_bits - bits_used as u8; + self.leftover_bits = high_bits; + + // The RBC should *always* be less than a full byte + debug_assert!(self.remaining_bit_count < 8); + debug_assert!(bits_used <= bits_needed); + + (low_bits, bits_used) } /// Choose one of the given choices. @@ -929,14 +1070,14 @@ mod tests { fn test_byte_size() { let mut u = Unstructured::new(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 6]); // Should take one byte off the end - assert_eq!(u.arbitrary_byte_size().unwrap(), 6); + assert_eq!(u.arbitrary_byte_size(), 2); assert_eq!(u.len(), 9); let mut v = vec![0; 260]; v.push(1); v.push(4); let mut u = Unstructured::new(&v); // Should read two bytes off the end - assert_eq!(u.arbitrary_byte_size().unwrap(), 0x104); + assert_eq!(u.arbitrary_byte_size(), 2); assert_eq!(u.len(), 260); } @@ -952,15 +1093,15 @@ mod tests { #[test] fn int_in_range_uses_minimal_amount_of_bytes() { let mut u = Unstructured::new(&[1, 2]); - assert_eq!(1, u.int_in_range::(0..=u8::MAX).unwrap()); + let _ = u.int_in_range::(0..=u8::MAX).unwrap(); assert_eq!(u.len(), 1); let mut u = Unstructured::new(&[1, 2]); - assert_eq!(1, u.int_in_range::(0..=u8::MAX as u32).unwrap()); + let _ = u.int_in_range::(0..=u8::MAX as u32).unwrap(); assert_eq!(u.len(), 1); let mut u = Unstructured::new(&[1]); - assert_eq!(1, u.int_in_range::(0..=u8::MAX as u32 + 1).unwrap()); + let _ = u.int_in_range::(0..=u8::MAX as u32 + 1).unwrap(); assert!(u.is_empty()); } @@ -1010,7 +1151,11 @@ mod tests { } for (i, covered) in full.iter().enumerate() { - assert!(covered, "full[{}] should have been generated", i); + assert!( + covered, + "full[{}] should have been generated: {:?}", + i, full + ); } for (i, covered) in no_zero.iter().enumerate() { assert!(covered, "no_zero[{}] should have been generated", i); @@ -1019,7 +1164,11 @@ mod tests { assert!(covered, "no_max[{}] should have been generated", i); } for (i, covered) in narrow.iter().enumerate() { - assert!(covered, "narrow[{}] should have been generated", i); + assert!( + covered, + "narrow[{}] should have been generated: {narrow:?}", + i + ); } } @@ -1068,4 +1217,39 @@ mod tests { assert!(covered, "narrow[{}] should have been generated", i); } } + + #[test] + fn test_bit_counts() { + assert_eq!(0, Unstructured::int_range_bits_needed(2, 2)); + assert_eq!(1, Unstructured::int_range_bits_needed(2, 3)); + assert_eq!(2, Unstructured::int_range_bits_needed(2, 4)); + assert_eq!(8, Unstructured::int_range_bits_needed(0, u8::MAX)); + assert_eq!(32, Unstructured::int_range_bits_needed(0, u32::MAX)); + } + + #[test] + fn test_byte_counts_empty() { + let blank = Unstructured::new(&[]); + + assert_eq!(0, blank.int_range_bytes_needed(2..=1)); + assert_eq!(0, blank.int_range_bytes_needed(2..=2)); + assert_eq!(1, blank.int_range_bytes_needed(2..=3)); + assert_eq!(1, blank.int_range_bytes_needed(2..=4)); + assert_eq!(1, blank.int_range_bytes_needed(0..=u8::MAX)); + assert_eq!(4, blank.int_range_bytes_needed(0..=u32::MAX)); + } + + #[test] + fn test_byte_counts_loaded() { + let mut loaded = Unstructured::new(&[255]); + let _ = loaded.int_in_range(0..=1); + + assert_eq!(0, loaded.int_range_bytes_needed(2..=1)); + assert_eq!(0, loaded.int_range_bytes_needed(2..=2)); + assert_eq!(0, loaded.int_range_bytes_needed(2..=3)); + assert_eq!(0, loaded.int_range_bytes_needed(2..=4)); + assert_eq!(0, loaded.int_range_bytes_needed(0..=126)); + assert_eq!(1, loaded.int_range_bytes_needed(0..=u8::MAX)); + assert_eq!(3, loaded.int_range_bytes_needed(0..=u32::MAX / 2)); + } }