Skip to content

Commit 5f0efb1

Browse files
committed
Implement dequantizeLinear
1 parent a97dfb2 commit 5f0efb1

3 files changed

Lines changed: 172 additions & 3 deletions

File tree

services/webnn/ort/context_impl_ort.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ ContextProperties ContextImplOrt::GetContextProperties() {
5959
static constexpr SupportedRanks kNonScalarMaxRank =
6060
SupportedRanks::NonScalarUpTo(8);
6161

62+
static constexpr SupportedDataTypes kDequantizeLinearInputSupportedDataTypes{
63+
OperandDataType::kInt4, OperandDataType::kUint4, OperandDataType::kUint8,
64+
OperandDataType::kInt8, OperandDataType::kInt32};
65+
6266
return ContextProperties(
6367
InputOperandLayout::kNchw, Resample2DAxes::kChannelsFirst,
6468
/*tensor_byte_length_limit=*/kTensorByteLengthLimit,
@@ -74,8 +78,8 @@ ContextProperties ContextImplOrt::GetContextProperties() {
7478
/*conv2d_input=*/DataTypeConstraint::kFloat16To32,
7579
/*conv_transpose2d_input=*/DataTypeConstraint::kFloat16To32,
7680
/*cumulative_sum_input=*/{},
77-
/*dequantize_linear_input=*/{},
78-
/*dequantize_linear_scale=*/{},
81+
/*dequantize_linear_input=*/kDequantizeLinearInputSupportedDataTypes,
82+
/*dequantize_linear_scale=*/DataTypeConstraint::kFloat16To32,
7983
/*add_input=*/
8084
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
8185
/*sub_input=*/

services/webnn/ort/graph_builder_ort.cc

Lines changed: 156 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ constexpr char kOpTypeClamp[] = "Clip";
6969
constexpr char kOpTypeConcat[] = "Concat";
7070
constexpr char kOpTypeConv2d[] = "Conv";
7171
constexpr char kOpTypeConvTranspose2d[] = "ConvTranspose";
72+
constexpr char kOpTypeDequantizeLinear[] = "DequantizeLinear";
7273
constexpr char kOpTypeExpand[] = "Expand";
7374
constexpr char kOpTypeGather[] = "Gather";
7475
constexpr char kOpTypeGelu[] = "Gelu";
@@ -300,6 +301,39 @@ void GraphBuilderOrt::AppendCast(std::string_view input_name,
300301
ADD_CAST_NODE(node_name, input_name, output_name, to_data_type);
301302
}
302303

304+
std::string GraphBuilderOrt::PrependTranspose(
305+
std::string_view input_name,
306+
base::span<const uint32_t> permutation) {
307+
const std::string node_name = GenerateNextOperationName("inserted_transpose");
308+
const std::string output_name = GenerateNextOperandName();
309+
310+
std::array<const char*, 1> input_names = {input_name.data()};
311+
std::array<const char*, 1> output_names = {output_name.data()};
312+
313+
std::vector<int64_t> perm(permutation.begin(), permutation.end());
314+
std::array<OrtOpAttr*, 1> attributes = {
315+
model_builder_.CreateAttribute(/*name=*/"perm", perm).Release()};
316+
317+
model_builder_.AddNode(kOpTypeTranspose, node_name, input_names, output_names,
318+
attributes);
319+
return output_name;
320+
}
321+
322+
void GraphBuilderOrt::AppendTranspose(std::string_view input_name,
323+
std::string_view output_name,
324+
base::span<const uint32_t> permutation) {
325+
const std::string node_name = GenerateNextOperationName("inserted_transpose");
326+
std::array<const char*, 1> input_names = {input_name.data()};
327+
std::array<const char*, 1> output_names = {output_name.data()};
328+
329+
std::vector<int64_t> perm(permutation.begin(), permutation.end());
330+
std::array<OrtOpAttr*, 1> attributes = {
331+
model_builder_.CreateAttribute(/*name=*/"perm", perm).Release()};
332+
333+
model_builder_.AddNode(kOpTypeTranspose, node_name, input_names, output_names,
334+
attributes);
335+
}
336+
303337
void GraphBuilderOrt::AddInput(uint64_t input_id) {
304338
const mojom::Operand& operand = GetOperand(input_id);
305339
std::string name = GetOperandNameById(input_id);
@@ -952,6 +986,123 @@ GraphBuilderOrt::AddExpandOperation(const mojom::Expand& expand) {
952986
return base::ok();
953987
}
954988

989+
[[nodiscard]] base::expected<void, mojom::ErrorPtr>
990+
GraphBuilderOrt::AddDequantizeLinearOperation(
991+
const mojom::DequantizeLinear& dequantize_linear) {
992+
const std::string node_name =
993+
GenerateNextOperationName(dequantize_linear.label);
994+
std::string input_name =
995+
GetOperandNameById(dequantize_linear.input_operand_id);
996+
std::string scale_name =
997+
GetOperandNameById(dequantize_linear.scale_operand_id);
998+
std::string zero_point_name =
999+
GetOperandNameById(dequantize_linear.zero_point_operand_id);
1000+
std::string output_name =
1001+
GetOperandNameById(dequantize_linear.output_operand_id);
1002+
1003+
const OperandDescriptor& input_descriptor =
1004+
GetOperand(dequantize_linear.input_operand_id).descriptor;
1005+
std::vector<uint32_t> input_shape = input_descriptor.shape();
1006+
1007+
const OperandDescriptor& scale_descriptor =
1008+
GetOperand(dequantize_linear.scale_operand_id).descriptor;
1009+
std::vector<uint32_t> scale_shape = scale_descriptor.shape();
1010+
1011+
int64_t axis = 1;
1012+
int64_t block_size = 0;
1013+
bool need_transpose = false;
1014+
1015+
// https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp#L220
1016+
if (scale_shape.size() > 2) {
1017+
return NewNotSupportedError(
1018+
"OpenVINO dequantizeLinear cannot operate with more than 2D scales");
1019+
}
1020+
1021+
if (scale_shape.empty()) {
1022+
// For per-tensor/layer dequantization the scale is a scalar.
1023+
axis = 0;
1024+
} else if (scale_shape.size() == 1) {
1025+
bool is_valid = false;
1026+
// For per-axis dequantization it is a 1-D tensor.
1027+
for (size_t i = 0; i < input_shape.size(); i++) {
1028+
if (scale_shape[0] == input_shape[i]) {
1029+
axis = i;
1030+
is_valid = true;
1031+
}
1032+
}
1033+
if (!is_valid) {
1034+
return NewNotSupportedError(
1035+
"For 1D scale, the size of scale must be the same as the size of the "
1036+
"input dim specified by the axis.");
1037+
}
1038+
} else {
1039+
CHECK_EQ(scale_shape.size(), 2u);
1040+
// For blocked dequantization it has the same shape as the input, except for
1041+
// one dimension in which blocking is performed.
1042+
if (scale_shape.size() == input_shape.size()) {
1043+
uint32_t diff_count = 0;
1044+
for (size_t i = 0; i < input_shape.size(); i++) {
1045+
if (scale_shape[i] != input_shape[i]) {
1046+
// https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp#L230
1047+
if (input_shape[i] % scale_shape[i] != 0) {
1048+
return NewNotSupportedError(
1049+
"For blocked dequantization, OpenVINO DequantizeLinear doesn't "
1050+
"support case when input cannot be divided by scale.");
1051+
}
1052+
block_size = input_shape[i] / scale_shape[i];
1053+
axis = i;
1054+
diff_count++;
1055+
if (diff_count > 1) {
1056+
return NewNotSupportedError(
1057+
"For blocked dequantization it has the same shape as the "
1058+
"input, except for one dimension in which blocking is "
1059+
"performed");
1060+
}
1061+
}
1062+
}
1063+
// The shape of scale is the same as the shape of input.
1064+
if (diff_count == 0) {
1065+
axis = 0;
1066+
block_size = 1;
1067+
}
1068+
1069+
// Currently, OpenVINO only supports axis == 0 when scale.size == 2.
1070+
// https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp#L228.
1071+
if (axis != 0) {
1072+
input_name = PrependTranspose(input_name, {1, 0});
1073+
scale_name = PrependTranspose(scale_name, {1, 0});
1074+
zero_point_name = PrependTranspose(zero_point_name, {1, 0});
1075+
axis = 0;
1076+
need_transpose = true;
1077+
}
1078+
}
1079+
}
1080+
1081+
const std::string transposed_output_name =
1082+
need_transpose ? GenerateNextOperandName() : output_name;
1083+
1084+
base::FixedArray<const char*> input_names = {
1085+
input_name.c_str(), scale_name.c_str(), zero_point_name.c_str()};
1086+
base::FixedArray<const char*> output_names = {transposed_output_name.c_str()};
1087+
1088+
std::array<OrtOpAttr*, 2> attributes = {
1089+
model_builder_
1090+
.CreateAttribute(/*name=*/"axis", base::checked_cast<int64_t>(axis))
1091+
.Release(),
1092+
model_builder_
1093+
.CreateAttribute(/*name=*/"block_size",
1094+
base::checked_cast<int64_t>(block_size))
1095+
.Release()};
1096+
1097+
model_builder_.AddNode(kOpTypeDequantizeLinear, node_name, input_names,
1098+
output_names, attributes);
1099+
1100+
if (need_transpose) {
1101+
AppendTranspose(transposed_output_name, output_name, {1, 0});
1102+
}
1103+
return base::ok();
1104+
}
1105+
9551106
void GraphBuilderOrt::AddGatherOperation(const mojom::Gather& gather) {
9561107
const std::string node_name = GenerateNextOperationName(gather.label);
9571108
const std::string input_name = GetOperandNameById(gather.input_operand_id);
@@ -1780,6 +1931,11 @@ GraphBuilderOrt::BuildModel() {
17801931
RETURN_IF_ERROR(AddConv2dOperation(*operation->get_conv2d()));
17811932
break;
17821933
}
1934+
case mojom::Operation::Tag::kDequantizeLinear: {
1935+
RETURN_IF_ERROR(
1936+
AddDequantizeLinearOperation(*operation->get_dequantize_linear()));
1937+
break;
1938+
}
17831939
case mojom::Operation::Tag::kExpand: {
17841940
RETURN_IF_ERROR(AddExpandOperation(*operation->get_expand()));
17851941
break;
@@ -1863,7 +2019,6 @@ GraphBuilderOrt::BuildModel() {
18632019
break;
18642020
}
18652021
case mojom::Operation::Tag::kCumulativeSum:
1866-
case mojom::Operation::Tag::kDequantizeLinear:
18672022
case mojom::Operation::Tag::kElu:
18682023
case mojom::Operation::Tag::kGatherElements:
18692024
case mojom::Operation::Tag::kGatherNd:

services/webnn/ort/graph_builder_ort.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ class GraphBuilderOrt {
131131
std::string PrependCast(std::string_view input_name,
132132
ONNXTensorElementDataType to_data_type);
133133

134+
std::string PrependTranspose(std::string_view input_name,
135+
base::span<const uint32_t> permutation);
136+
134137
// Insert a cast operation after an operation to convert its output to the
135138
// target `to_data_type`. The `input_name` specifies the cast operation's
136139
// input (the output of the operation to be casted), and the `output_name`
@@ -139,6 +142,10 @@ class GraphBuilderOrt {
139142
std::string_view output_name,
140143
ONNXTensorElementDataType to_data_type);
141144

145+
void AppendTranspose(std::string_view input_name,
146+
std::string_view output_name,
147+
base::span<const uint32_t> permutation);
148+
142149
void AddInput(uint64_t input_id);
143150
void AddOutput(uint64_t output_id);
144151

@@ -175,6 +182,9 @@ class GraphBuilderOrt {
175182
const mojom::Conv2d& conv2d);
176183
[[nodiscard]] base::expected<void, mojom::ErrorPtr> AddExpandOperation(
177184
const mojom::Expand& expand);
185+
[[nodiscard]] base::expected<void, mojom::ErrorPtr>
186+
AddDequantizeLinearOperation(
187+
const mojom::DequantizeLinear& dequantize_linear);
178188
void AddGatherOperation(const mojom::Gather& gather);
179189
void AddGemmOperation(const mojom::Gemm& gemm);
180190
[[nodiscard]] base::expected<void, mojom::ErrorPtr>

0 commit comments

Comments
 (0)