Skip to content

Commit c33e489

Browse files
committed
cryprot-core: fix avx2 transpose portability
1 parent 0f9dc86 commit c33e489

File tree

3 files changed

+48
-23
lines changed

3 files changed

+48
-23
lines changed

cryprot-core/benches/bench.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
1-
use std::{arch::x86_64::_mm256_setzero_si256, mem::transmute};
2-
3-
use criterion::{BatchSize, Criterion, black_box, criterion_group, criterion_main};
1+
use criterion::{BatchSize, Criterion, criterion_group, criterion_main};
42
use cryprot_core::{
53
Block,
64
aes_hash::FIXED_KEY_HASH,
75
aes_rng::AesRng,
86
buf::Buf,
97
transpose::{avx2, portable},
108
};
11-
use rand::{Rng, RngCore, rng};
9+
use rand::{RngCore, rng};
1210

1311
fn criterion_benchmark(c: &mut Criterion) {
1412
let rows = 128;
@@ -26,26 +24,28 @@ fn criterion_benchmark(c: &mut Criterion) {
2624
)
2725
});
2826

27+
#[cfg(target_feature = "avx2")]
2928
c.bench_function("avx2 transpose 128 x 2**20", |b| {
3029
b.iter_batched(
3130
|| {
3231
let mut bitmat = vec![0; rows * cols];
3332
rng().fill_bytes(&mut bitmat);
3433
bitmat
3534
},
36-
|bitmat| avx2::transpose_bitmatrix(&bitmat, &mut out, rows),
35+
|bitmat| unsafe { avx2::transpose_bitmatrix(&bitmat, &mut out, rows) },
3736
BatchSize::SmallInput,
3837
)
3938
});
4039

40+
#[cfg(target_feature = "avx2")]
4141
c.bench_function("avx2 transpose 128 x 128", |b| {
4242
b.iter_batched(
4343
|| {
44-
let mut bitmat = [unsafe { _mm256_setzero_si256() }; 64];
44+
let mut bitmat = [unsafe { std::arch::x86_64::_mm256_setzero_si256() }; 64];
4545
rng().fill_bytes(&mut bytemuck::cast_slice_mut(&mut bitmat));
4646
bitmat
4747
},
48-
|mut bitmat| avx2::avx_transpose128x128(&mut bitmat),
48+
|mut bitmat| unsafe { avx2::avx_transpose128x128(&mut bitmat) },
4949
BatchSize::SmallInput,
5050
)
5151
});

cryprot-core/src/transpose.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
1-
// #[cfg(target_feature = "avx2")]
1+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
22
pub mod avx2;
33
pub mod portable;
44

5+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
6+
cpufeatures::new!(target_feature_avx2, "avx2");
7+
8+
/// Transpose bit matrix.
9+
///
10+
/// # Panics
11+
/// If `rows % 128 != 0`
12+
/// If `cols % 128 != 0`, where `cols = input.len() * 8 / rows`
513
pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
6-
#[cfg(target_feature = "avx2")]
7-
avx2::transpose_bitmatrix(input, output, rows);
8-
#[cfg(not(target_feature = "avx2"))]
14+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
15+
if target_feature_avx2::get() {
16+
unsafe { avx2::transpose_bitmatrix(input, output, rows) }
17+
} else {
18+
portable::transpose_bitmatrix(input, output, rows);
19+
}
20+
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
921
portable::transpose_bitmatrix(input, output, rows);
1022
}

cryprot-core/src/transpose/avx2.rs

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::{arch::x86_64::*, hint::unreachable_unchecked};
22

