1+ // Copyright 2022 The WebNN-native Authors
2+ //
3+ // Licensed under the Apache License, Version 2.0 (the "License");
4+ // you may not use this file except in compliance with the License.
5+ // You may obtain a copy of the License at
6+ //
7+ // http://www.apache.org/licenses/LICENSE-2.0
8+ //
9+ // Unless required by applicable law or agreed to in writing, software
10+ // distributed under the License is distributed on an "AS IS" BASIS,
11+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ // See the License for the specific language governing permissions and
13+ // limitations under the License.
14+
15+ #include " examples/SuperResolution/SuperResolution.h"
16+
17+ SuperResolution::SuperResolution () : ExampleBase() {
18+ }
19+
20+ const wnn::Operand SuperResolution::BuildConstantFromNpy (const wnn::GraphBuilder& builder,
21+ const std::string& path) {
22+ const cnpy::NpyArray data = cnpy::npy_load (path);
23+ mConstants .push_back (data.data_holder );
24+ return utils::BuildConstant (builder, data.shape , data.data <float >(), data.num_bytes ());
25+ }
26+
27+ const wnn::Operand SuperResolution::BuildConv (const wnn::GraphBuilder& builder,
28+ const wnn::Operand& input,
29+ int32_t convIndex,
30+ bool relu,
31+ utils::Conv2dOptions* options,
32+ const std::string& biasName) {
33+ std::string prefix = mLayout == " nchw" ? mWeightsPath + " conv" : mWeightsPath + " Const_" ;
34+ std::string suffix = mLayout == " nchw" ? " _weight.npy" : " .npy" ;
35+ const std::string weightsPath = prefix + std::to_string (convIndex) + suffix;
36+ const wnn::Operand convWeights = BuildConstantFromNpy (builder, weightsPath);
37+
38+ // TODO: Figure out correct "channels last" path suffix.
39+ prefix = mLayout == " nchw" ? mWeightsPath + " conv" : mWeightsPath + " super_resolution_" ;
40+ if (mLayout == " nchw" ) {
41+ prefix.append (std::to_string (convIndex));
42+ }
43+
44+ const std::string biasPath = prefix + biasName + " _bias.npy" ;
45+ const wnn::Operand convBias = BuildConstantFromNpy (builder, biasPath);
46+
47+ const wnn::Conv2dOptions* conv2dOptions = options != nullptr ? options->AsPtr () : nullptr ;
48+ const wnn::Operand conv2d = builder.Conv2d (input, convWeights, conv2dOptions);
49+
50+ if (!mFused ) {
51+ if (relu) {
52+ return builder.Relu (conv2d);
53+ }
54+ return conv2d;
55+ }
56+
57+ // Fused
58+ utils::Conv2dOptions fusedOptions;
59+ if (options != nullptr ) {
60+ fusedOptions = *options;
61+ }
62+ fusedOptions.bias = convBias;
63+
64+ if (relu) {
65+ fusedOptions.activation = builder.ReluOperator ();
66+ }
67+
68+ return builder.Conv2d (input, convWeights, fusedOptions.AsPtr ());
69+ }
70+
71+ const wnn::Operand SuperResolution::LoadNchw (const wnn::GraphBuilder& builder, bool softmax) {
72+ const wnn::Operand input = utils::BuildInput (builder, " input" , {1 , 1 , 224 , 224 });
73+
74+ utils::Conv2dOptions conv1Options;
75+ conv1Options.strides = {1 , 1 };
76+ conv1Options.padding = {2 , 2 , 2 , 2 };
77+ conv1Options.dilations = {1 , 1 };
78+ const wnn::Operand conv1 =
79+ BuildConv (builder, input, /* convIndex*/ 1 , /* relu*/ true , &conv1Options);
80+
81+ utils::Conv2dOptions conv2Options;
82+ conv2Options.strides = {1 , 1 };
83+ conv2Options.padding = {1 , 1 , 1 , 1 };
84+ conv2Options.dilations = {1 , 1 };
85+ const wnn::Operand conv2 =
86+ BuildConv (builder, conv1, /* convIndex*/ 2 , /* relu*/ true , &conv2Options);
87+
88+ utils::Conv2dOptions conv3Options;
89+ conv3Options.strides = {1 , 1 };
90+ conv3Options.padding = {1 , 1 , 1 , 1 };
91+ conv3Options.dilations = {1 , 1 };
92+ const wnn::Operand conv3 =
93+ BuildConv (builder, conv2, /* convIndex*/ 3 , /* relu*/ true , &conv3Options);
94+
95+ utils::Conv2dOptions conv4Options;
96+ conv4Options.strides = {1 , 1 };
97+ conv4Options.padding = {1 , 1 , 1 , 1 };
98+ conv4Options.dilations = {1 , 1 };
99+ const wnn::Operand conv4 =
100+ BuildConv (builder, conv3, /* convIndex*/ 4 , /* relu*/ false , &conv4Options);
101+
102+ const std::vector<int32_t > newShape1 = {-1 , 1 , 3 , 3 , 224 , 224 };
103+ const wnn::Operand reshape1 = builder.Reshape (conv4, newShape1.data (), newShape1.size ());
104+
105+ wnn::TransposeOptions transpose1Options;
106+ std::vector<int32_t > permutation = {0 , 1 , 4 , 2 , 5 , 3 };
107+ transpose1Options.permutation = permutation.data ();
108+ transpose1Options.permutationCount = permutation.size ();
109+ const wnn::Operand transpose1 = builder.Transpose (reshape1, &transpose1Options);
110+
111+ const std::vector<int32_t > newShape2 = {-1 , 1 , 672 , 672 };
112+ return builder.Reshape (transpose1, newShape2.data (), newShape2.size ());
113+ }
0 commit comments