@@ -2312,7 +2312,6 @@ static inline simd_s_t srsran_simd_convert_2f_bf16(simd_f_t a, simd_f_t b)
23122312#ifdef __ARM_NEON
23132313 const uint32x4_t bias = vdupq_n_u32 (0x7fff );
23142314 const uint32x4_t one = vdupq_n_u32 (0x1 );
2315- const uint32x4_t mask = vdupq_n_u32 (0xffff0000 );
23162315
23172316 uint32x4_t a_i32 = vreinterpretq_u32_f32 (a);
23182317 uint32x4_t b_i32 = vreinterpretq_u32_f32 (b);
@@ -2322,15 +2321,15 @@ static inline simd_s_t srsran_simd_convert_2f_bf16(simd_f_t a, simd_f_t b)
23222321
23232322 // Remove the 16 least significant bits of the fractional part.
23242323 uint16x8_t tmp_a_1 = vreinterpretq_u16_u32 (vshrq_n_u32 (a_i32, 16 ));
2325- uint16x8_t tmp_a_2 = vextq_s16 (tmp_a_1, tmp_a_1, 3 );
2326- uint16x8_t a_packed = vorrq_u16 (tmp_a_1, tmp_a_2);
2324+ uint16x8_t tmp_a_2 = vextq_u16 (tmp_a_1, tmp_a_1, 1 );
2325+ uint32x4_t a_packed = vreinterpretq_u32_u16 ( vorrq_u16 (tmp_a_1, tmp_a_2) );
23272326
23282327 // Remove the 16 least significant bits of the fractional part.
23292328 uint16x8_t tmp_b_1 = vreinterpretq_u16_u32 (vshrq_n_u32 (b_i32, 16 ));
2330- uint16x8_t tmp_b_2 = vextq_s16 (tmp_b_1, tmp_b_1, 3 );
2331- uint16x8_t b_packed = vorrq_u16 (tmp_b_1, tmp_b_2);
2329+ uint16x8_t tmp_b_2 = vextq_u16 (tmp_b_1, tmp_b_1, 1 );
2330+ uint32x4_t b_packed = vreinterpretq_u32_u16 ( vorrq_u16 (tmp_b_1, tmp_b_2) );
23322331
2333- ret = vuzpq_u32 (vreinterpretq_u32_u16 ( a_packed), vreinterpretq_u32_u16 ( b_packed)) .val [0 ];
2332+ ret = vreinterpretq_s16_u32 ( vuzpq_u32 (a_packed, b_packed).val [0 ]) ;
23342333#endif /* __ARM_NEON */
23352334#endif /* __SSE4_1__ */
23362335#endif /* __AVX2__ */
@@ -2394,7 +2393,7 @@ static inline simd_s_t srsran_simd_convert_2f_interleaved_bf16(simd_f_t a, simd_
23942393 a_i32 = vaddq_u32 (a_i32, vaddq_u32 (bias, vandq_u32 (vshrq_n_u32 (a_i32, 16 ), one)));
23952394 b_i32 = vaddq_u32 (b_i32, vaddq_u32 (bias, vandq_u32 (vshrq_n_u32 (b_i32, 16 ), one)));
23962395
2397- return vorrq_u32 (vandq_u32 (b_i32, vdupq_n_u32 (0xffff0000 )), vshrq_n_u32 (a_i32, 16 ));
2396+ return vreinterpretq_s16_u32 ( vorrq_u32 (vandq_u32 (b_i32, vdupq_n_u32 (0xffff0000 )), vshrq_n_u32 (a_i32, 16 ) ));
23982397#endif /* __ARM_NEON */
23992398#endif /* __SSE4_1__ */
24002399#endif /* __AVX2__ */
@@ -2459,7 +2458,7 @@ inline void srsran_simd_bf16_storeu(bf16_t* ptr, simd_f_t a, simd_f_t b)
24592458 _mm_storeu_si128 (reinterpret_cast <__m128i*>(ptr), bf16_vec);
24602459#else /* __SSE4_1__ */
24612460#ifdef __ARM_NEON
2462- vst1q_u32 (reinterpret_cast <uint32_t *>(ptr), bf16_vec);
2461+ vst1q_u32 (reinterpret_cast <uint32_t *>(ptr), vreinterpretq_u32_s16 ( bf16_vec) );
24632462#endif /* __ARM_NEON */
24642463#endif /* __SSE4_1__ */
24652464#endif /* __AVX2__ */
@@ -2469,7 +2468,9 @@ inline void srsran_simd_bf16_storeu(bf16_t* ptr, simd_f_t a, simd_f_t b)
24692468#ifdef SRSRAN_SIMD_CF_SIZE
24702469inline void srsran_simd_cbf16_storeu (cbf16_t * ptr, simd_cf_t simdreg)
24712470{
2472- simd_s_t packed_iq_bf16 = srsran_simd_convert_2f_interleaved_bf16 (simdreg.re , simdreg.im );
2471+ simd_s_t packed_iq_bf16 =
2472+ srsran_simd_convert_2f_interleaved_bf16 (srsran_simd_cf_re (simdreg), srsran_simd_cf_im (simdreg));
2473+
24732474#ifdef __AVX512F__
24742475 _mm512_storeu_si512 (reinterpret_cast <__m512i*>(ptr), packed_iq_bf16);
24752476#else /* __AVX512F__ */
@@ -2480,7 +2481,7 @@ inline void srsran_simd_cbf16_storeu(cbf16_t* ptr, simd_cf_t simdreg)
24802481 _mm_storeu_si128 (reinterpret_cast <__m128i*>(ptr), packed_iq_bf16);
24812482#else /* __SSE4_1__ */
24822483#ifdef __ARM_NEON
2483- vst1q_u32 (reinterpret_cast <uint32_t *>(ptr), packed_iq_bf16);
2484+ vst1q_u32 (reinterpret_cast <uint32_t *>(ptr), vreinterpretq_u32_s16 ( packed_iq_bf16) );
24842485#endif /* __ARM_NEON */
24852486#endif /* __SSE4_1__ */
24862487#endif /* __AVX2__ */
0 commit comments