3-
#[inline(always)]
3+
#[inline]
4+
#[target_feature(enable = "avx2")]
45
unsafe fn _mm256_slli_epi64_var_shift(a: __m256i, shift: usize) -> __m256i {
56
unsafe {
67
match shift {
@@ -14,7 +15,8 @@ unsafe fn _mm256_slli_epi64_var_shift(a: __m256i, shift: usize) -> __m256i {
1415
}
1516
}
1617

17-
#[inline(always)]
18+
#[inline]
19+
#[target_feature(enable = "avx2")]
1820
unsafe fn _mm256_srli_epi64_var_shift(a: __m256i, shift: usize) -> __m256i {
1921
unsafe {
2022
match shift {
@@ -29,8 +31,10 @@ unsafe fn _mm256_srli_epi64_var_shift(a: __m256i, shift: usize) -> __m256i {
2931
}
3032

3133
// Transpose a 2^block_size_shift x 2^block_size_shift block within a larger
32-
// matrix Only handles first two rows out of every 2^block_rows_shift rows
33-
#[inline(always)] // in each block
34+
// matrix. Only handles first two rows out of every 2^block_rows_shift rows in
35+
// each block
36+
#[inline]
37+
#[target_feature(enable = "avx2")]
3438
unsafe fn avx_transpose_block_iter1(
3539
in_out: *mut __m256i,
3640
block_size_shift: usize,
@@ -85,7 +89,8 @@ unsafe fn avx_transpose_block_iter1(
8589
}
8690
}
8791

88-
#[inline(always)] // Process a range of rows in the matrix
92+
#[inline] // Process a range of rows in the matrix
93+
#[target_feature(enable = "avx2")]
8994
unsafe fn avx_transpose_block_iter2(
9095
in_out: *mut __m256i,
9196
block_size_shift: usize,
@@ -103,7 +108,8 @@ unsafe fn avx_transpose_block_iter2(
103108
}
104109
}
105110

106-
#[inline(always)] // Main transpose function for blocks within the matrix
111+
#[inline] // Main transpose function for blocks within the matrix
112+
#[target_feature(enable = "avx2")]
107113
unsafe fn avx_transpose_block(
108114
in_out: *mut __m256i,
109115
block_size_shift: usize,
@@ -136,7 +142,8 @@ const AVX_BLOCK_SHIFT: usize = 4;
136142
const AVX_BLOCK_SIZE: usize = 1 << AVX_BLOCK_SHIFT;
137143

138144
// Main entry point for matrix transpose
139-
pub fn avx_transpose128x128(in_out: &mut [__m256i; 64]) {
145+
#[target_feature(enable = "avx2")]
146+
pub unsafe fn avx_transpose128x128(in_out: &mut [__m256i; 64]) {
140147
const MAT_SIZE_SHIFT: usize = 7;
141148
unsafe {
142149
let in_out = in_out.as_mut_ptr();
@@ -166,7 +173,8 @@ pub fn avx_transpose128x128(in_out: &mut [__m256i; 64]) {
166173
}
167174
}
168175

169-
pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
176+
#[target_feature(enable = "avx2")]
177+
pub unsafe fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
170178
assert_eq!(input.len(), output.len());
171179
let cols = input.len() * 8 / rows;
172180
assert_eq!(0, cols % 128);
@@ -193,8 +201,11 @@ pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
193201
std::ptr::copy_nonoverlapping(src_row, buf_u8_ptr.add(k * 16), 16);
194202
}
195203
}
196-
// Transpose the 128x128 bit square
197-
avx_transpose128x128(&mut buf);
204+
// SAFETY: the caller guarantees AVX2 is available (this fn is `#[target_feature(enable = "avx2")]`)
205+
unsafe {
206+
// Transpose the 128x128 bit square
207+
avx_transpose128x128(&mut buf);
208+
}
198209

199210
unsafe {
200211
// needs to be recreated because prev &mut borrow invalidates ptr
@@ -210,7 +221,7 @@ pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
210221
}
211222
}
212223

213-
#[cfg(test)]
224+
#[cfg(all(test, target_feature = "avx2"))]
214225
mod tests {
215226
use std::arch::x86_64::_mm256_setzero_si256;
216227

@@ -253,7 +264,9 @@ mod tests {
253264

254265
let mut avx_transposed = v.clone();
255266
let mut sse_transposed = v.clone();
256-
transpose_bitmatrix(&v, &mut avx_transposed, rows);
267+
unsafe {
268+
transpose_bitmatrix(&v, &mut avx_transposed, rows);
269+
}
257270
crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows);
258271

259272
assert_eq!(sse_transposed, avx_transposed);

0 commit comments

Comments
 (0)