|
13 | 13 | // limitations under the License. |
14 | 14 |
|
15 | 15 | #include "Graph.h" |
16 | | - |
17 | | -#include <iostream> |
18 | | -#include <map> |
19 | | - |
20 | 16 | #include "Utils.h" |
21 | 17 |
|
22 | 18 | namespace node { |
23 | 19 |
|
24 | | - struct Input { |
25 | | - public: |
26 | | - wnn::ArrayBufferView bufferView; |
27 | | - std::vector<int32_t> dimensions; |
28 | | - |
29 | | - const wnn::Input* AsPtr() { |
30 | | - mInput.resource.arrayBufferView = bufferView; |
31 | | - mInput.resource.gpuBufferView = {}; |
32 | | - if (!dimensions.empty()) { |
33 | | - mInput.dimensions = dimensions.data(); |
34 | | - mInput.dimensionsCount = dimensions.size(); |
35 | | - } |
36 | | - return &mInput; |
37 | | - } |
38 | | - |
39 | | - private: |
40 | | - wnn::Input mInput; |
41 | | - }; |
42 | | - |
43 | | - bool GetNamedInputs(const Napi::Value& jsValue, std::map<std::string, Input>& namedInputs) { |
44 | | - if (!jsValue.IsObject()) { |
45 | | - return false; |
46 | | - } |
47 | | - Napi::Object jsNamedInputs = jsValue.As<Napi::Object>(); |
48 | | - Napi::Array names = jsNamedInputs.GetPropertyNames(); |
49 | | - if (names.Length() == 0) { |
50 | | - return false; |
51 | | - } |
52 | | - // typedef (MLBufferView or WebGLTexture or GPUTexture) MLResource; |
53 | | - // dictionary MLInput { |
54 | | - // required MLResource resource; |
55 | | - // required sequence<long> dimensions; |
56 | | - // }; |
57 | | - // typedef record<DOMString, (MLResource or MLInput)> MLNamedInputs; |
58 | | - for (size_t i = 0; i < names.Length(); ++i) { |
59 | | - Input input = {}; |
60 | | - std::string name = names.Get(i).As<Napi::String>().Utf8Value(); |
61 | | - // FIXME: validate the type of typed array. |
62 | | - Napi::TypedArray jsTypedArray; |
63 | | - if (jsNamedInputs.Get(name).IsTypedArray()) { |
64 | | - jsTypedArray = jsNamedInputs.Get(name).As<Napi::TypedArray>(); |
65 | | - } else { |
66 | | - Napi::Object jsInput = jsNamedInputs.Get(name).As<Napi::Object>(); |
67 | | - if (!jsInput.Has("resource") || !jsInput.Has("dimensions")) { |
68 | | - // Input resource and dimensions are required. |
69 | | - return false; |
70 | | - } |
71 | | - if (!jsInput.Get("resource").IsTypedArray()) { |
72 | | - return false; |
73 | | - } |
74 | | - jsTypedArray = jsInput.Get("resource").As<Napi::TypedArray>(); |
75 | | - |
76 | | - if (!GetArray(jsInput.Get("dimensions"), input.dimensions)) { |
77 | | - return false; |
78 | | - } |
79 | | - if (SizeOfShape(input.dimensions) != jsTypedArray.ElementSize()) { |
80 | | - return false; |
81 | | - } |
82 | | - } |
83 | | - if (!GetArrayBufferView(jsTypedArray, input.bufferView)) { |
84 | | - return false; |
85 | | - } |
86 | | - namedInputs[name] = input; |
87 | | - } |
88 | | - return true; |
| 20 | + Graph::Graph(const Napi::CallbackInfo& info) : Napi::ObjectWrap<Graph>(info) { |
89 | 21 | } |
90 | 22 |
|
91 | | - bool GetNamedOutputs(const Napi::Value& jsValue, |
92 | | - std::map<std::string, wnn::Resource>& namedOutputs) { |
93 | | - if (!jsValue.IsObject()) { |
94 | | - return false; |
95 | | - } |
96 | | - Napi::Object jsNamedOutputs = jsValue.As<Napi::Object>(); |
97 | | - Napi::Array names = jsNamedOutputs.GetPropertyNames(); |
98 | | - if (names.Length() == 0) { |
99 | | - return false; |
100 | | - } |
101 | | - // typedef (MLBufferView or WebGLTexture or GPUTexture) MLResource; |
102 | | - // typedef record<DOMString, MLResource> MLNamedOutputs; |
103 | | - for (size_t i = 0; i < names.Length(); ++i) { |
104 | | - wnn::ArrayBufferView arrayBuffer = {}; |
105 | | - std::string name = names.Get(i).As<Napi::String>().Utf8Value(); |
106 | | - if (!GetArrayBufferView(jsNamedOutputs.Get(name), arrayBuffer)) { |
107 | | - return false; |
108 | | - } |
109 | | - namedOutputs[name] = {arrayBuffer, {}}; |
110 | | - } |
111 | | - return true; |
| 23 | + wnn::Graph Graph::GetImpl() { |
| 24 | + return mImpl; |
112 | 25 | } |
113 | 26 |
|
114 | 27 | Napi::FunctionReference Graph::constructor; |
115 | 28 |
|
116 | | - Graph::Graph(const Napi::CallbackInfo& info) : Napi::ObjectWrap<Graph>(info) { |
117 | | - } |
118 | | - |
119 | | - Napi::Value Graph::Compute(const Napi::CallbackInfo& info) { |
120 | | - // status compute(NamedInputs inputs, NamedOutputs outputs); |
121 | | - WEBNN_NODE_ASSERT(info.Length() == 2, "The number of arguments is invalid."); |
122 | | - std::map<std::string, Input> inputs; |
123 | | - WEBNN_NODE_ASSERT(GetNamedInputs(info[0], inputs), "The inputs parameter is invalid."); |
124 | | - |
125 | | - std::map<std::string, wnn::Resource> outputs; |
126 | | - WEBNN_NODE_ASSERT(GetNamedOutputs(info[1], outputs), "The outputs parameter is invalid."); |
127 | | - |
128 | | - wnn::NamedInputs namedInputs = wnn::CreateNamedInputs(); |
129 | | - for (auto& input : inputs) { |
130 | | - namedInputs.Set(input.first.data(), input.second.AsPtr()); |
131 | | - } |
132 | | - wnn::NamedOutputs namedOutputs = wnn::CreateNamedOutputs(); |
133 | | - for (auto& output : outputs) { |
134 | | - namedOutputs.Set(output.first.data(), &output.second); |
135 | | - } |
136 | | - mImpl.Compute(namedInputs, namedOutputs); |
137 | | - |
138 | | - return Napi::Number::New(info.Env(), 0); |
139 | | - } |
140 | | - |
141 | 29 | Napi::Object Graph::Initialize(Napi::Env env, Napi::Object exports) { |
142 | 30 | Napi::HandleScope scope(env); |
143 | | - Napi::Function func = DefineClass( |
144 | | - env, "MLGraph", {InstanceMethod("compute", &Graph::Compute, napi_enumerable)}); |
| 31 | + Napi::Function func = DefineClass(env, "MLGraph", {}); |
145 | 32 | constructor = Napi::Persistent(func); |
146 | 33 | constructor.SuppressDestruct(); |
147 | 34 | exports.Set("MLGraph", func); |
|
0 commit comments