Skip to content

Commit 23ffb32

Browse files
committed
cryprot-core: remove unsafe code from the portable transpose
1 parent cde3e70 commit 23ffb32

File tree

1 file changed

+38
-63
lines changed

1 file changed

+38
-63
lines changed

cryprot-core/src/transpose/portable.rs

Lines changed: 38 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -3,47 +3,52 @@ use wide::{i8x16, i64x2};
33
/// Transpose a bit matrix.
44
///
55
/// # Panics
6-
/// TODO
6+
/// - If `rows < 16`
7+
/// - If `rows` is not divisible by 16
8+
/// - If `input.len()` is not divisible by `rows`
9+
/// - If the number of columns, computed as `input.len() * 8 / rows` is less
10+
/// than 16
11+
/// - If the number of columns is not divisible by 8
712
pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
8-
assert!(rows >= 16);
9-
assert_eq!(0, rows % 16);
10-
assert_eq!(0, input.len() % rows);
13+
assert!(rows >= 16, "rows must be at least 16");
14+
assert_eq!(0, rows % 16, "rows must be divisible by 16");
15+
assert_eq!(
16+
0,
17+
input.len() % rows,
18+
"input.len() must be divisible by rows"
19+
);
1120
let cols = input.len() * 8 / rows;
12-
assert!(cols >= 16);
21+
assert!(cols >= 16, "columns must be at least 16. Columns {cols}");
1322
assert_eq!(
1423
0,
1524
cols % 8,
1625
"Number of bitmatrix columns must be divisable by 8. columns: {cols}"
1726
);
1827

19-
unsafe {
20-
let mut row: usize = 0;
21-
while row <= rows - 16 {
22-
let mut col = 0;
23-
while col < cols {
24-
let mut v = load_bytes(input, row, col, cols);
25-
// reverse iterator because we start writing the msb of each byte, then shift
26-
// left for i = 0, we write the previous lsb
27-
for i in (0..8).rev() {
28-
// get msb of each byte
29-
let msbs = v.move_mask().to_le_bytes();
30-
// write msbs to output at transposed position as one i16
31-
let msb_i16 = i16::from_ne_bytes([msbs[0], msbs[1]]);
32-
let idx = out(row, col + i, rows) as isize;
33-
let out_ptr = output.as_mut_ptr().offset(idx) as *mut i16;
34-
// ptr is potentially unaligned
35-
out_ptr.write_unaligned(msb_i16);
28+
let mut row: usize = 0;
29+
while row <= rows - 16 {
30+
let mut col = 0;
31+
while col < cols {
32+
let mut v = load_bytes(input, row, col, cols);
33+
// reverse iterator because we start writing the msb of each byte, then shift
34+
// left for i = 0, we write the previous lsb
35+
for i in (0..8).rev() {
36+
// get msb of each byte
37+
let msbs = v.move_mask().to_le_bytes();
38+
// write msbs to output at transposed position
39+
let idx = out(row, col + i, rows) as isize;
40+
// This should result in only one bounds check for the output
41+
let out_bytes = &mut output[idx as usize..idx as usize + 2];
42+
out_bytes[0] = msbs[0];
43+
out_bytes[1] = msbs[1];
3644

37-
// SAFETY: u8x16 and i64x2 have the same layout
38-
// we need to convert cast it, because there is no shift impl for u8x16
39-
let v_i64x2 = &mut v as *mut _ as *mut i64x2;
40-
// shift each byte by one to the left (by shifting it as two i64)
41-
*v_i64x2 = *v_i64x2 << 1;
42-
}
43-
col += 8;
45+
let v: &mut i64x2 = bytemuck::must_cast_mut(&mut v);
46+
// shift each byte by one to the left (by shifting it as two i64)
47+
*v = *v << 1;
4448
}
45-
row += 16;
49+
col += 8;
4650
}
51+
row += 16;
4752
}
4853
}
4954

@@ -58,39 +63,9 @@ fn out(x: usize, y: usize, rows: usize) -> usize {
5863

5964
#[inline]
6065
// get col byte of row to row + 15
61-
unsafe fn load_bytes(b: &[u8], row: usize, col: usize, cols: usize) -> i8x16 {
62-
unsafe {
63-
// if we have sse2 we use _mm_setr_epi8 and transmute to convert bytes
64-
// faster than from impl
65-
#[cfg(target_feature = "sse2")]
66-
{
67-
use std::{arch::x86_64::_mm_setr_epi8, mem::transmute};
68-
let v = _mm_setr_epi8(
69-
*b.get_unchecked(inp(row, col, cols)) as i8,
70-
*b.get_unchecked(inp(row + 1, col, cols)) as i8,
71-
*b.get_unchecked(inp(row + 2, col, cols)) as i8,
72-
*b.get_unchecked(inp(row + 3, col, cols)) as i8,
73-
*b.get_unchecked(inp(row + 4, col, cols)) as i8,
74-
*b.get_unchecked(inp(row + 5, col, cols)) as i8,
75-
*b.get_unchecked(inp(row + 6, col, cols)) as i8,
76-
*b.get_unchecked(inp(row + 7, col, cols)) as i8,
77-
*b.get_unchecked(inp(row + 8, col, cols)) as i8,
78-
*b.get_unchecked(inp(row + 9, col, cols)) as i8,
79-
*b.get_unchecked(inp(row + 10, col, cols)) as i8,
80-
*b.get_unchecked(inp(row + 11, col, cols)) as i8,
81-
*b.get_unchecked(inp(row + 12, col, cols)) as i8,
82-
*b.get_unchecked(inp(row + 13, col, cols)) as i8,
83-
*b.get_unchecked(inp(row + 14, col, cols)) as i8,
84-
*b.get_unchecked(inp(row + 15, col, cols)) as i8,
85-
);
86-
transmute(v)
87-
}
88-
#[cfg(not(target_feature = "sse2"))]
89-
{
90-
let bytes = std::array::from_fn(|i| *b.get_unchecked(inp(row + i, col, cols)) as i8);
91-
i8x16::from(bytes)
92-
}
93-
}
66+
fn load_bytes(b: &[u8], row: usize, col: usize, cols: usize) -> i8x16 {
67+
let bytes = std::array::from_fn(|i| b[inp(row + i, col, cols)] as i8);
68+
i8x16::from(bytes)
9469
}
9570

9671
#[cfg(test)]

0 commit comments

Comments
 (0)