Skip to content

Commit 7e1681a

Browse files
xavierarteagacodebot
authored andcommitted
phy: optimize precoding for PDSCH DM-RS
fix NEON compilation phy: fix compilation
1 parent 3ee54c4 commit 7e1681a

22 files changed

+123
-140
lines changed

include/srsran/phy/generic_functions/precoding/channel_precoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class channel_precoder
3434
/// of RE per layer of the input buffer.
3535
/// \remark An assertion is triggered if the precoding matrix dimensions do not match the number of layers of the
3636
/// input buffer and the number of antenna ports of the output buffer.
37-
virtual void apply_precoding(re_buffer_writer<>& output,
37+
virtual void apply_precoding(re_buffer_writer<cbf16_t>& output,
3838
const re_buffer_reader<>& input,
3939
const precoding_weight_matrix& precoding) const = 0;
4040

include/srsran/phy/support/resource_grid_writer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class resource_grid_writer : public resource_grid_base
8585
/// \param[in] symbols Symbols to be written into the resource grid.
8686
/// \note The RE positions given \c k_init, the number of elements in \c symbols and the \c stride shall be within the
8787
/// resource grid number of subcarriers.
88-
virtual void put(unsigned port, unsigned l, unsigned k_init, unsigned stride, span<const cf_t> symbols) = 0;
88+
virtual void put(unsigned port, unsigned l, unsigned k_init, unsigned stride, span<const cbf16_t> symbols) = 0;
8989

9090
/// \brief Gets a read-write view of an OFDM symbol for a given port.
9191
///

lib/phy/generic_functions/precoding/channel_precoder_avx2.cpp

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,26 @@ simd_cf_interleaved operator*(const simd_cf_interleaved& re, const simd_cf_t& we
4343
return _mm256_fmaddsub_ps(re, weight.re, _mm256_mul_ps(_mm256_shuffle_ps(re, re, 0xb1), weight.im));
4444
}
4545

46+
inline __m128i ps_to_cbf16(simd_cf_interleaved in)
47+
{
48+
const __m256i bias = _mm256_set1_epi32(0x7fff);
49+
const __m256i one = _mm256_set1_epi32(0x1);
50+
51+
__m256i a_i32 = _mm256_castps_si256(in);
52+
53+
// Round to nearest even.
54+
a_i32 = _mm256_add_epi32(a_i32, _mm256_add_epi32(bias, _mm256_and_si256(_mm256_srli_epi32(a_i32, 16), one)));
55+
56+
// Shift right 16 bits.
57+
a_i32 = _mm256_srai_epi32(a_i32, 16);
58+
59+
// Pack both parts in 32-bit registers.
60+
return _mm_packs_epi32(_mm256_extractf128_si256(a_i32, 0), _mm256_extractf128_si256(a_i32, 1));
61+
}
62+
4663
} // namespace
4764

48-
void channel_precoder_avx2::apply_precoding_port(span<cf_t> port_re,
65+
void channel_precoder_avx2::apply_precoding_port(span<cbf16_t> port_re,
4966
const re_buffer_reader<>& input_re,
5067
span<const cf_t> port_weights) const
5168
{
@@ -84,15 +101,16 @@ void channel_precoder_avx2::apply_precoding_port(span<cf_t> port_
84101
}
85102

86103
// Store.
87-
_mm256_storeu_ps(reinterpret_cast<float*>(&port_re[i_re]), re_out);
104+
_mm_storeu_si128(reinterpret_cast<__m128i*>(&port_re[i_re]), ps_to_cbf16(re_out));
88105
}
89106

90107
for (; i_re != nof_re; ++i_re) {
91-
port_re[i_re] = layer_re_view_list[0][i_re] * port_weights[0];
108+
cf_t sum = layer_re_view_list[0][i_re] * port_weights[0];
92109

93110
for (unsigned i_layer = 1; i_layer != nof_layers; ++i_layer) {
94-
port_re[i_re] += layer_re_view_list[i_layer][i_re] * port_weights[i_layer];
111+
sum += layer_re_view_list[i_layer][i_re] * port_weights[i_layer];
95112
}
113+
port_re[i_re] = sum;
96114
}
97115
}
98116

@@ -181,23 +199,6 @@ static inline void layer4_map_and_ci8_to_cf(simd_cf_interleaved& out_l0,
181199
from_ci8_to_cf(out_l0, out_l1, out_l2, out_l3, tmp);
182200
}
183201

