Skip to content

Commit 320e789

Browse files
committed
[WebNN-native/Node] support context-based compute graph
1 parent 31d48a4 commit 320e789

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+532
-488
lines changed

examples/LeNet/LeNet.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ class LeNet {
2525
LeNet();
2626
~LeNet() = default;
2727

28-
wnn::Graph Build(const std::string& weigthsPath);
29-
30-
private:
3128
wnn::Context mContext;
29+
wnn::Graph Build(const std::string& weigthsPath);
3230
};

examples/LeNet/Main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ int main(int argc, const char* argv[]) {
8787
for (int i = 0; i < nIter; ++i) {
8888
std::chrono::time_point<std::chrono::high_resolution_clock> executionStartTime =
8989
std::chrono::high_resolution_clock::now();
90-
utils::Compute(graph, {{"input", input}}, {{"output", result}});
90+
utils::Compute(lenet.mContext, graph, {{"input", input}}, {{"output", result}});
9191
executionTimeVector.push_back(std::chrono::high_resolution_clock::now() -
9292
executionStartTime);
9393
}

examples/MobileNetV2/Main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,14 @@ int main(int argc, const char* argv[]) {
6060
std::vector<float> result(utils::SizeOfShape(mobilevetv2.mOutputShape));
6161
// Do the first inference for warming up if nIter > 1.
6262
if (mobilevetv2.mNIter > 1) {
63-
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
63+
utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}});
6464
}
6565

6666
std::vector<TIME_TYPE> executionTime;
6767
for (int i = 0; i < mobilevetv2.mNIter; ++i) {
6868
std::chrono::time_point<std::chrono::high_resolution_clock> executionStartTime =
6969
std::chrono::high_resolution_clock::now();
70-
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
70+
utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}});
7171
executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime);
7272
}
7373

examples/ResNet/Main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,14 @@ int main(int argc, const char* argv[]) {
6060
std::vector<float> result(utils::SizeOfShape(resnet.mOutputShape));
6161
// Do the first inference for warming up if nIter > 1.
6262
if (resnet.mNIter > 1) {
63-
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
63+
utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}});
6464
}
6565

6666
std::vector<TIME_TYPE> executionTime;
6767
for (int i = 0; i < resnet.mNIter; ++i) {
6868
std::chrono::time_point<std::chrono::high_resolution_clock> executionStartTime =
6969
std::chrono::high_resolution_clock::now();
70-
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
70+
utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}});
7171
executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime);
7272
}
7373

examples/SampleUtils.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,10 +290,11 @@ namespace utils {
290290
return builder.Build(namedOperands);
291291
}
292292

293-
void Compute(const wnn::Graph& graph,
293+
void Compute(const wnn::Context& context,
294+
const wnn::Graph& graph,
294295
const std::vector<NamedInput<float>>& inputs,
295296
const std::vector<NamedOutput<float>>& outputs) {
296-
return Compute<float>(graph, inputs, outputs);
297+
return Compute<float>(context, graph, inputs, outputs);
297298
}
298299

299300
std::vector<std::string> ReadTopKLabel(const std::vector<size_t>& topKIndex,

examples/SampleUtils.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,8 @@ namespace utils {
244244
};
245245

246246
template <typename T>
247-
void Compute(const wnn::Graph& graph,
247+
void Compute(const wnn::Context& context,
248+
const wnn::Graph& graph,
248249
const std::vector<NamedInput<T>>& inputs,
249250
const std::vector<NamedOutput<T>>& outputs) {
250251
if (graph.GetHandle() == nullptr) {
@@ -274,11 +275,12 @@ namespace utils {
274275
mlOutputs.push_back(resource);
275276
namedOutputs.Set(output.name.c_str(), &mlOutputs.back());
276277
}
277-
graph.Compute(namedInputs, namedOutputs);
278+
context.ComputeSync(graph, namedInputs, namedOutputs);
278279
DoFlush();
279280
}
280281

281-
void Compute(const wnn::Graph& graph,
282+
void Compute(const wnn::Context& context,
283+
const wnn::Graph& graph,
282284
const std::vector<NamedInput<float>>& inputs,
283285
const std::vector<NamedOutput<float>>& outputs);
284286

examples/SqueezeNet/Main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,14 @@ int main(int argc, const char* argv[]) {
6060
std::vector<float> result(utils::SizeOfShape(squeezenet.mOutputShape));
6161
// Do the first inference for warming up if nIter > 1.
6262
if (squeezenet.mNIter > 1) {
63-
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
63+
utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}});
6464
}
6565

6666
std::vector<TIME_TYPE> executionTime;
6767
for (int i = 0; i < squeezenet.mNIter; ++i) {
6868
std::chrono::time_point<std::chrono::high_resolution_clock> executionStartTime =
6969
std::chrono::high_resolution_clock::now();
70-
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
70+
utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}});
7171
executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime);
7272
}
7373

