@@ -202,6 +202,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
202202 GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
203203 GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
204204 GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
205+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
206+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
207+ GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
208+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
209+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
210+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
211+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
212+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
213+ GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
205214 GGML_METAL_KERNEL_TYPE_RMS_NORM,
206215 GGML_METAL_KERNEL_TYPE_L2_NORM,
207216 GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -1166,6 +1175,15 @@ @implementation GGMLMetalClass
11661175 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true );
11671176 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true );
11681177 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true );
1178+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true );
1179+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true );
1180+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
1181+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true );
1182+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true );
1183+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true );
1184+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true );
1185+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true );
1186+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true );
11691187 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
11701188 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
11711189 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
@@ -1630,7 +1648,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
16301648
16311649 if (!use_bfloat) {
16321650 for (size_t i = 0 , n = 3 ; i < n; ++i) {
1633- if (op->src [i] != NULL && op->src [i]->type == GGML_TYPE_BF16) {
1651+ if (op->src [i] != NULL && ( op->src [i]->type == GGML_TYPE_BF16 || op-> type == GGML_TYPE_BF16) ) {
16341652 return false ;
16351653 }
16361654 }
@@ -1798,6 +1816,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
17981816 {
17991817 return op->ne [3 ] == 1 ;
18001818 }
1819+ case GGML_OP_SET_ROWS:
1820+ {
1821+ if (op->src [0 ]->type != GGML_TYPE_F32) {
1822+ return false ;
1823+ }
1824+
1825+ switch (op->type ) {
1826+ case GGML_TYPE_F32:
1827+ case GGML_TYPE_F16:
1828+ case GGML_TYPE_BF16:
1829+ case GGML_TYPE_Q8_0:
1830+ case GGML_TYPE_Q4_0:
1831+ case GGML_TYPE_Q4_1:
1832+ case GGML_TYPE_Q5_0:
1833+ case GGML_TYPE_Q5_1:
1834+ case GGML_TYPE_IQ4_NL:
1835+ return true ;
1836+ default :
1837+ return false ;
1838+ };
1839+ }
18011840 default :
18021841 return false ;
18031842 }
@@ -3757,13 +3796,74 @@ static bool ggml_metal_encode_node(
37573796 };
37583797
37593798 [encoder setComputePipelineState: pipeline];
3760- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3761- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
3762- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
3763- [encoder setBytes: &args length: sizeof (args) atIndex: 3 ];
3799+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
3800+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
3801+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
3802+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
37643803
37653804 [encoder dispatchThreadgroups: MTLSizeMake (ne10, ne11, 1 ) threadsPerThreadgroup: MTLSizeMake (32 , 1 , 1 )];
37663805 } break ;
3806+ case GGML_OP_SET_ROWS:
3807+ {
3808+ id <MTLComputePipelineState > pipeline = nil ;
3809+
3810+ switch (dst->type ) {
3811+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline ; break ;
3812+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline ; break ;
3813+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline ; break ;
3814+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline ; break ;
3815+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline ; break ;
3816+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline ; break ;
3817+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline ; break ;
3818+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline ; break ;
3819+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline ; break ;
3820+ default : GGML_ABORT (" not implemented" );
3821+ }
3822+
3823+ const int32_t nk0 = ne0/ggml_blck_size (dst->type );
3824+
3825+ int nth = 32 ; // SIMD width
3826+
3827+ while (nth < nk0 && nth < (int ) pipeline.maxTotalThreadsPerThreadgroup ) {
3828+ nth *= 2 ;
3829+ }
3830+
3831+ int nrptg = 1 ;
3832+ if (nth > nk0) {
3833+ nrptg = (nth + nk0 - 1 )/nk0;
3834+ nth = nk0;
3835+
3836+ if (nrptg*nth > (int ) pipeline.maxTotalThreadsPerThreadgroup ) {
3837+ nrptg--;
3838+ }
3839+ }
3840+
3841+ nth = MIN (nth, nk0);
3842+
3843+ ggml_metal_kargs_set_rows args = {
3844+ /* .nk0 =*/ nk0,
3845+ /* .ne01 =*/ ne01,
3846+ /* .nb01 =*/ nb01,
3847+ /* .nb02 =*/ nb02,
3848+ /* .nb03 =*/ nb03,
3849+ /* .ne11 =*/ ne11,
3850+ /* .ne12 =*/ ne12,
3851+ /* .nb10 =*/ nb10,
3852+ /* .nb11 =*/ nb11,
3853+ /* .nb12 =*/ nb12,
3854+ /* .nb1 =*/ nb1,
3855+ /* .nb2 =*/ nb2,
3856+ /* .nb3 =*/ nb3,
3857+ };
3858+
3859+ [encoder setComputePipelineState: pipeline];
3860+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
3861+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
3862+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
3863+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
3864+
3865+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nrptg - 1 )/nrptg, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, nrptg, 1 )];
3866+ } break ;
37673867 case GGML_OP_RMS_NORM:
37683868 {
37693869 GGML_ASSERT (ne00 % 4 == 0 );
0 commit comments