184-
inline __m128i ps_to_cbf16(simd_cf_interleaved in)
185-
{
186-
const __m256i bias = _mm256_set1_epi32(0x7fff);
187-
const __m256i one = _mm256_set1_epi32(0x1);
188-
189-
__m256i a_i32 = _mm256_castps_si256(in);
190-
191-
// Round to nearest even.
192-
a_i32 = _mm256_add_epi32(a_i32, _mm256_add_epi32(bias, _mm256_and_si256(_mm256_srli_epi32(a_i32, 16), one)));
193-
194-
// Shift right 16 bits.
195-
a_i32 = _mm256_srai_epi32(a_i32, 16);
196-
197-
// Pack both parts in 32-bit registers.
198-
return _mm_packs_epi32(_mm256_extractf128_si256(a_i32, 0), _mm256_extractf128_si256(a_i32, 1));
199-
}
200-
201202
void channel_precoder_avx2::apply_layer_map_and_precoding(re_buffer_writer<cbf16_t>& output,
202203
span<const ci8_t> input,
203204
const precoding_weight_matrix& precoding) const

lib/phy/generic_functions/precoding/channel_precoder_avx2.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class channel_precoder_avx2 : public channel_precoder_impl
2222
{
2323
public:
2424
// See interface for documentation.
25-
void apply_precoding_port(span<cf_t> port_re,
25+
void apply_precoding_port(span<cbf16_t> port_re,
2626
const re_buffer_reader<>& input_re,
2727
span<const cf_t> port_weights) const override;
2828

lib/phy/generic_functions/precoding/channel_precoder_avx512.cpp

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,27 @@ struct simd_cf_t {
3737
// Type to hold a set of complex numbers using an AVX512 register, with interleaved real and imaginary parts.
3838
using simd_cf_interleaved = __m512;
3939

40+
inline __m256i ps_to_cbf16(simd_cf_interleaved in)
41+
{
42+
#if __AVX512BF16__
43+
return (__m256i)_mm512_cvtneps_pbh(in);
44+
#else // __AVX512BF16__
45+
const __m512i bias = _mm512_set1_epi32(0x7fff);
46+
const __m512i one = _mm512_set1_epi32(0x1);
47+
48+
__m512i a_i32 = _mm512_castps_si512(in);
49+
50+
// Round to nearest even.
51+
a_i32 = _mm512_add_epi32(a_i32, _mm512_add_epi32(bias, _mm512_and_si512(_mm512_srli_epi32(a_i32, 16), one)));
52+
53+
// Shift right 16 bits.
54+
a_i32 = _mm512_srli_epi32(a_i32, 16);
55+
56+
// Pack both parts in 32-bit registers.
57+
return _mm512_cvtepi32_epi16(a_i32);
58+
#endif // __AVX512BF16__
59+
}
60+
4061
} // namespace
4162

4263
// Multiplication operator for the precoding weights.
@@ -203,28 +224,7 @@ static inline void layer4_map_and_ci8_to_cf(simd_cf_interleaved& out0,
203224
from_ci8_to_cf(out0, out1, out2, out3, tmp);
204225
}
205226

206-
inline __m256i ps_to_cbf16(simd_cf_interleaved in)
207-
{
208-
#if __AVX512BF16__
209-
return (__m256i)_mm512_cvtneps_pbh(in);
210-
#else // __AVX512BF16__
211-
const __m512i bias = _mm512_set1_epi32(0x7fff);
212-
const __m512i one = _mm512_set1_epi32(0x1);
213-
214-
__m512i a_i32 = _mm512_castps_si512(in);
215-
216-
// Round to nearest even.
217-
a_i32 = _mm512_add_epi32(a_i32, _mm512_add_epi32(bias, _mm512_and_si512(_mm512_srli_epi32(a_i32, 16), one)));
218-
219-
// Shift right 16 bits.
220-
a_i32 = _mm512_srli_epi32(a_i32, 16);
221-
222-
// Pack both parts in 32-bit registers.
223-
return _mm512_cvtepi32_epi16(a_i32);
224-
#endif // __AVX512BF16__
225-
}
226-
227-
void channel_precoder_avx512::apply_precoding_port(span<cf_t> port_re,
227+
void channel_precoder_avx512::apply_precoding_port(span<cbf16_t> port_re,
228228
const re_buffer_reader<>& input_re,
229229
span<const cf_t> port_weights) const
230230
{
@@ -263,15 +263,15 @@ void channel_precoder_avx512::apply_precoding_port(span<cf_t> por
263263
}
264264

265265
// Store.
266-
_mm512_storeu_ps(reinterpret_cast<float*>(&port_re[i_re]), re_out);
266+
_mm256_storeu_si256(reinterpret_cast<__m256i*>(&port_re[i_re]), ps_to_cbf16(re_out));
267267
}
268268

269269
for (; i_re != nof_re; ++i_re) {
270-
port_re[i_re] = layer_re_view_list[0][i_re] * port_weights[0];
271-
270+
cf_t sum = layer_re_view_list[0][i_re] * port_weights[0];
272271
for (unsigned i_layer = 1; i_layer != nof_layers; ++i_layer) {
273-
port_re[i_re] += layer_re_view_list[i_layer][i_re] * port_weights[i_layer];
272+
sum += layer_re_view_list[i_layer][i_re] * port_weights[i_layer];
274273
}
274+
port_re[i_re] = sum;
275275
}
276276
}
277277

lib/phy/generic_functions/precoding/channel_precoder_avx512.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace srsran {
2121
class channel_precoder_avx512 : public channel_precoder_impl
2222
{
2323
// See interface for documentation.
24-
void apply_precoding_port(span<cf_t> port_re,
24+
void apply_precoding_port(span<cbf16_t> port_re,
2525
const re_buffer_reader<>& input_re,
2626
span<const cf_t> port_weights) const override;
2727

lib/phy/generic_functions/precoding/channel_precoder_generic.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
using namespace srsran;
1414

15-
void channel_precoder_generic::apply_precoding_port(span<cf_t> port_re,
15+
void channel_precoder_generic::apply_precoding_port(span<cbf16_t> port_re,
1616
const re_buffer_reader<>& input_re,
1717
span<const cf_t> port_weights) const
1818
{
@@ -26,12 +26,13 @@ void channel_precoder_generic::apply_precoding_port(span<cf_t> po
2626

2727
for (unsigned i_re = 0; i_re != nof_re; ++i_re) {
2828
// Set the port RE to the contribution of the first layer.
29-
port_re[i_re] = layer_re_view_list[0][i_re] * port_weights[0];
29+
cf_t sum = layer_re_view_list[0][i_re] * port_weights[0];
3030

3131
for (unsigned i_layer = 1; i_layer != nof_layers; ++i_layer) {
3232
// Accumulate the contributions of all other layers.
33-
port_re[i_re] += layer_re_view_list[i_layer][i_re] * port_weights[i_layer];
33+
sum += layer_re_view_list[i_layer][i_re] * port_weights[i_layer];
3434
}
35+
port_re[i_re] = sum;
3536
}
3637
}
3738

lib/phy/generic_functions/precoding/channel_precoder_generic.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace srsran {
2121
class channel_precoder_generic : public channel_precoder_impl
2222
{
2323
// See interface for documentation.
24-
void apply_precoding_port(span<cf_t> port_re,
24+
void apply_precoding_port(span<cbf16_t> port_re,
2525
const re_buffer_reader<>& input_re,
2626
span<const cf_t> port_weights) const override;
2727

lib/phy/generic_functions/precoding/channel_precoder_impl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
using namespace srsran;
1414

15-
void channel_precoder_impl::apply_precoding(re_buffer_writer<>& output,
15+
void channel_precoder_impl::apply_precoding(re_buffer_writer<cbf16_t>& output,
1616
const re_buffer_reader<>& input,
1717
const precoding_weight_matrix& precoding) const
1818
{
@@ -46,7 +46,7 @@ void channel_precoder_impl::apply_precoding(re_buffer_writer<>& outpu
4646

4747
for (unsigned i_port = 0; i_port != nof_tx_ports; ++i_port) {
4848
// View of the output RE for a single antenna port.
49-
span<cf_t> port_re_view = output.get_slice(i_port);
49+
span<cbf16_t> port_re_view = output.get_slice(i_port);
5050

5151
// View of the precoding weights applicable to a single antenna port, i.e., the coefficients applied to each
5252
// layer for the antenna port.

lib/phy/generic_functions/precoding/channel_precoder_impl.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class channel_precoder_impl : public channel_precoder
2525
explicit channel_precoder_impl() = default;
2626

2727
// See interface for documentation.
28-
void apply_precoding(re_buffer_writer<>& output,
28+
void apply_precoding(re_buffer_writer<cbf16_t>& output,
2929
const re_buffer_reader<>& input,
3030
const precoding_weight_matrix& precoding) const override;
3131

@@ -35,8 +35,9 @@ class channel_precoder_impl : public channel_precoder
3535
/// \param[out] port_re View over the RE of a single antenna port.
3636
/// \param[in] input Input symbols, indexed by RE and transmit layer.
3737
/// \param[in] precoding Precoding coefficients, indexed by layer.
38-
virtual void
39-
apply_precoding_port(span<cf_t> port_re, const re_buffer_reader<>& input_re, span<const cf_t> port_weights) const = 0;
38+
virtual void apply_precoding_port(span<cbf16_t> port_re,
39+
const re_buffer_reader<>& input_re,
40+
span<const cf_t> port_weights) const = 0;
4041
};
4142

4243
} // namespace srsran

0 commit comments

Comments
 (0)