Skip to content

Commit 89f3982

Browse files
committed
Create and initialize Device object: the first PR for rewriting DML backend
1 parent 294de6b commit 89f3982

File tree

8 files changed

+404
-0
lines changed

8 files changed

+404
-0
lines changed

src/webnn/native/BUILD.gn

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,17 @@ source_set("sources") {
210210
}
211211
}
212212

213+
if (webnn_enable_dml) {
214+
sources += [
215+
"dml/BackendDML.cpp",
216+
"dml/BackendDML.h",
217+
"dml/ContextDML.cpp",
218+
"dml/ContextDML.h",
219+
"dml/GraphDML.cpp",
220+
"dml/GraphDML.h",
221+
]
222+
}
223+
213224
if (webnn_enable_dmlx) {
214225
if (webnn_enable_gpu_buffer == false) {
215226
sources += [
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// Copyright 2019 The Dawn Authors
2+
// Copyright 2022 The WebNN-native Authors
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#include "webnn/native/dml/BackendDML.h"
17+
18+
#include "webnn/native/Instance.h"
19+
#include "webnn/native/dml/ContextDML.h"
20+
21+
namespace webnn::native::dml {
22+
23+
namespace {
24+
HRESULT EnumAdapter(DXGI_GPU_PREFERENCE gpuPreference,
25+
bool useGpu,
26+
ComPtr<IDXGIAdapter1> adapter) {
27+
ComPtr<IDXGIFactory6> dxgiFactory;
28+
WEBNN_RETURN_IF_FAILED(CreateDXGIFactory1(IID_PPV_ARGS(&dxgiFactory)));
29+
if (useGpu) {
30+
UINT adapterIndex = 0;
31+
while (dxgiFactory->EnumAdapterByGpuPreference(adapterIndex++, gpuPreference,
32+
IID_PPV_ARGS(&adapter)) !=
33+
DXGI_ERROR_NOT_FOUND) {
34+
DXGI_ADAPTER_DESC1 pDesc;
35+
adapter->GetDesc1(&pDesc);
36+
// An adapter called the "Microsoft Basic Render Driver" is always present.
37+
// This adapter is a render-only device that has no display outputs. See here
38+
// for documentation on filtering WARP adapter:
39+
// https://docs.microsoft.com/en-us/windows/desktop/direct3ddxgi/d3d10-graphics-programming-guide-dxgi#new-info-about-enumerating-adapters-for-windows-8
40+
bool isSoftwareAdapter = pDesc.Flags == DXGI_ADAPTER_FLAG_SOFTWARE ||
41+
(pDesc.VendorId == 0x1414 && pDesc.DeviceId == 0x8c);
42+
if (!isSoftwareAdapter) {
43+
break;
44+
}
45+
}
46+
} else {
47+
WEBNN_RETURN_IF_FAILED(dxgiFactory->EnumWarpAdapter(IID_PPV_ARGS(&adapter)));
48+
}
49+
return S_OK;
50+
}
51+
52+
} // namespace
53+
54+
Backend::Backend(InstanceBase* instance)
55+
: BackendConnection(instance, wnn::BackendType::DirectML) {
56+
}
57+
58+
MaybeError Backend::Initialize() {
59+
return {};
60+
}
61+
62+
ContextBase* Backend::CreateContext(ContextOptions const* options) {
63+
wnn::DevicePreference devicePreference =
64+
options == nullptr ? wnn::DevicePreference::Default : options->devicePreference;
65+
bool useGpu = devicePreference == wnn::DevicePreference::Cpu ? false : true;
66+
DXGI_GPU_PREFERENCE gpuPreference = DXGI_GPU_PREFERENCE_UNSPECIFIED;
67+
wnn::PowerPreference powerPreference =
68+
options == nullptr ? wnn::PowerPreference::Default : options->powerPreference;
69+
switch (powerPreference) {
70+
case wnn::PowerPreference::High_performance:
71+
gpuPreference = DXGI_GPU_PREFERENCE::DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE;
72+
break;
73+
case wnn::PowerPreference::Low_power:
74+
gpuPreference = DXGI_GPU_PREFERENCE::DXGI_GPU_PREFERENCE_MINIMUM_POWER;
75+
break;
76+
default:
77+
break;
78+
}
79+
80+
bool useDebugLayer = false;
81+
#ifdef _DEBUG
82+
useDebugLayer = true;
83+
#endif
84+
ComPtr<IDXGIAdapter1> adapter;
85+
if (FAILED(EnumAdapter(gpuPreference, useGpu, adapter))) {
86+
dawn::ErrorLog() << "Failed to enumerate adapters for creating the context.";
87+
return nullptr;
88+
}
89+
return Context::Create(adapter, useDebugLayer);
90+
}
91+
92+
BackendConnection* Connect(InstanceBase* instance) {
93+
Backend* backend = new Backend(instance);
94+
95+
if (instance->ConsumedError(backend->Initialize())) {
96+
delete backend;
97+
return nullptr;
98+
}
99+
100+
return backend;
101+
}
102+
103+
} // namespace webnn::native::dml

src/webnn/native/dml/BackendDML.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2019 The Dawn Authors
2+
// Copyright 2022 The WebNN-native Authors
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#ifndef WEBNN_NATIVE_DML_BACKEND_DML_H_
17+
#define WEBNN_NATIVE_DML_BACKEND_DML_H_
18+
19+
#include "webnn/native/BackendConnection.h"
20+
#include "webnn/native/Context.h"
21+
22+
namespace webnn::native::dml {
23+
24+
class Backend : public BackendConnection {
25+
public:
26+
Backend(InstanceBase* instance);
27+
28+
MaybeError Initialize();
29+
ContextBase* CreateContext(ContextOptions const* options = nullptr) override;
30+
};
31+
32+
} // namespace webnn::native::dml
33+
34+
#endif // WEBNN_NATIVE_DML_BACKEND_DML_H_
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Copyright 2022 The WebNN-native Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "webnn/native/dml/ContextDML.h"
16+
17+
#include "webnn/native/dml/GraphDML.h"
18+
19+
namespace webnn::native::dml {
20+
21+
HRESULT Context::Initialize() {
22+
if (mUseDebugLayer) {
23+
ComPtr<ID3D12Debug> debug;
24+
if (SUCCEEDED(D3D12GetDebugInterface(IID_PPV_ARGS(&debug)))) {
25+
debug->EnableDebugLayer();
26+
}
27+
}
28+
WEBNN_RETURN_IF_FAILED(D3D12CreateDevice(mAdapter.Get(), D3D_FEATURE_LEVEL_11_0,
29+
IID_PPV_ARGS(&mCommandRecorderDML.D3D12Device)));
30+
D3D12_COMMAND_QUEUE_DESC commandQueueDesc{};
31+
commandQueueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
32+
commandQueueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
33+
WEBNN_RETURN_IF_FAILED(mCommandRecorderDML.D3D12Device->CreateCommandQueue(
34+
&commandQueueDesc, IID_PPV_ARGS(&mCommandRecorderDML.commandQueue)));
35+
WEBNN_RETURN_IF_FAILED(mCommandRecorderDML.D3D12Device->CreateCommandAllocator(
36+
D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&mCommandRecorderDML.commandAllocator)));
37+
WEBNN_RETURN_IF_FAILED(mCommandRecorderDML.D3D12Device->CreateCommandList(
38+
0, D3D12_COMMAND_LIST_TYPE_DIRECT, mCommandRecorderDML.commandAllocator.Get(), nullptr,
39+
IID_PPV_ARGS(&mCommandRecorderDML.commandList)));
40+
41+
// Create the DirectML device.
42+
ComPtr<ID3D12DebugDevice> debugDevice;
43+
if (mUseDebugLayer && SUCCEEDED(mCommandRecorderDML.D3D12Device.As(&debugDevice))) {
44+
WEBNN_RETURN_IF_FAILED(DMLCreateDevice(mCommandRecorderDML.D3D12Device.Get(),
45+
DML_CREATE_DEVICE_FLAG_DEBUG,
46+
IID_PPV_ARGS(&mCommandRecorderDML.device)));
47+
} else {
48+
WEBNN_RETURN_IF_FAILED(DMLCreateDevice(mCommandRecorderDML.D3D12Device.Get(),
49+
DML_CREATE_DEVICE_FLAG_NONE,
50+
IID_PPV_ARGS(&mCommandRecorderDML.device)));
51+
}
52+
return S_OK;
53+
};
54+
55+
// static
56+
ContextBase* Context::Create(ComPtr<IDXGIAdapter1> adapter, bool useDebugLayer) {
57+
Context* context = new Context(adapter, useDebugLayer);
58+
if (FAILED(context->Initialize())) {
59+
dawn::ErrorLog() << "Failed to initialize Device.";
60+
delete context;
61+
return nullptr;
62+
}
63+
return context;
64+
}
65+
66+
Context::Context(ComPtr<IDXGIAdapter1> adapter, bool useDebugLayer)
67+
: mAdapter(std::move(adapter)), mUseDebugLayer(useDebugLayer) {
68+
}
69+
70+
GraphBase* Context::CreateGraphImpl() {
71+
return new Graph(this);
72+
}
73+
74+
} // namespace webnn::native::dml

src/webnn/native/dml/ContextDML.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright 2022 The WebNN-native Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef WEBNN_NATIVE_DML_CONTEXT_DML_H_
16+
#define WEBNN_NATIVE_DML_CONTEXT_DML_H_
17+
18+
#include "webnn/native/Context.h"
19+
20+
#include "common/Log.h"
21+
#include "dml_platform.h"
22+
#include "webnn/native/Graph.h"
23+
24+
namespace webnn::native::dml {
25+
26+
struct CommandRecorderDML {
27+
ComPtr<IDMLDevice> device;
28+
ComPtr<ID3D12Device> D3D12Device;
29+
ComPtr<IDMLCommandRecorder> commandRecorder;
30+
ComPtr<ID3D12CommandQueue> commandQueue;
31+
ComPtr<ID3D12CommandAllocator> commandAllocator;
32+
ComPtr<ID3D12GraphicsCommandList> commandList;
33+
};
34+
35+
class Context : public ContextBase {
36+
public:
37+
static ContextBase* Create(ComPtr<IDXGIAdapter1> adapter, bool useDebugLayer);
38+
~Context() override = default;
39+
40+
private:
41+
Context(ComPtr<IDXGIAdapter1> adapter, bool useDebugLayer);
42+
HRESULT Initialize();
43+
44+
GraphBase* CreateGraphImpl() override;
45+
46+
CommandRecorderDML mCommandRecorderDML;
47+
ComPtr<IDXGIAdapter1> mAdapter;
48+
bool mUseDebugLayer = false;
49+
};
50+
51+
} // namespace webnn::native::dml
52+
53+
#endif // WEBNN_NATIVE_DML_CONTEXT_DML_H_

src/webnn/native/dml/GraphDML.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright 2022 The WebNN-native Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "webnn/native/dml/GraphDML.h"
16+
17+
#include "webnn/native/NamedInputs.h"
18+
#include "webnn/native/NamedOutputs.h"
19+
20+
namespace webnn::native ::dml {
21+
22+
Graph::Graph(Context* context) : GraphBase(context) {
23+
}
24+
25+
MaybeError Graph::CompileImpl() {
26+
return {};
27+
}
28+
29+
MaybeError Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) {
30+
return {};
31+
}
32+
33+
} // namespace webnn::native::dml

src/webnn/native/dml/GraphDML.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright 2022 The WebNN-native Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef WEBNN_NATIVE_DML_GRAPH_DML_H_
16+
#define WEBNN_NATIVE_DML_GRAPH_DML_H_
17+
18+
#include "webnn/native/Graph.h"
19+
#include "webnn/native/Operand.h"
20+
#include "webnn/native/Operator.h"
21+
#include "webnn/native/dml/ContextDML.h"
22+
#include "webnn/native/ops/BatchNorm.h"
23+
#include "webnn/native/ops/Binary.h"
24+
#include "webnn/native/ops/Clamp.h"
25+
#include "webnn/native/ops/Concat.h"
26+
#include "webnn/native/ops/Constant.h"
27+
#include "webnn/native/ops/Conv2d.h"
28+
#include "webnn/native/ops/Gemm.h"
29+
#include "webnn/native/ops/Gru.h"
30+
#include "webnn/native/ops/Input.h"
31+
#include "webnn/native/ops/InstanceNorm.h"
32+
#include "webnn/native/ops/LeakyRelu.h"
33+
#include "webnn/native/ops/Pad.h"
34+
#include "webnn/native/ops/Pool2d.h"
35+
#include "webnn/native/ops/Reduce.h"
36+
#include "webnn/native/ops/Resample2d.h"
37+
#include "webnn/native/ops/Reshape.h"
38+
#include "webnn/native/ops/Slice.h"
39+
#include "webnn/native/ops/Split.h"
40+
#include "webnn/native/ops/Squeeze.h"
41+
#include "webnn/native/ops/Transpose.h"
42+
#include "webnn/native/ops/Unary.h"
43+
44+
namespace webnn::native::dml {
45+
46+
class Graph : public GraphBase {
47+
public:
48+
explicit Graph(Context* context);
49+
~Graph() override = default;
50+
51+
private:
52+
MaybeError CompileImpl() override;
53+
MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override;
54+
};
55+
56+
} // namespace webnn::native::dml
57+
58+
#endif // WEBNN_NATIVE_DML_GRAPH_DML_H_

0 commit comments

Comments
 (0)