@@ -2127,6 +2127,30 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
21272127 sum_rows_f32_sycl (src0_dd, dst_dd, ncols, nrows, main_stream);
21282128}
21292129
2130+ inline void ggml_sycl_op_mean (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2131+ GGML_ASSERT (dst->src [0 ]->type == GGML_TYPE_F32);
2132+ GGML_ASSERT (dst->type == GGML_TYPE_F32);
2133+
2134+ dpct::queue_ptr main_stream = ctx.stream ();
2135+ SYCL_CHECK (ggml_sycl_set_device (ctx.device ));
2136+
2137+ const float * src0_dd = static_cast <const float *>(dst->src [0 ]->data );
2138+ float * dst_dd = static_cast <float *>(dst->data );
2139+
2140+ const int64_t ncols = dst->src [0 ]->ne [0 ];
2141+ const int64_t nrows = ggml_nrows (dst->src [0 ]);
2142+
2143+ sum_rows_f32_sycl (src0_dd, dst_dd, ncols, nrows, main_stream);
2144+
2145+ main_stream->parallel_for (
2146+ sycl::range<1 >(nrows),
2147+ [=](sycl::id<1 > row) {
2148+ dst_dd[row] /= ncols;
2149+ }
2150+ );
2151+ }
2152+
2153+
21302154inline void ggml_sycl_op_argsort (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
21312155 GGML_ASSERT (dst->src [0 ]->type == GGML_TYPE_F32);
21322156 GGML_ASSERT (dst->type == GGML_TYPE_I32);
@@ -3510,6 +3534,12 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
35103534 ggml_sycl_op_sum_rows (ctx, dst);
35113535}
35123536
3537+ static void ggml_sycl_mean (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3538+ scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 1 );
3539+ GGML_ASSERT (ggml_is_contiguous (dst->src [0 ]));
3540+ ggml_sycl_op_mean (ctx, dst);
3541+ }
3542+
35133543static void ggml_sycl_argsort (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
35143544 scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 1 );
35153545 GGML_ASSERT (ggml_is_contiguous (dst->src [0 ]));
@@ -3756,6 +3786,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
37563786 case GGML_OP_SUM_ROWS:
37573787 ggml_sycl_sum_rows (ctx, dst);
37583788 break ;
3789+ case GGML_OP_MEAN:
3790+ ggml_sycl_mean (ctx, dst);
3791+ break ;
37593792 case GGML_OP_ARGSORT:
37603793 ggml_sycl_argsort (ctx, dst);
37613794 break ;
@@ -4406,6 +4439,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
44064439 return op->src [0 ]->type == GGML_TYPE_F32 && op->op_params [0 ] == GGML_SCALE_MODE_NEAREST;
44074440 case GGML_OP_SUM:
44084441 case GGML_OP_SUM_ROWS:
4442+ case GGML_OP_MEAN:
44094443 case GGML_OP_ARGSORT:
44104444 return ggml_is_contiguous (op->src [0 ]);
44114445 case GGML_OP_POOL_2D:
0 commit comments