From 839b0aa0c61f5b1fabf7c195c65d13edb5ae1dec Mon Sep 17 00:00:00 2001 From: robinhundt <24554122+robinhundt@users.noreply.github.com> Date: Fri, 6 Jun 2025 12:25:14 +0200 Subject: [PATCH] cryprot-core: portable transpose remove unsafe --- cryprot-core/src/transpose/portable.rs | 140 ++++++++++++++----------- 1 file changed, 77 insertions(+), 63 deletions(-) diff --git a/cryprot-core/src/transpose/portable.rs b/cryprot-core/src/transpose/portable.rs index 50d04ce..31f8d75 100644 --- a/cryprot-core/src/transpose/portable.rs +++ b/cryprot-core/src/transpose/portable.rs @@ -3,47 +3,91 @@ use wide::{i8x16, i64x2}; /// Transpose a bit matrix. /// /// # Panics -/// TODO +/// - If `rows < 16` +/// - If `rows` is not divisible by 16 +/// - If `input.len()` is not divisible by `rows` +/// - If the number of columns, computed as `input.len() * 8 / rows` is less +/// than 16 +/// - If the number of columns is not divisible by 8 pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) { - assert!(rows >= 16); - assert_eq!(0, rows % 16); - assert_eq!(0, input.len() % rows); + assert!(rows >= 16, "rows must be at least 16"); + assert_eq!(0, rows % 16, "rows must be divisible by 16"); + assert_eq!( + 0, + input.len() % rows, + "input.len() must be divisible by rows" + ); let cols = input.len() * 8 / rows; - assert!(cols >= 16); + assert!(cols >= 16, "columns must be at least 16. Columns {cols}"); assert_eq!( 0, cols % 8, "Number of bitmatrix columns must be divisable by 8. 
columns: {cols}" ); - unsafe { - let mut row: usize = 0; - while row <= rows - 16 { - let mut col = 0; - while col < cols { - let mut v = load_bytes(input, row, col, cols); - // reverse iterator because we start writing the msb of each byte, then shift - // left for i = 0, we write the previous lsb - for i in (0..8).rev() { - // get msb of each byte - let msbs = v.to_bitmask().to_le_bytes(); - // write msbs to output at transposed position as one i16 - let msb_i16 = i16::from_ne_bytes([msbs[0], msbs[1]]); - let idx = out(row, col + i, rows) as isize; - let out_ptr = output.as_mut_ptr().offset(idx) as *mut i16; - // ptr is potentially unaligned - out_ptr.write_unaligned(msb_i16); + // Transpose a matrix by splitting it into 16x8 blocks (=16 bytes which can be + // one SSE2 128 bit vector depending on target support), write transposed + // block into output at the transposed position and continue with next + // block. + + let mut row: usize = 0; + while row <= rows - 16 { + let mut col = 0; + while col < cols { + // Load 16x8 sub-block by loading (row + 0, col) .. (row + 15, col) + let mut v = load_bytes(input, row, col, cols); + // The idea is to take the most significant bit of each row of the sub-block + // (msb of each byte) and write these 16 bits coming from 16 rows + // and one column in the input to the correct row and columns in the + // output. Because the `move_mask` instruction gives us the `msb` of each byte + // (=row) in our input, we iterate the output_row_offset from large + // to small. After each iteration, we shift each byte in the + // sub-block one bit to the left and then again get the msb of each byte and + // write it to the next row. 
+ // Visualization done by Claude 4 Sonnet: + // ┌─────────────────────────────────────────────────────────────┐ + // │ BIT MATRIX TRANSPOSE: 16x8 BLOCK │ + // ├─────────────────────────────────────────────────────────────┤ + // │ │ + // │ INPUT (16×8) OUTPUT (8×16) │ + // │ │ + // │ 0 1 ⋯ 6 7 0 1 ⋯ E F │ + // │ 0 ◆|◇|⋯|▲|△ 0 ◆|◆|⋯|◆|◆ │ + // │ 1 ◆|◇|⋯|▲|△ 1 ◇|◇|⋯|◇|◇ │ + // │ ⋮ ⋮ │ + // │ E ◆|◇|⋯|▲|△ 6 ▲|▲|⋯|▲|▲ │ + // │ F ◆|◇|⋯|▲|△ 7 △|△|⋯|△|△ │ + // │ │ + // │ MOVE_MASK: Extract column bits → Write as rows │ + // │ │ + // │ Iter 0: MSB column 7 → output row 7 │ + // │ Iter 1: << 1, MSB → output row 6 │ + // │ ⋮ │ + // │ Iter 7: << 7, MSB → output row 0 │ + // │ │ + // │ INPUT[row,col] → OUTPUT[col,row] │ + // │ │ + // └─────────────────────────────────────────────────────────────┘ + for output_row_offset in (0..8).rev() { + // get msb of each byte + let msbs = v.to_bitmask().to_le_bytes(); + // write msbs to output at transposed position + let idx = out(row, col + output_row_offset, rows) as isize; + // This should result in only one bounds check for the output + let out_bytes = &mut output[idx as usize..idx as usize + 2]; + out_bytes[0] = msbs[0]; + out_bytes[1] = msbs[1]; - // SAFETY: u8x16 and i64x2 have the same layout - // we need to convert cast it, because there is no shift impl for u8x16 - let v_i64x2 = &mut v as *mut _ as *mut i64x2; - // shift each byte by one to the left (by shifting it as two i64) - *v_i64x2 = *v_i64x2 << 1; - } - col += 8; + // There is no shift impl for i8x16 so we cast to i64x2 and shift these. + // The bits shifted to neighbouring bytes are ignored because we iterate + // and call move_mask 8 times. 
+ let v: &mut i64x2 = bytemuck::must_cast_mut(&mut v); + // shift each byte by one to the left (by shifting it as two i64) + *v = *v << 1; } - row += 16; + col += 8; } + row += 16; } } @@ -58,39 +102,9 @@ fn out(x: usize, y: usize, rows: usize) -> usize { #[inline] // get col byte of row to row + 15 -unsafe fn load_bytes(b: &[u8], row: usize, col: usize, cols: usize) -> i8x16 { - unsafe { - // if we have sse2 we use _mm_setr_epi8 and transmute to convert bytes - // faster than from impl - #[cfg(target_feature = "sse2")] - { - use std::{arch::x86_64::_mm_setr_epi8, mem::transmute}; - let v = _mm_setr_epi8( - *b.get_unchecked(inp(row, col, cols)) as i8, - *b.get_unchecked(inp(row + 1, col, cols)) as i8, - *b.get_unchecked(inp(row + 2, col, cols)) as i8, - *b.get_unchecked(inp(row + 3, col, cols)) as i8, - *b.get_unchecked(inp(row + 4, col, cols)) as i8, - *b.get_unchecked(inp(row + 5, col, cols)) as i8, - *b.get_unchecked(inp(row + 6, col, cols)) as i8, - *b.get_unchecked(inp(row + 7, col, cols)) as i8, - *b.get_unchecked(inp(row + 8, col, cols)) as i8, - *b.get_unchecked(inp(row + 9, col, cols)) as i8, - *b.get_unchecked(inp(row + 10, col, cols)) as i8, - *b.get_unchecked(inp(row + 11, col, cols)) as i8, - *b.get_unchecked(inp(row + 12, col, cols)) as i8, - *b.get_unchecked(inp(row + 13, col, cols)) as i8, - *b.get_unchecked(inp(row + 14, col, cols)) as i8, - *b.get_unchecked(inp(row + 15, col, cols)) as i8, - ); - transmute(v) - } - #[cfg(not(target_feature = "sse2"))] - { - let bytes = std::array::from_fn(|i| *b.get_unchecked(inp(row + i, col, cols)) as i8); - i8x16::from(bytes) - } - } +fn load_bytes(b: &[u8], row: usize, col: usize, cols: usize) -> i8x16 { + let bytes = std::array::from_fn(|i| b[inp(row + i, col, cols)] as i8); + i8x16::from(bytes) } #[cfg(test)]