@@ -512,6 +512,7 @@ struct ggml_backend_opencl_context {
512512 cl_kernel kernel_conv_2d_f16;
513513 cl_kernel kernel_conv_2d_f32;
514514 cl_kernel kernel_conv_2d_f16_f32;
515+ cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4;
515516 cl_kernel kernel_timestep_embedding;
516517 cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
517518 cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
@@ -1888,6 +1889,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
18881889 }
18891890 }
18901891
1892+ // ssm_conv
1893+ {
1894+ #ifdef GGML_OPENCL_EMBED_KERNELS
1895+ const std::string kernel_src {
1896+ #include " ssm_conv.cl.h"
1897+ };
1898+ #else
1899+ const std::string kernel_src = read_file (" ssm_conv.cl" );
1900+ #endif
1901+ cl_program prog =
1902+ build_program_from_source (backend_ctx->context , backend_ctx->device , kernel_src.c_str (), compile_opts);
1903+
1904+ CL_CHECK ((backend_ctx->kernel_ssm_conv_f32_f32 = clCreateKernel (prog, " kernel_ssm_conv_f32_f32" , &err), err));
1905+ CL_CHECK ((backend_ctx->kernel_ssm_conv_f32_f32_4 = clCreateKernel (prog, " kernel_ssm_conv_f32_f32_4" , &err), err));
1906+ CL_CHECK (clReleaseProgram (prog));
1907+ GGML_LOG_CONT (" ." );
1908+ }
1909+
18911910 // mul_mv_id_q4_0_f32_8x_flat
18921911 {
18931912#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3074,6 +3093,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
30743093 return (op->src [0 ]->type == GGML_TYPE_F16 && op->src [1 ]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
30753094 (op->src [0 ]->type == GGML_TYPE_F32 && op->src [1 ]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
30763095 (op->src [0 ]->type == GGML_TYPE_F16 && op->src [1 ]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
3096+ case GGML_OP_SSM_CONV:
3097+ return (op->src [0 ]->type == GGML_TYPE_F32 && op->src [1 ]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
30773098 case GGML_OP_CONCAT:
30783099 return op->src [0 ]->type == GGML_TYPE_F32 && op->src [1 ]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
30793100 case GGML_OP_TIMESTEP_EMBEDDING:
@@ -5415,6 +5436,70 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const
54155436 backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size, dst);
54165437}
54175438
5439+ static void ggml_cl_ssm_conv (ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
5440+ GGML_ASSERT (src0);
5441+ GGML_ASSERT (src0->extra );
5442+ GGML_ASSERT (src1);
5443+ GGML_ASSERT (src1->extra );
5444+ GGML_ASSERT (dst);
5445+ GGML_ASSERT (dst->extra );
5446+
5447+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context ;
5448+
5449+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra ;
5450+ ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra ;
5451+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra ;
5452+
5453+ cl_ulong offset0 = extra0->offset + src0->view_offs ;
5454+ cl_ulong offset1 = extra1->offset + src1->view_offs ;
5455+ cl_ulong offsetd = extrad->offset + dst->view_offs ;
5456+
5457+ int ne01 = src0->ne [1 ];
5458+ cl_ulong nb00 = src0->nb [0 ];
5459+ cl_ulong nb01 = src0->nb [1 ];
5460+ cl_ulong nb02 = src0->nb [2 ];
5461+
5462+ int ne10 = src1->ne [0 ];
5463+ cl_ulong nb11 = src1->nb [1 ];
5464+
5465+ int ne1 = dst->ne [1 ];
5466+ int ne2 = dst->ne [2 ];
5467+ cl_ulong nb0 = dst->nb [0 ];
5468+ cl_ulong nb1 = dst->nb [1 ];
5469+ cl_ulong nb2 = dst->nb [2 ];
5470+
5471+ cl_kernel kernel = backend_ctx->kernel_ssm_conv_f32_f32 ;
5472+
5473+ if (ne10 % 4 == 0 ) {
5474+ kernel = backend_ctx->kernel_ssm_conv_f32_f32_4 ;
5475+ }
5476+
5477+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
5478+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &offset0));
5479+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra1->data_device ));
5480+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &offset1));
5481+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (cl_mem), &extrad->data_device ));
5482+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (cl_ulong), &offsetd));
5483+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (cl_ulong), &nb00));
5484+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (cl_ulong), &nb01));
5485+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &nb02));
5486+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &ne10));
5487+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &nb11));
5488+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &nb0));
5489+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (cl_ulong), &nb1));
5490+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (cl_ulong), &nb2));
5491+
5492+ size_t global_work_size[] = {(size_t )ne01, (size_t )ne1, (size_t )ne2};
5493+ size_t local_work_size[] = {64 , 1 , 1 };
5494+
5495+ size_t * local_work_size_ptr = local_work_size;
5496+ if (ne01 % 64 != 0 && !backend_ctx->non_uniform_workgroups ) {
5497+ local_work_size_ptr = nullptr ;
5498+ }
5499+
5500+ backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size_ptr, dst);
5501+ }
5502+
54185503static void ggml_cl_gelu (ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
54195504 GGML_ASSERT (src0);
54205505 GGML_ASSERT (src0->extra );
@@ -9432,6 +9517,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
94329517 }
94339518 func = ggml_cl_conv_2d;
94349519 break ;
9520+ case GGML_OP_SSM_CONV:
9521+ if (!any_on_device) {
9522+ return false ;
9523+ }
9524+ func = ggml_cl_ssm_conv;
9525+ break ;
94359526 case GGML_OP_CONCAT:
94369527 if (!any_on_device) {
94379528 return false ;
0 commit comments