Skip to content

Commit edb1688

Browse files
AArch64: Add Neon implementation of load_tmvs
This patch adds a vectorised variant of the mv_projection calculation and a faster initialisation of motion vectors for load_tmvs_neon. Checkasm uplifts after this patch on some Neoverse and Cortex CPU cores, compared with the C reference compiled with GCC-13 and Clang-19:

                 GCC    Clang
AWS Graviton 4:  1.62x  1.59x
Cortex-X4:       1.45x  1.46x
Cortex-X3:       1.68x  1.69x
Cortex-X1:       1.55x  1.52x
Cortex-A720:     1.54x  1.57x
Cortex-A715:     1.47x  1.55x
Cortex-A78:      1.21x  1.18x
Cortex-A76:      1.38x  1.37x
Cortex-A72:      1.08x  1.11x
Cortex-A520:     0.97x  1.18x
Cortex-A510:     0.99x  1.14x
Cortex-A55:      1.16x  1.23x

This patch increases the .text size by ~660 bytes, but the implementation is about 0.5 KiB smaller than the reference implementation.
1 parent b129d9f commit edb1688

File tree

3 files changed

+285
-0
lines changed

3 files changed

+285
-0
lines changed

src/arm/64/refmvs.S

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,12 @@
2525
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626
*/
2727

28+
#include "src/arm/asm-offsets.h"
2829
#include "src/arm/asm.S"
2930
#include "util.S"
3031

// Sentinel motion vector stored (with ref = 0) into every rp_proj block
// cleared by the fill stage of load_tmvs_neon below.
#define INVALID_MV 0x80008000
33+
3134
// void dav1d_splat_mv_neon(refmvs_block **rr, const refmvs_block *rmv,
3235
// int bx4, int bw4, int bh4)
3336

