Skip to content

Commit dd64bc9

Browse files
committed
opencl: add ssm_conv
1 parent 95949c5 commit dd64bc9

File tree

3 files changed

+169
-0
lines changed

3 files changed

+169
-0
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ set(GGML_OPENCL_KERNELS
112112
softmax_f16
113113
sqr
114114
sqrt
115+
ssm_conv
115116
sub
116117
sum_rows
117118
transpose

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,7 @@ struct ggml_backend_opencl_context {
512512
cl_kernel kernel_conv_2d_f16;
513513
cl_kernel kernel_conv_2d_f32;
514514
cl_kernel kernel_conv_2d_f16_f32;
515+
cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4;
515516
cl_kernel kernel_timestep_embedding;
516517
cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
517518
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
@@ -1888,6 +1889,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
18881889
}
18891890
}
18901891

1892+
// ssm_conv
1893+
{
1894+
#ifdef GGML_OPENCL_EMBED_KERNELS
1895+
const std::string kernel_src {
1896+
#include "ssm_conv.cl.h"
1897+
};
1898+
#else
1899+
const std::string kernel_src = read_file("ssm_conv.cl");
1900+
#endif
1901+
cl_program prog =
1902+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1903+
1904+
CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32", &err), err));
1905+
CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32_4 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32_4", &err), err));
1906+
CL_CHECK(clReleaseProgram(prog));
1907+
GGML_LOG_CONT(".");
1908+
}
1909+
18911910
// mul_mv_id_q4_0_f32_8x_flat
18921911
{
18931912
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3074,6 +3093,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
30743093
return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
30753094
(op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
30763095
(op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
3096+
case GGML_OP_SSM_CONV:
3097+
return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
30773098
case GGML_OP_CONCAT:
30783099
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
30793100
case GGML_OP_TIMESTEP_EMBEDDING:
@@ -5415,6 +5436,70 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const
54155436
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
54165437
}
54175438

5439+
// Launches the OpenCL SSM_CONV kernel that writes `dst` from `src0` and `src1`.
//
// All three tensors must already carry an OpenCL `extra`. The float4 kernel
// variant is picked when src1->ne[0] is a multiple of 4, the scalar variant
// otherwise. One work-item is launched per (src0->ne[1], dst->ne[1], dst->ne[2])
// index triple.
static void ggml_cl_ssm_conv(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src0);
    GGML_ASSERT(src0->extra);
    GGML_ASSERT(src1);
    GGML_ASSERT(src1->extra);
    GGML_ASSERT(dst);
    GGML_ASSERT(dst->extra);

    ggml_backend_opencl_context * backend_ctx = (ggml_backend_opencl_context *) backend->context;

    ggml_tensor_extra_cl * src0_extra = (ggml_tensor_extra_cl *) src0->extra;
    ggml_tensor_extra_cl * src1_extra = (ggml_tensor_extra_cl *) src1->extra;
    ggml_tensor_extra_cl * dst_extra  = (ggml_tensor_extra_cl *) dst->extra;

    // Byte offsets of each tensor inside its device buffer (view offset included).
    cl_ulong src0_offset = src0_extra->offset + src0->view_offs;
    cl_ulong src1_offset = src1_extra->offset + src1->view_offs;
    cl_ulong dst_offset  = dst_extra->offset + dst->view_offs;

    // Byte strides forwarded to the kernel.
    cl_ulong src0_nb0 = src0->nb[0];
    cl_ulong src0_nb1 = src0->nb[1];
    cl_ulong src0_nb2 = src0->nb[2];
    cl_ulong src1_nb1 = src1->nb[1];
    cl_ulong dst_nb0  = dst->nb[0];
    cl_ulong dst_nb1  = dst->nb[1];
    cl_ulong dst_nb2  = dst->nb[2];

    // Element counts: accumulation length and the three launch-grid extents.
    int n_taps   = src1->ne[0];
    int n_rows   = src0->ne[1];
    int dst_dim1 = dst->ne[1];
    int dst_dim2 = dst->ne[2];

    // Use the float4 variant whenever the accumulation length allows it.
    cl_kernel kernel = (n_taps % 4 == 0)
        ? backend_ctx->kernel_ssm_conv_f32_f32_4
        : backend_ctx->kernel_ssm_conv_f32_f32;

    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &src0_extra->data_device));
    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &src0_offset));
    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &src1_extra->data_device));
    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &src1_offset));
    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &dst_extra->data_device));
    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &dst_offset));
    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &src0_nb0));
    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &src0_nb1));
    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &src0_nb2));
    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &n_taps));
    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &src1_nb1));
    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &dst_nb0));
    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &dst_nb1));
    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &dst_nb2));

    size_t global_size[] = {(size_t)n_rows, (size_t)dst_dim1, (size_t)dst_dim2};
    size_t local_size[]  = {64, 1, 1};

    // The fixed 64x1x1 work-group is usable when it divides dimension 0 evenly
    // or the device supports non-uniform work-groups; otherwise pass no local
    // size and let the runtime choose one.
    const bool fixed_local_ok = (n_rows % 64 == 0) || backend_ctx->non_uniform_workgroups;
    size_t * local_size_ptr = fixed_local_ok ? local_size : nullptr;

    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size_ptr, dst);
}
5502+
54185503
static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
54195504
GGML_ASSERT(src0);
54205505
GGML_ASSERT(src0->extra);
@@ -9432,6 +9517,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
94329517
}
94339518
func = ggml_cl_conv_2d;
94349519
break;
9520+
case GGML_OP_SSM_CONV:
9521+
if (!any_on_device) {
9522+
return false;
9523+
}
9524+
func = ggml_cl_ssm_conv;
9525+
break;
94359526
case GGML_OP_CONCAT:
94369527
if (!any_on_device) {
94379528
return false;
ggml/src/ggml-opencl/kernels/ssm_conv.cl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Scalar SSM_CONV kernel: each work-item produces one float of dst as the sum
// of ne10 element-wise products between a src0 window and a src1 row.
//
// Work-item ids:
//   dim 0 -> row index      (scaled by nb01 in src0, nb11 in src1, nb0 in dst)
//   dim 1 -> window start   (scaled by nb00 in src0, nb1 in dst)
//   dim 2 -> outer index    (scaled by nb02 in src0, nb2 in dst)
// All offset*/nb* arguments are byte quantities applied to the char pointers.
kernel void kernel_ssm_conv_f32_f32(
    global char * src0,
    ulong offset0,
    global char * src1,
    ulong offset1,
    global char * dst,
    ulong offsetd,
    ulong nb00,
    ulong nb01,
    ulong nb02,
    int ne10,
    ulong nb11,
    ulong nb0,
    ulong nb1,
    ulong nb2
) {
    const int gid0 = get_global_id(0);
    const int gid1 = get_global_id(1);
    const int gid2 = get_global_id(2);

    // Apply the buffer offsets, then step to this work-item's data.
    global float * window  = (global float *) (src0 + offset0 + gid0*nb01 + gid1*nb00 + gid2*nb02);
    global float * weights = (global float *) (src1 + offset1 + gid0*nb11);
    global float * out     = (global float *) (dst  + offsetd + gid0*nb0  + gid1*nb1  + gid2*nb2);

    float acc = 0.0f;
    int k = 0;
    while (k < ne10) {
        acc += window[k] * weights[k];
        ++k;
    }

    out[0] = acc;
}
39+
40+
// Vectorized SSM_CONV kernel: each work-item produces one float of dst as the
// sum of ne10 element-wise products between a src0 window and a src1 row,
// processed four floats at a time.
//
// Work-item ids:
//   dim 0 -> row index      (scaled by nb01 in src0, nb11 in src1, nb0 in dst)
//   dim 1 -> window start   (scaled by nb00 in src0, nb1 in dst)
//   dim 2 -> outer index    (scaled by nb02 in src0, nb2 in dst)
// All offset*/nb* arguments are byte quantities applied to the char pointers.
//
// Only ne10/4 full groups of four are accumulated, so ne10 must be a multiple
// of 4 for a complete result (the host selects this kernel only in that case).
kernel void kernel_ssm_conv_f32_f32_4(
    global char * src0,
    ulong offset0,
    global char * src1,
    ulong offset1,
    global char * dst,
    ulong offsetd,
    ulong nb00,
    ulong nb01,
    ulong nb02,
    int ne10,
    ulong nb11,
    ulong nb0,
    ulong nb1,
    ulong nb2
) {
    src0 = src0 + offset0;
    src1 = src1 + offset1;
    dst  = dst  + offsetd;

    int ir = get_global_id(0);
    int i2 = get_global_id(1);
    int i3 = get_global_id(2);

    int nc = ne10;

    // Keep these as float pointers: the window start moves in steps of nb00
    // (one element) per i2, so the address is generally not 16-byte aligned
    // and must not be dereferenced as a float4 pointer.
    global const float * s = (global const float *) (src0 + ir*nb01 + i2*nb00 + i3*nb02);
    global const float * c = (global const float *) (src1 + ir*nb11);
    global float * d = (global float *) (dst + ir*nb0 + i2*nb1 + i3*nb2);

    float sumf = 0.0f;

    for (int i0 = 0; i0 < nc/4; ++i0) {
        // vload4(i0, p) reads p[4*i0 .. 4*i0+3] and only needs float alignment.
        sumf += dot(vload4(i0, s), vload4(i0, c));
    }

    d[0] = sumf;
}

0 commit comments

Comments
 (0)