namespace at::mps {
44
static const char * SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
// Forward declaration of the element-type conversion helper; the concrete
// specialization is instantiated by the host-side template substitution.
template<typename Y, typename X>
Y cast(const X x);

@@ -15,32 +11,26 @@ template<>
1511 return {2};
1612}}
1713
// Generic N-dimensional scatter: copies element `linear_index` of the
// densely-packed source buffer into the strided destination buffer.
// `size` and `stride` are arrays of length `ndim`; index ndim-1 is the
// fastest-varying dimension (it is peeled off first by the modulo below).
// NOTE: braces are doubled and types are spelled as the placeholders
// 0 / 1 because this text is a host-side format template.
kernel void scatter_kernel_n(uint linear_index          [[thread_position_in_grid]],
                             constant void * src_       [[buffer(0)]],
                             device void * dst_         [[buffer(1)]],
                             constant uint32_t * size   [[buffer(2)]],
                             constant uint32_t * stride [[buffer(3)]],
                             constant uint32_t & numel  [[buffer(4)]],
                             constant int32_t & ndim    [[buffer(5)]]) {{
    // Guard the grid tail: the dispatch may be rounded up past numel.
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    // Decompose the linear index into per-dimension coordinates and
    // accumulate the strided destination offset in 64-bit to avoid
    // overflow on large tensors.
    uint64_t dst_offs = 0;
    auto dst_idx = linear_index;
    for(int dim = ndim - 1; dim >= 0; --dim) {{
      dst_offs += stride[dim] * (dst_idx % size[dim]);
      dst_idx /= size[dim];
    }}

    dst[dst_offs] = cast<{1}>(src[linear_index]);
}}
4535
4636kernel void scatter_kernel_4(uint linear_index [[thread_position_in_grid]],
@@ -121,10 +111,6 @@ kernel void scatter_kernel_1(uint linear_index [[thread_position_in
121111)METAL_SCATTER" ;
122112
static const char * GATHER_OPS_TEMPLATE = R"METAL_GATHER(
// Forward declaration of the element-type conversion helper; the concrete
// specialization is instantiated by the host-side template substitution.
template<typename Y, typename X>
Y cast(const X x);

@@ -133,33 +119,26 @@ template<>
133119 return {2};
134120}}
135121
// Generic N-dimensional gather: reads from the strided source buffer and
// writes element `linear_index` of the densely-packed destination buffer.
// `size` and `stride` are arrays of length `ndim`; index ndim-1 is the
// fastest-varying dimension (it is peeled off first by the modulo below).
// NOTE: braces are doubled and types are spelled as the placeholders
// 0 / 1 because this text is a host-side format template.
kernel void gather_kernel_n(uint linear_index          [[thread_position_in_grid]],
                            constant void * src_       [[buffer(0)]],
                            device void * dst_         [[buffer(1)]],
                            constant uint32_t * size   [[buffer(2)]],
                            constant uint32_t * stride [[buffer(3)]],
                            constant uint32_t & numel  [[buffer(4)]],
                            constant int32_t & ndim    [[buffer(5)]]) {{
    // Guard the grid tail: the dispatch may be rounded up past numel.
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    // Decompose the linear index into per-dimension coordinates and
    // accumulate the strided source offset in 64-bit to avoid overflow
    // on large tensors.
    uint64_t src_offs = 0;
    auto src_idx = linear_index;
    for(int dim = ndim - 1; dim >= 0; --dim) {{
      src_offs += stride[dim] * (src_idx % size[dim]);
      src_idx /= size[dim];
    }}

    dst[linear_index] = cast<{1}>(src[src_offs]);
}}
164143
165144kernel void gather_kernel_4(uint linear_index [[thread_position_in_grid]],