@@ -36,14 +36,24 @@
  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
 
-#include <arm_neon.h>
-
 #ifdef FIXED_POINT
-#ifdef __thumb2__
+#if defined(__aarch64__)
+static inline int32_t saturate_32bit_to_16bit(int32_t a) {
+    int32_t ret;
+    asm ("fmov s0, %w[a]\n"
+         "sqxtn h0, s0\n"
+         "sxtl v0.4s, v0.4h\n"
+         "fmov %w[ret], s0\n"
+         : [ret] "=r" (ret)
+         : [a] "r" (a)
+         : "v0");
+    return ret;
+}
+#elif defined(__thumb2__)
 static inline int32_t saturate_32bit_to_16bit(int32_t a) {
     int32_t ret;
     asm ("ssat %[ret], #16, %[a]"
-         : [ret] "=&r" (ret)
+         : [ret] "=r" (ret)
          : [a] "r" (a)
          : );
     return ret;
@@ -54,7 +64,7 @@ static inline int32_t saturate_32bit_to_16bit(int32_t a) {
     asm ("vmov.s32 d0[0], %[a]\n"
          "vqmovn.s32 d0, q0\n"
          "vmov.s16 %[ret], d0[0]\n"
-         : [ret] "=&r" (ret)
+         : [ret] "=r" (ret)
          : [a] "r" (a)
          : "q0");
     return ret;
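
All three saturate_32bit_to_16bit variants above (the new AArch64 sqxtn path, the Thumb-2 ssat path, and the ARMv7 vqmovn path) implement the same contract: clamp a 32-bit value to the signed 16-bit range. A minimal scalar sketch of that contract, with an illustrative _ref name that is not part of the patch:

#include <stdint.h>

/* Scalar equivalent (illustrative, not part of the patch): clamp to
 * the signed 16-bit range [-32768, 32767], as sqxtn/ssat/vqmovn do. */
static inline int32_t saturate_32bit_to_16bit_ref(int32_t a) {
    if (a > INT16_MAX) return INT16_MAX;
    if (a < INT16_MIN) return INT16_MIN;
    return a;
}
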
@@ -64,7 +74,63 @@ static inline int32_t saturate_32bit_to_16bit(int32_t a) {
 #define WORD2INT(x) (saturate_32bit_to_16bit(x))
 
 #define OVERRIDE_INNER_PRODUCT_SINGLE
-/* Only works when len % 4 == 0 */
+/* Only works when len % 4 == 0 and len >= 4 */
+#if defined(__aarch64__)
+static inline int32_t inner_product_single(const int16_t *a, const int16_t *b, unsigned int len)
+{
+    int32_t ret;
+    uint32_t remainder = len % 16;
+    len = len - remainder;
+
+    asm volatile ("    cmp %w[len], #0\n"
+                  "    b.ne 1f\n"
+                  "    ld1 {v16.4h}, [%[b]], #8\n"
+                  "    ld1 {v20.4h}, [%[a]], #8\n"
+                  "    subs %w[remainder], %w[remainder], #4\n"
+                  "    smull v0.4s, v16.4h, v20.4h\n"
+                  "    b.ne 4f\n"
+                  "    b 5f\n"
+                  "1:"
+                  "    ld1 {v16.4h, v17.4h, v18.4h, v19.4h}, [%[b]], #32\n"
+                  "    ld1 {v20.4h, v21.4h, v22.4h, v23.4h}, [%[a]], #32\n"
+                  "    subs %w[len], %w[len], #16\n"
+                  "    smull v0.4s, v16.4h, v20.4h\n"
+                  "    smlal v0.4s, v17.4h, v21.4h\n"
+                  "    smlal v0.4s, v18.4h, v22.4h\n"
+                  "    smlal v0.4s, v19.4h, v23.4h\n"
+                  "    b.eq 3f\n"
+                  "2:"
+                  "    ld1 {v16.4h, v17.4h, v18.4h, v19.4h}, [%[b]], #32\n"
+                  "    ld1 {v20.4h, v21.4h, v22.4h, v23.4h}, [%[a]], #32\n"
+                  "    subs %w[len], %w[len], #16\n"
+                  "    smlal v0.4s, v16.4h, v20.4h\n"
+                  "    smlal v0.4s, v17.4h, v21.4h\n"
+                  "    smlal v0.4s, v18.4h, v22.4h\n"
+                  "    smlal v0.4s, v19.4h, v23.4h\n"
+                  "    b.ne 2b\n"
+                  "3:"
+                  "    cmp %w[remainder], #0\n"
+                  "    b.eq 5f\n"
+                  "4:"
+                  "    ld1 {v18.4h}, [%[b]], #8\n"
+                  "    ld1 {v22.4h}, [%[a]], #8\n"
+                  "    subs %w[remainder], %w[remainder], #4\n"
+                  "    smlal v0.4s, v18.4h, v22.4h\n"
+                  "    b.ne 4b\n"
+                  "5:"
+                  "    saddlv d0, v0.4s\n"
+                  "    sqxtn s0, d0\n"
+                  "    sqrshrn h0, s0, #15\n"
+                  "    sxtl v0.4s, v0.4h\n"
+                  "    fmov %w[ret], s0\n"
+                  : [ret] "=r" (ret), [a] "+r" (a), [b] "+r" (b),
+                    [len] "+r" (len), [remainder] "+r" (remainder)
+                  :
+                  : "cc", "v0",
+                    "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
+    return ret;
+}
+#else
 static inline int32_t inner_product_single(const int16_t *a, const int16_t *b, unsigned int len)
 {
     int32_t ret;
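
For readers untangling the new AArch64 fixed-point kernel above: it accumulates 16-bit by 16-bit products into 32-bit lanes (smull/smlal), folds the lanes with saddlv, then applies a saturating, rounding right-shift by 15 (sqxtn/sqrshrn) to produce a Q15 result. A hedged scalar sketch of the same computation, ignoring the possibility of 32-bit lane wrap-around on very long full-scale inputs (the _ref name is illustrative, not part of the patch):

#include <stdint.h>

/* Illustrative scalar reference: 64-bit accumulation, then the same
 * saturate-to-32, round-and-shift-by-15, saturate-to-16 tail that the
 * sqxtn/sqrshrn sequence performs. */
static inline int32_t inner_product_single_ref(const int16_t *a,
                                               const int16_t *b,
                                               unsigned int len)
{
    int64_t acc = 0;
    unsigned int i;
    for (i = 0; i < len; i++)
        acc += (int32_t)a[i] * (int32_t)b[i];
    /* sqxtn: saturate the folded sum to 32 bits */
    if (acc > INT32_MAX) acc = INT32_MAX;
    if (acc < INT32_MIN) acc = INT32_MIN;
    /* sqrshrn #15: add the rounding constant, shift, saturate to 16 bits */
    acc = (acc + (1 << 14)) >> 15;
    if (acc > INT16_MAX) acc = INT16_MAX;
    if (acc < INT16_MIN) acc = INT16_MIN;
    return (int32_t)acc;
}
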
@@ -112,33 +178,104 @@ static inline int32_t inner_product_single(const int16_t *a, const int16_t *b, unsigned int len)
112178 " vqmovn.s64 d0, q0\n"
113179 " vqrshrn.s32 d0, q0, #15\n"
114180 " vmov.s16 %[ret], d0[0]\n"
115- : [ret ] "=& r" (ret ), [a ] "+r" (a ), [b ] "+r" (b ),
181+ : [ret ] "=r" (ret ), [a ] "+r" (a ), [b ] "+r" (b ),
116182 [len ] "+r" (len ), [remainder ] "+r" (remainder )
117183 :
118184 : "cc" , "q0" ,
119- "d16" , "d17" , "d18" , "d19" ,
120- "d20" , "d21" , "d22" , "d23" );
185+ "d16" , "d17" , "d18" , "d19" , "d20" , "d21" , "d22" , "d23" );
121186
122187 return ret ;
123188}
-#elif defined(FLOATING_POINT)
+#endif // !defined(__aarch64__)
 
+#elif defined(FLOATING_POINT)
+#if defined(__aarch64__)
+static inline int32_t saturate_float_to_16bit(float a) {
+    int32_t ret;
+    asm ("fcvtas s1, %s[a]\n"
+         "sqxtn h1, s1\n"
+         "sxtl v1.4s, v1.4h\n"
+         "fmov %w[ret], s1\n"
+         : [ret] "=r" (ret)
+         : [a] "w" (a)
+         : "v1");
+    return ret;
+}
+#else
 static inline int32_t saturate_float_to_16bit(float a) {
     int32_t ret;
     asm ("vmov.f32 d0[0], %[a]\n"
          "vcvt.s32.f32 d0, d0, #15\n"
          "vqrshrn.s32 d0, q0, #15\n"
          "vmov.s16 %[ret], d0[0]\n"
-         : [ret] "=&r" (ret)
+         : [ret] "=r" (ret)
          : [a] "r" (a)
          : "q0");
     return ret;
 }
+#endif
+
 #undef WORD2INT
 #define WORD2INT(x) (saturate_float_to_16bit(x))
 
 #define OVERRIDE_INNER_PRODUCT_SINGLE
-/* Only works when len % 4 == 0 */
+/* Only works when len % 4 == 0 and len >= 4 */
+#if defined(__aarch64__)
+static inline float inner_product_single(const float *a, const float *b, unsigned int len)
+{
+    float ret;
+    uint32_t remainder = len % 16;
+    len = len - remainder;
+
+    asm volatile ("    cmp %w[len], #0\n"
+                  "    b.ne 1f\n"
+                  "    ld1 {v16.4s}, [%[b]], #16\n"
+                  "    ld1 {v20.4s}, [%[a]], #16\n"
+                  "    subs %w[remainder], %w[remainder], #4\n"
+                  "    fmul v1.4s, v16.4s, v20.4s\n"
+                  "    b.ne 4f\n"
+                  "    b 5f\n"
+                  "1:"
+                  "    ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[b]], #64\n"
+                  "    ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[a]], #64\n"
+                  "    subs %w[len], %w[len], #16\n"
+                  "    fmul v1.4s, v16.4s, v20.4s\n"
+                  "    fmul v2.4s, v17.4s, v21.4s\n"
+                  "    fmul v3.4s, v18.4s, v22.4s\n"
+                  "    fmul v4.4s, v19.4s, v23.4s\n"
+                  "    b.eq 3f\n"
+                  "2:"
+                  "    ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[b]], #64\n"
+                  "    ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[a]], #64\n"
+                  "    subs %w[len], %w[len], #16\n"
+                  "    fmla v1.4s, v16.4s, v20.4s\n"
+                  "    fmla v2.4s, v17.4s, v21.4s\n"
+                  "    fmla v3.4s, v18.4s, v22.4s\n"
+                  "    fmla v4.4s, v19.4s, v23.4s\n"
+                  "    b.ne 2b\n"
+                  "3:"
+                  "    fadd v16.4s, v1.4s, v2.4s\n"
+                  "    fadd v17.4s, v3.4s, v4.4s\n"
+                  "    cmp %w[remainder], #0\n"
+                  "    fadd v1.4s, v16.4s, v17.4s\n"
+                  "    b.eq 5f\n"
+                  "4:"
+                  "    ld1 {v18.4s}, [%[b]], #16\n"
+                  "    ld1 {v22.4s}, [%[a]], #16\n"
+                  "    subs %w[remainder], %w[remainder], #4\n"
+                  "    fmla v1.4s, v18.4s, v22.4s\n"
+                  "    b.ne 4b\n"
+                  "5:"
+                  "    faddp v1.4s, v1.4s, v1.4s\n"
+                  "    faddp %[ret].4s, v1.4s, v1.4s\n"
+                  : [ret] "=w" (ret), [a] "+r" (a), [b] "+r" (b),
+                    [len] "+r" (len), [remainder] "+r" (remainder)
+                  :
+                  : "cc", "v1", "v2", "v3", "v4",
+                    "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
+    return ret;
+}
+#else
 static inline float inner_product_single(const float *a, const float *b, unsigned int len)
 {
     float ret;
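
The float kernel follows the same structure as the fixed-point one: a 16-sample main loop feeding four independent accumulators (v1-v4), a 4-sample remainder loop, and two faddp pairwise additions producing the final horizontal sum. A scalar sketch of the value it computes; the NEON version keeps 16 partial sums, so the two may differ in the last few ulps due to floating-point summation order (the _ref name is illustrative, not part of the patch):

/* Illustrative scalar reference for the float inner product. */
static inline float inner_product_single_ref(const float *a,
                                             const float *b,
                                             unsigned int len)
{
    float sum = 0.f;
    unsigned int i;
    for (i = 0; i < len; i++)
        sum += a[i] * b[i];
    return sum;
}
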
@@ -191,11 +328,12 @@ static inline float inner_product_single(const float *a, const float *b, unsigned int len)
191328 " vadd.f32 d0, d0, d1\n"
192329 " vpadd.f32 d0, d0, d0\n"
193330 " vmov.f32 %[ret], d0[0]\n"
194- : [ret ] "=& r" (ret ), [a ] "+r" (a ), [b ] "+r" (b ),
331+ : [ret ] "=r" (ret ), [a ] "+r" (a ), [b ] "+r" (b ),
195332 [len ] "+l" (len ), [remainder ] "+l" (remainder )
196333 :
197- : "cc" , "q0" , "q1" , "q2" , "q3" , "q4" , "q5" , "q6" , "q7" , "q8" ,
198- "q9" , "q10" , "q11" );
334+ : "cc" , "q0" , "q1" , "q2" , "q3" ,
335+ "q4" , "q5" , "q6" , "q7" , "q8" , "q9" , "q10" , "q11" );
199336 return ret ;
200337}
338+ #endif // defined(__aarch64__)
201339#endif
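
A minimal sanity-check sketch for the fixed-point path (not part of the patch). It assumes this header is included in a FIXED_POINT build on AArch64, and that inner_product_single_ref is the scalar reference sketched earlier; the length 32 exercises both the 16-sample main loop and the 4-sample remainder handling required by the len % 4 == 0 contract.

#include <stdint.h>
#include <stdio.h>

int main(void)
{
    int16_t a[32], b[32];
    int i;
    for (i = 0; i < 32; i++) {   /* values chosen to stay within int16 range */
        a[i] = (int16_t)(i * 1000 - 16000);
        b[i] = (int16_t)(8000 - i * 500);
    }
    printf("neon=%d ref=%d\n",
           inner_product_single(a, b, 32),
           inner_product_single_ref(a, b, 32));
    return 0;
}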