Skip to content

Commit c75ee80

Browse files
committed
Implement dequantizeLinear
1 parent 17d8c1d commit c75ee80

3 files changed

Lines changed: 175 additions & 3 deletions

File tree

services/webnn/ort/context_impl_ort.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ static constexpr uint64_t kTensorByteLengthLimit =
5555

5656
// static
5757
ContextProperties ContextImplOrt::GetContextProperties() {
58+
static constexpr SupportedDataTypes kDequantizeLinearInputSupportedDataTypes{
59+
OperandDataType::kInt4, OperandDataType::kUint4, OperandDataType::kUint8,
60+
OperandDataType::kInt8, OperandDataType::kInt32};
61+
5862
return ContextProperties(
5963
InputOperandLayout::kNchw, Resample2DAxes::kChannelsFirst,
6064
/*tensor_byte_length_limit=*/kTensorByteLengthLimit,
@@ -69,8 +73,8 @@ ContextProperties ContextImplOrt::GetContextProperties() {
6973
/*conv2d_input=*/DataTypeConstraint::kFloat16To32,
7074
/*conv_transpose2d_input=*/DataTypeConstraint::kFloat16To32,
7175
/*cumulative_sum_input=*/{},
72-
/*dequantize_linear_input=*/{},
73-
/*dequantize_linear_scale=*/{},
76+
/*dequantize_linear_input=*/kDequantizeLinearInputSupportedDataTypes,
77+
/*dequantize_linear_scale=*/DataTypeConstraint::kFloat16To32,
7478
/*add_input=*/
7579
{DataTypeConstraint::kAllDataTypesAtLeast8bits, SupportedRanks::UpTo(8)},
7680
/*sub_input=*/

services/webnn/ort/graph_builder_ort.cc

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ constexpr char kOpTypeClamp[] = "Clip";
6969
constexpr char kOpTypeConcat[] = "Concat";
7070
constexpr char kOpTypeConv2d[] = "Conv";
7171
constexpr char kOpTypeConvTranspose2d[] = "ConvTranspose";
72+
constexpr char kOpTypeDequantizeLinear[] = "DequantizeLinear";
7273
constexpr char kOpTypeExpand[] = "Expand";
7374
constexpr char kOpTypeGather[] = "Gather";
7475
constexpr char kOpTypeGelu[] = "Gelu";
@@ -291,6 +292,39 @@ void GraphBuilderOrt::AppendCast(std::string_view input_name,
291292
ADD_CAST_NODE(node_name, input_name, output_name, to_data_type);
292293
}
293294

295+
std::string GraphBuilderOrt::PrependTranspose(
296+
std::string_view input_name,
297+
base::span<const uint32_t> permutation) {
298+
const std::string node_name = GenerateNextOperationName("inserted_transpose");
299+
const std::string output_name = GenerateNextOperandName();
300+
301+
std::array<const char*, 1> input_names = {input_name.data()};
302+
std::array<const char*, 1> output_names = {output_name.data()};
303+
304+
std::vector<int64_t> perm(permutation.begin(), permutation.end());
305+
std::array<OrtOpAttr*, 1> attributes = {
306+
model_builder_.CreateAttribute(/*name=*/"perm", perm).Release()};
307+
308+
model_builder_.AddNode(kOpTypeTranspose, node_name, input_names, output_names,
309+
attributes);
310+
return output_name;
311+
}
312+
313+
void GraphBuilderOrt::AppendTranspose(std::string_view input_name,
314+
std::string_view output_name,
315+
base::span<const uint32_t> permutation) {
316+
const std::string node_name = GenerateNextOperationName("inserted_transpose");
317+
std::array<const char*, 1> input_names = {input_name.data()};
318+
std::array<const char*, 1> output_names = {output_name.data()};
319+
320+
std::vector<int64_t> perm(permutation.begin(), permutation.end());
321+
std::array<OrtOpAttr*, 1> attributes = {
322+
model_builder_.CreateAttribute(/*name=*/"perm", perm).Release()};
323+
324+
model_builder_.AddNode(kOpTypeTranspose, node_name, input_names, output_names,
325+
attributes);
326+
}
327+
294328
void GraphBuilderOrt::AddInput(uint64_t input_id) {
295329
const mojom::Operand& operand = GetOperand(input_id);
296330
std::string name = GetOperandNameById(input_id);
@@ -924,6 +958,126 @@ void GraphBuilderOrt::AddExpandOperation(const mojom::Expand& expand) {
924958
model_builder_.AddNode(kOpTypeExpand, node_name, input_names, output_names);
925959
}
926960

961+
[[nodiscard]] base::expected<void, mojom::ErrorPtr>
962+
GraphBuilderOrt::AddDequantizeLinearOperation(
963+
const mojom::DequantizeLinear& dequantize_linear) {
964+
const std::string node_name =
965+
GenerateNextOperationName(dequantize_linear.label);
966+
std::string input_name =
967+
GetOperandNameById(dequantize_linear.input_operand_id);
968+
std::string scale_name =
969+
GetOperandNameById(dequantize_linear.scale_operand_id);
970+
std::string zero_point_name =
971+
GetOperandNameById(dequantize_linear.zero_point_operand_id);
972+
std::string output_name =
973+
GetOperandNameById(dequantize_linear.output_operand_id);
974+
975+
const OperandDescriptor& input_descriptor =
976+
GetOperand(dequantize_linear.input_operand_id).descriptor;
977+
std::vector<uint32_t> input_shape = input_descriptor.shape();
978+
979+
const OperandDescriptor& scale_descriptor =
980+
GetOperand(dequantize_linear.scale_operand_id).descriptor;
981+
std::vector<uint32_t> scale_shape = scale_descriptor.shape();
982+
983+
int64_t axis = 1;
984+
int64_t block_size = 0;
985+
bool need_transpose = false;
986+
987+
// https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp#L220
988+
if (scale_shape.size() > 2) {
989+
return NewNotSupportedError(
990+
"OpenVINO dequantizeLinear cannot operate with more than 2D scales");
991+
}
992+
993+
if (scale_shape.empty()) {
994+
// For per-tensor/layer dequantization the scale is a scalar.
995+
axis = 0;
996+
} else if (scale_shape.size() == 1) {
997+
bool is_valid = false;
998+
// for per per-axis dequantization it is a 1-D Tensor
999+
for (size_t i = 0; i < input_shape.size(); i++) {
1000+
if (scale_shape[0] == input_shape[i]) {
1001+
axis = i;
1002+
is_valid = true;
1003+
}
1004+
}
1005+
if (!is_valid) {
1006+
return NewNotSupportedError(
1007+
"For 1D scale, the size of scale must be the same as the size of the "
1008+
"input dim specified by the axis.");
1009+
}
1010+
} else {
1011+
CHECK_EQ(scale_shape.size(), 2u);
1012+
// For blocked dequantization it has the same shape as the input, except for
1013+
// one dimension in which blocking is performed.
1014+
if (scale_shape.size() == input_shape.size()) {
1015+
uint32_t diff_count = 0;
1016+
for (size_t i = 0; i < input_shape.size(); i++) {
1017+
if (scale_shape[i] != input_shape[i]) {
1018+
// https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp#L230
1019+
if (input_shape[i] % scale_shape[i] != 0) {
1020+
return NewNotSupportedError(
1021+
"For blocked dequantization, OpenVINO DequantizeLinear doesn't "
1022+
"support case when input cannot be divided by scale.");
1023+
}
1024+
block_size = input_shape[i] / scale_shape[i];
1025+
axis = i;
1026+
diff_count++;
1027+
if (diff_count > 1) {
1028+
return NewNotSupportedError(
1029+
"For blocked dequantization it has the same shape as the "
1030+
"input, except for one dimension in which blocking is "
1031+
"performed");
1032+
}
1033+
}
1034+
}
1035+
// The shape of scale is the same as the shape of input.
1036+
if (diff_count == 0) {
1037+
axis = 0;
1038+
block_size = 1;
1039+
}
1040+
1041+
// Currently, OpenVINO only supports axis == 0 when scale.size == 2.
1042+
// https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp#L228.
1043+
if (axis != 0) {
1044+
input_name = PrependTranspose(input_name, {1, 0});
1045+
scale_name = PrependTranspose(scale_name, {1, 0});
1046+
zero_point_name = PrependTranspose(zero_point_name, {1, 0});
1047+
axis = 0;
1048+
need_transpose = true;
1049+
}
1050+
}
1051+
}
1052+
1053+
LOG(ERROR) << "AXIS: " << axis;
1054+
LOG(ERROR) << "BLOCK_WISE: " << block_size;
1055+
1056+
const std::string transposed_output_name =
1057+
need_transpose ? GenerateNextOperandName() : output_name;
1058+
1059+
base::FixedArray<const char*> input_names = {
1060+
input_name.c_str(), scale_name.c_str(), zero_point_name.c_str()};
1061+
base::FixedArray<const char*> output_names = {transposed_output_name.c_str()};
1062+
1063+
std::array<OrtOpAttr*, 2> attributes = {
1064+
model_builder_
1065+
.CreateAttribute(/*name=*/"axis", base::checked_cast<int64_t>(axis))
1066+
.Release(),
1067+
model_builder_
1068+
.CreateAttribute(/*name=*/"block_size",
1069+
base::checked_cast<int64_t>(block_size))
1070+
.Release()};
1071+
1072+
model_builder_.AddNode(kOpTypeDequantizeLinear, node_name, input_names,
1073+
output_names, attributes);
1074+
1075+
if (need_transpose) {
1076+
AppendTranspose(transposed_output_name, output_name, {1, 0});
1077+
}
1078+
return base::ok();
1079+
}
1080+
9271081
void GraphBuilderOrt::AddGatherOperation(const mojom::Gather& gather) {
9281082
const std::string node_name = GenerateNextOperationName(gather.label);
9291083
const std::string input_name = GetOperandNameById(gather.input_operand_id);
@@ -1724,6 +1878,11 @@ GraphBuilderOrt::BuildModel() {
17241878
RETURN_IF_ERROR(AddConv2dOperation(*operation->get_conv2d()));
17251879
break;
17261880
}
1881+
case mojom::Operation::Tag::kDequantizeLinear: {
1882+
RETURN_IF_ERROR(
1883+
AddDequantizeLinearOperation(*operation->get_dequantize_linear()));
1884+
break;
1885+
}
17271886
case mojom::Operation::Tag::kExpand: {
17281887
AddExpandOperation(*operation->get_expand());
17291888
break;
@@ -1807,7 +1966,6 @@ GraphBuilderOrt::BuildModel() {
18071966
break;
18081967
}
18091968
case mojom::Operation::Tag::kCumulativeSum:
1810-
case mojom::Operation::Tag::kDequantizeLinear:
18111969
case mojom::Operation::Tag::kElu:
18121970
case mojom::Operation::Tag::kGatherElements:
18131971
case mojom::Operation::Tag::kGatherNd:

services/webnn/ort/graph_builder_ort.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ class GraphBuilderOrt {
129129
std::string PrependCast(std::string_view input_name,
130130
ONNXTensorElementDataType to_data_type);
131131

132+
std::string PrependTranspose(std::string_view input_name,
133+
base::span<const uint32_t> permutation);
134+
132135
// Insert a cast operation after an operation to convert its output to the
133136
// target `to_data_type`. The `input_name` specifies the cast operation's
134137
// input (the output of the operation to be casted), and the `output_name`
@@ -137,6 +140,10 @@ class GraphBuilderOrt {
137140
std::string_view output_name,
138141
ONNXTensorElementDataType to_data_type);
139142

143+
void AppendTranspose(std::string_view input_name,
144+
std::string_view output_name,
145+
base::span<const uint32_t> permutation);
146+
140147
void AddInput(uint64_t input_id);
141148
void AddOutput(uint64_t output_id);
142149

@@ -169,6 +176,9 @@ class GraphBuilderOrt {
169176
void AddConcatOperation(const mojom::Concat& concat);
170177
[[nodiscard]] base::expected<void, mojom::ErrorPtr> AddConv2dOperation(
171178
const mojom::Conv2d& conv2d);
179+
[[nodiscard]] base::expected<void, mojom::ErrorPtr>
180+
AddDequantizeLinearOperation(
181+
const mojom::DequantizeLinear& dequantize_linear);
172182
void AddExpandOperation(const mojom::Expand& expand);
173183
void AddGatherOperation(const mojom::Gather& gather);
174184
void AddGemmOperation(const mojom::Gemm& gemm);

0 commit comments

Comments
 (0)