3535 GGML_METAL_KERNEL_TYPE_MUL_ROW,
3636 GGML_METAL_KERNEL_TYPE_DIV,
3737 GGML_METAL_KERNEL_TYPE_DIV_ROW,
38+ GGML_METAL_KERNEL_TYPE_REPEAT_F32,
39+ GGML_METAL_KERNEL_TYPE_REPEAT_F16,
40+ GGML_METAL_KERNEL_TYPE_REPEAT_I32,
41+ GGML_METAL_KERNEL_TYPE_REPEAT_I16,
3842 GGML_METAL_KERNEL_TYPE_SCALE,
3943 GGML_METAL_KERNEL_TYPE_SCALE_4,
4044 GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -485,6 +489,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
485489 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true );
486490 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV, div, true );
487491 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true );
492+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true );
493+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true );
494+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true );
495+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true );
488496 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SCALE, scale, true );
489497 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true );
490498 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true );
@@ -746,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
746754 case GGML_OP_ACC:
747755 case GGML_OP_MUL:
748756 case GGML_OP_DIV:
757+ case GGML_OP_REPEAT:
749758 case GGML_OP_SCALE:
750759 case GGML_OP_CLAMP:
751760 case GGML_OP_SQR:
@@ -979,8 +988,6 @@ static enum ggml_status ggml_metal_graph_compute(
979988 switch (dst->op ) {
980989 case GGML_OP_CONCAT:
981990 {
982- const int64_t nb = ne00;
983-
984991 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CONCAT].pipeline ;
985992
986993 [encoder setComputePipelineState: pipeline];
@@ -1011,7 +1018,6 @@ static enum ggml_status ggml_metal_graph_compute(
10111018 [encoder setBytes: &nb1 length: sizeof (nb1) atIndex: 24 ];
10121019 [encoder setBytes: &nb2 length: sizeof (nb2) atIndex: 25 ];
10131020 [encoder setBytes: &nb3 length: sizeof (nb3) atIndex: 26 ];
1014- [encoder setBytes: &nb length: sizeof (nb) atIndex: 27 ];
10151021
10161022 const int nth = MIN (1024 , ne0);
10171023
@@ -1021,11 +1027,14 @@ static enum ggml_status ggml_metal_graph_compute(
10211027 case GGML_OP_MUL:
10221028 case GGML_OP_DIV:
10231029 {
1030+ GGML_ASSERT (src0t == GGML_TYPE_F32);
1031+ GGML_ASSERT (src1t == GGML_TYPE_F32);
1032+
10241033 const size_t offs = 0 ;
10251034
10261035 bool bcast_row = false ;
10271036
1028- int64_t nb = ne00;
1037+ int64_t nb = ne00; // used by the "row" kernels
10291038
10301039 id <MTLComputePipelineState > pipeline = nil ;
10311040
@@ -1094,6 +1103,42 @@ static enum ggml_status ggml_metal_graph_compute(
10941103 [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
10951104 }
10961105 } break ;
1106+ case GGML_OP_REPEAT: // encodes one compute dispatch using the repeat kernel that matches src0's type
1107+ {
1108+ id <MTLComputePipelineState > pipeline; // not initialized here; each supported-type arm of the switch below assigns it
1109+
1110+ switch (src0t) { // choose the kernel variant by source type: F32, F16, I32 or I16
1111+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline ; break ;
1112+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline ; break ;
1113+ case GGML_TYPE_I32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline ; break ;
1114+ case GGML_TYPE_I16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline ; break ;
1115+ default : GGML_ASSERT (false ); // any other src0 type trips the assert; there is no fallback kernel
1116+ }
1117+
1118+ [encoder setComputePipelineState: pipeline];
1119+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ]; // argument 0: source buffer
1120+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ]; // argument 1: destination buffer
1121+ [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 2 ]; // arguments 2-5: ne00..ne03 (presumably src0 element counts per dimension - confirm)
1122+ [encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 3 ];
1123+ [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 4 ];
1124+ [encoder setBytes: &ne03 length: sizeof (ne03) atIndex: 5 ];
1125+ [encoder setBytes: &nb00 length: sizeof (nb00) atIndex: 6 ]; // arguments 6-9: nb00..nb03 (presumably src0 byte strides - confirm)
1126+ [encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 7 ];
1127+ [encoder setBytes: &nb02 length: sizeof (nb02) atIndex: 8 ];
1128+ [encoder setBytes: &nb03 length: sizeof (nb03) atIndex: 9 ];
1129+ [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 10 ]; // arguments 10-13: ne0..ne3 (presumably dst element counts per dimension - confirm)
1130+ [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 11 ];
1131+ [encoder setBytes: &ne2 length: sizeof (ne2) atIndex: 12 ];
1132+ [encoder setBytes: &ne3 length: sizeof (ne3) atIndex: 13 ];
1133+ [encoder setBytes: &nb0 length: sizeof (nb0) atIndex: 14 ]; // arguments 14-17: nb0..nb3 (presumably dst byte strides - confirm)
1134+ [encoder setBytes: &nb1 length: sizeof (nb1) atIndex: 15 ];
1135+ [encoder setBytes: &nb2 length: sizeof (nb2) atIndex: 16 ];
1136+ [encoder setBytes: &nb3 length: sizeof (nb3) atIndex: 17 ];
1137+
1138+ const int nth = MIN ((int ) pipeline.maxTotalThreadsPerThreadgroup , ne0); // threads per threadgroup: the pipeline's limit or ne0, whichever is smaller
1139+
1140+ [encoder dispatchThreadgroups: MTLSizeMake (ne1, ne2, ne3) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )]; // grid of ne1 x ne2 x ne3 threadgroups, each with nth threads along x
1141+ } break ;
10971142 case GGML_OP_ACC:
10981143 {
10991144 GGML_ASSERT (src0t == GGML_TYPE_F32);
0 commit comments