Skip to content

Commit 95949c5

Browse files
committed
opencl: add mean
1 parent d478554 commit 95949c5

File tree

3 files changed

+120
-0
lines changed

3 files changed

+120
-0
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ set(GGML_OPENCL_KERNELS
7070
group_norm
7171
im2col_f32
7272
im2col_f16
73+
mean
7374
mul_mat_Ab_Bi_8x4
7475
mul_mv_f16_f16
7576
mul_mv_f16_f32_1row

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,7 @@ struct ggml_backend_opencl_context {
451451
cl_kernel kernel_scale;
452452
cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4;
453453
cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4;
454+
cl_kernel kernel_mean_f32;
454455
cl_kernel kernel_silu, kernel_silu_4;
455456
cl_kernel kernel_gelu, kernel_gelu_4;
456457
cl_kernel kernel_gelu_erf, kernel_gelu_erf_4;
@@ -1596,6 +1597,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
15961597
GGML_LOG_CONT(".");
15971598
}
15981599

1600+
// mean
1601+
{
1602+
#ifdef GGML_OPENCL_EMBED_KERNELS
1603+
const std::string kernel_src {
1604+
#include "mean.cl.h"
1605+
};
1606+
#else
1607+
const std::string kernel_src = read_file("mean.cl");
1608+
#endif
1609+
cl_program prog =
1610+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1611+
1612+
CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err));
1613+
1614+
CL_CHECK(clReleaseProgram(prog));
1615+
GGML_LOG_CONT(".");
1616+
}
1617+
15991618
// sub
16001619
{
16011620
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3123,6 +3142,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
31233142
return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32;
31243143
}
31253144
case GGML_OP_SUM_ROWS:
3145+
case GGML_OP_MEAN:
31263146
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
31273147
case GGML_OP_FLASH_ATTN_EXT:
31283148
{
@@ -5341,6 +5361,60 @@ static void ggml_cl_sqrt(ggml_backend_t backend, const ggml_tensor * src0, const
53415361
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
53425362
}
53435363

5364+
static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
5365+
GGML_ASSERT(src0);
5366+
GGML_ASSERT(src0->extra);
5367+
GGML_ASSERT(dst);
5368+
GGML_ASSERT(dst->extra);
5369+
GGML_UNUSED(src1);
5370+
5371+
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
5372+
GGML_ASSERT(ggml_is_contiguous(src0));
5373+
5374+
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
5375+
5376+
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
5377+
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
5378+
5379+
cl_ulong offset0 = extra0->offset + src0->view_offs;
5380+
cl_ulong offsetd = extrad->offset + dst->view_offs;
5381+
5382+
const int ne00 = src0->ne[0];
5383+
const int ne01 = src0->ne[1];
5384+
const int ne02 = src0->ne[2];
5385+
const int ne03 = src0->ne[3];
5386+
5387+
const cl_ulong nb01 = src0->nb[1];
5388+
const cl_ulong nb02 = src0->nb[2];
5389+
const cl_ulong nb03 = src0->nb[3];
5390+
5391+
const cl_ulong nb1 = dst->nb[1];
5392+
const cl_ulong nb2 = dst->nb[2];
5393+
const cl_ulong nb3 = dst->nb[3];
5394+
5395+
cl_kernel kernel = backend_ctx->kernel_mean_f32;
5396+
5397+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
5398+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
5399+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
5400+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
5401+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
5402+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
5403+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
5404+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
5405+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
5406+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
5407+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
5408+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb1));
5409+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2));
5410+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3));
5411+
5412+
size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
5413+
size_t local_work_size[] = {(size_t)64, 1, 1};
5414+
5415+
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
5416+
}
5417+
53445418
static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
53455419
GGML_ASSERT(src0);
53465420
GGML_ASSERT(src0->extra);
@@ -9251,6 +9325,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
92519325
}
92529326
func = ggml_cl_sqrt;
92539327
break;
9328+
case GGML_OP_MEAN:
9329+
if (!any_on_device) {
9330+
return false;
9331+
}
9332+
func = ggml_cl_mean;
9333+
break;
92549334
case GGML_OP_UNARY:
92559335
switch (ggml_get_unary_op(tensor)) {
92569336
case GGML_UNARY_OP_GELU:
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
2+
kernel void kernel_mean_f32(
3+
global float * src0,
4+
ulong offset0,
5+
global float * dst,
6+
ulong offsetd,
7+
int ne00,
8+
int ne01,
9+
int ne02,
10+
int ne03,
11+
ulong nb01,
12+
ulong nb02,
13+
ulong nb03,
14+
ulong nb1,
15+
ulong nb2,
16+
ulong nb3
17+
) {
18+
src0 = (global float *)((global char *)src0 + offset0);
19+
dst = (global float *)((global char *)dst + offsetd);
20+
21+
int i3 = get_global_id(2);
22+
int i2 = get_global_id(1);
23+
int i1 = get_global_id(0);
24+
25+
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
26+
return;
27+
}
28+
29+
global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
30+
global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
31+
32+
float row_sum = 0;
33+
34+
for (int i0 = 0; i0 < ne00; i0++) {
35+
row_sum += src_row[i0];
36+
}
37+
38+
dst_row[0] = row_sum / ne00;
39+
}

0 commit comments

Comments
 (0)