Skip to content

Commit 1f72c4d

Browse files
committed
cryprot-core: portable transpose remove unsafe
1 parent d7d4a37 commit 1f72c4d

File tree

1 file changed

+78
-64
lines changed

1 file changed

+78
-64
lines changed

cryprot-core/src/transpose/portable.rs

Lines changed: 78 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -3,47 +3,91 @@ use wide::{i8x16, i64x2};
33
/// Transpose a bit matrix.
44
///
55
/// # Panics
6-
/// TODO
7-
pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
8-
assert!(rows >= 16);
9-
assert_eq!(0, rows % 16);
10-
assert_eq!(0, input.len() % rows);
6+
/// - If `rows < 16`
7+
/// - If `rows` is not divisible by 16
8+
/// - If `input.len()` is not divisible by `rows`
9+
/// - If the number of columns, computed as `input.len() * 8 / rows`, is less
10+
/// than 16
11+
/// - If the number of columns is not divisible by 8
12+
pub(super) fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
13+
assert!(rows >= 16, "rows must be at least 16");
14+
assert_eq!(0, rows % 16, "rows must be divisible by 16");
15+
assert_eq!(
16+
0,
17+
input.len() % rows,
18+
"input.len() must be divisible by rows"
19+
);
1120
let cols = input.len() * 8 / rows;
12-
assert!(cols >= 16);
21+
assert!(cols >= 16, "columns must be at least 16. Columns {cols}");
1322
assert_eq!(
1423
0,
1524
cols % 8,
1625
"Number of bitmatrix columns must be divisible by 8. columns: {cols}"
1726
);
1827

