Commit 9f2b7d5

lisa0314 authored and shiyi9801 committed
Implement dequantizeLinear (#134)
1 parent b4b6b91 commit 9f2b7d5

4 files changed

Lines changed: 223 additions & 4 deletions

File tree

services/webnn/ort/context_impl_ort.cc
services/webnn/ort/graph_builder_ort.cc
services/webnn/ort/graph_builder_ort.h
services/webnn/ort/ort_model_builder.cc

services/webnn/ort/context_impl_ort.cc

Lines changed: 6 additions & 2 deletions
@@ -59,6 +59,10 @@ ContextProperties ContextImplOrt::GetContextProperties() {
   static constexpr SupportedRanks kNonScalarMaxRank =
       SupportedRanks::NonScalarUpTo(8);
 
+  static constexpr SupportedDataTypes kDequantizeLinearInputSupportedDataTypes{
+      OperandDataType::kInt4, OperandDataType::kUint4, OperandDataType::kUint8,
+      OperandDataType::kInt8, OperandDataType::kInt32};
+
   return ContextProperties(
       InputOperandLayout::kNchw, Resample2DAxes::kChannelsFirst,
       /*tensor_byte_length_limit=*/kTensorByteLengthLimit,
@@ -74,8 +78,8 @@ ContextProperties ContextImplOrt::GetContextProperties() {
       /*conv2d_input=*/DataTypeConstraint::kFloat16To32,
       /*conv_transpose2d_input=*/DataTypeConstraint::kFloat16To32,
       /*cumulative_sum_input=*/{},
-      /*dequantize_linear_input=*/{},
-      /*dequantize_linear_scale=*/{},
+      /*dequantize_linear_input=*/kDequantizeLinearInputSupportedDataTypes,
+      /*dequantize_linear_scale=*/DataTypeConstraint::kFloat16To32,
       /*add_input=*/
       {DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
       /*sub_input=*/
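
For reference (not part of the commit): WebNN's dequantizeLinear computes output = (input - zeroPoint) * scale elementwise, with scale and zeroPoint broadcast per-tensor, per-axis, or per-block, which is why the input may be a narrow integer type while the scale must be float16/float32. A minimal standalone C++ sketch of the per-tensor case, illustrative only; the ORT backend emits an ONNX DequantizeLinear node rather than computing this in Chromium:

#include <cstdint>
#include <vector>

// Per-tensor dequantization: every element shares one scale and zero point.
std::vector<float> DequantizePerTensor(const std::vector<int8_t>& input,
                                       float scale,
                                       int8_t zero_point) {
  std::vector<float> output;
  output.reserve(input.size());
  for (int8_t value : input) {
    output.push_back((static_cast<float>(value) - zero_point) * scale);
  }
  return output;
}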

services/webnn/ort/graph_builder_ort.cc

Lines changed: 202 additions & 1 deletion
@@ -69,6 +69,7 @@ constexpr char kOpTypeClamp[] = "Clip";
 constexpr char kOpTypeConcat[] = "Concat";
 constexpr char kOpTypeConv2d[] = "Conv";
 constexpr char kOpTypeConvTranspose2d[] = "ConvTranspose";
+constexpr char kOpTypeDequantizeLinear[] = "DequantizeLinear";
 constexpr char kOpTypeExpand[] = "Expand";
 constexpr char kOpTypeGather[] = "Gather";
 constexpr char kOpTypeGelu[] = "Gelu";
@@ -301,6 +302,60 @@ void GraphBuilderOrt::AppendCast(std::string_view input_name,
   ADD_CAST_NODE(node_name, input_name, output_name, to_data_type);
 }
 
+std::string GraphBuilderOrt::PrependTranspose(
+    std::string_view input_name,
+    base::span<const uint32_t> permutation) {
+  const std::string node_name = GenerateNextOperationName("inserted_transpose");
+  const std::string output_name = GenerateNextOperandName();
+
+  std::array<const char*, 1> input_names = {input_name.data()};
+  std::array<const char*, 1> output_names = {output_name.data()};
+
+  std::vector<int64_t> perm(permutation.begin(), permutation.end());
+  std::array<OrtOpAttr*, 1> attributes = {
+      model_builder_.CreateAttribute(/*name=*/"perm", perm).Release()};
+
+  model_builder_.AddNode(kOpTypeTranspose, node_name, input_names,
+                         output_names, attributes);
+  return output_name;
+}
+
+void GraphBuilderOrt::AppendTranspose(std::string_view input_name,
+                                      std::string_view output_name,
+                                      base::span<const uint32_t> permutation) {
+  const std::string node_name = GenerateNextOperationName("inserted_transpose");
+  std::array<const char*, 1> input_names = {input_name.data()};
+  std::array<const char*, 1> output_names = {output_name.data()};
+
+  std::vector<int64_t> perm(permutation.begin(), permutation.end());
+  std::array<OrtOpAttr*, 1> attributes = {
+      model_builder_.CreateAttribute(/*name=*/"perm", perm).Release()};
+
+  model_builder_.AddNode(kOpTypeTranspose, node_name, input_names,
+                         output_names, attributes);
+}
+
+[[nodiscard]] base::expected<std::string, mojom::ErrorPtr>
+GraphBuilderOrt::PrependReshape(std::string_view input_name,
+                                base::span<const int64_t> new_shape) {
+  const std::string node_name = GenerateNextOperationName("inserted_reshape");
+  const std::string output_name = GenerateNextOperandName();
+
+  // Shape is an operand with data type int64, not an attribute.
+  std::vector<uint32_t> new_shape_dims = {
+      base::checked_cast<uint32_t>(new_shape.size())};
+  ASSIGN_OR_RETURN(const std::string shape_name,
+                   CreateInitializer<int64_t>(new_shape_dims, new_shape));
+
+  std::array<const char*, 2> input_names = {input_name.data(),
+                                            shape_name.c_str()};
+  std::array<const char*, 1> output_names = {output_name.c_str()};
+
+  model_builder_.AddNode(kOpTypeReshape, node_name, input_names, output_names);
+
+  return output_name;
+}
+
 void GraphBuilderOrt::AddInput(uint64_t input_id) {
   const mojom::Operand& operand = GetOperand(input_id);
   std::string name = GetOperandNameById(input_id);
@@ -953,6 +1008,148 @@ GraphBuilderOrt::AddExpandOperation(const mojom::Expand& expand) {
   return base::ok();
 }
 
+[[nodiscard]] base::expected<void, mojom::ErrorPtr>
+GraphBuilderOrt::AddDequantizeLinearOperation(
+    const mojom::DequantizeLinear& dequantize_linear) {
+  const std::string node_name =
+      GenerateNextOperationName(dequantize_linear.label);
+  std::string input_name =
+      GetOperandNameById(dequantize_linear.input_operand_id);
+  std::string scale_name =
+      GetOperandNameById(dequantize_linear.scale_operand_id);
+  std::string zero_point_name =
+      GetOperandNameById(dequantize_linear.zero_point_operand_id);
+  std::string output_name =
+      GetOperandNameById(dequantize_linear.output_operand_id);
+
+  const OperandDescriptor& input_descriptor =
+      GetOperand(dequantize_linear.input_operand_id).descriptor;
+  std::vector<uint32_t> input_shape = input_descriptor.shape();
+
+  const OperandDescriptor& scale_descriptor =
+      GetOperand(dequantize_linear.scale_operand_id).descriptor;
+  // zeroPoint has the same shape as the scale.
+  std::vector<uint32_t> scale_shape = scale_descriptor.shape();
+
+  std::optional<int64_t> axis;
+  uint32_t not_one_value_dim_count = 0;
+  bool found_same_size = false;
+  CHECK_LE(scale_shape.size(), input_shape.size());
+  for (size_t i = 0; i < scale_shape.size(); i++) {
+    if (scale_shape[scale_shape.size() - i - 1] != 1) {
+      not_one_value_dim_count++;
+      if (scale_shape[scale_shape.size() - i - 1] ==
+          input_shape[input_shape.size() - i - 1]) {
+        axis = input_shape.size() - i - 1;
+        found_same_size = true;
+      }
+    }
+  }
+  // TODO(https://github.com/shiyi9801/chromium/issues/139): Consider adding
+  // emulation to support the multiple-axes case, e.g. input shape is
+  // [2, 3, 4, 5] and scale shape is [1, 3, 4, 1].
+  bool is_per_axis = found_same_size && not_one_value_dim_count == 1;
+
+  std::optional<int64_t> block_size;
+  bool need_transpose = false;
+  if (scale_shape.empty()) {
+    // For per-tensor/layer dequantization the scale is a scalar.
+  } else if (not_one_value_dim_count == 0) {
+    // The scale shape contains only 1s; scale and zeroPoint should be
+    // reshaped to a scalar.
+    ASSIGN_OR_RETURN(scale_name, PrependReshape(scale_name, {}));
+    ASSIGN_OR_RETURN(zero_point_name, PrependReshape(zero_point_name, {}));
+  } else if (is_per_axis) {
+    // For per-axis dequantization, scale and zeroPoint must be 1-D tensors.
+    CHECK(axis.has_value());
+    ASSIGN_OR_RETURN(scale_name,
+                     PrependReshape(scale_name, {input_shape[axis.value()]}));
+    ASSIGN_OR_RETURN(
+        zero_point_name,
+        PrependReshape(zero_point_name, {input_shape[axis.value()]}));
+  } else if (scale_shape.size() == input_shape.size()) {
+    // For blocked dequantization the scale has the same shape as the input,
+    // except for the one dimension in which blocking is performed.
+    uint32_t blocked_axis_count = 0;
+    axis = 0;
+    block_size = 1;
+    for (size_t i = 0; i < input_shape.size(); i++) {
+      if (scale_shape[i] != input_shape[i]) {
+        CHECK_EQ(input_shape[i] % scale_shape[i], 0u);
+        block_size = input_shape[i] / scale_shape[i];
+        axis = i;
+        blocked_axis_count++;
+        // TODO(https://github.com/shiyi9801/chromium/issues/135): Consider
+        // adding emulation to support multi-dimensional blockwise
+        // dequantization.
+        if (blocked_axis_count > 1) {
+          return NewNotSupportedError(
+              "For blocked dequantization the scale must have the same shape "
+              "as the input except for the one dimension in which blocking "
+              "is performed.");
+        }
+      }
+    }
+
+    // Currently, OpenVINO only supports axis == 0 when scale.size == 2.
+    // https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp#L228
+    if (base::CommandLine::ForCurrentProcess()->HasSwitch(
+            switches::kWebNNOrtUseOpenvino)) {
+      if (scale_shape.size() != 2) {
+        // https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp#L220
+        return NewNotSupportedError(
+            "Currently ORT OpenVINO only supports 2-D scale for block-wise "
+            "dequantizeLinear.");
+      } else if (axis == 1) {
+        input_name = PrependTranspose(input_name, {1, 0});
+        scale_name = PrependTranspose(scale_name, {1, 0});
+        zero_point_name = PrependTranspose(zero_point_name, {1, 0});
+        axis = 0;
+        need_transpose = true;
+      }
+    }
+  } else {
+    // The proposal of requiring scale and zeroPoint to be the same rank as
+    // the input is under discussion:
+    // https://github.com/webmachinelearning/webnn/pull/805#discussion_r1919498405
+    return NewNotSupportedError(
+        "Currently, ONNX only supports per-tensor, per-axis and block-wise "
+        "dequantizeLinear.");
+  }
+
+  const std::string transposed_output_name =
+      need_transpose ? GenerateNextOperandName() : output_name;
+
+  base::FixedArray<const char*> input_names = {
+      input_name.c_str(), scale_name.c_str(), zero_point_name.c_str()};
+  base::FixedArray<const char*> output_names = {transposed_output_name.c_str()};
+
+  std::vector<OrtOpAttr*> attributes;
+  if (axis.has_value()) {
+    attributes.push_back(
+        model_builder_
+            .CreateAttribute(/*name=*/"axis",
+                             base::checked_cast<int64_t>(axis.value()))
+            .Release());
+  }
+
+  if (block_size.has_value()) {
+    attributes.push_back(
+        model_builder_
+            .CreateAttribute(/*name=*/"block_size",
+                             base::checked_cast<int64_t>(block_size.value()))
+            .Release());
+  }
+
+  model_builder_.AddNode(kOpTypeDequantizeLinear, node_name, input_names,
+                         output_names, attributes);
+
+  if (need_transpose) {
+    AppendTranspose(transposed_output_name, output_name, {1, 0});
+  }
+  return base::ok();
+}
+
 void GraphBuilderOrt::AddGatherOperation(const mojom::Gather& gather) {
   const std::string node_name = GenerateNextOperationName(gather.label);
   const std::string input_name = GetOperandNameById(gather.input_operand_id);
@@ -1824,6 +2021,11 @@ GraphBuilderOrt::BuildModel() {
       RETURN_IF_ERROR(AddConv2dOperation(*operation->get_conv2d()));
      break;
     }
+    case mojom::Operation::Tag::kDequantizeLinear: {
+      RETURN_IF_ERROR(
+          AddDequantizeLinearOperation(*operation->get_dequantize_linear()));
+      break;
+    }
     case mojom::Operation::Tag::kExpand: {
       RETURN_IF_ERROR(AddExpandOperation(*operation->get_expand()));
       break;
@@ -1911,7 +2113,6 @@ GraphBuilderOrt::BuildModel() {
       break;
     }
     case mojom::Operation::Tag::kCumulativeSum:
-    case mojom::Operation::Tag::kDequantizeLinear:
     case mojom::Operation::Tag::kElu:
     case mojom::Operation::Tag::kGatherElements:
     case mojom::Operation::Tag::kGatherNd:
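
The bulk of AddDequantizeLinearOperation is deciding, from the scale shape alone, whether the WebNN op lowers to per-tensor, per-axis, or blocked ONNX DequantizeLinear. A simplified standalone sketch of that classification, assuming (as the builder CHECKs) that the scale's rank does not exceed the input's; the enum and function names are illustrative, not from the commit:

#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

enum class DequantizeMode { kPerTensor, kPerAxis, kBlocked, kUnsupported };

// Count the non-1 dims of the scale, aligned to the trailing dims of the
// input. Empty or all-1 scale is per-tensor; exactly one non-1 dim that
// matches the corresponding input dim is per-axis; a same-rank scale that
// divides the input evenly along exactly one dim is blocked.
DequantizeMode ClassifyDequantize(const std::vector<uint32_t>& input_shape,
                                  const std::vector<uint32_t>& scale_shape,
                                  std::optional<size_t>& axis_out,
                                  std::optional<uint32_t>& block_size_out) {
  if (scale_shape.empty()) {
    return DequantizeMode::kPerTensor;
  }
  size_t non_one_dims = 0;
  std::optional<size_t> axis;
  for (size_t i = 0; i < scale_shape.size(); ++i) {
    size_t s = scale_shape.size() - i - 1;
    size_t in = input_shape.size() - i - 1;
    if (scale_shape[s] != 1) {
      ++non_one_dims;
      if (scale_shape[s] == input_shape[in]) {
        axis = in;
      }
    }
  }
  if (non_one_dims == 0) {
    return DequantizeMode::kPerTensor;  // All-1 scale; reshape to a scalar.
  }
  if (non_one_dims == 1 && axis.has_value()) {
    axis_out = axis;
    return DequantizeMode::kPerAxis;
  }
  if (scale_shape.size() == input_shape.size()) {
    // Blocked: exactly one dim may differ, and it must divide evenly.
    size_t blocked_dims = 0;
    for (size_t i = 0; i < input_shape.size(); ++i) {
      if (scale_shape[i] != input_shape[i]) {
        if (++blocked_dims > 1 || input_shape[i] % scale_shape[i] != 0) {
          return DequantizeMode::kUnsupported;
        }
        axis_out = i;
        block_size_out = input_shape[i] / scale_shape[i];
      }
    }
    return DequantizeMode::kBlocked;
  }
  return DequantizeMode::kUnsupported;
}

For example, with an input of shape [2, 3, 4, 5], a scale of [5] or [1, 1, 1, 5] classifies as per-axis on axis 3, while a same-rank scale of [2, 3, 4, 1] classifies as blocked with block_size 5 on axis 3.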

services/webnn/ort/graph_builder_ort.h

Lines changed: 14 additions & 0 deletions
@@ -131,6 +131,13 @@ class GraphBuilderOrt {
   std::string PrependCast(std::string_view input_name,
                           ONNXTensorElementDataType to_data_type);
 
+  [[nodiscard]] base::expected<std::string, mojom::ErrorPtr> PrependReshape(
+      std::string_view input_name,
+      base::span<const int64_t> new_shape);
+
+  std::string PrependTranspose(std::string_view input_name,
+                               base::span<const uint32_t> permutation);
+
   // Insert a cast operation after an operation to convert its output to the
   // target `to_data_type`. The `input_name` specifies the cast operation's
   // input (the output of the operation to be casted), and the `output_name`
@@ -139,6 +146,10 @@ class GraphBuilderOrt {
                     std::string_view output_name,
                     ONNXTensorElementDataType to_data_type);
 
+  void AppendTranspose(std::string_view input_name,
+                       std::string_view output_name,
+                       base::span<const uint32_t> permutation);
+
   void AddInput(uint64_t input_id);
   void AddOutput(uint64_t output_id);
 
@@ -175,6 +186,9 @@ class GraphBuilderOrt {
                                   const mojom::Conv2d& conv2d);
   [[nodiscard]] base::expected<void, mojom::ErrorPtr> AddExpandOperation(
       const mojom::Expand& expand);
+  [[nodiscard]] base::expected<void, mojom::ErrorPtr>
+  AddDequantizeLinearOperation(
+      const mojom::DequantizeLinear& dequantize_linear);
   void AddGatherOperation(const mojom::Gather& gather);
   void AddGemmOperation(const mojom::Gemm& gemm);
   [[nodiscard]] base::expected<void, mojom::ErrorPtr>
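
These declarations follow the existing PrependCast/AppendCast convention: Prepend* helpers mint and return a fresh intermediate operand name that the caller feeds into the node it is about to add, while Append* helpers take the caller's final output name and write into it after the node. A toy, self-contained sketch of that naming convention; ToyBuilder and its members are illustrative stand-ins, not Chromium types:

#include <string>
#include <vector>

// Toy stand-in for the graph builder's operand-name plumbing.
struct ToyBuilder {
  int next_id = 0;
  std::vector<std::string> nodes;

  std::string NewOperand() { return "operand_" + std::to_string(next_id++); }

  // Prepend: consumes `input` and returns a fresh intermediate name.
  std::string PrependTranspose(const std::string& input) {
    std::string out = NewOperand();
    nodes.push_back("Transpose(" + input + ") -> " + out);
    return out;
  }

  // Append: consumes `input` and writes into the caller-chosen `output`.
  void AppendTranspose(const std::string& input, const std::string& output) {
    nodes.push_back("Transpose(" + input + ") -> " + output);
  }
};

This is how the OpenVINO axis workaround above composes: transpose the inputs via Prepend, emit the op into a temporary name, then Append a transpose that writes the graph's real output name.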

services/webnn/ort/ort_model_builder.cc

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ void OrtModelBuilder::AddOutput(std::string_view name,
   void* ort_tensor_raw_data = nullptr;
   RETURN_STATUS_IF_FAILED(
       GetOrtApi()->GetTensorMutableData(initializer, &ort_tensor_raw_data));
-  CHECK(ort_tensor_raw_data);
+  // ort_tensor_raw_data can be nullptr when data is empty.
   UNSAFE_BUFFERS(
       base::span(static_cast<uint8_t*>(ort_tensor_raw_data), data.size()))
       .copy_from(data);
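
The dropped CHECK was most likely hit by zero-length initializers, such as the empty int64 shape tensor that PrependReshape(name, {}) creates in this commit, for which ORT can hand back a null data pointer. Copying zero bytes never dereferences the destination, so the copy stays well-defined. A minimal sketch of the equivalent guard, using plain std::memcpy instead of Chromium's base::span (illustrative only):

#include <cstdint>
#include <cstring>
#include <vector>

// Copy initializer bytes into a destination that may be null when the
// payload is empty; skipping the zero-byte copy avoids touching dst.
void CopyInitializerBytes(uint8_t* dst, const std::vector<uint8_t>& data) {
  if (data.empty()) {
    return;  // Nothing to copy; dst may legitimately be nullptr here.
  }
  std::memcpy(dst, data.data(), data.size());
}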
