Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions DEPS
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ gclient_gn_args = [

vars = {
'chromium_git': 'https://chromium.googlesource.com',
'dawn_git': 'https://dawn.googlesource.com',
# 'dawn_git': 'https://github.com/fujunwei',
'dawn_git': 'https://github.com/lisa0314',
'github_git': 'https://github.com',

'dawn_standalone': True,
Expand Down Expand Up @@ -45,9 +46,15 @@ deps = {

# Dependencies required for code generator and infrastructure code.
'third_party/dawn': {
'url': '{dawn_git}/dawn.git@bf1c0cf52377b4db2bf3a433dc5056620aad7cdd'
# 'url': '{dawn_git}/dawn.git@f4c84e239bf8b5b2c4733d68ca38e1e9049fd895'
'url': '{dawn_git}/dawn.git@5e6f6fbfcb038e7a0f7857cda186a8771c6eba05'
},

'third_party/abseil-cpp': {
'url': '{chromium_git}/chromium/src/third_party/abseil-cpp@789af048b388657987c59d4da406859034fe310f',
'condition': 'dawn_standalone',
},

# Dependencies required for backends.
'third_party/DirectML': {
'url': '{github_git}/microsoft/DirectML.git@c3f16a701beeeefc9ce5b67c71b554a6903c0f67',
Expand Down Expand Up @@ -136,7 +143,7 @@ deps = {

# Jinja2 and MarkupSafe for the code generator
'third_party/jinja2': {
'url': '{chromium_git}/chromium/src/third_party/jinja2@a82a4944a7f2496639f34a89c9923be5908b80aa',
'url': '{chromium_git}/chromium/src/third_party/jinja2@ee69aa00ee8536f61db6a451f3858745cf587de6',
'condition': 'dawn_standalone',
},
'third_party/markupsafe': {
Expand Down
3 changes: 2 additions & 1 deletion build_overrides/webnn.gni
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
webnn_standalone = true

# The paths to WebNN's dependencies
webnn_dawn_root = "//third_party/dawn"
webnn_abseil_dir = "//third_party/abseil-cpp"
dawn_root = "//third_party/dawn"
webnn_googletest_dir = "//third_party/googletest"
webnn_jinja2_dir = "//third_party/jinja2"
webnn_gpgmm_dir = "//third_party/gpgmm"
2 changes: 1 addition & 1 deletion examples/LeNet/LeNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ wnn::Graph LeNet::Build(const std::string& weigthsPath) {
return nullptr;
}

const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(mContext);
const wnn::GraphBuilder builder = utils::CreateGraphBuilder(mContext);

uint32_t byteOffset = 0;
const wnn::Operand input = utils::BuildInput(builder, "input", {1, 1, 28, 28});
Expand Down
3 changes: 2 additions & 1 deletion examples/MobileNetV2/Main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ int main(int argc, const char* argv[]) {
}
},
&mobilevetv2);
wnn::GraphBuilder builder = wnn::CreateGraphBuilder(context);

wnn::GraphBuilder builder = utils::CreateGraphBuilder(context);
wnn::Operand output = mobilevetv2.mLayout == "nchw" ? mobilevetv2.LoadNchw(builder)
: mobilevetv2.LoadNhwc(builder);

Expand Down
2 changes: 1 addition & 1 deletion examples/ResNet/Main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ int main(int argc, const char* argv[]) {
}
},
&resnet);
wnn::GraphBuilder builder = wnn::CreateGraphBuilder(context);
wnn::GraphBuilder builder = utils::CreateGraphBuilder(context);
wnn::Operand output =
resnet.mLayout == "nchw" ? resnet.LoadNchw(builder) : resnet.LoadNhwc(builder);

Expand Down
47 changes: 17 additions & 30 deletions examples/SampleUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ static webnn::wire::WireClient* wireClient = nullptr;
static utils::TerribleCommandBuffer* c2sBuf = nullptr;
static utils::TerribleCommandBuffer* s2cBuf = nullptr;

static wnn::Instance clientInstance;
static wnn::Instance instance;
static std::unique_ptr<webnn::native::Instance> nativeInstance;
wnn::Context CreateCppContext(wnn::ContextOptions const* options) {
nativeInstance = std::make_unique<webnn::native::Instance>();
Expand All @@ -62,11 +62,14 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) {
// Choose whether to use the backend procs and context directly, or set up the wire.
WNNContext context = nullptr;
WebnnProcTable procs;
WNNInstance wnnInstance;


switch (cmdBufType) {
case CmdBufType::None:
procs = backendProcs;
context = backendContext;
wnnInstance = nativeInstance->Get();
break;

case CmdBufType::Terrible: {
Expand Down Expand Up @@ -94,23 +97,23 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) {

context = contextReservation.context;
#else
webnnProcSetProcs(&procs);
auto instanceReservation = wireClient->ReserveInstance();
wireServer->InjectInstance(nativeInstance->Get(), instanceReservation.id,
instanceReservation.generation);
// Keep the reference instread of using Acquire.
// TODO:: make the instance in the client as singleton object.
clientInstance = wnn::Instance(instanceReservation.instance);
return clientInstance.CreateContext(options);
wnnInstance = instanceReservation.instance;
break;
#endif
}
default:
dawn::ErrorLog() << "Invaild CmdBufType";
DAWN_ASSERT(0);
}
webnnProcSetProcs(&procs);

return wnn::Context::Acquire(context);
instance = wnn::Instance(wnnInstance);
return instance.CreateContext(options);
;
}

void DoFlush() {
Expand All @@ -123,35 +126,15 @@ void DoFlush() {
}

wnn::NamedInputs CreateCppNamedInputs() {
#if defined(WEBNN_ENABLE_WIRE)
return clientInstance.CreateNamedInputs();
#else
return wnn::CreateNamedInputs();
#endif // defined(WEBNN_ENABLE_WIRE)
}

wnn::NamedOperands CreateCppNamedOperands() {
#if defined(WEBNN_ENABLE_WIRE)
return clientInstance.CreateNamedOperands();
#else
return wnn::CreateNamedOperands();
#endif // defined(WEBNN_ENABLE_WIRE)
return instance.CreateNamedInputs();
}

wnn::NamedOutputs CreateCppNamedOutputs() {
#if defined(WEBNN_ENABLE_WIRE)
return clientInstance.CreateNamedOutputs();
#else
return wnn::CreateNamedOutputs();
#endif // defined(WEBNN_ENABLE_WIRE)
return instance.CreateNamedOutputs();
}

wnn::OperatorArray CreateCppOperatorArray() {
#if defined(WEBNN_ENABLE_WIRE)
return clientInstance.CreateOperatorArray();
#else
return wnn::CreateOperatorArray();
#endif // defined(WEBNN_ENABLE_WIRE)
return instance.CreateOperatorArray();
}

bool ExampleBase::ParseAndCheckExampleOptions(int argc, const char* argv[]) {
Expand Down Expand Up @@ -264,6 +247,10 @@ namespace utils {
return activationOperand;
}

wnn::GraphBuilder CreateGraphBuilder(const wnn::Context& context) {
return instance.CreateGraphBuilder(context);
}

wnn::Operand BuildInput(const wnn::GraphBuilder& builder,
std::string name,
const std::vector<int32_t>& dimensions,
Expand All @@ -283,7 +270,7 @@ namespace utils {
}

wnn::Graph Build(const wnn::GraphBuilder& builder, const std::vector<NamedOperand>& outputs) {
wnn::NamedOperands namedOperands = CreateCppNamedOperands();
wnn::NamedOperands namedOperands = instance.CreateNamedOperands();
for (auto& output : outputs) {
namedOperands.Set(output.name.c_str(), output.operand);
}
Expand Down
5 changes: 3 additions & 2 deletions examples/SampleUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ namespace utils {
FusedActivation activation,
const void* options = nullptr);

wnn::GraphBuilder CreateGraphBuilder(const wnn::Context& context);
wnn::Operand BuildInput(const wnn::GraphBuilder& builder,
std::string name,
const std::vector<int32_t>& dimensions,
Expand Down Expand Up @@ -247,7 +248,7 @@ namespace utils {
void Compute(const wnn::Graph& graph,
const std::vector<NamedInput<T>>& inputs,
const std::vector<NamedOutput<T>>& outputs) {
if (graph.GetHandle() == nullptr) {
if (graph.Get() == nullptr) {
dawn::ErrorLog() << "The graph is invaild.";
}

Expand All @@ -272,7 +273,7 @@ namespace utils {
resource.arrayBufferView.buffer = output.resource.data();
resource.arrayBufferView.byteLength = output.resource.size() * sizeof(float);
mlOutputs.push_back(resource);
namedOutputs.Set(output.name.c_str(), &mlOutputs.back());
namedOutputs.SetOutput(output.name.c_str(), &mlOutputs.back());
}
graph.Compute(namedInputs, namedOutputs);
DoFlush();
Expand Down
2 changes: 1 addition & 1 deletion examples/SqueezeNet/Main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ int main(int argc, const char* argv[]) {
}
},
&squeezenet);
wnn::GraphBuilder builder = wnn::CreateGraphBuilder(context);
wnn::GraphBuilder builder = utils::CreateGraphBuilder(context);
wnn::Operand output =
squeezenet.mLayout == "nchw" ? squeezenet.LoadNchw(builder) : squeezenet.LoadNhwc(builder);

Expand Down
22 changes: 13 additions & 9 deletions generator/webnn_generator.gni
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import("//third_party/dawn/generator/generator_lib.gni")
import("//third_party/dawn/generator/dawn_generator.gni")
import("../scripts/webnn_overrides_with_defaults.gni")

# Dawn used to put autogenerated files in a lot of different places. When we
Expand All @@ -39,12 +39,16 @@ import("../scripts/webnn_overrides_with_defaults.gni")
# disallowed gen directories.

webnn_allowed_gen_output_dirs = [
"src/dawn/",
"src/dawn/native/",
"src/webnn/",
"src/webnn/native/",
"src/webnn/wire/client/",
"src/webnn/wire/server/",
"src/webnn/wire/",
"include/webnn/",
"emscripten-bits/",
"include/dawn/",
]

# Template to help invoking Dawn code generators based on generator_lib
Expand Down Expand Up @@ -77,7 +81,7 @@ template("webnn_generator") {
forward_variables_from(invoker, "*")

# Set arguments required to find the python libraries for the generator
generator_lib_dir = "${webnn_dawn_root}/generator"
generator_lib_dir = "${dawn_root}/generator"
jinja2_path = webnn_jinja2_dir

# Force Dawn's autogenerated file structure to mirror exactly the source
Expand All @@ -87,37 +91,37 @@ template("webnn_generator") {

# Make sure that we delete stale autogenerated file in directories that are
# no longer used by code generation to avoid include conflicts.
deps = [ "${webnn_root}/generator:remove_stale_autogen_files" ]
deps = [ "${dawn_root}/generator:remove_stale_autogen_files" ]

template_dir = "${dawn_root}/generator/templates"

template_dir = "${webnn_root}/generator/templates"
}
}

# Helper generator for calling the generator from webnn.json
#
# dawn_json_generator("my_target_gen") {
# webnn_json_generator("my_target_gen") {
# # Which generator target to output
# target = "my_target"
#
# # Also supports `outputs` and `custom_gen_dir` like dawn_generator.
# }
template("webnn_json_generator") {
webnn_generator(target_name) {
script = "${webnn_root}/generator/webnn_json_generator.py"
script = "${dawn_root}/generator/dawn_json_generator.py"

# The base arguments for the generator: from this webnn.json, generate this
# target using templates in this directory.
args = [
"--webnn-json",
"--dawn-json",
rebase_path("${webnn_root}/webnn.json", root_build_dir),
"--wire-json",
rebase_path("${webnn_root}/webnn_wire.json", root_build_dir),
"--targets",
invoker.target,
"--dawn-generator-path",
rebase_path("${webnn_root}/third_party/dawn/generator"),
]

forward_variables_from(invoker, "*", [ "target" ])

}
}
21 changes: 10 additions & 11 deletions include/webnn/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import("../../scripts/webnn_overrides_with_defaults.gni")

import("${webnn_dawn_root}/scripts/dawn_component.gni")
import("${dawn_root}/scripts/dawn_component.gni")
import("${webnn_root}/generator/webnn_generator.gni")

###############################################################################
Expand All @@ -25,8 +25,8 @@ import("${webnn_root}/generator/webnn_generator.gni")
webnn_json_generator("headers_gen") {
target = "headers"
outputs = [
"include/webnn/webnn_proc_table.h",
"include/webnn/webnn.h",
"include/dawn/webnn_proc_table.h",
"include/dawn/webnn.h",
]
}

Expand All @@ -43,7 +43,10 @@ source_set("headers") {

webnn_json_generator("cpp_headers_gen") {
target = "cpp_headers"
outputs = [ "include/webnn/webnn_cpp.h" ]
outputs = [
"include/dawn/webnn_cpp.h",
"include/dawn/webnn_cpp_print.h",
]
}

source_set("cpp_headers") {
Expand All @@ -63,18 +66,14 @@ config("public") {
include_dirs = [
"${target_gen_dir}/../../include",
"${webnn_root}/include",
"${dawn_root}/include",
"${dawn_gen_root}/include",
]

if (build_with_chromium) {
include_dirs += [
"${webnn_dawn_root}/include",
"${dawn_root}/include",
"${dawn_gen_root}/include",
]
} else {
# TODO: Remove after upgrading webnn infranstructure align with dawn.
include_dirs += [
"${webnn_dawn_root}/src/include",
"${dawn_gen_root}/src/include",
]
}
}
Loading