Skip to content

Commit b4b6b91

Browse files
committed
Implement ScatterND (#138)
Fix #116
1 parent 17bd6fd commit b4b6b91

3 files changed

Lines changed: 56 additions & 4 deletions

File tree

services/webnn/ort/context_impl_ort.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,13 @@ ContextProperties ContextImplOrt::GetContextProperties() {
187187
/*reverse_input=*/{},
188188
/*scatter_elements_input=*/{},
189189
/*scatter_elements_indices=*/{},
190-
/*scatter_nd_input=*/{},
191-
/*scatter_nd_indices=*/{},
192-
/*scatter_nd_updates=*/{},
190+
/*scatter_nd_input=*/
191+
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kNonScalarMaxRank},
192+
/*scatter_nd_indices=*/
193+
{DataTypeConstraint::kGatherScatterIndicesSupportedDataTypes,
194+
kNonScalarMaxRank},
195+
/*scatter_nd_updates=*/
196+
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
193197
/*sigmoid_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
194198
/*slice_input=*/
195199
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},

services/webnn/ort/graph_builder_ort.cc

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ constexpr char kOpTypeReduceSumSquare[] = "ReduceSumSquare";
9898
constexpr char kOpTypeRelu[] = "Relu";
9999
constexpr char kOpTypeResample2d[] = "Resize";
100100
constexpr char kOpTypeReshape[] = "Reshape";
101+
constexpr char kOpTypeScatterND[] = "ScatterND";
101102
constexpr char kOpTypeSigmoid[] = "Sigmoid";
102103
constexpr char kOpTypeSlice[] = "Slice";
103104
constexpr char kOpTypeSoftmax[] = "Softmax";
@@ -1557,6 +1558,49 @@ GraphBuilderOrt::AddReshapeOperation(const mojom::Reshape& reshape) {
15571558
return base::ok();
15581559
}
15591560

1561+
void GraphBuilderOrt::AddScatterNDOperation(
1562+
const mojom::ScatterND& scatter_nd) {
1563+
const std::string node_name = GenerateNextOperationName(scatter_nd.label);
1564+
const std::string input_name =
1565+
GetOperandNameById(scatter_nd.input_operand_id);
1566+
const std::string indices_name =
1567+
GetOperandNameById(scatter_nd.indices_operand_id);
1568+
const std::string updates_name =
1569+
GetOperandNameById(scatter_nd.updates_operand_id);
1570+
const std::string output_name =
1571+
GetOperandNameById(scatter_nd.output_operand_id);
1572+
1573+
std::string int64_indices_name;
1574+
const OperandDataType indices_data_type =
1575+
GetOperand(scatter_nd.indices_operand_id).descriptor.data_type();
1576+
1577+
// ONNX only supports int64 indices.
1578+
switch (indices_data_type) {
1579+
case OperandDataType::kInt64: {
1580+
int64_indices_name = indices_name;
1581+
break;
1582+
}
1583+
case OperandDataType::kInt32:
1584+
case OperandDataType::kUint32: {
1585+
int64_indices_name = GenerateNextOperandName();
1586+
AppendCast(
1587+
indices_name, int64_indices_name,
1588+
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64);
1589+
break;
1590+
}
1591+
default:
1592+
NOTREACHED() << "[WebNN] ScatterND only supports data type int32, uint32 "
1593+
"and int64.";
1594+
}
1595+
1596+
std::array<const char*, 3> input_names = {
1597+
input_name.c_str(), int64_indices_name.c_str(), updates_name.c_str()};
1598+
std::array<const char*, 1> output_names = {output_name.c_str()};
1599+
1600+
model_builder_.AddNode(kOpTypeScatterND, node_name, input_names,
1601+
output_names);
1602+
}
1603+
15601604
[[nodiscard]] base::expected<void, mojom::ErrorPtr>
15611605
GraphBuilderOrt::AddSliceOperation(const mojom::Slice& slice) {
15621606
const std::string node_name = GenerateNextOperationName(slice.label);
@@ -1834,6 +1878,10 @@ GraphBuilderOrt::BuildModel() {
18341878
RETURN_IF_ERROR(AddReshapeOperation(*operation->get_reshape()));
18351879
break;
18361880
}
1881+
case mojom::Operation::Tag::kScatterNd: {
1882+
AddScatterNDOperation(*operation->get_scatter_nd());
1883+
break;
1884+
}
18371885
case mojom::Operation::Tag::kSigmoid: {
18381886
AddUnaryOperation(*operation->get_sigmoid(), kOpTypeSigmoid);
18391887
break;
@@ -1879,7 +1927,6 @@ GraphBuilderOrt::BuildModel() {
18791927
case mojom::Operation::Tag::kQuantizeLinear:
18801928
case mojom::Operation::Tag::kReverse:
18811929
case mojom::Operation::Tag::kScatterElements:
1882-
case mojom::Operation::Tag::kScatterNd:
18831930
case mojom::Operation::Tag::kSoftplus:
18841931
case mojom::Operation::Tag::kSoftsign:
18851932
case mojom::Operation::Tag::kTanh:

services/webnn/ort/graph_builder_ort.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ class GraphBuilderOrt {
194194
const mojom::Resample2d& resample2d);
195195
[[nodiscard]] base::expected<void, mojom::ErrorPtr> AddReshapeOperation(
196196
const mojom::Reshape& reshape);
197+
void AddScatterNDOperation(const mojom::ScatterND& scatter_nd);
197198
[[nodiscard]] base::expected<void, mojom::ErrorPtr> AddSliceOperation(
198199
const mojom::Slice& slice);
199200
void AddSoftmaxOperation(const mojom::Softmax& softmax);

0 commit comments

Comments
 (0)