19-
unsafe {
20-
let mut row: usize = 0;
21-
while row <= rows - 16 {
22-
let mut col = 0;
23-
while col < cols {
24-
let mut v = load_bytes(input, row, col, cols);
25-
// reverse iterator because we start writing the msb of each byte, then shift
26-
// left for i = 0, we write the previous lsb
27-
for i in (0..8).rev() {
28-
// get msb of each byte
29-
let msbs = v.to_bitmask().to_le_bytes();
30-
// write msbs to output at transposed position as one i16
31-
let msb_i16 = i16::from_ne_bytes([msbs[0], msbs[1]]);
32-
let idx = out(row, col + i, rows) as isize;
33-
let out_ptr = output.as_mut_ptr().offset(idx) as *mut i16;
34-
// ptr is potentially unaligned
35-
out_ptr.write_unaligned(msb_i16);
28+
// Transpose a matrix by splitting it into 16x8 blocks (=16 bytes which can be
29+
// one SSE2 128 bit vector depending on target support), write transposed
30+
// block into output at the transposed position and continue with next
31+
// block.
32+
33+
let mut row: usize = 0;
34+
while row <= rows - 16 {
35+
let mut col = 0;
36+
while col < cols {
37+
// Load 16x8 sub-block by loading (row + 0, col) .. (row + 15, col)
38+
let mut v = load_bytes(input, row, col, cols);
39+
// The idea is to take the most significant bit of each row of the sub-block
40+
// (msb of each byte) and write these 16 bits coming from 16 rows
41+
// and one column in the input to the correct row and columns in the
42+
// output. Because the `move_mask` instruction gives us the `msb` of each byte
43+
// (=row) in our input, we iterate the output_row_offset from large
44+
// to small. After each iteration, we shift each byte in the
45+
// sub-block one bit to the left and then again get the msb of each byte and
46+
// write it to the next row.
47+
// Visualization done by Claude 4 Sonnet:
48+
// ┌─────────────────────────────────────────────────────────────┐
49+
// │ BIT MATRIX TRANSPOSE: 16x8 BLOCK │
50+
// ├─────────────────────────────────────────────────────────────┤
51+
// │ │
52+
// │ INPUT (16×8) OUTPUT (8×16) │
53+
// │ │
54+
// │ 0 1 ⋯ 6 7 0 1 ⋯ E F │
55+
// │ 0 ◆|◇|⋯|▲|△ 0 ◆|◆|⋯|◆|◆ │
56+
// │ 1 ◆|◇|⋯|▲|△ 1 ◇|◇|⋯|◇|◇ │
57+
// │ ⋮ ⋮ │
58+
// │ E ◆|◇|⋯|▲|△ 6 ▲|▲|⋯|▲|▲ │
59+
// │ F ◆|◇|⋯|▲|△ 7 △|△|⋯|△|△ │
60+
// │ │
61+
// │ MOVE_MASK: Extract column bits → Write as rows │
62+
// │ │
63+
// │ Iter 0: MSB column 7 → output row 7 │
64+
// │ Iter 1: << 1, MSB → output row 6 │
65+
// │ ⋮ │
66+
// │ Iter 7: << 7, MSB → output row 0 │
67+
// │ │
68+
// │ INPUT[row,col] → OUTPUT[col,row] │
69+
// │ │
70+
// └─────────────────────────────────────────────────────────────┘
71+
for output_row_offset in (0..8).rev() {
72+
// get msb of each byte
73+
let msbs = v.to_bitmask().to_le_bytes();
74+
// write msbs to output at transposed position
75+
let idx = out(row, col + output_row_offset, rows) as isize;
76+
// This should result in only one bounds check for the output
77+
let out_bytes = &mut output[idx as usize..idx as usize + 2];
78+
out_bytes[0] = msbs[0];
79+
out_bytes[1] = msbs[1];
3680

37-
// SAFETY: u8x16 and i64x2 have the same layout
38-
// we need to convert cast it, because there is no shift impl for u8x16
39-
let v_i64x2 = &mut v as *mut _ as *mut i64x2;
40-
// shift each byte by one to the left (by shifting it as two i64)
41-
*v_i64x2 = *v_i64x2 << 1;
42-
}
43-
col += 8;
81+
// There is no shift impl for i8x16 so we cast to i64x2 and shift these.
82+
// The bits shifted to neighbouring bytes are ignored because we iterate
83+
// and call move_mask 8 times.
84+
let v: &mut i64x2 = bytemuck::must_cast_mut(&mut v);
85+
// shift each byte by one to the left (by shifting it as two i64)
86+
*v = *v << 1;
4487
}
45-
row += 16;
88+
col += 8;
4689
}
90+
row += 16;
4791
}
4892
}
4993

@@ -58,39 +102,9 @@ fn out(x: usize, y: usize, rows: usize) -> usize {
58102

59103
#[inline]
60104
// get col byte of row to row + 15
61-
unsafe fn load_bytes(b: &[u8], row: usize, col: usize, cols: usize) -> i8x16 {
62-
unsafe {
63-
// if we have sse2 we use _mm_setr_epi8 and transmute to convert bytes
64-
// faster than from impl
65-
#[cfg(target_feature = "sse2")]
66-
{
67-
use std::{arch::x86_64::_mm_setr_epi8, mem::transmute};
68-
let v = _mm_setr_epi8(
69-
*b.get_unchecked(inp(row, col, cols)) as i8,
70-
*b.get_unchecked(inp(row + 1, col, cols)) as i8,
71-
*b.get_unchecked(inp(row + 2, col, cols)) as i8,
72-
*b.get_unchecked(inp(row + 3, col, cols)) as i8,
73-
*b.get_unchecked(inp(row + 4, col, cols)) as i8,
74-
*b.get_unchecked(inp(row + 5, col, cols)) as i8,
75-
*b.get_unchecked(inp(row + 6, col, cols)) as i8,
76-
*b.get_unchecked(inp(row + 7, col, cols)) as i8,
77-
*b.get_unchecked(inp(row + 8, col, cols)) as i8,
78-
*b.get_unchecked(inp(row + 9, col, cols)) as i8,
79-
*b.get_unchecked(inp(row + 10, col, cols)) as i8,
80-
*b.get_unchecked(inp(row + 11, col, cols)) as i8,
81-
*b.get_unchecked(inp(row + 12, col, cols)) as i8,
82-
*b.get_unchecked(inp(row + 13, col, cols)) as i8,
83-
*b.get_unchecked(inp(row + 14, col, cols)) as i8,
84-
*b.get_unchecked(inp(row + 15, col, cols)) as i8,
85-
);
86-
transmute(v)
87-
}
88-
#[cfg(not(target_feature = "sse2"))]
89-
{
90-
let bytes = std::array::from_fn(|i| *b.get_unchecked(inp(row + i, col, cols)) as i8);
91-
i8x16::from(bytes)
92-
}
93-
}
105+
fn load_bytes(b: &[u8], row: usize, col: usize, cols: usize) -> i8x16 {
106+
let bytes = std::array::from_fn(|i| b[inp(row + i, col, cols)] as i8);
107+
i8x16::from(bytes)
94108
}
95109

96110
#[cfg(test)]

0 commit comments

Comments
 (0)