@@ -69,6 +69,7 @@ constexpr char kOpTypeClamp[] = "Clip";
// ONNX operator type names handed to the model builder when adding nodes.
// Each value must match the ONNX operator name exactly (case-sensitive, with
// no surrounding whitespace), otherwise the operator cannot be resolved.
constexpr char kOpTypeConcat[] = "Concat";
constexpr char kOpTypeConv2d[] = "Conv";
constexpr char kOpTypeConvTranspose2d[] = "ConvTranspose";
constexpr char kOpTypeDequantizeLinear[] = "DequantizeLinear";
constexpr char kOpTypeExpand[] = "Expand";
constexpr char kOpTypeGather[] = "Gather";
constexpr char kOpTypeGelu[] = "Gelu";
@@ -291,6 +292,39 @@ void GraphBuilderOrt::AppendCast(std::string_view input_name,
291292 ADD_CAST_NODE (node_name, input_name, output_name, to_data_type);
292293}
293294
295+ std::string GraphBuilderOrt::PrependTranspose (
296+ std::string_view input_name,
297+ base::span<const uint32_t > permutation) {
298+ const std::string node_name = GenerateNextOperationName (" inserted_transpose" );
299+ const std::string output_name = GenerateNextOperandName ();
300+
301+ std::array<const char *, 1 > input_names = {input_name.data ()};
302+ std::array<const char *, 1 > output_names = {output_name.data ()};
303+
304+ std::vector<int64_t > perm (permutation.begin (), permutation.end ());
305+ std::array<OrtOpAttr*, 1 > attributes = {
306+ model_builder_.CreateAttribute (/* name=*/ " perm" , perm).Release ()};
307+
308+ model_builder_.AddNode (kOpTypeTranspose , node_name, input_names, output_names,
309+ attributes);
310+ return output_name;
311+ }
312+
313+ void GraphBuilderOrt::AppendTranspose (std::string_view input_name,
314+ std::string_view output_name,
315+ base::span<const uint32_t > permutation) {
316+ const std::string node_name = GenerateNextOperationName (" inserted_transpose" );
317+ std::array<const char *, 1 > input_names = {input_name.data ()};
318+ std::array<const char *, 1 > output_names = {output_name.data ()};
319+
320+ std::vector<int64_t > perm (permutation.begin (), permutation.end ());
321+ std::array<OrtOpAttr*, 1 > attributes = {
322+ model_builder_.CreateAttribute (/* name=*/ " perm" , perm).Release ()};
323+
324+ model_builder_.AddNode (kOpTypeTranspose , node_name, input_names, output_names,
325+ attributes);
326+ }
327+
294328void GraphBuilderOrt::AddInput (uint64_t input_id) {
295329 const mojom::Operand& operand = GetOperand (input_id);
296330 std::string name = GetOperandNameById (input_id);
@@ -924,6 +958,126 @@ void GraphBuilderOrt::AddExpandOperation(const mojom::Expand& expand) {
924958 model_builder_.AddNode (kOpTypeExpand , node_name, input_names, output_names);
925959}
926960
961+ [[nodiscard]] base::expected<void , mojom::ErrorPtr>
962+ GraphBuilderOrt::AddDequantizeLinearOperation (
963+ const mojom::DequantizeLinear& dequantize_linear) {
964+ const std::string node_name =
965+ GenerateNextOperationName (dequantize_linear.label );
966+ std::string input_name =
967+ GetOperandNameById (dequantize_linear.input_operand_id );
968+ std::string scale_name =
969+ GetOperandNameById (dequantize_linear.scale_operand_id );
970+ std::string zero_point_name =
971+ GetOperandNameById (dequantize_linear.zero_point_operand_id );
972+ std::string output_name =
973+ GetOperandNameById (dequantize_linear.output_operand_id );
974+
975+ const OperandDescriptor& input_descriptor =
976+ GetOperand (dequantize_linear.input_operand_id ).descriptor ;
977+ std::vector<uint32_t > input_shape = input_descriptor.shape ();
978+
979+ const OperandDescriptor& scale_descriptor =
980+ GetOperand (dequantize_linear.scale_operand_id ).descriptor ;
981+ std::vector<uint32_t > scale_shape = scale_descriptor.shape ();
982+
983+ int64_t axis = 1 ;
984+ int64_t block_size = 0 ;
985+ bool need_transpose = false ;
986+
987+ // https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp#L220
988+ if (scale_shape.size () > 2 ) {
989+ return NewNotSupportedError (
990+ " OpenVINO dequantizeLinear cannot operate with more than 2D scales" );
991+ }
992+
993+ if (scale_shape.empty ()) {
994+ // For per-tensor/layer dequantization the scale is a scalar.
995+ axis = 0 ;
996+ } else if (scale_shape.size () == 1 ) {
997+ bool is_valid = false ;
998+ // for per per-axis dequantization it is a 1-D Tensor
999+ for (size_t i = 0 ; i < input_shape.size (); i++) {
1000+ if (scale_shape[0 ] == input_shape[i]) {
1001+ axis = i;
1002+ is_valid = true ;
1003+ }
1004+ }
1005+ if (!is_valid) {
1006+ return NewNotSupportedError (
1007+ " For 1D scale, the size of scale must be the same as the size of the "
1008+ " input dim specified by the axis." );
1009+ }
1010+ } else {
1011+ CHECK_EQ (scale_shape.size (), 2u );
1012+ // For blocked dequantization it has the same shape as the input, except for
1013+ // one dimension in which blocking is performed.
1014+ if (scale_shape.size () == input_shape.size ()) {
1015+ uint32_t diff_count = 0 ;
1016+ for (size_t i = 0 ; i < input_shape.size (); i++) {
1017+ if (scale_shape[i] != input_shape[i]) {
1018+ // https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp#L230
1019+ if (input_shape[i] % scale_shape[i] != 0 ) {
1020+ return NewNotSupportedError (
1021+ " For blocked dequantization, OpenVINO DequantizeLinear doesn't "
1022+ " support case when input cannot be divided by scale." );
1023+ }
1024+ block_size = input_shape[i] / scale_shape[i];
1025+ axis = i;
1026+ diff_count++;
1027+ if (diff_count > 1 ) {
1028+ return NewNotSupportedError (
1029+ " For blocked dequantization it has the same shape as the "
1030+ " input, except for one dimension in which blocking is "
1031+ " performed" );
1032+ }
1033+ }
1034+ }
1035+ // The shape of scale is the same as the shape of input.
1036+ if (diff_count == 0 ) {
1037+ axis = 0 ;
1038+ block_size = 1 ;
1039+ }
1040+
1041+ // Currently, OpenVINO only supports axis == 0 when scale.size == 2.
1042+ // https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp#L228.
1043+ if (axis != 0 ) {
1044+ input_name = PrependTranspose (input_name, {1 , 0 });
1045+ scale_name = PrependTranspose (scale_name, {1 , 0 });
1046+ zero_point_name = PrependTranspose (zero_point_name, {1 , 0 });
1047+ axis = 0 ;
1048+ need_transpose = true ;
1049+ }
1050+ }
1051+ }
1052+
1053+ LOG (ERROR) << " AXIS: " << axis;
1054+ LOG (ERROR) << " BLOCK_WISE: " << block_size;
1055+
1056+ const std::string transposed_output_name =
1057+ need_transpose ? GenerateNextOperandName () : output_name;
1058+
1059+ base::FixedArray<const char *> input_names = {
1060+ input_name.c_str (), scale_name.c_str (), zero_point_name.c_str ()};
1061+ base::FixedArray<const char *> output_names = {transposed_output_name.c_str ()};
1062+
1063+ std::array<OrtOpAttr*, 2 > attributes = {
1064+ model_builder_
1065+ .CreateAttribute (/* name=*/ " axis" , base::checked_cast<int64_t >(axis))
1066+ .Release (),
1067+ model_builder_
1068+ .CreateAttribute (/* name=*/ " block_size" ,
1069+ base::checked_cast<int64_t >(block_size))
1070+ .Release ()};
1071+
1072+ model_builder_.AddNode (kOpTypeDequantizeLinear , node_name, input_names,
1073+ output_names, attributes);
1074+
1075+ if (need_transpose) {
1076+ AppendTranspose (transposed_output_name, output_name, {1 , 0 });
1077+ }
1078+ return base::ok ();
1079+ }
1080+
9271081void GraphBuilderOrt::AddGatherOperation (const mojom::Gather& gather) {
9281082 const std::string node_name = GenerateNextOperationName (gather.label );
9291083 const std::string input_name = GetOperandNameById (gather.input_operand_id );
@@ -1724,6 +1878,11 @@ GraphBuilderOrt::BuildModel() {
17241878 RETURN_IF_ERROR (AddConv2dOperation (*operation->get_conv2d ()));
17251879 break ;
17261880 }
1881+ case mojom::Operation::Tag::kDequantizeLinear : {
1882+ RETURN_IF_ERROR (
1883+ AddDequantizeLinearOperation (*operation->get_dequantize_linear ()));
1884+ break ;
1885+ }
17271886 case mojom::Operation::Tag::kExpand : {
17281887 AddExpandOperation (*operation->get_expand ());
17291888 break ;
@@ -1807,7 +1966,6 @@ GraphBuilderOrt::BuildModel() {
18071966 break ;
18081967 }
18091968 case mojom::Operation::Tag::kCumulativeSum :
1810- case mojom::Operation::Tag::kDequantizeLinear :
18111969 case mojom::Operation::Tag::kElu :
18121970 case mojom::Operation::Tag::kGatherElements :
18131971 case mojom::Operation::Tag::kGatherNd :
0 commit comments