@@ -306,6 +306,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
306306 GGML_METAL_KERNEL_TYPE_IM2COL_F32,
307307 GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
308308 GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
309+ GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,
310+ GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
309311 GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
310312 GGML_METAL_KERNEL_TYPE_PAD_F32,
311313 GGML_METAL_KERNEL_TYPE_ARANGE_F32,
@@ -870,6 +872,8 @@ @implementation GGMLMetalClass
870872 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true );
871873 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true );
872874 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true );
875+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true );
876+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true );
873877 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true );
874878 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true );
875879 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true );
@@ -1069,6 +1073,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
10691073 case GGML_OP_REPEAT:
10701074 case GGML_OP_SCALE:
10711075 case GGML_OP_CLAMP:
1076+ case GGML_OP_CONV_TRANSPOSE_1D:
10721077 return true ;
10731078 case GGML_OP_SQR:
10741079 case GGML_OP_SQRT:
@@ -3138,6 +3143,49 @@ static void ggml_metal_encode_node(
31383143 [encoder dispatchThreadgroups: MTLSizeMake (IC, OH, OW) threadsPerThreadgroup: MTLSizeMake (N, KH, KW)];
31393144 }
31403145 } break ;
3146+ case GGML_OP_CONV_TRANSPOSE_1D:
3147+ {
3148+ GGML_ASSERT (ggml_is_contiguous (src0));
3149+ GGML_ASSERT (ggml_is_contiguous (src1));
3150+ GGML_ASSERT (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
3151+ GGML_ASSERT (src1->type == GGML_TYPE_F32);
3152+ GGML_ASSERT ( dst->type == GGML_TYPE_F32);
3153+
3154+ const int32_t s0 = ((const int32_t *)(dst->op_params ))[0 ];
3155+
3156+ const int32_t IC = src1->ne [1 ];
3157+ const int32_t IL = src1->ne [0 ];
3158+
3159+ const int32_t K = src0->ne [0 ];
3160+
3161+ const int32_t OL = dst->ne [0 ];
3162+ const int32_t OC = dst->ne [1 ];
3163+
3164+ id <MTLComputePipelineState > pipeline;
3165+
3166+ switch (src0->type ) {
3167+ case GGML_TYPE_F32: {
3168+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline ;
3169+ } break ;
3170+ case GGML_TYPE_F16: {
3171+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline ;
3172+ } break ;
3173+ default : GGML_ABORT (" fatal error" );
3174+ };
3175+
3176+ [encoder setComputePipelineState: pipeline];
3177+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3178+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
3179+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
3180+ [encoder setBytes: &IC length: sizeof ( int32_t ) atIndex: 3 ];
3181+ [encoder setBytes: &IL length: sizeof ( int32_t ) atIndex: 4 ];
3182+ [encoder setBytes: &K length: sizeof ( int32_t ) atIndex: 5 ];
3183+ [encoder setBytes: &s0 length: sizeof ( int32_t ) atIndex: 6 ];
3184+ [encoder setBytes: &nb0 length: sizeof (uint64_t ) atIndex: 7 ];
3185+ [encoder setBytes: &nb1 length: sizeof (uint64_t ) atIndex: 8 ];
3186+
3187+ [encoder dispatchThreadgroups: MTLSizeMake (OL, OC, 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
3188+ } break ;
31413189 case GGML_OP_UPSCALE:
31423190 {
31433191 GGML_ASSERT (src0->type == GGML_TYPE_F32);
0 commit comments