@@ -3,47 +3,91 @@ use wide::{i8x16, i64x2};
33/// Transpose a bit matrix.
44///
55/// # Panics
6- /// TODO
7- pub fn transpose_bitmatrix ( input : & [ u8 ] , output : & mut [ u8 ] , rows : usize ) {
8- assert ! ( rows >= 16 ) ;
9- assert_eq ! ( 0 , rows % 16 ) ;
10- assert_eq ! ( 0 , input. len( ) % rows) ;
6+ /// - If `rows < 16`
7+ /// - If `rows` is not divisible by 16
8+ /// - If `input.len()` is not divisible by `rows`
9+ /// - If the number of columns, computed as `input.len() * 8 / rows` is less
10+ /// than 16
11+ /// - If the number of columns is not divisible by 8
12+ pub ( super ) fn transpose_bitmatrix ( input : & [ u8 ] , output : & mut [ u8 ] , rows : usize ) {
13+ assert ! ( rows >= 16 , "rows must be at least 16" ) ;
14+ assert_eq ! ( 0 , rows % 16 , "rows must be divisible by 16" ) ;
15+ assert_eq ! (
16+ 0 ,
17+ input. len( ) % rows,
18+ "input.len() must be divisible by rows"
19+ ) ;
1120 let cols = input. len ( ) * 8 / rows;
12- assert ! ( cols >= 16 ) ;
21+ assert ! ( cols >= 16 , "columns must be at least 16. Columns {cols}" ) ;
1322 assert_eq ! (
1423 0 ,
1524 cols % 8 ,
1625 "Number of bitmatrix columns must be divisable by 8. columns: {cols}"
1726 ) ;
1827
19- unsafe {
20- let mut row: usize = 0 ;
21- while row <= rows - 16 {
22- let mut col = 0 ;
23- while col < cols {
24- let mut v = load_bytes ( input, row, col, cols) ;
25- // reverse iterator because we start writing the msb of each byte, then shift
26- // left for i = 0, we write the previous lsb
27- for i in ( 0 ..8 ) . rev ( ) {
28- // get msb of each byte
29- let msbs = v. to_bitmask ( ) . to_le_bytes ( ) ;
30- // write msbs to output at transposed position as one i16
31- let msb_i16 = i16:: from_ne_bytes ( [ msbs[ 0 ] , msbs[ 1 ] ] ) ;
32- let idx = out ( row, col + i, rows) as isize ;
33- let out_ptr = output. as_mut_ptr ( ) . offset ( idx) as * mut i16 ;
34- // ptr is potentially unaligned
35- out_ptr. write_unaligned ( msb_i16) ;
28+ // Transpose a matrix by splitting it into 16x8 blocks (=16 bytes which can be
29+ // one SSE2 128 bit vector depending on target support), write transposed
30+ // block into output at the transposed position and continue with next
31+ // block.
32+
33+ let mut row: usize = 0 ;
34+ while row <= rows - 16 {
35+ let mut col = 0 ;
36+ while col < cols {
37+ // Load 16x8 sub-block by loading (row + 0, col) .. (row + 15, col)
38+ let mut v = load_bytes ( input, row, col, cols) ;
39+ // The ideas is to take the most significant bit of each row of the sub-block
40+ // (msb of each byte) and write these 16 bits coming from 16 rows
41+ // and one column in the input to the correct row and columns in the
42+ // output. Because the `move_mask` instruction gives us the `msb` of each byte
43+ // (=row) in our input, we iterate the output_row_offset from large
44+ // to small. After each iteration, we shift each byte in the
45+ // sub-block one bit to the left and then again get the msb of each byte and
46+ // write it to the next row.
47+ // Visualization done by Claude 4 Sonnet:
48+ // ┌─────────────────────────────────────────────────────────────┐
49+ // │ BIT MATRIX TRANSPOSE: 16x8 BLOCK │
50+ // ├─────────────────────────────────────────────────────────────┤
51+ // │ │
52+ // │ INPUT (16×8) OUTPUT (8×16) │
53+ // │ │
54+ // │ 0 1 ⋯ 6 7 0 1 ⋯ E F │
55+ // │ 0 ◆|◇|⋯|▲|△ 0 ◆|◆|⋯|◆|◆ │
56+ // │ 1 ◆|◇|⋯|▲|△ 1 ◇|◇|⋯|◇|◇ │
57+ // │ ⋮ ⋮ │
58+ // │ E ◆|◇|⋯|▲|△ 6 ▲|▲|⋯|▲|▲ │
59+ // │ F ◆|◇|⋯|▲|△ 7 △|△|⋯|△|△ │
60+ // │ │
61+ // │ MOVE_MASK: Extract column bits → Write as rows │
62+ // │ │
63+ // │ Iter 0: MSB column 7 → output row 7 │
64+ // │ Iter 1: << 1, MSB → output row 6 │
65+ // │ ⋮ │
66+ // │ Iter 7: << 7, MSB → output row 0 │
67+ // │ │
68+ // │ INPUT[row,col] → OUTPUT[col,row] │
69+ // │ │
70+ // └─────────────────────────────────────────────────────────────┘
71+ for output_row_offset in ( 0 ..8 ) . rev ( ) {
72+ // get msb of each byte
73+ let msbs = v. to_bitmask ( ) . to_le_bytes ( ) ;
74+ // write msbs to output at transposed position
75+ let idx = out ( row, col + output_row_offset, rows) as isize ;
76+ // This should result in only one bounds check for the output
77+ let out_bytes = & mut output[ idx as usize ..idx as usize + 2 ] ;
78+ out_bytes[ 0 ] = msbs[ 0 ] ;
79+ out_bytes[ 1 ] = msbs[ 1 ] ;
3680
37- // SAFETY: u8x16 and i64x2 have the same layout
38- // we need to convert cast it, because there is no shift impl for u8x16
39- let v_i64x2 = & mut v as * mut _ as * mut i64x2 ;
40- // shift each byte by one to the left (by shifting it as two i64)
41- * v_i64x2 = * v_i64x2 << 1 ;
42- }
43- col += 8 ;
81+ // There is no shift impl for i8x16 so we cast to i64x2 and shift these.
82+ // The bits shifted to neighbouring bytes are ignored because we iterate
83+ // and call move_mask 8 times.
84+ let v: & mut i64x2 = bytemuck:: must_cast_mut ( & mut v) ;
85+ // shift each byte by one to the left (by shifting it as two i64)
86+ * v = * v << 1 ;
4487 }
45- row += 16 ;
88+ col += 8 ;
4689 }
90+ row += 16 ;
4791 }
4892}
4993
@@ -58,39 +102,9 @@ fn out(x: usize, y: usize, rows: usize) -> usize {
58102
59103#[ inline]
60104// get col byte of row to row + 15
61- unsafe fn load_bytes ( b : & [ u8 ] , row : usize , col : usize , cols : usize ) -> i8x16 {
62- unsafe {
63- // if we have sse2 we use _mm_setr_epi8 and transmute to convert bytes
64- // faster than from impl
65- #[ cfg( target_feature = "sse2" ) ]
66- {
67- use std:: { arch:: x86_64:: _mm_setr_epi8, mem:: transmute} ;
68- let v = _mm_setr_epi8 (
69- * b. get_unchecked ( inp ( row, col, cols) ) as i8 ,
70- * b. get_unchecked ( inp ( row + 1 , col, cols) ) as i8 ,
71- * b. get_unchecked ( inp ( row + 2 , col, cols) ) as i8 ,
72- * b. get_unchecked ( inp ( row + 3 , col, cols) ) as i8 ,
73- * b. get_unchecked ( inp ( row + 4 , col, cols) ) as i8 ,
74- * b. get_unchecked ( inp ( row + 5 , col, cols) ) as i8 ,
75- * b. get_unchecked ( inp ( row + 6 , col, cols) ) as i8 ,
76- * b. get_unchecked ( inp ( row + 7 , col, cols) ) as i8 ,
77- * b. get_unchecked ( inp ( row + 8 , col, cols) ) as i8 ,
78- * b. get_unchecked ( inp ( row + 9 , col, cols) ) as i8 ,
79- * b. get_unchecked ( inp ( row + 10 , col, cols) ) as i8 ,
80- * b. get_unchecked ( inp ( row + 11 , col, cols) ) as i8 ,
81- * b. get_unchecked ( inp ( row + 12 , col, cols) ) as i8 ,
82- * b. get_unchecked ( inp ( row + 13 , col, cols) ) as i8 ,
83- * b. get_unchecked ( inp ( row + 14 , col, cols) ) as i8 ,
84- * b. get_unchecked ( inp ( row + 15 , col, cols) ) as i8 ,
85- ) ;
86- transmute ( v)
87- }
88- #[ cfg( not( target_feature = "sse2" ) ) ]
89- {
90- let bytes = std:: array:: from_fn ( |i| * b. get_unchecked ( inp ( row + i, col, cols) ) as i8 ) ;
91- i8x16:: from ( bytes)
92- }
93- }
105+ fn load_bytes ( b : & [ u8 ] , row : usize , col : usize , cols : usize ) -> i8x16 {
106+ let bytes = std:: array:: from_fn ( |i| b[ inp ( row + i, col, cols) ] as i8 ) ;
107+ i8x16:: from ( bytes)
94108}
95109
96110#[ cfg( test) ]
0 commit comments