Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion utils/resb/src/binary/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ pub fn option_i32_tuple<'de, D: Deserializer<'de>>(
/// # Safety
/// Alignment and length are checked, however the caller has to guarantee that the byte representation is valid for type T.
pub unsafe fn cast_bytes_to_slice<T, E: Error>(bytes: &[u8]) -> Result<&[T], E> {
if bytes.as_ptr().align_offset(align_of::<T>()) != 0 && bytes.len() % size_of::<T>() != 0 {
if bytes.as_ptr().align_offset(align_of::<T>()) != 0 || bytes.len() % size_of::<T>() != 0 {
return Err(E::custom("Wrong length or align"));
}

Expand All @@ -109,3 +109,42 @@ pub unsafe fn cast_bytes_to_slice<T, E: Error>(bytes: &[u8]) -> Result<&[T], E>
core::slice::from_raw_parts(bytes.as_ptr() as *const T, bytes.len() / size_of::<T>())
})
}

#[cfg(test)]
mod tests {
use super::*;

type E = value::Error;

#[test]
fn cast_aligned_correct_length() {
let data: [u32; 2] = [1, 2];
let bytes: &[u8] =
unsafe { core::slice::from_raw_parts(data.as_ptr() as *const u8, size_of_val(&data)) };
let result: Result<&[u32], E> = unsafe { cast_bytes_to_slice(bytes) };
assert!(result.is_ok());
assert_eq!(result.unwrap(), &[1u32, 2]);
}

#[test]
fn cast_wrong_length_only() {
// 4-byte aligned but 5 bytes long (not a multiple of 4)
let data: [u32; 2] = [1, 2];
let bytes: &[u8] = unsafe { core::slice::from_raw_parts(data.as_ptr() as *const u8, 5) };
let result: Result<&[u32], E> = unsafe { cast_bytes_to_slice(bytes) };
assert!(result.is_err(), "wrong length alone must be rejected");
}

#[test]
fn cast_wrong_alignment_only() {
// Correctly sized (4 bytes) but misaligned by 1
let data: [u8; 8] = [0; 8];
// Find a u32-aligned start, then offset by 1
let aligned_start = data.as_ptr().align_offset(align_of::<u32>());
let misaligned = &data[aligned_start + 1..aligned_start + 5];
assert_eq!(misaligned.len(), 4); // correct length for one u32
assert_ne!(misaligned.as_ptr().align_offset(align_of::<u32>()), 0);
let result: Result<&[u32], E> = unsafe { cast_bytes_to_slice(misaligned) };
assert!(result.is_err(), "wrong alignment alone must be rejected");
}
}
Loading