Skip to content

Commit 11c325c

Browse files
authored
ggml-webgpu: Add unary op (SQR, SQRT, SIN, COS) support. (ggml-org#19700)
* ggml-webgpu: Add unary op (SQR, SQRT, SIN, COS) support. * Fix: cast the src value to f32 before computing sin/cos.
1 parent 237958d commit 11c325c

File tree

4 files changed

+62
-20
lines changed

4 files changed

+62
-20
lines changed

docs/ops.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Legend:
3131
| CONV_3D ||||||||||||
3232
| CONV_TRANSPOSE_1D ||||||||||||
3333
| CONV_TRANSPOSE_2D ||||||||||||
34-
| COS ||||| 🟡 ||| 🟡 | |||
34+
| COS ||||| 🟡 ||| 🟡 | |||
3535
| COUNT_EQUAL ||||||||||||
3636
| CPY || 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |||
3737
| CROSS_ENTROPY_LOSS ||||||||||||
@@ -96,13 +96,13 @@ Legend:
9696
| SIGMOID |||| 🟡 | 🟡 | 🟡 || 🟡 ||||
9797
| SILU |||| 🟡 | 🟡 | 🟡 || 🟡 ||||
9898
| SILU_BACK ||||||||||||
99-
| SIN ||||| 🟡 ||| 🟡 | |||
99+
| SIN ||||| 🟡 ||| 🟡 | |||
100100
| SOFTPLUS |||| 🟡 | 🟡 ||| 🟡 ||||
101101
| SOFT_MAX || 🟡 ||||||||||
102102
| SOFT_MAX_BACK ||| 🟡 | 🟡 ||| 🟡 |||||
103103
| SOLVE_TRI |||| 🟡 |||| 🟡 ||||
104-
| SQR ||||| 🟡 ||| 🟡 | |||
105-
| SQRT ||||| 🟡 ||| 🟡 | |||
104+
| SQR ||||| 🟡 ||| 🟡 | |||
105+
| SQRT ||||| 🟡 ||| 🟡 | |||
106106
| SSM_CONV ||||||||||||
107107
| SSM_SCAN |||||||| 🟡 ||||
108108
| STEP |||| 🟡 | 🟡 ||| 🟡 ||||

docs/ops/WebGPU.csv

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8760,22 +8760,14 @@
87608760
"WebGPU: WebGPU","ADD_ID","type_a=f32,type_b=f32,n_embd=129,n_experts=8,n_experts_used=4,n_token=1","support","0","no","WebGPU"
87618761
"WebGPU: WebGPU","ADD_ID","type_a=f32,type_b=f32,n_embd=129,n_experts=8,n_experts_used=4,n_token=32","support","0","no","WebGPU"
87628762
"WebGPU: WebGPU","ADD_ID","type_a=f32,type_b=f32,n_embd=129,n_experts=8,n_experts_used=4,n_token=129","support","0","no","WebGPU"
8763-
"WebGPU: WebGPU","SQR","type=f16,ne=[10,5,4,3]","support","0","no","WebGPU"
8764-
"WebGPU: WebGPU","SQRT","type=f16,ne=[10,3,3,2]","support","0","no","WebGPU"
87658763
"WebGPU: WebGPU","LOG","type=f16,ne=[10,5,4,3]","support","1","yes","WebGPU"
8766-
"WebGPU: WebGPU","SIN","type=f16,ne=[10,2,2,2]","support","0","no","WebGPU"
8767-
"WebGPU: WebGPU","COS","type=f16,ne=[10,2,2,2]","support","0","no","WebGPU"
87688764
"WebGPU: WebGPU","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","WebGPU"
87698765
"WebGPU: WebGPU","LEAKY_RELU","type=f16,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","WebGPU"
87708766
"WebGPU: WebGPU","FLOOR","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
87718767
"WebGPU: WebGPU","CEIL","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
87728768
"WebGPU: WebGPU","ROUND","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
87738769
"WebGPU: WebGPU","TRUNC","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
8774-
"WebGPU: WebGPU","SQR","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU"
8775-
"WebGPU: WebGPU","SQRT","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU"
87768770
"WebGPU: WebGPU","LOG","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
8777-
"WebGPU: WebGPU","SIN","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU"
8778-
"WebGPU: WebGPU","COS","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU"
87798771
"WebGPU: WebGPU","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","WebGPU"
87808772
"WebGPU: WebGPU","LEAKY_RELU","type=f16,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","WebGPU"
87818773
"WebGPU: WebGPU","FLOOR","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
@@ -8786,22 +8778,14 @@
87868778
"WebGPU: WebGPU","ROUND","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
87878779
"WebGPU: WebGPU","TRUNC","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
87888780
"WebGPU: WebGPU","TRUNC","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
8789-
"WebGPU: WebGPU","SQR","type=f32,ne=[10,5,4,3]","support","0","no","WebGPU"
8790-
"WebGPU: WebGPU","SQRT","type=f32,ne=[10,3,3,2]","support","0","no","WebGPU"
87918781
"WebGPU: WebGPU","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","WebGPU"
8792-
"WebGPU: WebGPU","SIN","type=f32,ne=[10,2,2,2]","support","0","no","WebGPU"
8793-
"WebGPU: WebGPU","COS","type=f32,ne=[10,2,2,2]","support","0","no","WebGPU"
87948782
"WebGPU: WebGPU","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","WebGPU"
87958783
"WebGPU: WebGPU","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","WebGPU"
87968784
"WebGPU: WebGPU","FLOOR","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
87978785
"WebGPU: WebGPU","CEIL","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
87988786
"WebGPU: WebGPU","ROUND","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
87998787
"WebGPU: WebGPU","TRUNC","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
8800-
"WebGPU: WebGPU","SQR","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU"
8801-
"WebGPU: WebGPU","SQRT","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU"
88028788
"WebGPU: WebGPU","LOG","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
8803-
"WebGPU: WebGPU","SIN","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU"
8804-
"WebGPU: WebGPU","COS","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU"
88058789
"WebGPU: WebGPU","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","WebGPU"
88068790
"WebGPU: WebGPU","LEAKY_RELU","type=f32,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","WebGPU"
88078791
"WebGPU: WebGPU","FLOOR","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
@@ -18901,3 +18885,27 @@
1890118885
"WebGPU: WebGPU","CROSS_ENTROPY_LOSS_BACK","type=f32,ne=[30000,1,1,1]","support","0","no","WebGPU"
1890218886
"WebGPU: WebGPU","OPT_STEP_ADAMW","type=f32,ne=[10,5,4,3]","support","0","no","WebGPU"
1890318887
"WebGPU: WebGPU","OPT_STEP_SGD","type=f32,ne=[10,5,4,3]","support","0","no","WebGPU"
18888+
"WebGPU: WebGPU","SQR","type=f16,ne=[10,5,4,3]","support","1","yes","WebGPU"
18889+
"WebGPU: WebGPU","SQRT","type=f16,ne=[10,3,3,2]","support","1","yes","WebGPU"
18890+
"WebGPU: WebGPU","SIN","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
18891+
"WebGPU: WebGPU","COS","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
18892+
"WebGPU: WebGPU","SQR","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
18893+
"WebGPU: WebGPU","SQR","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
18894+
"WebGPU: WebGPU","SQRT","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
18895+
"WebGPU: WebGPU","SQRT","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
18896+
"WebGPU: WebGPU","SIN","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
18897+
"WebGPU: WebGPU","SIN","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
18898+
"WebGPU: WebGPU","COS","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
18899+
"WebGPU: WebGPU","COS","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
18900+
"WebGPU: WebGPU","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","WebGPU"
18901+
"WebGPU: WebGPU","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","WebGPU"
18902+
"WebGPU: WebGPU","SIN","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
18903+
"WebGPU: WebGPU","COS","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
18904+
"WebGPU: WebGPU","SQR","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
18905+
"WebGPU: WebGPU","SQR","type=f32,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
18906+
"WebGPU: WebGPU","SQRT","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
18907+
"WebGPU: WebGPU","SQRT","type=f32,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
18908+
"WebGPU: WebGPU","SIN","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
18909+
"WebGPU: WebGPU","SIN","type=f32,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
18910+
"WebGPU: WebGPU","COS","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
18911+
"WebGPU: WebGPU","COS","type=f32,ne=[1024,1024,1,1]","support","1","yes","WebGPU"

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2008,6 +2008,14 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
20082008
return ggml_webgpu_unary_op(ctx, src0, node);
20092009
case GGML_OP_LOG:
20102010
return ggml_webgpu_unary_op(ctx, src0, node);
2011+
case GGML_OP_SQR:
2012+
return ggml_webgpu_unary_op(ctx, src0, node);
2013+
case GGML_OP_SQRT:
2014+
return ggml_webgpu_unary_op(ctx, src0, node);
2015+
case GGML_OP_SIN:
2016+
return ggml_webgpu_unary_op(ctx, src0, node);
2017+
case GGML_OP_COS:
2018+
return ggml_webgpu_unary_op(ctx, src0, node);
20112019
case GGML_OP_PAD:
20122020
return ggml_webgpu_pad(ctx, src0, node);
20132021
case GGML_OP_ARGMAX:
@@ -2967,6 +2975,18 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
29672975
case GGML_OP_LOG:
29682976
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
29692977
break;
2978+
case GGML_OP_SQR:
2979+
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
2980+
break;
2981+
case GGML_OP_SQRT:
2982+
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
2983+
break;
2984+
case GGML_OP_SIN:
2985+
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
2986+
break;
2987+
case GGML_OP_COS:
2988+
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
2989+
break;
29702990
case GGML_OP_PAD:
29712991
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
29722992
break;

ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
170170
#ifdef TRUNC
171171
let res = trunc(src[params.offset_src + src_idx]);
172172
#endif
173+
#ifdef SQR
174+
let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx];
175+
#endif
176+
#ifdef SQRT
177+
let res = sqrt(src[params.offset_src + src_idx]);
178+
#endif
179+
#ifdef SIN
180+
let res_f32 = sin(f32(src[params.offset_src + src_idx]));
181+
let res = TYPE(res_f32);
182+
#endif
183+
#ifdef COS
184+
let res_f32 = cos(f32(src[params.offset_src + src_idx]));
185+
let res = TYPE(res_f32);
186+
#endif
173187

174188
#ifdef INPLACE
175189
src[params.offset_src + src_idx] = res;

0 commit comments

Comments
 (0)