Skip to content

Commit 89cc712

Browse files
guschmuerohan11235813
authored andcommitted
handle fp16 for where op (#19969)
this prevents falling back from webgpu to cpu, aka helps performance
1 parent bf70b8d commit 89cc712

File tree

1 file changed

+14
-12
lines changed
  • onnxruntime/core/providers/js/operators

1 file changed

+14
-12
lines changed

onnxruntime/core/providers/js/operators/where.cc

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,19 @@
66
namespace onnxruntime {
77
namespace js {
88

9-
#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
10-
ONNX_OPERATOR_KERNEL_EX( \
11-
OP_TYPE, \
12-
kOnnxDomain, \
13-
VERSION, \
14-
kJsExecutionProvider, \
15-
KernelDefBuilder() \
16-
.TypeConstraint("T", \
17-
{DataTypeImpl::GetTensorType<float>(), \
18-
DataTypeImpl::GetTensorType<int32_t>(), \
19-
DataTypeImpl::GetTensorType<uint32_t>(), \
20-
DataTypeImpl::GetTensorType<bool>()}), \
9+
#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
10+
ONNX_OPERATOR_KERNEL_EX( \
11+
OP_TYPE, \
12+
kOnnxDomain, \
13+
VERSION, \
14+
kJsExecutionProvider, \
15+
KernelDefBuilder() \
16+
.TypeConstraint("T", \
17+
{DataTypeImpl::GetTensorType<float>(), \
18+
DataTypeImpl::GetTensorType<MLFloat16>(), \
19+
DataTypeImpl::GetTensorType<int32_t>(), \
20+
DataTypeImpl::GetTensorType<uint32_t>(), \
21+
DataTypeImpl::GetTensorType<bool>()}), \
2122
KERNEL_CLASS);
2223

2324
#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \
@@ -29,6 +30,7 @@ namespace js {
2930
KernelDefBuilder() \
3031
.TypeConstraint("T", \
3132
{DataTypeImpl::GetTensorType<float>(), \
33+
DataTypeImpl::GetTensorType<MLFloat16>(), \
3234
DataTypeImpl::GetTensorType<int32_t>(), \
3335
DataTypeImpl::GetTensorType<uint32_t>(), \
3436
DataTypeImpl::GetTensorType<bool>()}), \

0 commit comments

Comments
 (0)