-
Notifications
You must be signed in to change notification settings - Fork 1.1k
cpu: aarch64: add ASIMD softmax JIT implementation #4441
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
As this change is pretty big, do you think it would be possible to neatly split it into two commits: one for the sve optimizations and one for the asimd impl? The sve changes should even maybe be a separate PR. |
This commit moves all SVE-specific code into a new construct `jit_softmax_sve_t`.
This commit introduces an f32 ASIMD `softmax` JIT implementation.
9555ee7 to
d614a2c
Compare
|
I've now split up the changes into 3 separate commits:
I will move the final commit to a follow-up PR if you think that's best. I've only left all 3 together for now as the c7g/c8g speedups would be less noticeable at a glance with the SVE improvements in commits 2 and 3 split up, compared to being altogether in a single table like this. |
This commit adapts some of the ASIMD softmax changes for the SVE kernels. In particular, the `jit:sve_128` logic more closely resembles `jit:asimd` (e.g. its `exp` eltwise injector is inlined and uses `compute_vector_range()` instead of `compute_vector()`).
d614a2c to
61355d8
Compare
jondea
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks really good, thank you! Couple of comments, I will have another look over and will probably have more.
| const auto &t4 = VReg4S(vmm_aux3.getIdx()); | ||
| const auto &t_tmp = VReg4S(vmm_tmp.getIdx()); | ||
|
|
||
| const float special_bound_input = 126.5f * logf(2.0f); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this variable have a more specific name?
| h->fmov(h->X_TMP_0, DReg(t_tmp.getIdx())); | ||
| h->cbnz(h->X_TMP_0, L_special); | ||
| if (need_special_case) { | ||
| // Check if any lane needs special-case handling |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this include NaN and Inf? I think a couple of comments explaining the flow for them would be useful
Description
This commit introduces an f32 ASIMD
softmaxJIT implementation using theexpeltwise injector added in #4376, while also improving performance for the existingsve_*implementations (primarily by increasing the unrolling factorunroll_regs_and skipping the multiplication with default dequantization / requantization factorssrc_scales/dst_scales). Forjit:asimdandjit:sve_128, theexpfunction is also effectively inlined by settingpreserve_vmm = false, whereasjit:sve_256did not benefit from such a change.As the previous softmax implementation heavily relied on predicated instructions,
jit_softmax_base_twas refactored to only include common logic for SVE and non-SVE implementations alike. At the same time, two different derived constructs were added to handle ISA-specific work:jit_softmax_sve_tandjit_softmax_asimd_t.In addition, the JIT eltwise injector was changed to support storing/loading preserved vectors on non-SVE targets.
Performance improvements (f32)
c6g
c7g
c8g