@@ -3,47 +3,52 @@ use wide::{i8x16, i64x2};
33/// Transpose a bit matrix.
44///
55/// # Panics
6- /// TODO
6+ /// - If `rows < 16`
7+ /// - If `rows` is not divisible by 16
8+ /// - If `input.len()` is not divisible by `rows`
9+ /// - If the number of columns, computed as `input.len() * 8 / rows` is less
10+ /// than 16
11+ /// - If the number of columns is not divisible by 8
712pub fn transpose_bitmatrix ( input : & [ u8 ] , output : & mut [ u8 ] , rows : usize ) {
8- assert ! ( rows >= 16 ) ;
9- assert_eq ! ( 0 , rows % 16 ) ;
10- assert_eq ! ( 0 , input. len( ) % rows) ;
13+ assert ! ( rows >= 16 , "rows must be at least 16" ) ;
14+ assert_eq ! ( 0 , rows % 16 , "rows must be divisible by 16" ) ;
15+ assert_eq ! (
16+ 0 ,
17+ input. len( ) % rows,
18+ "input.len() must be divisible by rows"
19+ ) ;
1120 let cols = input. len ( ) * 8 / rows;
12- assert ! ( cols >= 16 ) ;
21+ assert ! ( cols >= 16 , "columns must be at least 16. Columns {cols}" ) ;
1322 assert_eq ! (
1423 0 ,
1524 cols % 8 ,
1625 "Number of bitmatrix columns must be divisable by 8. columns: {cols}"
1726 ) ;
1827
19- unsafe {
20- let mut row: usize = 0 ;
21- while row <= rows - 16 {
22- let mut col = 0 ;
23- while col < cols {
24- let mut v = load_bytes ( input, row, col, cols) ;
25- // reverse iterator because we start writing the msb of each byte, then shift
26- // left for i = 0, we write the previous lsb
27- for i in ( 0 ..8 ) . rev ( ) {
28- // get msb of each byte
29- let msbs = v. move_mask ( ) . to_le_bytes ( ) ;
30- // write msbs to output at transposed position as one i16
31- let msb_i16 = i16:: from_ne_bytes ( [ msbs[ 0 ] , msbs[ 1 ] ] ) ;
32- let idx = out ( row, col + i, rows) as isize ;
33- let out_ptr = output. as_mut_ptr ( ) . offset ( idx) as * mut i16 ;
34- // ptr is potentially unaligned
35- out_ptr. write_unaligned ( msb_i16) ;
28+ let mut row: usize = 0 ;
29+ while row <= rows - 16 {
30+ let mut col = 0 ;
31+ while col < cols {
32+ let mut v = load_bytes ( input, row, col, cols) ;
33+ // reverse iterator because we start writing the msb of each byte, then shift
34+ // left for i = 0, we write the previous lsb
35+ for i in ( 0 ..8 ) . rev ( ) {
36+ // get msb of each byte
37+ let msbs = v. move_mask ( ) . to_le_bytes ( ) ;
38+ // write msbs to output at transposed position
39+ let idx = out ( row, col + i, rows) as isize ;
40+ // This should result in only one bounds check for the output
41+ let out_bytes = & mut output[ idx as usize ..idx as usize + 2 ] ;
42+ out_bytes[ 0 ] = msbs[ 0 ] ;
43+ out_bytes[ 1 ] = msbs[ 1 ] ;
3644
37- // SAFETY: u8x16 and i64x2 have the same layout
38- // we need to convert cast it, because there is no shift impl for u8x16
39- let v_i64x2 = & mut v as * mut _ as * mut i64x2 ;
40- // shift each byte by one to the left (by shifting it as two i64)
41- * v_i64x2 = * v_i64x2 << 1 ;
42- }
43- col += 8 ;
45+ let v: & mut i64x2 = bytemuck:: must_cast_mut ( & mut v) ;
46+ // shift each byte by one to the left (by shifting it as two i64)
47+ * v = * v << 1 ;
4448 }
45- row += 16 ;
49+ col += 8 ;
4650 }
51+ row += 16 ;
4752 }
4853}
4954
@@ -58,39 +63,9 @@ fn out(x: usize, y: usize, rows: usize) -> usize {
5863
5964#[ inline]
6065// get col byte of row to row + 15
61- unsafe fn load_bytes ( b : & [ u8 ] , row : usize , col : usize , cols : usize ) -> i8x16 {
62- unsafe {
63- // if we have sse2 we use _mm_setr_epi8 and transmute to convert bytes
64- // faster than from impl
65- #[ cfg( target_feature = "sse2" ) ]
66- {
67- use std:: { arch:: x86_64:: _mm_setr_epi8, mem:: transmute} ;
68- let v = _mm_setr_epi8 (
69- * b. get_unchecked ( inp ( row, col, cols) ) as i8 ,
70- * b. get_unchecked ( inp ( row + 1 , col, cols) ) as i8 ,
71- * b. get_unchecked ( inp ( row + 2 , col, cols) ) as i8 ,
72- * b. get_unchecked ( inp ( row + 3 , col, cols) ) as i8 ,
73- * b. get_unchecked ( inp ( row + 4 , col, cols) ) as i8 ,
74- * b. get_unchecked ( inp ( row + 5 , col, cols) ) as i8 ,
75- * b. get_unchecked ( inp ( row + 6 , col, cols) ) as i8 ,
76- * b. get_unchecked ( inp ( row + 7 , col, cols) ) as i8 ,
77- * b. get_unchecked ( inp ( row + 8 , col, cols) ) as i8 ,
78- * b. get_unchecked ( inp ( row + 9 , col, cols) ) as i8 ,
79- * b. get_unchecked ( inp ( row + 10 , col, cols) ) as i8 ,
80- * b. get_unchecked ( inp ( row + 11 , col, cols) ) as i8 ,
81- * b. get_unchecked ( inp ( row + 12 , col, cols) ) as i8 ,
82- * b. get_unchecked ( inp ( row + 13 , col, cols) ) as i8 ,
83- * b. get_unchecked ( inp ( row + 14 , col, cols) ) as i8 ,
84- * b. get_unchecked ( inp ( row + 15 , col, cols) ) as i8 ,
85- ) ;
86- transmute ( v)
87- }
88- #[ cfg( not( target_feature = "sse2" ) ) ]
89- {
90- let bytes = std:: array:: from_fn ( |i| * b. get_unchecked ( inp ( row + i, col, cols) ) as i8 ) ;
91- i8x16:: from ( bytes)
92- }
93- }
66+ fn load_bytes ( b : & [ u8 ] , row : usize , col : usize , cols : usize ) -> i8x16 {
67+ let bytes = std:: array:: from_fn ( |i| b[ inp ( row + i, col, cols) ] as i8 ) ;
68+ i8x16:: from ( bytes)
9469}
9570
9671#[ cfg( test) ]
0 commit comments