@@ -98,6 +98,7 @@ constexpr char kOpTypeReduceSumSquare[] = "ReduceSumSquare";
9898constexpr char kOpTypeRelu [] = " Relu" ;
9999constexpr char kOpTypeResample2d [] = " Resize" ;
100100constexpr char kOpTypeReshape [] = " Reshape" ;
101+ constexpr char kOpTypeScatterND [] = " ScatterND" ;
101102constexpr char kOpTypeSigmoid [] = " Sigmoid" ;
102103constexpr char kOpTypeSlice [] = " Slice" ;
103104constexpr char kOpTypeSoftmax [] = " Softmax" ;
@@ -1557,6 +1558,49 @@ GraphBuilderOrt::AddReshapeOperation(const mojom::Reshape& reshape) {
15571558 return base::ok ();
15581559}
15591560
1561+ void GraphBuilderOrt::AddScatterNDOperation (
1562+ const mojom::ScatterND& scatter_nd) {
1563+ const std::string node_name = GenerateNextOperationName (scatter_nd.label );
1564+ const std::string input_name =
1565+ GetOperandNameById (scatter_nd.input_operand_id );
1566+ const std::string indices_name =
1567+ GetOperandNameById (scatter_nd.indices_operand_id );
1568+ const std::string updates_name =
1569+ GetOperandNameById (scatter_nd.updates_operand_id );
1570+ const std::string output_name =
1571+ GetOperandNameById (scatter_nd.output_operand_id );
1572+
1573+ std::string int64_indices_name;
1574+ const OperandDataType indices_data_type =
1575+ GetOperand (scatter_nd.indices_operand_id ).descriptor .data_type ();
1576+
1577+ // ONNX only supports int64 indices.
1578+ switch (indices_data_type) {
1579+ case OperandDataType::kInt64 : {
1580+ int64_indices_name = indices_name;
1581+ break ;
1582+ }
1583+ case OperandDataType::kInt32 :
1584+ case OperandDataType::kUint32 : {
1585+ int64_indices_name = GenerateNextOperandName ();
1586+ AppendCast (
1587+ indices_name, int64_indices_name,
1588+ ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64);
1589+ break ;
1590+ }
1591+ default :
1592+ NOTREACHED () << " [WebNN] ScatterND only supports data type int32, uint32 "
1593+ " and int64." ;
1594+ }
1595+
1596+ std::array<const char *, 3 > input_names = {
1597+ input_name.c_str (), int64_indices_name.c_str (), updates_name.c_str ()};
1598+ std::array<const char *, 1 > output_names = {output_name.c_str ()};
1599+
1600+ model_builder_.AddNode (kOpTypeScatterND , node_name, input_names,
1601+ output_names);
1602+ }
1603+
15601604[[nodiscard]] base::expected<void , mojom::ErrorPtr>
15611605GraphBuilderOrt::AddSliceOperation (const mojom::Slice& slice) {
15621606 const std::string node_name = GenerateNextOperationName (slice.label );
@@ -1834,6 +1878,10 @@ GraphBuilderOrt::BuildModel() {
18341878 RETURN_IF_ERROR (AddReshapeOperation (*operation->get_reshape ()));
18351879 break ;
18361880 }
1881+ case mojom::Operation::Tag::kScatterNd : {
1882+ AddScatterNDOperation (*operation->get_scatter_nd ());
1883+ break ;
1884+ }
18371885 case mojom::Operation::Tag::kSigmoid : {
18381886 AddUnaryOperation (*operation->get_sigmoid (), kOpTypeSigmoid );
18391887 break ;
@@ -1879,7 +1927,6 @@ GraphBuilderOrt::BuildModel() {
18791927 case mojom::Operation::Tag::kQuantizeLinear :
18801928 case mojom::Operation::Tag::kReverse :
18811929 case mojom::Operation::Tag::kScatterElements :
1882- case mojom::Operation::Tag::kScatterNd :
18831930 case mojom::Operation::Tag::kSoftplus :
18841931 case mojom::Operation::Tag::kSoftsign :
18851932 case mojom::Operation::Tag::kTanh :
0 commit comments