Commit 6b27f08

Merge pull request #284 from bbernhar/super_resolution
Add SuperResolution example and test
2 parents bce6861 + 7c4030c commit 6b27f08

File tree

examples/SuperResolution/SuperResolution.cpp
examples/SuperResolution/SuperResolution.h
src/webnn/native/dmlx/GraphDMLX.cpp
src/webnn/tests/BUILD.gn
src/webnn/tests/end2end/models/SuperResolutionNchw.cpp

5 files changed, +201 -3 lines changed

examples/SuperResolution/SuperResolution.cpp

Lines changed: 113 additions & 0 deletions

@@ -0,0 +1,113 @@
// Copyright 2022 The WebNN-native Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "examples/SuperResolution/SuperResolution.h"

SuperResolution::SuperResolution() : ExampleBase() {
}

const wnn::Operand SuperResolution::BuildConstantFromNpy(const wnn::GraphBuilder& builder,
                                                         const std::string& path) {
    const cnpy::NpyArray data = cnpy::npy_load(path);
    mConstants.push_back(data.data_holder);
    return utils::BuildConstant(builder, data.shape, data.data<float>(), data.num_bytes());
}

const wnn::Operand SuperResolution::BuildConv(const wnn::GraphBuilder& builder,
                                              const wnn::Operand& input,
                                              int32_t convIndex,
                                              bool relu,
                                              utils::Conv2dOptions* options,
                                              const std::string& biasName) {
    std::string prefix = mLayout == "nchw" ? mWeightsPath + "conv" : mWeightsPath + "Const_";
    std::string suffix = mLayout == "nchw" ? "_weight.npy" : ".npy";
    const std::string weightsPath = prefix + std::to_string(convIndex) + suffix;
    const wnn::Operand convWeights = BuildConstantFromNpy(builder, weightsPath);

    // TODO: Figure out correct "channels last" path suffix.
    prefix = mLayout == "nchw" ? mWeightsPath + "conv" : mWeightsPath + "super_resolution_";
    if (mLayout == "nchw") {
        prefix.append(std::to_string(convIndex));
    }

    const std::string biasPath = prefix + biasName + "_bias.npy";
    const wnn::Operand convBias = BuildConstantFromNpy(builder, biasPath);

    const wnn::Conv2dOptions* conv2dOptions = options != nullptr ? options->AsPtr() : nullptr;
    const wnn::Operand conv2d = builder.Conv2d(input, convWeights, conv2dOptions);

    if (!mFused) {
        if (relu) {
            return builder.Relu(conv2d);
        }
        return conv2d;
    }

    // Fused
    utils::Conv2dOptions fusedOptions;
    if (options != nullptr) {
        fusedOptions = *options;
    }
    fusedOptions.bias = convBias;

    if (relu) {
        fusedOptions.activation = builder.ReluOperator();
    }

    return builder.Conv2d(input, convWeights, fusedOptions.AsPtr());
}

const wnn::Operand SuperResolution::LoadNchw(const wnn::GraphBuilder& builder, bool softmax) {
    const wnn::Operand input = utils::BuildInput(builder, "input", {1, 1, 224, 224});

    utils::Conv2dOptions conv1Options;
    conv1Options.strides = {1, 1};
    conv1Options.padding = {2, 2, 2, 2};
    conv1Options.dilations = {1, 1};
    const wnn::Operand conv1 =
        BuildConv(builder, input, /*convIndex*/ 1, /*relu*/ true, &conv1Options);

    utils::Conv2dOptions conv2Options;
    conv2Options.strides = {1, 1};
    conv2Options.padding = {1, 1, 1, 1};
    conv2Options.dilations = {1, 1};
    const wnn::Operand conv2 =
        BuildConv(builder, conv1, /*convIndex*/ 2, /*relu*/ true, &conv2Options);

    utils::Conv2dOptions conv3Options;
    conv3Options.strides = {1, 1};
    conv3Options.padding = {1, 1, 1, 1};
    conv3Options.dilations = {1, 1};
    const wnn::Operand conv3 =
        BuildConv(builder, conv2, /*convIndex*/ 3, /*relu*/ true, &conv3Options);

    utils::Conv2dOptions conv4Options;
    conv4Options.strides = {1, 1};
    conv4Options.padding = {1, 1, 1, 1};
    conv4Options.dilations = {1, 1};
    const wnn::Operand conv4 =
        BuildConv(builder, conv3, /*convIndex*/ 4, /*relu*/ false, &conv4Options);

    const std::vector<int32_t> newShape1 = {-1, 1, 3, 3, 224, 224};
    const wnn::Operand reshape1 = builder.Reshape(conv4, newShape1.data(), newShape1.size());

    wnn::TransposeOptions transpose1Options;
    std::vector<int32_t> permutation = {0, 1, 4, 2, 5, 3};
    transpose1Options.permutation = permutation.data();
    transpose1Options.permutationCount = permutation.size();
    const wnn::Operand transpose1 = builder.Transpose(reshape1, &transpose1Options);

    const std::vector<int32_t> newShape2 = {-1, 1, 672, 672};
    return builder.Reshape(transpose1, newShape2.data(), newShape2.size());
}
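
The reshape → transpose → reshape tail of LoadNchw is a depth-to-space (pixel-shuffle) step: the 9 channels produced by conv4 become a 3x3 block of output pixels, turning the 224x224 feature map into the 672x672 result. The sketch below is not part of the commit; it is a plain-C++ illustration of the same index mapping, with the function name and hard-coded shapes chosen only for illustration.

// Illustrative sketch: the data movement performed by the final
// reshape/transpose/reshape in SuperResolution::LoadNchw, written as a
// depth-to-space with block size 3 over a {1, 9, 224, 224} NCHW buffer.
#include <cstddef>
#include <vector>

std::vector<float> DepthToSpace3x(const std::vector<float>& src,
                                  std::size_t height = 224,
                                  std::size_t width = 224) {
    const std::size_t block = 3;  // upscale factor; conv4 emits block * block channels
    std::vector<float> dst(height * block * width * block);
    for (std::size_t i = 0; i < block; ++i) {           // row offset inside a block
        for (std::size_t j = 0; j < block; ++j) {       // column offset inside a block
            const std::size_t channel = i * block + j;  // conv4 channel feeding pixel (i, j)
            for (std::size_t h = 0; h < height; ++h) {
                for (std::size_t w = 0; w < width; ++w) {
                    // src is NCHW with N == 1; dst is a single 672x672 plane.
                    const std::size_t srcIndex = (channel * height + h) * width + w;
                    const std::size_t dstIndex =
                        (h * block + i) * (width * block) + (w * block + j);
                    dst[dstIndex] = src[srcIndex];
                }
            }
        }
    }
    return dst;
}

This is exactly what the permutation {0, 1, 4, 2, 5, 3} encodes: after conv4 is reshaped to {-1, 1, 3, 3, 224, 224}, the transpose moves each 3x3 channel block next to its spatial position before the final reshape to {-1, 1, 672, 672}.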

examples/SuperResolution/SuperResolution.h

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
// Copyright 2022 The WebNN-native Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <webnn/webnn.h>
#include <webnn/webnn_cpp.h>

#include "examples/SampleUtils.h"

class SuperResolution : public ExampleBase {
  public:
    SuperResolution();
    ~SuperResolution() override = default;

    const wnn::Operand LoadNchw(const wnn::GraphBuilder& builder, bool softmax);

  private:
    const wnn::Operand BuildConstantFromNpy(const wnn::GraphBuilder& builder,
                                            const std::string& path);

    const wnn::Operand BuildConv(const wnn::GraphBuilder& builder,
                                 const wnn::Operand& input,
                                 int32_t convIndex,
                                 bool relu,
                                 utils::Conv2dOptions* options,
                                 const std::string& biasName = "");

    std::vector<SHARED_DATA_TYPE> mConstants;
};
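
A minimal usage sketch (not part of the commit), mirroring the calls made by the end-to-end test added further down. The function name, the way the wnn::Context is obtained, and the weights directory are illustrative assumptions; only APIs already exercised by this commit (wnn::CreateGraphBuilder, LoadNchw, utils::Build, utils::SizeOfShape, utils::Compute) are relied on.

// Hypothetical driver, assuming a valid wnn::Context and a directory of
// .npy weights laid out as SuperResolution::BuildConv expects.
#include "examples/SuperResolution/SuperResolution.h"

std::vector<float> RunSuperResolutionNchw(const wnn::Context& context,
                                          const std::string& weightsPath,
                                          const std::vector<float>& nchwInput) {
    SuperResolution superResolution;
    superResolution.mFused = true;               // fold bias and relu into conv2d
    superResolution.mWeightsPath = weightsPath;  // e.g. ".../super_resolution_nchw/weights/"

    const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(context);
    wnn::Operand output = superResolution.LoadNchw(builder, false);
    wnn::Graph graph = utils::Build(builder, {{"output", output}});

    // A {1, 1, 224, 224} input is upscaled 3x to {1, 1, 672, 672}.
    std::vector<float> result(utils::SizeOfShape({1, 1, 672, 672}));
    utils::Compute(graph, {{"input", nchwInput}}, {{"output", result}});
    return result;
}

Note that BuildConv applies the .npy bias only on the fused path (through Conv2dOptions::bias and the fused activation), which is why both this sketch and the end-to-end test below keep mFused set to true.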

src/webnn/native/dmlx/GraphDMLX.cpp

Lines changed: 2 additions & 2 deletions

@@ -1270,7 +1270,7 @@ namespace webnn::native::dmlx {
             DAWN_ASSERT(mExpression.find(inputOperand) != mExpression.end());
             ::dml::Expression input = mExpression.at(inputOperand);
             auto newShape = reshape->GetNewShape();
-            if (newShape.size() > DML_TENSOR_DIMENSION_COUNT_MAX) {
+            if (newShape.size() > DML_TENSOR_DIMENSION_COUNT_MAX1) {
                 return DAWN_INTERNAL_ERROR("The size of new shape is not supported by DML.");
             }
             ::dml::TensorDimensions newSizes(newShape.size());

@@ -1424,7 +1424,7 @@ namespace webnn::native::dmlx {
             DAWN_ASSERT(mExpression.find(inputOperand) != mExpression.end());
             ::dml::Expression input = mExpression.at(inputOperand);
             std::vector<int32_t> permutation = transpose->GetPermutation();
-            if (permutation.size() > DML_TENSOR_DIMENSION_COUNT_MAX) {
+            if (permutation.size() > DML_TENSOR_DIMENSION_COUNT_MAX1) {
                 return DAWN_INTERNAL_ERROR("The size of permutation is not supported by DML.");
             }
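
The two-character change above is what lets the new example run on the DirectML backend: LoadNchw reshapes conv4 to a 6-D tensor ({-1, 1, 3, 3, 224, 224}) and transposes it with a 6-element permutation, which the old guard against DML_TENSOR_DIMENSION_COUNT_MAX would have rejected. DML_TENSOR_DIMENSION_COUNT_MAX1 is DirectML's wider dimension limit from its later feature levels. The snippet below is only an illustrative check; the numeric values are assumptions taken from DirectML's public headers (5 and 8), not something shown in this diff.

// Illustrative only: relate the SuperResolution reshape rank to the two limits.
// Assumed values: DML_TENSOR_DIMENSION_COUNT_MAX == 5, DML_TENSOR_DIMENSION_COUNT_MAX1 == 8.
constexpr int kAssumedDmlDimensionCountMax = 5;
constexpr int kAssumedDmlDimensionCountMax1 = 8;
constexpr int kSuperResolutionReshapeRank = 6;  // {-1, 1, 3, 3, 224, 224}

static_assert(kSuperResolutionReshapeRank > kAssumedDmlDimensionCountMax,
              "rejected by the old DML_TENSOR_DIMENSION_COUNT_MAX check");
static_assert(kSuperResolutionReshapeRank <= kAssumedDmlDimensionCountMax1,
              "accepted once the guard uses DML_TENSOR_DIMENSION_COUNT_MAX1");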

src/webnn/tests/BUILD.gn

Lines changed: 4 additions & 1 deletion

@@ -132,8 +132,8 @@ test("webnn_unittests") {
     "${webnn_root}/src/webnn:cpp",
     "${webnn_root}/src/webnn:webnn_proc",
     "${webnn_root}/src/webnn/common",
-    "${webnn_root}/src/webnn/native:webnn_native",
     "${webnn_root}/src/webnn/native:sources",
+    "${webnn_root}/src/webnn/native:webnn_native",
     "${webnn_root}/src/webnn/utils:webnn_utils",
   ]

@@ -205,6 +205,8 @@ source_set("webnn_end2end_tests_sources") {
     "${webnn_root}/examples/ResNet/ResNet.h",
     "${webnn_root}/examples/SqueezeNet/SqueezeNet.cpp",
     "${webnn_root}/examples/SqueezeNet/SqueezeNet.h",
+    "${webnn_root}/examples/SuperResolution/SuperResolution.cpp",
+    "${webnn_root}/examples/SuperResolution/SuperResolution.h",
     "WebnnTest.cpp",
     "WebnnTest.h",
     "end2end/AddTests.cpp",

@@ -249,6 +251,7 @@ source_set("webnn_end2end_tests_sources") {
     "end2end/models/ResNetNhwc.cpp",
     "end2end/models/SqueezeNetNchw.cpp",
     "end2end/models/SqueezeNetNhwc.cpp",
+    "end2end/models/SuperResolutionNchw.cpp",
   ]

   # Validation tests that need OS windows live in end2end tests.

src/webnn/tests/end2end/models/SuperResolutionNchw.cpp

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
// Copyright 2022 The WebNN-native Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "examples/SuperResolution/SuperResolution.h"
#include "webnn/tests/WebnnTest.h"

static const std::string kModelPath = WEBNN_END2END_TEST_MODEL_PATH;

class SuperResolutionNchwTests : public WebnnTest {
  public:
    void TestSuperResolutionNchw(const std::string& inputFile,
                                 const std::string& expectedFile,
                                 bool fused = true) {
        SuperResolution superresolution;
        superresolution.mFused = true;
        const std::string nchwPath = kModelPath + "/super_resolution_nchw/";
        superresolution.mWeightsPath = nchwPath + "weights/";
        const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext());
        wnn::Operand output = superresolution.LoadNchw(builder, false);
        wnn::Graph graph = utils::Build(builder, {{"output", output}});
        const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile);
        const std::vector<float> inputData = inputNpy.as_vec<float>();
        std::vector<float> result(utils::SizeOfShape({/*TODO: batchSize?*/ 1, 1, 672, 672}));
        utils::Compute(graph, {{"input", inputData}}, {{"output", result}});
        const cnpy::NpyArray outputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + expectedFile);
        EXPECT_TRUE(utils::CheckValue(result, outputNpy.as_vec<float>()));
    }
};

TEST_F(SuperResolutionNchwTests, NchwTest0) {
    TestSuperResolutionNchw("0/input_0.npy", "0/output_0.npy", false);
}