node/src/Context.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
#include <napi.h>
1818
#include <iostream>
1919

20+
#include "Graph.h"
2021
#include "ML.h"
22+
#include "Utils.h"
2123

2224
Napi::FunctionReference node::Context::constructor;
2325

@@ -90,10 +92,37 @@ namespace node {
9092

9193
Napi::Object Context::Initialize(Napi::Env env, Napi::Object exports) {
9294
Napi::HandleScope scope(env);
93-
Napi::Function func = DefineClass(env, "MLContext", {});
95+
Napi::Function func = DefineClass(
96+
env, "MLContext", {InstanceMethod("compute", &Context::Compute, napi_enumerable)});
9497
constructor = Napi::Persistent(func);
9598
constructor.SuppressDestruct();
9699
exports.Set("MLContext", func);
97100
return exports;
98101
}
102+
103+
Napi::Value Context::Compute(const Napi::CallbackInfo& info) {
104+
// status compute(NamedInputs inputs, NamedOutputs outputs);
105+
WEBNN_NODE_ASSERT(info.Length() == 3, "The number of arguments is invalid.");
106+
Napi::Object object = info[0].As<Napi::Object>();
107+
node::Graph* jsGraph = Napi::ObjectWrap<node::Graph>::Unwrap(object);
108+
109+
std::map<std::string, Input> inputs;
110+
WEBNN_NODE_ASSERT(GetNamedInputs(info[1], inputs), "The inputs parameter is invalid.");
111+
112+
std::map<std::string, wnn::Resource> outputs;
113+
WEBNN_NODE_ASSERT(GetNamedOutputs(info[2], outputs), "The outputs parameter is invalid.");
114+
115+
wnn::NamedInputs namedInputs = wnn::CreateNamedInputs();
116+
for (auto& input : inputs) {
117+
namedInputs.Set(input.first.data(), input.second.AsPtr());
118+
}
119+
wnn::NamedOutputs namedOutputs = wnn::CreateNamedOutputs();
120+
for (auto& output : outputs) {
121+
namedOutputs.Set(output.first.data(), &output.second);
122+
}
123+
mImpl.ComputeSync(jsGraph->GetImpl(), namedInputs, namedOutputs);
124+
125+
return Napi::Number::New(info.Env(), 0);
126+
}
127+
99128
} // namespace node

node/src/Context.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ namespace node {
3131
wnn::Context GetImpl();
3232

3333
private:
34+
Napi::Value Compute(const Napi::CallbackInfo& info);
3435
wnn::Context mImpl;
3536
};
3637

node/src/Graph.cpp

Lines changed: 4 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -13,135 +13,22 @@
1313
// limitations under the License.
1414

1515
#include "Graph.h"
16-
17-
#include <iostream>
18-
#include <map>
19-
2016
#include "Utils.h"
2117

