#include "roll.hpp"
#include "common.hpp"

#include <cstdio> // for std::fprintf in the error path below

using namespace sycl;

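// Add a pre-normalized shift (0 <= shift < n) to index i and wrap once.
// A single compare-and-subtract is enough because both inputs are already in [0, n).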
static inline int wrap_add(int i, int shift, int n) {
    int s = i + shift;
    return (s >= n) ? (s - n) : s;
}

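// Roll a contiguous F32 tensor: every work-item copies one element from its
// cyclically shifted source position to the destination. Dimensions i0 and i1
// are fused into the last range dimension so the 4-D index space fits a 3-D range.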
static void kernel_roll_fused_i0_i1(
    queue &q,
    const float *src_d,
    float *dst_d,
    int ne0, int ne1, int ne2, int ne3,
    int sh0, int sh1, int sh2, int sh3)
{
    if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return;

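    // Element strides for a contiguous row-major layout (i0 is the fastest-moving index).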
    const int stride1 = ne0;
    const int stride2 = ne0 * ne1;
    const int stride3 = ne0 * ne1 * ne2;

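    // Convert each destination-relative shift into a source offset:
    // src index = (dst index + (ne - sh)) % ne, with sh already normalized to [0, ne).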
    const int shNe0 = (ne0 - sh0) % ne0;
    const int shNe1 = (ne1 - sh1) % ne1;
    const int shNe2 = (ne2 - sh2) % ne2;
    const int shNe3 = (ne3 - sh3) % ne3;

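    // One work-item per element: {ne3, ne2, ne1 * ne0}.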
    const size_t g0 = (size_t) ne3;
    const size_t g1 = (size_t) ne2;
    const size_t g2 = (size_t) (ne1 * ne0);

    const range<3> global{ g0, g1, g2 };

    q.submit([&](handler &h) {
        h.parallel_for(global, [=](id<3> idx) {
            const int i3 = (int) idx[0];
            const int i2 = (int) idx[1];

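            // Un-fuse the last range dimension back into (i1, i0).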
            const int fused = (int) idx[2];
            const int i1 = fused / ne0;
            const int i0 = fused - i1 * ne0; // fused % ne0

            const int idx_dst = i0
                              + i1 * stride1
                              + i2 * stride2
                              + i3 * stride3;

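            // Source coordinates after the cyclic shift.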
            const int s0 = wrap_add(i0, shNe0, ne0);
            const int s1 = wrap_add(i1, shNe1, ne1);
            const int s2 = wrap_add(i2, shNe2, ne2);
            const int s3 = wrap_add(i3, shNe3, ne3);

            const int idx_src = s0
                              + s1 * stride1
                              + s2 * stride2
                              + s3 * stride3;

            dst_d[idx_dst] = src_d[idx_src];
        });
    });
}

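// Backend entry point for the ROLL op: dst = src0 rolled by the per-dimension
// shifts stored in op_params. Assumes contiguous F32 tensors of identical shape.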
void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const ggml_tensor *src = dst->src[0];
    GGML_ASSERT(src && src->type == GGML_TYPE_F32);

    const int ne0 = (int) dst->ne[0];
    const int ne1 = (int) dst->ne[1];
    const int ne2 = (int) dst->ne[2];
    const int ne3 = (int) dst->ne[3];

    const int32_t *params = (const int32_t *) dst->op_params;
    int shift0 = params[0];
    int shift1 = params[1];
    int shift2 = params[2];
    int shift3 = params[3];

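    // Fast path: all-zero shifts reduce to a device-to-device copy.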
    if ((shift0 | shift1 | shift2 | shift3) == 0) {
        const size_t nb = ggml_nbytes(src);
        queue *q = ctx.stream();
        SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb)));
        return;
    }

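    // Normalize each shift into [0, ne) so the kernel can wrap with a single conditional subtraction.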
    auto norm = [](int sh, int n) -> int {
        if (n <= 0) return 0;
        sh %= n;
        if (sh < 0) sh += n;
        return sh;
    };
    shift0 = norm(shift0, ne0);
    shift1 = norm(shift1, ne1);
    shift2 = norm(shift2, ne2);
    shift3 = norm(shift3, ne3);

    try {
        queue *q = ctx.stream();

        const float *src_d = (const float *) src->data;
        float *dst_d = (float *) dst->data;
        GGML_ASSERT(src_d && dst_d);

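        // Launch the element-wise roll kernel on the context's stream.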
        kernel_roll_fused_i0_i1(
            *q, src_d, dst_d,
            ne0, ne1, ne2, ne3,
            shift0, shift1, shift2, shift3
        );
    } catch (const std::exception &e) {
        std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what());
        throw;
    }
}