140 changes: 77 additions & 63 deletions cryprot-core/src/transpose/portable.rs
@@ -3,47 +3,91 @@ use wide::{i8x16, i64x2};
/// Transpose a bit matrix.
///
/// # Panics
/// TODO
/// - If `rows < 16`
/// - If `rows` is not divisible by 16
/// - If `input.len()` is not divisible by `rows`
/// - If the number of columns, computed as `input.len() * 8 / rows`, is less
///   than 16
/// - If the number of columns is not divisible by 8
pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
assert!(rows >= 16);
assert_eq!(0, rows % 16);
assert_eq!(0, input.len() % rows);
assert!(rows >= 16, "rows must be at least 16");
assert_eq!(0, rows % 16, "rows must be divisible by 16");
assert_eq!(
0,
input.len() % rows,
"input.len() must be divisible by rows"
);
let cols = input.len() * 8 / rows;
assert!(cols >= 16);
assert!(cols >= 16, "columns must be at least 16, got {cols}");
assert_eq!(
0,
cols % 8,
"Number of bitmatrix columns must be divisable by 8. columns: {cols}"
);

unsafe {
let mut row: usize = 0;
while row <= rows - 16 {
let mut col = 0;
while col < cols {
let mut v = load_bytes(input, row, col, cols);
// iterate in reverse because we start by writing the msb of each byte, then
// shift left; for i = 0 we write what was originally the lsb
for i in (0..8).rev() {
// get msb of each byte
let msbs = v.to_bitmask().to_le_bytes();
// write msbs to output at transposed position as one i16
let msb_i16 = i16::from_ne_bytes([msbs[0], msbs[1]]);
let idx = out(row, col + i, rows) as isize;
let out_ptr = output.as_mut_ptr().offset(idx) as *mut i16;
// ptr is potentially unaligned
out_ptr.write_unaligned(msb_i16);
// Transpose the matrix by splitting it into 16x8 blocks (= 16 bytes, which
// can be one SSE2 128 bit vector depending on target support), writing each
// transposed block into the output at the transposed position, and
// continuing with the next block.

let mut row: usize = 0;
while row <= rows - 16 {
let mut col = 0;
while col < cols {
// Load 16x8 sub-block by loading (row + 0, col) .. (row + 15, col)
let mut v = load_bytes(input, row, col, cols);
// The idea is to take the most significant bit of each row of the sub-block
// (msb of each byte) and write these 16 bits, coming from 16 rows
// and one column of the input, to the correct row and column in the
// output. Because the `move_mask` instruction gives us the msb of each byte
// (= row) of our input, we iterate output_row_offset from large
// to small. After each iteration, we shift each byte in the
// sub-block one bit to the left, then again take the msb of each byte and
// write it to the next row.
// Visualization done by Claude 4 Sonnet:
// ┌─────────────────────────────────────────────────────────────┐
// │ BIT MATRIX TRANSPOSE: 16x8 BLOCK │
// ├─────────────────────────────────────────────────────────────┤
// │ │
// │ INPUT (16×8) OUTPUT (8×16) │
// │ │
// │ 0 1 ⋯ 6 7 0 1 ⋯ E F │
// │ 0 ◆|◇|⋯|▲|△ 0 ◆|◆|⋯|◆|◆ │
// │ 1 ◆|◇|⋯|▲|△ 1 ◇|◇|⋯|◇|◇ │
// │ ⋮ ⋮ │
// │ E ◆|◇|⋯|▲|△ 6 ▲|▲|⋯|▲|▲ │
// │ F ◆|◇|⋯|▲|△ 7 △|△|⋯|△|△ │
// │ │
// │ MOVE_MASK: Extract column bits → Write as rows │
// │ │
// │ Iter 0: MSB column 7 → output row 7 │
// │ Iter 1: << 1, MSB → output row 6 │
// │ ⋮ │
// │ Iter 7: << 7, MSB → output row 0 │
// │ │
// │ INPUT[row,col] → OUTPUT[col,row] │
// │ │
// └─────────────────────────────────────────────────────────────┘
for output_row_offset in (0..8).rev() {
// get msb of each byte
let msbs = v.to_bitmask().to_le_bytes();
// write msbs to output at transposed position
let idx = out(row, col + output_row_offset, rows);
// This should result in only one bounds check for the output
let out_bytes = &mut output[idx..idx + 2];
out_bytes[0] = msbs[0];
out_bytes[1] = msbs[1];

// SAFETY: i8x16 and i64x2 have the same layout
// we need to cast it, because there is no shift impl for i8x16
let v_i64x2 = &mut v as *mut _ as *mut i64x2;
// shift each byte by one to the left (by shifting it as two i64)
*v_i64x2 = *v_i64x2 << 1;
}
col += 8;
// There is no shift impl for i8x16 so we cast to i64x2 and shift these.
// The bits shifted to neighbouring bytes are ignored because we iterate
// and call move_mask 8 times.
let v: &mut i64x2 = bytemuck::must_cast_mut(&mut v);
// shift each byte by one to the left (by shifting it as two i64)
*v = *v << 1;
}
row += 16;
col += 8;
}
row += 16;
}
}
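
For a reader checking the kernel, here is a minimal scalar sketch (hypothetical, not part of the PR) of the mapping it implements: bit (r, c) of the row-major input becomes bit (c, r) of the output, with the lsb-first bit order within each byte that the shift/`move_mask` loop implies:

```rust
/// Naive reference for testing against small inputs. Assumes a
/// row-major `rows x cols` bit matrix with `cols / 8` bytes per row
/// and lsb-first bit order within each byte.
fn transpose_naive(input: &[u8], rows: usize) -> Vec<u8> {
    let cols = input.len() * 8 / rows;
    let mut out = vec![0u8; input.len()];
    for r in 0..rows {
        for c in 0..cols {
            // read bit (r, c) of the input...
            let bit = (input[r * (cols / 8) + c / 8] >> (c % 8)) & 1;
            // ...and set it as bit (c, r) of the output
            out[c * (rows / 8) + r / 8] |= bit << (r % 8);
        }
    }
    out
}
```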

@@ -58,39 +102,9 @@ fn out(x: usize, y: usize, rows: usize) -> usize {

#[inline]
// get the byte containing `col` for each of rows `row` to `row + 15`
unsafe fn load_bytes(b: &[u8], row: usize, col: usize, cols: usize) -> i8x16 {
unsafe {
// if we have sse2 we use _mm_setr_epi8 and transmute to convert the bytes,
// which is faster than the From impl
#[cfg(target_feature = "sse2")]
{
use std::{arch::x86_64::_mm_setr_epi8, mem::transmute};
let v = _mm_setr_epi8(
*b.get_unchecked(inp(row, col, cols)) as i8,
*b.get_unchecked(inp(row + 1, col, cols)) as i8,
*b.get_unchecked(inp(row + 2, col, cols)) as i8,
*b.get_unchecked(inp(row + 3, col, cols)) as i8,
*b.get_unchecked(inp(row + 4, col, cols)) as i8,
*b.get_unchecked(inp(row + 5, col, cols)) as i8,
*b.get_unchecked(inp(row + 6, col, cols)) as i8,
*b.get_unchecked(inp(row + 7, col, cols)) as i8,
*b.get_unchecked(inp(row + 8, col, cols)) as i8,
*b.get_unchecked(inp(row + 9, col, cols)) as i8,
*b.get_unchecked(inp(row + 10, col, cols)) as i8,
*b.get_unchecked(inp(row + 11, col, cols)) as i8,
*b.get_unchecked(inp(row + 12, col, cols)) as i8,
*b.get_unchecked(inp(row + 13, col, cols)) as i8,
*b.get_unchecked(inp(row + 14, col, cols)) as i8,
*b.get_unchecked(inp(row + 15, col, cols)) as i8,
);
transmute(v)
}
#[cfg(not(target_feature = "sse2"))]
{
let bytes = std::array::from_fn(|i| *b.get_unchecked(inp(row + i, col, cols)) as i8);
i8x16::from(bytes)
}
}
fn load_bytes(b: &[u8], row: usize, col: usize, cols: usize) -> i8x16 {
let bytes = std::array::from_fn(|i| b[inp(row + i, col, cols)] as i8);
i8x16::from(bytes)
}
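
The bodies of `inp` and `out` are collapsed in this diff. Going only by how they are called, they presumably compute flat byte offsets roughly like the following sketch (hypothetical reconstructions, not the file's actual code):

```rust
// Sketch: byte offset of bit column `col` in row `row` of a row-major
// rows x cols bit matrix (cols / 8 bytes per row).
fn inp(row: usize, col: usize, cols: usize) -> usize {
    row * (cols / 8) + col / 8
}

// Sketch: byte offset of the transposed position. Output row `y` (an
// input column) is rows / 8 bytes wide; `x` (an input row) picks the byte.
fn out(x: usize, y: usize, rows: usize) -> usize {
    y * (rows / 8) + x / 8
}
```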

#[cfg(test)]
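
The test module itself is not shown in this diff; one property worth checking (a sketch, using only the public function above) is that transposing twice restores the input:

```rust
#[test]
fn transpose_roundtrip() {
    let rows = 32;
    let cols = 64;
    // 32x64 bit matrix = 256 bytes, satisfying all asserts
    let input: Vec<u8> = (0..rows * cols / 8).map(|i| i as u8).collect();
    let mut once = vec![0u8; input.len()];
    let mut twice = vec![0u8; input.len()];
    transpose_bitmatrix(&input, &mut once, rows);
    // the transposed matrix has `cols` rows
    transpose_bitmatrix(&once, &mut twice, cols);
    assert_eq!(input, twice);
}
```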