@@ -215,6 +215,14 @@ - (void) dealloc {
215215 GGML_METAL_KERNEL_TYPE_REPEAT_F16,
216216 GGML_METAL_KERNEL_TYPE_REPEAT_I32,
217217 GGML_METAL_KERNEL_TYPE_REPEAT_I16,
218+ GGML_METAL_KERNEL_TYPE_OUT_PROD_F32,
219+ GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F32,
220+ GGML_METAL_KERNEL_TYPE_OUT_PROD_F32_F16,
221+ GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F16,
222+ GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F32,
223+ GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F16,
224+ GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F32,
225+ GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F16,
218226 GGML_METAL_KERNEL_TYPE_SCALE,
219227 GGML_METAL_KERNEL_TYPE_SCALE_4,
220228 GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -229,6 +237,8 @@ - (void) dealloc {
229237 GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
230238 GGML_METAL_KERNEL_TYPE_SILU,
231239 GGML_METAL_KERNEL_TYPE_SILU_4,
240+ GGML_METAL_KERNEL_TYPE_SILU_BACK,
241+ GGML_METAL_KERNEL_TYPE_SILU_BACK_4,
232242 GGML_METAL_KERNEL_TYPE_ELU,
233243 GGML_METAL_KERNEL_TYPE_ABS,
234244 GGML_METAL_KERNEL_TYPE_SGN,
@@ -278,6 +288,7 @@ - (void) dealloc {
278288 GGML_METAL_KERNEL_TYPE_RMS_NORM,
279289 GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
280290 GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
291+ GGML_METAL_KERNEL_TYPE_RMS_NORM_BACK,
281292 GGML_METAL_KERNEL_TYPE_L2_NORM,
282293 GGML_METAL_KERNEL_TYPE_GROUP_NORM,
283294 GGML_METAL_KERNEL_TYPE_NORM,
@@ -1137,6 +1148,14 @@ @implementation GGMLMetalClass
11371148 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true );
11381149 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true );
11391150 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true );
1151+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_OUT_PROD_F32, out_prod_f32, true );
1152+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F32, out_prod_f16_f32, true );
1153+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_OUT_PROD_F32_F16, out_prod_f32_f16, true );
1154+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F16, out_prod_f16, true );
1155+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F32, out_prod_q8_0_f32, true );
1156+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F16, out_prod_q8_0_f16, true );
1157+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F32, out_prod_q4_0_f32, true );
1158+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F16, out_prod_q4_0_f16, true );
11401159 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SCALE, scale, true );
11411160 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true );
11421161 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true );
@@ -1151,6 +1170,8 @@ @implementation GGMLMetalClass
11511170 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true );
11521171 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SILU, silu, true );
11531172 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true );
1173+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SILU_BACK, silu_back, true );
1174+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SILU_BACK_4, silu_back_4, true );
11541175 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ELU, elu, true );
11551176 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ABS, abs, true );
11561177 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SGN, sgn, true );
@@ -1200,6 +1221,7 @@ @implementation GGMLMetalClass
12001221 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
12011222 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
12021223 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
1224+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM_BACK, rms_norm_back, has_simdgroup_reduction);
12031225 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
12041226 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
12051227 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NORM, norm, true );
@@ -1853,13 +1875,54 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
18531875 case GGML_OP_DIV:
18541876 case GGML_OP_ADD_ID:
18551877 return op->src [0 ]->type == GGML_TYPE_F32;
1878+ case GGML_OP_OUT_PROD:
1879+ if (op->type != GGML_TYPE_F32) {
1880+ return false ;
1881+ }
1882+
1883+ {
1884+ const enum ggml_type src0_type = op->src [0 ]->type ;
1885+ const enum ggml_type src1_type = op->src [1 ]->type ;
1886+
1887+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
1888+ return true ;
1889+ }
1890+
1891+ if (src0_type == GGML_TYPE_F32 && (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_F16)) {
1892+ return true ;
1893+ }
1894+
1895+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
1896+ return true ;
1897+ }
1898+
1899+ if (src0_type == GGML_TYPE_Q8_0 && (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_F16)) {
1900+ return true ;
1901+ }
1902+
1903+ if (src0_type == GGML_TYPE_Q4_0 && (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_F16)) {
1904+ return true ;
1905+ }
1906+ }
1907+
1908+ return false ;
18561909 case GGML_OP_ACC:
18571910 case GGML_OP_REPEAT:
18581911 case GGML_OP_SCALE:
18591912 case GGML_OP_CONV_TRANSPOSE_1D:
18601913 return true ;
18611914 case GGML_OP_CLAMP:
18621915 return op->src [0 ]->type == GGML_TYPE_F32;
1916+ case GGML_OP_SILU_BACK:
1917+ return op->type == GGML_TYPE_F32 &&
1918+ op->src [0 ] != NULL && op->src [1 ] != NULL &&
1919+ op->src [0 ]->type == GGML_TYPE_F32 &&
1920+ op->src [1 ]->type == GGML_TYPE_F32 &&
1921+ ggml_is_contiguous_1 (op->src [0 ]) &&
1922+ ggml_is_contiguous_1 (op->src [1 ]) &&
1923+ ggml_is_contiguous_1 (op) &&
1924+ ggml_are_same_shape (op, op->src [0 ]) &&
1925+ ggml_are_same_shape (op, op->src [1 ]);
18631926 case GGML_OP_SQR:
18641927 case GGML_OP_SQRT:
18651928 case GGML_OP_SIN:
@@ -1875,6 +1938,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
18751938 case GGML_OP_RMS_NORM:
18761939 case GGML_OP_L2_NORM:
18771940 return has_simdgroup_reduction && (op->ne [0 ] % 4 == 0 && ggml_is_contiguous_1 (op->src [0 ]));
1941+ case GGML_OP_RMS_NORM_BACK:
1942+ return has_simdgroup_reduction &&
1943+ op->type == GGML_TYPE_F32 &&
1944+ op->src [0 ] != NULL && op->src [1 ] != NULL &&
1945+ op->src [0 ]->type == GGML_TYPE_F32 &&
1946+ op->src [1 ]->type == GGML_TYPE_F32 &&
1947+ op->ne [0 ] % 4 == 0 &&
1948+ ggml_is_contiguous_1 (op->src [0 ]) &&
1949+ ggml_is_contiguous_1 (op->src [1 ]) &&
1950+ ggml_is_contiguous_1 (op) &&
1951+ ggml_are_same_shape (op, op->src [0 ]) &&
1952+ ggml_are_same_shape (op, op->src [1 ]);
18781953 case GGML_OP_ARGMAX:
18791954 return true ;
18801955 case GGML_OP_NORM:
@@ -2365,6 +2440,80 @@ static int ggml_metal_encode_node(
23652440 [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
23662441 }
23672442 } break ;
2443+ case GGML_OP_OUT_PROD:
2444+ {
2445+ GGML_ASSERT (dstt == GGML_TYPE_F32);
2446+ GGML_ASSERT (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q4_0);
2447+ GGML_ASSERT (src1t == GGML_TYPE_F32 || src1t == GGML_TYPE_F16);
2448+
2449+ id <MTLComputePipelineState > pipeline = nil ;
2450+
2451+ if (src0t == GGML_TYPE_Q8_0) {
2452+ if (src1t == GGML_TYPE_F32) {
2453+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F32].pipeline ;
2454+ } else if (src1t == GGML_TYPE_F16) {
2455+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F16].pipeline ;
2456+ } else {
2457+ GGML_ABORT (" fatal error" );
2458+ }
2459+ } else if (src0t == GGML_TYPE_Q4_0) {
2460+ if (src1t == GGML_TYPE_F32) {
2461+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F32].pipeline ;
2462+ } else if (src1t == GGML_TYPE_F16) {
2463+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F16].pipeline ;
2464+ } else {
2465+ GGML_ABORT (" fatal error" );
2466+ }
2467+ } else if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32) {
2468+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F32].pipeline ;
2469+ } else if (src0t == GGML_TYPE_F32 && src1t == GGML_TYPE_F32) {
2470+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_OUT_PROD_F32].pipeline ;
2471+ } else if (src0t == GGML_TYPE_F32 && src1t == GGML_TYPE_F16) {
2472+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_OUT_PROD_F32_F16].pipeline ;
2473+ } else if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F16) {
2474+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F16].pipeline ;
2475+ } else {
2476+ GGML_ABORT (" fatal error" );
2477+ }
2478+
2479+ ggml_metal_kargs_out_prod args = {
2480+ (int32_t ) ne00,
2481+ (int32_t ) ne01,
2482+ (int32_t ) ne02,
2483+ (int32_t ) ne03,
2484+ nb00,
2485+ nb01,
2486+ nb02,
2487+ nb03,
2488+ (int32_t ) ne10,
2489+ (int32_t ) ne11,
2490+ (int32_t ) ne12,
2491+ (int32_t ) ne13,
2492+ nb10,
2493+ nb11,
2494+ nb12,
2495+ nb13,
2496+ (int32_t ) ne0,
2497+ (int32_t ) ne1,
2498+ (int32_t ) ne2,
2499+ (int32_t ) ne3,
2500+ nb0,
2501+ nb1,
2502+ nb2,
2503+ nb3,
2504+ };
2505+
2506+ [encoder setComputePipelineState: pipeline];
2507+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
2508+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
2509+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
2510+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
2511+
2512+ const int threads = ne0 < 1 ? 1 : (int ) ne0;
2513+ const int nth = MIN ((int ) pipeline.maxTotalThreadsPerThreadgroup , threads);
2514+
2515+ [encoder dispatchThreadgroups: MTLSizeMake (ne1, ne2, ne3) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
2516+ } break ;
23682517 case GGML_OP_ADD_ID:
23692518 {
23702519 GGML_ASSERT (src0t == GGML_TYPE_F32);
@@ -2575,6 +2724,37 @@ static int ggml_metal_encode_node(
25752724
25762725 const int64_t n = ggml_nelements (dst);
25772726
2727+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
2728+ } break ;
2729+ case GGML_OP_SILU_BACK:
2730+ {
2731+ GGML_ASSERT (src0 != NULL );
2732+ GGML_ASSERT (src1 != NULL );
2733+ GGML_ASSERT (src0t == GGML_TYPE_F32);
2734+ GGML_ASSERT (src1t == GGML_TYPE_F32);
2735+ GGML_ASSERT (dstt == GGML_TYPE_F32);
2736+ GGML_ASSERT (ggml_are_same_shape (dst, src0));
2737+ GGML_ASSERT (ggml_are_same_shape (dst, src1));
2738+ GGML_ASSERT (ggml_is_contiguous_1 (src0));
2739+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
2740+ GGML_ASSERT (ggml_is_contiguous_1 (dst));
2741+
2742+ int64_t n = ggml_nelements (dst);
2743+
2744+ id <MTLComputePipelineState > pipeline = nil ;
2745+
2746+ if (n % 4 == 0 ) {
2747+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SILU_BACK_4].pipeline ;
2748+ n /= 4 ;
2749+ } else {
2750+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SILU_BACK].pipeline ;
2751+ }
2752+
2753+ [encoder setComputePipelineState: pipeline];
2754+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2755+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
2756+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
2757+
25782758 [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
25792759 } break ;
25802760 case GGML_OP_UNARY:
@@ -4508,6 +4688,59 @@ static int ggml_metal_encode_node(
45084688
45094689 [encoder setThreadgroupMemoryLength: 32 *sizeof (float ) atIndex: 0 ];
45104690
4691+ [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
4692+ } break ;
4693+ case GGML_OP_RMS_NORM_BACK:
4694+ {
4695+ GGML_ASSERT (src0 != NULL );
4696+ GGML_ASSERT (src1 != NULL );
4697+ GGML_ASSERT (ne00 % 4 == 0 );
4698+ GGML_ASSERT (dstt == GGML_TYPE_F32);
4699+ GGML_ASSERT (src0t == GGML_TYPE_F32);
4700+ GGML_ASSERT (src1t == GGML_TYPE_F32);
4701+ GGML_ASSERT (ggml_are_same_shape (dst, src0));
4702+ GGML_ASSERT (ggml_are_same_shape (dst, src1));
4703+ GGML_ASSERT (ggml_is_contiguous_1 (src0));
4704+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
4705+ GGML_ASSERT (ggml_is_contiguous_1 (dst));
4706+
4707+ float eps;
4708+ memcpy (&eps, dst->op_params , sizeof (float ));
4709+
4710+ ggml_metal_kargs_rms_norm_back args = {
4711+ /* .ne00 =*/ ne00,
4712+ /* .ne00_4 =*/ ne00/4 ,
4713+ /* .nb01 =*/ nb01,
4714+ /* .nb02 =*/ nb02,
4715+ /* .nb03 =*/ nb03,
4716+ /* .nb11 =*/ nb11,
4717+ /* .nb12 =*/ nb12,
4718+ /* .nb13 =*/ nb13,
4719+ /* .nb1 =*/ nb1,
4720+ /* .nb2 =*/ nb2,
4721+ /* .nb3 =*/ nb3,
4722+ /* .eps =*/ eps,
4723+ };
4724+
4725+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM_BACK].pipeline ;
4726+
4727+ int nth = 32 ; // SIMD width
4728+
4729+ while (nth < ne00/4 && nth < (int ) pipeline.maxTotalThreadsPerThreadgroup ) {
4730+ nth *= 2 ;
4731+ }
4732+
4733+ nth = MIN (nth, (int ) pipeline.maxTotalThreadsPerThreadgroup );
4734+ nth = MIN (nth, ne00/4 );
4735+
4736+ [encoder setComputePipelineState: pipeline];
4737+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
4738+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
4739+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
4740+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
4741+
4742+ [encoder setThreadgroupMemoryLength: 2 *32 *sizeof (float ) atIndex: 0 ];
4743+
45114744 [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
45124745 } break ;
45134746 case GGML_OP_L2_NORM:
0 commit comments