Skip to content

Commit e8a84f6

Browse files
committed
Add OUT_PROD, RMS_NORM_BACK, and SILU_BACK Metal shaders.
Signed-off-by: Marcus Edel <[email protected]>
1 parent ad4b2d7 commit e8a84f6

File tree

3 files changed

+518
-0
lines changed

3 files changed

+518
-0
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,33 @@ typedef struct {
184184
uint64_t nb3;
185185
} ggml_metal_kargs_cpy;
186186

187+
// Kernel-argument block for the OUT_PROD Metal kernels (out_prod_*).
//
// The host builds one of these in the GGML_OP_OUT_PROD case of
// ggml_metal_encode_node and hands it to the kernel by value with
// setBytes:length:atIndex:0; the src0, src1 and dst buffers are bound at
// indices 1, 2 and 3 respectively.
//
// Field groups, in declaration order:
//   ne00..ne03 / nb00..nb03 - filled from the host locals of the same name
//   ne10..ne13 / nb10..nb13 - filled from the host locals of the same name
//   ne0..ne3   / nb0..nb3   - filled from the host locals of the same name
// NOTE(review): by the naming used elsewhere in this backend these are
// presumably the per-dimension element counts (ne*) and byte strides (nb*)
// of src0, src1 and dst - confirm against the kernel source.
//
// The ne* values are explicitly cast to int32_t by the host before being
// stored here; the nb* values are stored uncast into the uint64_t fields.
//
// The host initializer for this struct is positional (no field designators),
// so the field order below must not be changed without updating that
// initializer in lock-step.
// NOTE(review): the shader-side declaration presumably has to match this
// layout as well - confirm how this header is shared with the Metal source.
typedef struct {
188+
int32_t ne00;
189+
int32_t ne01;
190+
int32_t ne02;
191+
int32_t ne03;
192+
uint64_t nb00;
193+
uint64_t nb01;
194+
uint64_t nb02;
195+
uint64_t nb03;
196+
int32_t ne10;
197+
int32_t ne11;
198+
int32_t ne12;
199+
int32_t ne13;
200+
uint64_t nb10;
201+
uint64_t nb11;
202+
uint64_t nb12;
203+
uint64_t nb13;
204+
int32_t ne0;
205+
int32_t ne1;
206+
int32_t ne2;
207+
int32_t ne3;
208+
uint64_t nb0;
209+
uint64_t nb1;
210+
uint64_t nb2;
211+
uint64_t nb3;
212+
} ggml_metal_kargs_out_prod;
213+
187214
typedef struct {
188215
int64_t ne10;
189216
int64_t ne11;
@@ -439,6 +466,21 @@ typedef struct {
439466
uint64_t nbf3[3];
440467
} ggml_metal_kargs_rms_norm;
441468

469+
// Kernel-argument block for the RMS_NORM_BACK Metal kernel (rms_norm_back).
//
// The host builds one of these in the GGML_OP_RMS_NORM_BACK case of
// ggml_metal_encode_node and passes it by value with
// setBytes:length:atIndex:0; the src0, src1 and dst buffers are bound at
// indices 1, 2 and 3 respectively.
//
// Before filling this struct the host asserts that ne00 is a multiple of 4,
// that src0, src1 and dst are all F32 and have the same shape, and that each
// passes ggml_is_contiguous_1.
//
// NOTE(review): nb01..nb03, nb11..nb13 and nb1..nb3 are filled from the host
// locals of the same name - presumably the byte strides of dimensions 1..3
// of src0, src1 and dst; confirm against the kernel source.
typedef struct {
470+
int32_t ne00;   // filled from the host local ne00 (asserted to be a multiple of 4)
471+
int32_t ne00_4; // set to ne00/4 by the host
472+
uint64_t nb01;
473+
uint64_t nb02;
474+
uint64_t nb03;
475+
uint64_t nb11;
476+
uint64_t nb12;
477+
uint64_t nb13;
478+
uint64_t nb1;
479+
uint64_t nb2;
480+
uint64_t nb3;
481+
float eps;      // copied by the host from the first float of dst->op_params
482+
} ggml_metal_kargs_rms_norm_back;
483+
442484
typedef struct {
443485
int32_t ne00;
444486
int32_t ne00_4;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,14 @@ - (void) dealloc {
215215
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
216216
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
217217
GGML_METAL_KERNEL_TYPE_REPEAT_I16,
218+
GGML_METAL_KERNEL_TYPE_OUT_PROD_F32,
219+
GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F32,
220+
GGML_METAL_KERNEL_TYPE_OUT_PROD_F32_F16,
221+
GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F16,
222+
GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F32,
223+
GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F16,
224+
GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F32,
225+
GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F16,
218226
GGML_METAL_KERNEL_TYPE_SCALE,
219227
GGML_METAL_KERNEL_TYPE_SCALE_4,
220228
GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -229,6 +237,8 @@ - (void) dealloc {
229237
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
230238
GGML_METAL_KERNEL_TYPE_SILU,
231239
GGML_METAL_KERNEL_TYPE_SILU_4,
240+
GGML_METAL_KERNEL_TYPE_SILU_BACK,
241+
GGML_METAL_KERNEL_TYPE_SILU_BACK_4,
232242
GGML_METAL_KERNEL_TYPE_ELU,
233243
GGML_METAL_KERNEL_TYPE_ABS,
234244
GGML_METAL_KERNEL_TYPE_SGN,
@@ -278,6 +288,7 @@ - (void) dealloc {
278288
GGML_METAL_KERNEL_TYPE_RMS_NORM,
279289
GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
280290
GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
291+
GGML_METAL_KERNEL_TYPE_RMS_NORM_BACK,
281292
GGML_METAL_KERNEL_TYPE_L2_NORM,
282293
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
283294
GGML_METAL_KERNEL_TYPE_NORM,
@@ -1137,6 +1148,14 @@ @implementation GGMLMetalClass
11371148
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
11381149
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
11391150
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
1151+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_OUT_PROD_F32, out_prod_f32, true);
1152+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F32, out_prod_f16_f32, true);
1153+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_OUT_PROD_F32_F16, out_prod_f32_f16, true);
1154+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F16, out_prod_f16, true);
1155+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F32, out_prod_q8_0_f32, true);
1156+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F16, out_prod_q8_0_f16, true);
1157+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F32, out_prod_q4_0_f32, true);
1158+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F16, out_prod_q4_0_f16, true);
11401159
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
11411160
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
11421161
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
@@ -1151,6 +1170,8 @@ @implementation GGMLMetalClass
11511170
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
11521171
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
11531172
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
1173+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_BACK, silu_back, true);
1174+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_BACK_4, silu_back_4, true);
11541175
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
11551176
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true);
11561177
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, sgn, true);
@@ -1200,6 +1221,7 @@ @implementation GGMLMetalClass
12001221
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
12011222
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
12021223
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
1224+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_BACK, rms_norm_back, has_simdgroup_reduction);
12031225
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
12041226
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
12051227
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
@@ -1853,13 +1875,54 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
18531875
case GGML_OP_DIV:
18541876
case GGML_OP_ADD_ID:
18551877
return op->src[0]->type == GGML_TYPE_F32;
1878+
case GGML_OP_OUT_PROD:
1879+
if (op->type != GGML_TYPE_F32) {
1880+
return false;
1881+
}
1882+
1883+
{
1884+
const enum ggml_type src0_type = op->src[0]->type;
1885+
const enum ggml_type src1_type = op->src[1]->type;
1886+
1887+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
1888+
return true;
1889+
}
1890+
1891+
if (src0_type == GGML_TYPE_F32 && (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_F16)) {
1892+
return true;
1893+
}
1894+
1895+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
1896+
return true;
1897+
}
1898+
1899+
if (src0_type == GGML_TYPE_Q8_0 && (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_F16)) {
1900+
return true;
1901+
}
1902+
1903+
if (src0_type == GGML_TYPE_Q4_0 && (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_F16)) {
1904+
return true;
1905+
}
1906+
}
1907+
1908+
return false;
18561909
case GGML_OP_ACC:
18571910
case GGML_OP_REPEAT:
18581911
case GGML_OP_SCALE:
18591912
case GGML_OP_CONV_TRANSPOSE_1D:
18601913
return true;
18611914
case GGML_OP_CLAMP:
18621915
return op->src[0]->type == GGML_TYPE_F32;
1916+
case GGML_OP_SILU_BACK:
1917+
return op->type == GGML_TYPE_F32 &&
1918+
op->src[0] != NULL && op->src[1] != NULL &&
1919+
op->src[0]->type == GGML_TYPE_F32 &&
1920+
op->src[1]->type == GGML_TYPE_F32 &&
1921+
ggml_is_contiguous_1(op->src[0]) &&
1922+
ggml_is_contiguous_1(op->src[1]) &&
1923+
ggml_is_contiguous_1(op) &&
1924+
ggml_are_same_shape(op, op->src[0]) &&
1925+
ggml_are_same_shape(op, op->src[1]);
18631926
case GGML_OP_SQR:
18641927
case GGML_OP_SQRT:
18651928
case GGML_OP_SIN:
@@ -1875,6 +1938,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
18751938
case GGML_OP_RMS_NORM:
18761939
case GGML_OP_L2_NORM:
18771940
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
1941+
case GGML_OP_RMS_NORM_BACK:
1942+
return has_simdgroup_reduction &&
1943+
op->type == GGML_TYPE_F32 &&
1944+
op->src[0] != NULL && op->src[1] != NULL &&
1945+
op->src[0]->type == GGML_TYPE_F32 &&
1946+
op->src[1]->type == GGML_TYPE_F32 &&
1947+
op->ne[0] % 4 == 0 &&
1948+
ggml_is_contiguous_1(op->src[0]) &&
1949+
ggml_is_contiguous_1(op->src[1]) &&
1950+
ggml_is_contiguous_1(op) &&
1951+
ggml_are_same_shape(op, op->src[0]) &&
1952+
ggml_are_same_shape(op, op->src[1]);
18781953
case GGML_OP_ARGMAX:
18791954
return true;
18801955
case GGML_OP_NORM:
@@ -2365,6 +2440,80 @@ static int ggml_metal_encode_node(
23652440
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
23662441
}
23672442
} break;
2443+
case GGML_OP_OUT_PROD:
2444+
{
2445+
GGML_ASSERT(dstt == GGML_TYPE_F32);
2446+
GGML_ASSERT(src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q4_0);
2447+
GGML_ASSERT(src1t == GGML_TYPE_F32 || src1t == GGML_TYPE_F16);
2448+
2449+
id<MTLComputePipelineState> pipeline = nil;
2450+
2451+
if (src0t == GGML_TYPE_Q8_0) {
2452+
if (src1t == GGML_TYPE_F32) {
2453+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F32].pipeline;
2454+
} else if (src1t == GGML_TYPE_F16) {
2455+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F16].pipeline;
2456+
} else {
2457+
GGML_ABORT("fatal error");
2458+
}
2459+
} else if (src0t == GGML_TYPE_Q4_0) {
2460+
if (src1t == GGML_TYPE_F32) {
2461+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F32].pipeline;
2462+
} else if (src1t == GGML_TYPE_F16) {
2463+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F16].pipeline;
2464+
} else {
2465+
GGML_ABORT("fatal error");
2466+
}
2467+
} else if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32) {
2468+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F32].pipeline;
2469+
} else if (src0t == GGML_TYPE_F32 && src1t == GGML_TYPE_F32) {
2470+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_OUT_PROD_F32].pipeline;
2471+
} else if (src0t == GGML_TYPE_F32 && src1t == GGML_TYPE_F16) {
2472+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_OUT_PROD_F32_F16].pipeline;
2473+
} else if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F16) {
2474+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F16].pipeline;
2475+
} else {
2476+
GGML_ABORT("fatal error");
2477+
}
2478+
2479+
ggml_metal_kargs_out_prod args = {
2480+
(int32_t) ne00,
2481+
(int32_t) ne01,
2482+
(int32_t) ne02,
2483+
(int32_t) ne03,
2484+
nb00,
2485+
nb01,
2486+
nb02,
2487+
nb03,
2488+
(int32_t) ne10,
2489+
(int32_t) ne11,
2490+
(int32_t) ne12,
2491+
(int32_t) ne13,
2492+
nb10,
2493+
nb11,
2494+
nb12,
2495+
nb13,
2496+
(int32_t) ne0,
2497+
(int32_t) ne1,
2498+
(int32_t) ne2,
2499+
(int32_t) ne3,
2500+
nb0,
2501+
nb1,
2502+
nb2,
2503+
nb3,
2504+
};
2505+
2506+
[encoder setComputePipelineState:pipeline];
2507+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
2508+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2509+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2510+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2511+
2512+
const int threads = ne0 < 1 ? 1 : (int) ne0;
2513+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, threads);
2514+
2515+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2516+
} break;
23682517
case GGML_OP_ADD_ID:
23692518
{
23702519
GGML_ASSERT(src0t == GGML_TYPE_F32);
@@ -2575,6 +2724,37 @@ static int ggml_metal_encode_node(
25752724

25762725
const int64_t n = ggml_nelements(dst);
25772726

2727+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2728+
} break;
2729+
case GGML_OP_SILU_BACK:
2730+
{
2731+
GGML_ASSERT(src0 != NULL);
2732+
GGML_ASSERT(src1 != NULL);
2733+
GGML_ASSERT(src0t == GGML_TYPE_F32);
2734+
GGML_ASSERT(src1t == GGML_TYPE_F32);
2735+
GGML_ASSERT(dstt == GGML_TYPE_F32);
2736+
GGML_ASSERT(ggml_are_same_shape(dst, src0));
2737+
GGML_ASSERT(ggml_are_same_shape(dst, src1));
2738+
GGML_ASSERT(ggml_is_contiguous_1(src0));
2739+
GGML_ASSERT(ggml_is_contiguous_1(src1));
2740+
GGML_ASSERT(ggml_is_contiguous_1(dst));
2741+
2742+
int64_t n = ggml_nelements(dst);
2743+
2744+
id<MTLComputePipelineState> pipeline = nil;
2745+
2746+
if (n % 4 == 0) {
2747+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_BACK_4].pipeline;
2748+
n /= 4;
2749+
} else {
2750+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_BACK].pipeline;
2751+
}
2752+
2753+
[encoder setComputePipelineState:pipeline];
2754+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2755+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2756+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2757+
25782758
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
25792759
} break;
25802760
case GGML_OP_UNARY:
@@ -4508,6 +4688,59 @@ static int ggml_metal_encode_node(
45084688

45094689
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
45104690

4691+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4692+
} break;
4693+
case GGML_OP_RMS_NORM_BACK:
4694+
{
4695+
GGML_ASSERT(src0 != NULL);
4696+
GGML_ASSERT(src1 != NULL);
4697+
GGML_ASSERT(ne00 % 4 == 0);
4698+
GGML_ASSERT(dstt == GGML_TYPE_F32);
4699+
GGML_ASSERT(src0t == GGML_TYPE_F32);
4700+
GGML_ASSERT(src1t == GGML_TYPE_F32);
4701+
GGML_ASSERT(ggml_are_same_shape(dst, src0));
4702+
GGML_ASSERT(ggml_are_same_shape(dst, src1));
4703+
GGML_ASSERT(ggml_is_contiguous_1(src0));
4704+
GGML_ASSERT(ggml_is_contiguous_1(src1));
4705+
GGML_ASSERT(ggml_is_contiguous_1(dst));
4706+
4707+
float eps;
4708+
memcpy(&eps, dst->op_params, sizeof(float));
4709+
4710+
ggml_metal_kargs_rms_norm_back args = {
4711+
/*.ne00 =*/ ne00,
4712+
/*.ne00_4 =*/ ne00/4,
4713+
/*.nb01 =*/ nb01,
4714+
/*.nb02 =*/ nb02,
4715+
/*.nb03 =*/ nb03,
4716+
/*.nb11 =*/ nb11,
4717+
/*.nb12 =*/ nb12,
4718+
/*.nb13 =*/ nb13,
4719+
/*.nb1 =*/ nb1,
4720+
/*.nb2 =*/ nb2,
4721+
/*.nb3 =*/ nb3,
4722+
/*.eps =*/ eps,
4723+
};
4724+
4725+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_BACK].pipeline;
4726+
4727+
int nth = 32; // SIMD width
4728+
4729+
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
4730+
nth *= 2;
4731+
}
4732+
4733+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
4734+
nth = MIN(nth, ne00/4);
4735+
4736+
[encoder setComputePipelineState:pipeline];
4737+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
4738+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
4739+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
4740+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
4741+
4742+
[encoder setThreadgroupMemoryLength:2*32*sizeof(float) atIndex:0];
4743+
45114744
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
45124745
} break;
45134746
case GGML_OP_L2_NORM:

0 commit comments

Comments
 (0)