2218
namespace node {
2319

24-
struct Input {
25-
public:
26-
wnn::ArrayBufferView bufferView;
27-
std::vector<int32_t> dimensions;
28-
29-
const wnn::Input* AsPtr() {
30-
mInput.resource.arrayBufferView = bufferView;
31-
mInput.resource.gpuBufferView = {};
32-
if (!dimensions.empty()) {
33-
mInput.dimensions = dimensions.data();
34-
mInput.dimensionsCount = dimensions.size();
35-
}
36-
return &mInput;
37-
}
38-
39-
private:
40-
wnn::Input mInput;
41-
};
42-
43-
bool GetNamedInputs(const Napi::Value& jsValue, std::map<std::string, Input>& namedInputs) {
44-
if (!jsValue.IsObject()) {
45-
return false;
46-
}
47-
Napi::Object jsNamedInputs = jsValue.As<Napi::Object>();
48-
Napi::Array names = jsNamedInputs.GetPropertyNames();
49-
if (names.Length() == 0) {
50-
return false;
51-
}
52-
// typedef (MLBufferView or WebGLTexture or GPUTexture) MLResource;
53-
// dictionary MLInput {
54-
// required MLResource resource;
55-
// required sequence<long> dimensions;
56-
// };
57-
// typedef record<DOMString, (MLResource or MLInput)> MLNamedInputs;
58-
for (size_t i = 0; i < names.Length(); ++i) {
59-
Input input = {};
60-
std::string name = names.Get(i).As<Napi::String>().Utf8Value();
61-
// FIXME: validate the type of typed array.
62-
Napi::TypedArray jsTypedArray;
63-
if (jsNamedInputs.Get(name).IsTypedArray()) {
64-
jsTypedArray = jsNamedInputs.Get(name).As<Napi::TypedArray>();
65-
} else {
66-
Napi::Object jsInput = jsNamedInputs.Get(name).As<Napi::Object>();
67-
if (!jsInput.Has("resource") || !jsInput.Has("dimensions")) {
68-
// Input resource and dimensions are required.
69-
return false;
70-
}
71-
if (!jsInput.Get("resource").IsTypedArray()) {
72-
return false;
73-
}
74-
jsTypedArray = jsInput.Get("resource").As<Napi::TypedArray>();
75-
76-
if (!GetArray(jsInput.Get("dimensions"), input.dimensions)) {
77-
return false;
78-
}
79-
if (SizeOfShape(input.dimensions) != jsTypedArray.ElementSize()) {
80-
return false;
81-
}
82-
}
83-
if (!GetArrayBufferView(jsTypedArray, input.bufferView)) {
84-
return false;
85-
}
86-
namedInputs[name] = input;
87-
}
88-
return true;
20+
Graph::Graph(const Napi::CallbackInfo& info) : Napi::ObjectWrap<Graph>(info) {
8921
}
9022

91-
bool GetNamedOutputs(const Napi::Value& jsValue,
92-
std::map<std::string, wnn::Resource>& namedOutputs) {
93-
if (!jsValue.IsObject()) {
94-
return false;
95-
}
96-
Napi::Object jsNamedOutputs = jsValue.As<Napi::Object>();
97-
Napi::Array names = jsNamedOutputs.GetPropertyNames();
98-
if (names.Length() == 0) {
99-
return false;
100-
}
101-
// typedef (MLBufferView or WebGLTexture or GPUTexture) MLResource;
102-
// typedef record<DOMString, MLResource> MLNamedOutputs;
103-
for (size_t i = 0; i < names.Length(); ++i) {
104-
wnn::ArrayBufferView arrayBuffer = {};
105-
std::string name = names.Get(i).As<Napi::String>().Utf8Value();
106-
if (!GetArrayBufferView(jsNamedOutputs.Get(name), arrayBuffer)) {
107-
return false;
108-
}
109-
namedOutputs[name] = {arrayBuffer, {}};
110-
}
111-
return true;
23+
wnn::Graph Graph::GetImpl() {
24+
return mImpl;
11225
}
11326

11427
Napi::FunctionReference Graph::constructor;
11528

116-
Graph::Graph(const Napi::CallbackInfo& info) : Napi::ObjectWrap<Graph>(info) {
117-
}
118-
119-
Napi::Value Graph::Compute(const Napi::CallbackInfo& info) {
120-
// status compute(NamedInputs inputs, NamedOutputs outputs);
121-
WEBNN_NODE_ASSERT(info.Length() == 2, "The number of arguments is invalid.");
122-
std::map<std::string, Input> inputs;
123-
WEBNN_NODE_ASSERT(GetNamedInputs(info[0], inputs), "The inputs parameter is invalid.");
124-
125-
std::map<std::string, wnn::Resource> outputs;
126-
WEBNN_NODE_ASSERT(GetNamedOutputs(info[1], outputs), "The outputs parameter is invalid.");
127-
128-
wnn::NamedInputs namedInputs = wnn::CreateNamedInputs();
129-
for (auto& input : inputs) {
130-
namedInputs.Set(input.first.data(), input.second.AsPtr());
131-
}
132-
wnn::NamedOutputs namedOutputs = wnn::CreateNamedOutputs();
133-
for (auto& output : outputs) {
134-
namedOutputs.Set(output.first.data(), &output.second);
135-
}
136-
mImpl.Compute(namedInputs, namedOutputs);
137-
138-
return Napi::Number::New(info.Env(), 0);
139-
}
140-
14129
Napi::Object Graph::Initialize(Napi::Env env, Napi::Object exports) {
14230
Napi::HandleScope scope(env);
143-
Napi::Function func = DefineClass(
144-
env, "MLGraph", {InstanceMethod("compute", &Graph::Compute, napi_enumerable)});
31+
Napi::Function func = DefineClass(env, "MLGraph", {});
14532
constructor = Napi::Persistent(func);
14633
constructor.SuppressDestruct();
14734
exports.Set("MLGraph", func);

0 commit comments

Comments
 (0)