11use std:: { arch:: x86_64:: * , hint:: unreachable_unchecked} ;
22
3- #[ inline( always) ]
3+ #[ inline]
4+ #[ target_feature( enable = "avx2" ) ]
45unsafe fn _mm256_slli_epi64_var_shift ( a : __m256i , shift : usize ) -> __m256i {
56 unsafe {
67 match shift {
@@ -14,7 +15,8 @@ unsafe fn _mm256_slli_epi64_var_shift(a: __m256i, shift: usize) -> __m256i {
1415 }
1516}
1617
17- #[ inline( always) ]
18+ #[ inline]
19+ #[ target_feature( enable = "avx2" ) ]
1820unsafe fn _mm256_srli_epi64_var_shift ( a : __m256i , shift : usize ) -> __m256i {
1921 unsafe {
2022 match shift {
@@ -29,8 +31,10 @@ unsafe fn _mm256_srli_epi64_var_shift(a: __m256i, shift: usize) -> __m256i {
2931}
3032
3133// Transpose a 2^block_size_shift x 2^block_size_shift block within a larger
32- // matrix Only handles first two rows out of every 2^block_rows_shift rows
33- #[ inline( always) ] // in each block
34+ // matrix. Only handles first two rows out of every 2^block_rows_shift rows in
35+ // each block
36+ #[ inline]
37+ #[ target_feature( enable = "avx2" ) ]
3438unsafe fn avx_transpose_block_iter1 (
3539 in_out : * mut __m256i ,
3640 block_size_shift : usize ,
@@ -85,7 +89,8 @@ unsafe fn avx_transpose_block_iter1(
8589 }
8690}
8791
88- #[ inline( always) ] // Process a range of rows in the matrix
92+ #[ inline] // Process a range of rows in the matrix
93+ #[ target_feature( enable = "avx2" ) ]
8994unsafe fn avx_transpose_block_iter2 (
9095 in_out : * mut __m256i ,
9196 block_size_shift : usize ,
@@ -103,7 +108,8 @@ unsafe fn avx_transpose_block_iter2(
103108 }
104109}
105110
106- #[ inline( always) ] // Main transpose function for blocks within the matrix
111+ #[ inline] // Main transpose function for blocks within the matrix
112+ #[ target_feature( enable = "avx2" ) ]
107113unsafe fn avx_transpose_block (
108114 in_out : * mut __m256i ,
109115 block_size_shift : usize ,
@@ -136,7 +142,8 @@ const AVX_BLOCK_SHIFT: usize = 4;
136142const AVX_BLOCK_SIZE : usize = 1 << AVX_BLOCK_SHIFT ;
137143
138144// Main entry point for matrix transpose
139- pub fn avx_transpose128x128 ( in_out : & mut [ __m256i ; 64 ] ) {
145+ #[ target_feature( enable = "avx2" ) ]
146+ pub unsafe fn avx_transpose128x128 ( in_out : & mut [ __m256i ; 64 ] ) {
140147 const MAT_SIZE_SHIFT : usize = 7 ;
141148 unsafe {
142149 let in_out = in_out. as_mut_ptr ( ) ;
@@ -166,7 +173,8 @@ pub fn avx_transpose128x128(in_out: &mut [__m256i; 64]) {
166173 }
167174}
168175
169- pub fn transpose_bitmatrix ( input : & [ u8 ] , output : & mut [ u8 ] , rows : usize ) {
176+ #[ target_feature( enable = "avx2" ) ]
177+ pub unsafe fn transpose_bitmatrix ( input : & [ u8 ] , output : & mut [ u8 ] , rows : usize ) {
170178 assert_eq ! ( input. len( ) , output. len( ) ) ;
171179 let cols = input. len ( ) * 8 / rows;
172180 assert_eq ! ( 0 , cols % 128 ) ;
@@ -193,8 +201,11 @@ pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
193201 std:: ptr:: copy_nonoverlapping ( src_row, buf_u8_ptr. add ( k * 16 ) , 16 ) ;
194202 }
195203 }
196- // Transpose the 128x128 bit square
197- avx_transpose128x128 ( & mut buf) ;
204+ // SAFETY: avx2 is enabled
205+ unsafe {
206+ // Transpose the 128x128 bit square
207+ avx_transpose128x128 ( & mut buf) ;
208+ }
198209
199210 unsafe {
200211 // needs to be recreated because prev &mut borrow invalidates ptr
@@ -210,7 +221,7 @@ pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
210221 }
211222}
212223
213- #[ cfg( test) ]
224+ #[ cfg( all ( test, target_feature = "avx2" ) ) ]
214225mod tests {
215226 use std:: arch:: x86_64:: _mm256_setzero_si256;
216227
@@ -253,7 +264,9 @@ mod tests {
253264
254265 let mut avx_transposed = v. clone ( ) ;
255266 let mut sse_transposed = v. clone ( ) ;
256- transpose_bitmatrix ( & v, & mut avx_transposed, rows) ;
267+ unsafe {
268+ transpose_bitmatrix ( & v, & mut avx_transposed, rows) ;
269+ }
257270 crate :: transpose:: portable:: transpose_bitmatrix ( & v, & mut sse_transposed, rows) ;
258271
259272 assert_eq ! ( sse_transposed, avx_transposed) ;
0 commit comments