@@ -292,3 +295,252 @@ jumptable save_tmvs_tbl
292295
.word 1 * 12
293296
.word 10b - save_tmvs_tbl
294297
endjumptable
298+
299+
// void dav1d_load_tmvs_neon(const refmvs_frame *const rf, int tile_row_idx,
//                           const int col_start8, const int col_end8,
//                           const int row_start8, int row_end8)
//
// Stage 1 fills the touched rows of rf->rp_proj with {mv = INVALID_MV,
// ref = 0} blocks; stage 2 (skipped when rf->n_mfmvs == 0) projects the
// temporal MVs of each enabled reference frame into rp_proj, using a
// 2-lane Neon implementation of the mv_projection fixed-point scaling.
//
// refmvs_temporal_block is 5 bytes: { mv (4 bytes), ref (1 byte) }, hence
// the recurring "x + x*4" (i.e. x*5) address arithmetic and #5 increments.
// All callee-saved GPRs plus fp/lr are saved (96-byte frame); x29 is
// reused below as a scratch base pointer and restored before returning.
function load_tmvs_neon, export=1
        rf           .req x0
        tile_row_idx .req w1
        col_start8   .req w2
        col_end8     .req w3
        row_start8   .req w4
        row_end8     .req w5
        col_start8i  .req w6
        col_end8i    .req w7
        rp_proj      .req x8
        stride5      .req x9            // rp_stride in bytes (elements * 5)
        wstride5     .req w9
        stp             x28, x27, [sp, #-96]!
        stp             x26, x25, [sp, #16]
        stp             x24, x23, [sp, #32]
        stp             x22, x21, [sp, #48]
        stp             x20, x19, [sp, #64]
        stp             x29, x30, [sp, #80]

        ldr             w15, [rf, #RMVSF_N_TILE_THREADS]
        ldp             w16, w17, [rf, #RMVSF_IW8]      // include rf->ih8 too
        sub             col_start8i, col_start8, #8     // col_start8 - 8
        add             col_end8i, col_end8, #8         // col_end8 + 8
        ldr             wstride5, [rf, #RMVSF_RP_STRIDE]
        ldr             rp_proj, [rf, #RMVSF_RP_PROJ]

        cmp             w15, #1
        csel            tile_row_idx, wzr, tile_row_idx, eq // if (rf->n_tile_threads == 1) tile_row_idx = 0

        bic             col_start8i, col_start8i, col_start8i, asr #31 // imax(col_start8 - 8, 0)
        cmp             col_end8i, w16
        csel            col_end8i, col_end8i, w16, lt   // imin(col_end8 + 8, rf->iw8)

        lsl             tile_row_idx, tile_row_idx, #4  // 16 * tile_row_idx

        cmp             row_end8, w17
        csel            row_end8, row_end8, w17, lt     // imin(row_end8, rf->ih8)

        add             wstride5, wstride5, wstride5, lsl #2 // stride * sizeof(refmvs_temporal_block)
        and             w15, row_start8, #15            // row_start8 & 15
        add             w10, col_start8, col_start8, lsl #2 // col_start8 * sizeof(refmvs_temporal_block)
        smaddl          rp_proj, tile_row_idx, wstride5, rp_proj // &rf->rp_proj[16 * stride * tile_row_idx]
        smaddl          x10, w15, wstride5, x10         // ((row_start8 & 15) * stride + col_start8) * sizeof(refmvs_temporal_block)
        mov             w15, #INVALID_MV
        sub             w11, col_end8, col_start8       // xfill loop count
        add             x10, x10, rp_proj               // &rf->rp_proj[16 * stride * tile_row_idx + (row_start8 & 15) * stride + col_start8]
        // Four consecutive 5-byte {INVALID_MV, 0} blocks form a 20-byte
        // repeating pattern, prebuilt here as two 64-bit halves + one 32-bit
        // tail so the fill loop can store it with stp + str.
        add             x15, x15, x15, lsl #40          // first 64b of 4 [INVALID_MV, 0]... patterns
        mov             w17, #(INVALID_MV >> 8)         // last 32b of 4 patterns
        sub             w12, row_end8, row_start8       // yfill loop count
        ror             x16, x15, #48                   // second 64b of 4 patterns
        ldr             w19, [rf, #RMVSF_N_MFMVS]

5:      // yfill loop
        and             w13, w11, #-4                   // xfill 4x count by patterns
        mov             x14, x10                        // fill_ptr = row_ptr
        add             x10, x10, stride5               // row_ptr += stride
        sub             w12, w12, #1                    // y--

        cbz             w13, 3f

4:      // xfill loop 4x
        sub             w13, w13, #4                    // xfill 4x count -= 4
        stp             x15, x16, [x14]
        str             w17, [x14, #16]
        add             x14, x14, #20                   // fill_ptr += 4 * sizeof(refmvs_temporal_block)
        cbnz            w13, 4b

3:      // up to 3 residuals
        tbz             w11, #1, 1f
        str             x15, [x14]
        strh            w16, [x14, #8]
        add             x14, x14, #10                   // fill_ptr += 2 * sizeof(refmvs_temporal_block)

1:      // up to 1 residual
        tbz             w11, #0, 2f
        str             w15, [x14]
2:
        cbnz            w12, 5b                         // yfill loop

        cbz             w19, 11f                        // if (!rf->n_mfmvs) skip nloop

        add             x29, rf, #RMVSF_MFMV_REF2CUR    // x29 repurposed as &rf->mfmv_ref2cur
        mov             w10, #0                         // n = 0
        movi            v3.2s, #255                     // 0x3FFF >> 6, for MV clamp
        movrel          x1, div_mult_tbl

10:     // nloop: one iteration per enabled temporal reference frame
        ldr             w16, [x29, x10, lsl #2]         // ref2cur = rf->mfmv_ref2cur[n]
        cmp             w16, #-32                       // instead of INT_MIN, we can use smaller constants
        b.lt            9f                              // if (ref2cur == INT_MIN) continue

        add             x17, x10, #(RMVSF_MFMV_REF - RMVSF_MFMV_REF2CUR) // n - (&rf->mfmv_ref - &rf->mfmv_ref2cur)
        mov             x20, #4
        ldrb            w17, [x29, x17]                 // ref = rf->mfmv_ref[n]
        ldr             x13, [x29, #(RMVSF_RP_REF - RMVSF_MFMV_REF2CUR)]
        mov             w28, #28                        // 7 * sizeof(int)
        smaddl          x20, row_start8, wstride5, x20  // row_start8 * stride * sizeof(refmvs_temporal_block) + 4
        mov             w12, row_start8                 // y = row_start8
        add             x21, x29, #(RMVSF_MFMV_REF2REF - RMVSF_MFMV_REF2CUR - 4) // &rf->mfmv_ref2ref - 1
        ldr             x13, [x13, x17, lsl #3]         // rf->rp_ref[ref]
        smaddl          x28, w28, w10, x21              // rf->mfmv_ref2ref[n] - 1
        sub             w17, w17, #4                    // ref_sign = ref - 4
        // x13 points at the .ref byte of element 0 (base + 4), so below
        // ldrb [x23, x] with x23 = x13 + 4x hits r[x].ref at base + 5x,
        // and ldur #-4 from &r[x].ref reads r[x].mv.
        add             x13, x13, x20                   // r = &rf->rp_ref[ref][row_start8 * stride].ref
        dup             v0.2s, w17                      // ref_sign

5:      // yloop
        and             w14, w12, #-8                   // y_sb_align = y & ~7
        mov             w11, col_start8i                // x = col_start8i
        add             w15, w14, #8                    // y_sb_align + 8
        cmp             w14, row_start8
        csel            w14, w14, row_start8, gt        // imax(y_sb_align, row_start8)
        cmp             w15, row_end8
        csel            w15, w15, row_end8, lt          // imin(y_sb_align + 8, row_end8)

4:      // xloop
        add             x23, x13, x11, lsl #2           // partial &r[x] address
        ldrb            w22, [x23, x11]                 // b_ref = rb->ref
        cbz             w22, 6f                         // if (!b_ref) continue

        ldr             w24, [x28, x22, lsl #2]         // ref2ref = rf->mfmv_ref2ref[n][b_ref - 1]
        cbz             w24, 6f                         // if (!ref2ref) continue

        ldrh            w20, [x1, x24, lsl #1]          // div_mult[ref2ref]
        add             x23, x23, x11                   // &r[x]
        mul             w20, w20, w16                   // frac = ref2cur * div_mult[ref2ref]

        // mv_projection: scale {y, x} by frac in Q14, round to nearest
        // (ties away from zero via the sign-bit correction), then clamp
        // the magnitude and reapply the sign xor'd with ref_sign.
        ldur            s1, [x23, #-4]                  // mv{y, x} = rb->mv
        fmov            s2, w20                         // frac
        sxtl            v1.4s, v1.4h
        mul             v1.2s, v1.2s, v2.s[0]           // offset{y, x} = frac * mv{y, x}

        ssra            v1.2s, v1.2s, #31               // offset{y, x} + (offset{y, x} >> 31)
        ldur            w25, [x23, #-4]                 // b_mv = rb->mv
        srshr           v1.2s, v1.2s, #14               // (offset{y, x} + (offset{y, x} >> 31) + 8192) >> 14

        abs             v2.2s, v1.2s                    // abs(offset{y, x})
        eor             v1.8b, v1.8b, v0.8b             // offset{y, x} ^ ref_sign

        sshr            v2.2s, v2.2s, #6                // abs(offset{y, x}) >> 6
        cmlt            v1.2s, v1.2s, #0                // sign(offset{y, x} ^ ref_sign): -1 or 0
        umin            v2.2s, v2.2s, v3.2s             // iclip(abs(offset{y, x}) >> 6, 0, 0x3FFF >> 6)

        neg             v4.2s, v2.2s
        bsl             v1.8b, v4.8b, v2.8b             // apply_sign(iclip(abs(offset{y, x}) >> 6, 0, 0x3FFF >> 6))
        fmov            x20, d1                         // offset{y, x} packed: low 32b = y, high 32b = x

        add             w21, w12, w20                   // pos_y = y + offset.y
        cmp             w21, w14                        // pos_y >= y_proj_start
        b.lt            1f
        cmp             w21, w15                        // pos_y < y_proj_end
        b.ge            1f
        add             x26, x11, x20, asr #32          // pos_x = x + offset.x
        and             w27, w21, #15                   // pos_y & 15
        add             x21, x26, x26, lsl #2           // pos_x * sizeof(refmvs_temporal_block)
        umaddl          x27, w27, wstride5, rp_proj     // &rp_proj[(pos_y & 15) * stride]
        add             x27, x27, x21                   // &rp_proj[(pos_y & 15) * stride + pos_x]

3:      // copy loop: store while successive rb entries repeat {b_ref, b_mv}
        and             w20, w11, #-8                   // x_sb_align = x & ~7
        sub             w21, w20, #8                    // x_sb_align - 8
        cmp             w21, col_start8
        csel            w21, w21, col_start8, gt        // imax(x_sb_align - 8, col_start8)
        cmp             w26, w21                        // pos_x >= imax(x_sb_align - 8, col_start8)
        b.lt            2f
        add             w20, w20, #16                   // x_sb_align + 16
        cmp             w20, col_end8
        csel            w20, w20, col_end8, lt          // imin(x_sb_align + 16, col_end8)
        cmp             w26, w20                        // pos_x < imin(x_sb_align + 16, col_end8)
        b.ge            2f
        str             w25, [x27]                      // rp_proj[pos + pos_x].mv = rb->mv (b_mv)
        strb            w24, [x27, #4]                  // rp_proj[pos + pos_x].ref = ref2ref

2:      // search part of copy loop
        add             w11, w11, #1                    // x++
        cmp             w11, col_end8i                  // if (++x >= col_end8i) break xloop
        b.ge            8f

        ldrb            w20, [x23, #5]!                 // rb++; rb->ref
        cmp             w20, w22                        // if (rb->ref != b_ref) break
        b.ne            7f

        ldur            w21, [x23, #-4]                 // rb->mv.n
        cmp             w21, w25                        // if (rb->mv.n != b_mv.n) break
        b.ne            7f

        add             w26, w26, #1                    // pos_x++
        add             x27, x27, #5                    // advance &rp_proj[(pos_y & 15) * stride + pos_x]
        b               3b                              // copy loop

1:      // search loop: pos_y out of range, just skip identical neighbours
        add             w11, w11, #1                    // x++
        cmp             w11, col_end8i                  // if (++x >= col_end8i) break xloop
        b.ge            8f

        ldrb            w20, [x23, #5]!                 // rb++; rb->ref
        cmp             w20, w22                        // if (rb->ref != b_ref) break
        b.ne            7f

        ldur            w21, [x23, #-4]                 // rb->mv.n
        cmp             w21, w25                        // if (rb->mv.n == b_mv.n) continue
        b.eq            1b                              // search loop
7:
        cmp             w11, col_end8i                  // x < col_end8i
        b.lt            4b                              // xloop

6:      // continue case of xloop
        add             w11, w11, #1                    // x++
        cmp             w11, col_end8i                  // x < col_end8i
        b.lt            4b                              // xloop
8:
        add             w12, w12, #1                    // y++
        add             x13, x13, stride5               // r += stride
        cmp             w12, row_end8                   // y < row_end8
        b.lt            5b                              // yloop
9:
        add             w10, w10, #1
        cmp             w10, w19                        // n < rf->n_mfmvs
        b.lt            10b                             // nloop
11:
        ldp             x29, x30, [sp, #80]
        ldp             x20, x19, [sp, #64]
        ldp             x22, x21, [sp, #48]
        ldp             x24, x23, [sp, #32]
        ldp             x26, x25, [sp, #16]
        ldp             x28, x27, [sp], #96
        ret
        .unreq          rf
        .unreq          tile_row_idx
        .unreq          col_start8
        .unreq          col_end8
        .unreq          row_start8
        .unreq          row_end8
        .unreq          col_start8i
        .unreq          col_end8i
        .unreq          rp_proj
        .unreq          stride5
        .unreq          wstride5
endfunc
540+
541+
// div_mult_tbl[n] = 16384 / n, truncated (Q14 reciprocals used by the
// mv_projection scaling above); entry 0 is unused padding.
const div_mult_tbl
        .hword  0, 16384, 8192, 5461, 4096, 3276, 2730, 2340
        .hword  2048, 1820, 1638, 1489, 1365, 1260, 1170, 1092
        .hword  1024, 963, 910, 862, 819, 780, 744, 712
        .hword  682, 655, 630, 606, 585, 564, 546, 528
endconst

src/arm/asm-offsets.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
#ifndef ARM_ASM_OFFSETS_H
2828
#define ARM_ASM_OFFSETS_H
2929

30+
#include "config.h"
31+
3032
#define FGD_SEED 0
3133
#define FGD_AR_COEFF_LAG 92
3234
#define FGD_AR_COEFFS_Y 96
@@ -40,4 +42,17 @@
4042
#define FGD_UV_OFFSET 204
4143
#define FGD_CLIP_TO_RESTRICTED_RANGE 216
4244

#if ARCH_AARCH64
/* Byte offsets of fields within struct refmvs_frame, for use from AArch64
 * assembly.  Each value is statically verified against the real C struct
 * layout by the CHECK_OFFSET() assertions in src/arm/refmvs.h. */
#define RMVSF_IW8 16
#define RMVSF_IH8 20
#define RMVSF_MFMV_REF 53
#define RMVSF_MFMV_REF2CUR 56
#define RMVSF_MFMV_REF2REF 68
#define RMVSF_N_MFMVS 152
#define RMVSF_RP_REF 168
#define RMVSF_RP_PROJ 176
#define RMVSF_RP_STRIDE 184
#define RMVSF_N_TILE_THREADS 200
#endif
57+
4358
#endif /* ARM_ASM_OFFSETS_H */

src/arm/refmvs.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,24 @@
2525
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626
*/
2727

28+
#include "src/arm/asm-offsets.h"
2829
#include "src/cpu.h"
2930
#include "src/refmvs.h"
3031

32+
#if ARCH_AARCH64
33+
CHECK_OFFSET(refmvs_frame, iw8, RMVSF_IW8);
34+
CHECK_OFFSET(refmvs_frame, ih8, RMVSF_IH8);
35+
CHECK_OFFSET(refmvs_frame, mfmv_ref, RMVSF_MFMV_REF);
36+
CHECK_OFFSET(refmvs_frame, mfmv_ref2cur, RMVSF_MFMV_REF2CUR);
37+
CHECK_OFFSET(refmvs_frame, mfmv_ref2ref, RMVSF_MFMV_REF2REF);
38+
CHECK_OFFSET(refmvs_frame, n_mfmvs, RMVSF_N_MFMVS);
39+
CHECK_OFFSET(refmvs_frame, rp_ref, RMVSF_RP_REF);
40+
CHECK_OFFSET(refmvs_frame, rp_proj, RMVSF_RP_PROJ);
41+
CHECK_OFFSET(refmvs_frame, rp_stride, RMVSF_RP_STRIDE);
42+
CHECK_OFFSET(refmvs_frame, n_tile_threads, RMVSF_N_TILE_THREADS);
43+
#endif
44+
45+
decl_load_tmvs_fn(dav1d_load_tmvs_neon);
3146
decl_save_tmvs_fn(dav1d_save_tmvs_neon);
3247
decl_splat_mv_fn(dav1d_splat_mv_neon);
3348

@@ -36,6 +51,9 @@ static ALWAYS_INLINE void refmvs_dsp_init_arm(Dav1dRefmvsDSPContext *const c) {
3651

3752
if (!(flags & DAV1D_ARM_CPU_FLAG_NEON)) return;
3853

54+
#if ARCH_AARCH64
55+
c->load_tmvs = dav1d_load_tmvs_neon;
56+
#endif
3957
c->save_tmvs = dav1d_save_tmvs_neon;
4058
c->splat_mv = dav1d_splat_mv_neon;
4159
}

0 commit comments

Comments
 (0)