Skip to content

Commit 9d7c518

Browse files
ye-NX and CISC authored
sycl: add CONCAT operator support (ggml-org#16047)
* sycl: add CONCAT operator support
* cleanup: remove stray lines added by mistake
* fix: code format issues in concat.cpp and tests/test-backend-ops.cpp
* chore: fix editorconfig violations
* cleanup: drop unnecessary i16 type support
* docs: update sycl-csv and regenerate ops.md
* update docs/ops.md
* fix: adapt to upstream master changes after rebase
* fix: remove empty files
* fix: drop whitespace

---------

Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent 22c8c3c commit 9d7c518

File tree

4 files changed

+73
-66
lines changed

4 files changed

+73
-66
lines changed

docs/ops.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ Legend:
2424
| ARGSORT ||||||||||
2525
| CEIL |||| 🟡 ||||||
2626
| CLAMP ||||| 🟡 | 🟡 || 🟡 ||
27-
| CONCAT |||| 🟡 || 🟡 | 🟡 |||
27+
| CONCAT |||| 🟡 || 🟡 | |||
2828
| CONT || 🟡 |||| 🟡 | 🟡 | 🟡 ||
2929
| CONV_2D |||| 🟡 ||||||
3030
| CONV_2D_DW ||||||||||

docs/ops/SYCL.csv

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -9307,37 +9307,37 @@
93079307
"SYCL0","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=24,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=0,v=0,inplace=1","support","1","yes","SYCL"
93089308
"SYCL0","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=24,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=1,v=0,inplace=1","support","1","yes","SYCL"
93099309
"SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","1","yes","SYCL"
9310-
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","0","no","SYCL"
9310+
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","0","yes","SYCL"
93119311
"SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","1","yes","SYCL"
9312-
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","0","no","SYCL"
9312+
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","0","yes","SYCL"
93139313
"SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","1","yes","SYCL"
9314-
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","0","no","SYCL"
9314+
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","0","yes","SYCL"
93159315
"SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","1","yes","SYCL"
9316-
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","0","no","SYCL"
9316+
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","0","yes","SYCL"
93179317
"SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","1","yes","SYCL"
9318-
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","0","no","SYCL"
9318+
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","0","yes","SYCL"
93199319
"SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","1","yes","SYCL"
9320-
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","0","no","SYCL"
9320+
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","0","yes","SYCL"
93219321
"SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","1","yes","SYCL"
9322-
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","0","no","SYCL"
9322+
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","0","yes","SYCL"
93239323
"SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","1","yes","SYCL"
9324-
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","0","no","SYCL"
9324+
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","0","yes","SYCL"
93259325
"SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","1","yes","SYCL"
9326-
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","0","no","SYCL"
9326+
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","0","yes","SYCL"
93279327
"SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","1","yes","SYCL"
9328-
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","0","no","SYCL"
9328+
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","0","yes","SYCL"
93299329
"SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","1","yes","SYCL"
9330-
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","0","no","SYCL"
9330+
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","0","yes","SYCL"
93319331
"SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","1","yes","SYCL"
9332-
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","0","no","SYCL"
9332+
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","0","yes","SYCL"
93339333
"SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","1","yes","SYCL"
9334-
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","0","no","SYCL"
9334+
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","0","yes","SYCL"
93359335
"SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","1","yes","SYCL"
9336-
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","0","no","SYCL"
9336+
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","0","yes","SYCL"
93379337
"SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","1","yes","SYCL"
9338-
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","0","no","SYCL"
9338+
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","0","yes","SYCL"
93399339
"SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","SYCL"
9340-
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","0","no","SYCL"
9340+
"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","0","yes","SYCL"
93419341
"SYCL0","ARGSORT","type=f32,ne=[8,1,1,1],order=0","support","1","yes","SYCL"
93429342
"SYCL0","ARGSORT","type=f32,ne=[16,10,10,10],order=0","support","1","yes","SYCL"
93439343
"SYCL0","ARGSORT","type=f32,ne=[60,10,10,10],order=0","support","1","yes","SYCL"

ggml/src/ggml-sycl/concat.cpp

Lines changed: 55 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -11,9 +11,13 @@
1111
//
1212

1313
#include "concat.hpp"
14-
#include "common.hpp"
1514

16-
static void concat_f32_dim0(const float *x, const float *y, float *dst,
15+
// Returns ggml_type_size(t) divided by ggml_blck_size(t) for the given type.
// concat_impl_sycl divides the nb[3] strides and the src0 byte count by this
// value to obtain offsets in units of T.
// NOTE(review): presumably this is the byte size of a single element — confirm
// against the ggml_type_size / ggml_blck_size declarations in ggml.h.
static inline size_t elem_size(ggml_type t) {
16+
return ggml_type_size(t) / ggml_blck_size(t);
17+
}
18+
19+
template <typename T>
20+
static void concat_T_dim0(const T *x, const T *y, T *dst,
1721
const int ne0, const int ne00,
1822
const sycl::nd_item<3> &item_ct1) {
1923
int nidx = item_ct1.get_local_id(2) +
@@ -36,7 +40,8 @@ static void concat_f32_dim0(const float *x, const float *y, float *dst,
3640
}
3741
}
3842

39-
static void concat_f32_dim1(const float *x, const float *y, float *dst,
43+
template <typename T>
44+
static void concat_T_dim1(const T *x, const T *y, T *dst,
4045
const int ne0, const int ne01,
4146
const sycl::nd_item<3> &item_ct1) {
4247
int nidx = item_ct1.get_local_id(2) +
@@ -59,7 +64,8 @@ static void concat_f32_dim1(const float *x, const float *y, float *dst,
5964
}
6065
}
6166

62-
static void concat_f32_dim2(const float *x, const float *y, float *dst,
67+
template <typename T>
68+
static void concat_T_dim2(const T *x, const T *y, T *dst,
6369
const int ne0, const int ne02,
6470
const sycl::nd_item<3> &item_ct1) {
6571
int nidx = item_ct1.get_local_id(2) +
@@ -82,45 +88,35 @@ static void concat_f32_dim2(const float *x, const float *y, float *dst,
8288
}
8389
}
8490

85-
static void concat_f32_sycl(const float *x, const float *y, float *dst,
91+
template <typename T>
92+
static void concat_T_sycl(const T *x, const T *y, T *dst,
8693
int ne00, int ne01, int ne02, int ne0, int ne1,
8794
int ne2, int dim, queue_ptr stream) {
8895
int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
8996
sycl::range<3> gridDim(ne2, ne1, num_blocks);
9097
switch (dim) {
9198
case 0:
92-
stream->parallel_for(
93-
sycl::nd_range<3>(gridDim *
94-
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
95-
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
96-
[=](sycl::nd_item<3> item_ct1) {
97-
concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1);
98-
});
99-
break;
99+
stream->parallel_for(sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
100+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
101+
[=](sycl::nd_item<3> item_ct1) { concat_T_dim0<T>(x, y, dst, ne0, ne00, item_ct1); });
102+
break;
100103
case 1:
101-
stream->parallel_for(
102-
sycl::nd_range<3>(gridDim *
103-
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
104-
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
105-
[=](sycl::nd_item<3> item_ct1) {
106-
concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
107-
});
108-
break;
104+
stream->parallel_for(sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
105+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
106+
[=](sycl::nd_item<3> item_ct1) { concat_T_dim1<T>(x, y, dst, ne0, ne01, item_ct1); });
107+
break;
109108
// dim >=2 will be dispatched to the default path
110109
default:
111-
stream->parallel_for(
112-
sycl::nd_range<3>(gridDim *
113-
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
114-
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
115-
[=](sycl::nd_item<3> item_ct1) {
116-
concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1);
117-
});
118-
break;
110+
stream->parallel_for(sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
111+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
112+
[=](sycl::nd_item<3> item_ct1) { concat_T_dim2<T>(x, y, dst, ne0, ne02, item_ct1); });
113+
break;
119114
}
120115
}
121116

122117
// non-contiguous kernel (slow)
123-
static void concat_f32_sycl_non_cont(
118+
template<typename T>
119+
static void concat_T_sycl_non_cont(
124120
queue_ptr stream, const char *src0, const char *src1, char *dst,
125121
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, uint64_t nb00,
126122
uint64_t nb01, uint64_t nb02, uint64_t nb03, int64_t /*ne10*/,
@@ -137,24 +133,25 @@ static void concat_f32_sycl_non_cont(
137133
int64_t o[4] = { 0, 0, 0, 0 };
138134
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
139135

140-
const float * x;
136+
const T * x;
141137

142138
for (int i0 = item_ct1.get_local_id(2); i0 < ne0; i0 += item_ct1.get_local_range(2)) {
143139
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
144-
x = (const float *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00);
140+
x = (const T *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00);
145141
} else {
146-
x = (const float *) (src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + (i1 - o[1]) * nb11 +
142+
x = (const T *) (src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + (i1 - o[1]) * nb11 +
147143
(i0 - o[0]) * nb10);
148144
}
149145

150-
float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
146+
T *y = (T *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
151147

152148
*y = *x;
153149
}
154150
});
155151
}
156152

157-
void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
153+
template <typename T>
154+
void concat_impl_sycl(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
158155
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
159156
const ggml_tensor * src0 = dst->src[0];
160157
const ggml_tensor * src1 = dst->src[1];
@@ -163,29 +160,43 @@ void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
163160
const int32_t dim = ((int32_t *) dst->op_params)[0];
164161

165162
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
166-
const float * src0_d = (const float *) src0->data;
167-
const float * src1_d = (const float *) src1->data;
168-
169-
float * dst_d = (float *) dst->data;
170-
163+
const T * src0_d = (const T *) src0->data;
164+
const T * src1_d = (const T *) src1->data;
165+
T * dst_d = (T *) dst->data;
166+
size_t type_size = elem_size(dst->type);
171167
if (dim != 3) {
172168
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
173-
concat_f32_sycl(src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4),
174-
dst_d + i3 * (dst->nb[3] / 4), src0->ne[0], src0->ne[1], src0->ne[2], dst->ne[0],
169+
concat_T_sycl<T>(src0_d + i3 * (src0->nb[3] / type_size), src1_d + i3 * (src1->nb[3] / type_size),
170+
dst_d + i3 * (dst->nb[3] / type_size), src0->ne[0], src0->ne[1], src0->ne[2], dst->ne[0],
175171
dst->ne[1], dst->ne[2], dim, stream);
176172
}
177173
} else {
178174
const size_t size0 = ggml_nbytes(src0);
179175
const size_t size1 = ggml_nbytes(src1);
180176

181177
SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d, src0_d, size0).wait()));
182-
SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d + size0 / 4, src1_d, size1).wait()));
178+
SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d + size0 / type_size, src1_d, size1).wait()));
183179
}
184180
} else {
185-
concat_f32_sycl_non_cont(stream, (const char *) src0->data, (const char *) src1->data, (char *) dst->data,
181+
concat_T_sycl_non_cont<T>(stream, (const char *) src0->data, (const char *) src1->data, (char *) dst->data,
186182
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1],
187183
src0->nb[2], src0->nb[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
188184
src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3], dst->ne[0], dst->ne[1], dst->ne[2],
189185
dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim);
190186
}
191187
}
188+
189+
// Public entry point for the CONCAT op. Picks the template instantiation of
// concat_impl_sycl from dst->type: GGML_TYPE_F32 runs as float, GGML_TYPE_I32
// runs as int32_t. Any other dst->type reaches the default branch and fails
// the GGML_ASSERT there.
// NOTE(review): confirm that the backend's supports_op check rejects CONCAT
// for types other than F32/I32; otherwise such tensors would reach the assert
// below instead of being reported as unsupported.
void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
190+
191+
switch (dst->type) {
192+
case GGML_TYPE_F32:
193+
concat_impl_sycl<float>(ctx, dst);
194+
break;
195+
case GGML_TYPE_I32:
196+
concat_impl_sycl<int32_t>(ctx, dst);
197+
break;
198+
default:
199+
// Only the two types above are dispatched; everything else is a hard failure.
GGML_ASSERT(false && "ggml_sycl_op_concat: unsupported type");
200+
break;
201+
}
202+
}

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4534,16 +4534,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
45344534
}
45354535
return false;
45364536
}
4537-
case GGML_OP_CONCAT:
4538-
{
4539-
ggml_type src0_type = op->src[0]->type;
4540-
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
4541-
}
45424537
case GGML_OP_REPEAT_BACK:
45434538
{
45444539
ggml_type src0_type = op->src[0]->type;
45454540
return src0_type == GGML_TYPE_F32;
45464541
}
4542+
case GGML_OP_CONCAT:
45474543
case GGML_OP_DUP:
45484544
case GGML_OP_ARGMAX:
45494545
case GGML_OP_NONE:

0 commit comments

Comments
 (0)