Skip to content

Commit 5c4f5bb

Browse files
fs-eireskottmckay
andauthored
[js/web] WebGPU backend via JSEP (microsoft#14579)
### Description This change introduced the following new components into ONNX Runtime Web: - JavaScript Execution Provider (JSEP) - Asynchronized inferencing execution powered by Emscripten's Asyncify - WebGPU backend implemented in TypeScript - initial implementation of kernels: - elementwise operators (22) - binary operators (5) - tensor: Shape, Reshape, Transpose, Gemm - nn: Conv, {Global}Maxpool, {Global}AveragePool Code need to be polished. still working on it. ## Q&A What is JSEP? > JSEP, aka JavaScript Execution Provider, is a new ONNXRuntime execution provider that specifically works on Web environment (browsers). JSEP allows JavaScript code to kick in from various places when ONNX Runtime inferences a model. Why JSEP? > JSEP is a hybrid mode EP that contains both C/C++ and TypeScript/JavaScript implementation. There are 2 strong reasons why we introduces JSEP: > 1. the C/C++ part helps JSEP to leverage ONNX Runtime's capabilities as much as possible including graph transformer, optimizers and also the capabilities to fallback to CPU EP. TypeScript/JavaScript helps JSEP to develop and debug much easier in the browser for the kernel implementation. > 2. the requirement of asynchronized execution from JavaScript API (eg. `buffer.mapAsync()`) makes it impossible to run `OrtRun()` in a synchronized context (see "async problem" section below). This is done by using Emscripten's Asyncify. What is WebGPU? > WebGPU is the new GPU API that available in browser. It's one of the only 2 APIs that currently available to access the GPU from browser (the other is WebGL). > WebGPU is designed with more advanced and stronger features comparing to WebGL and is potentially solution that offer the best GPU performance for model inferencing that currently available. What is the async problem and why we have the problem? > The "async problem" is a problem that you cannot call an async function in a synchronous context. Think about the following C++ code: > ```c > // C-style declarations (API) > typedef void (*ON_COMPLETE)(PVOID state, DATA *data); > void read_data_from_file(FILEHANDLE file, ON_COMPLETE on_complete); > > // implementation > DATA * my_impl_read_data_from_file_sync(FILEHANDLE file) { > // how to implement? > } > ``` > The answer is, it's impossible to implement this function. Usually we try to find a sync version API, or launch a thread to call the async function and sync-wait on the main thread. Unfortunately, in browser environment, neither is possible. > > WebGPU does not offer any synchronized API for data downloading (GPU to CPU). This is the only operation that MUST be async. As `OrtRun()` will eventually call into DataTransfer for copy data from GPU to CPU, and `OrtRun()` is a synchronized function, this cannot be done in normal way. What is Emscripten? How is the Asyncify feature resolved the problem? > Emscripten is the C/C++ compiler for WebAssembly. It's what we use to compile ORT and generates the WebAssembly artifacts which runs on browsers. > > Asyncify is a [compiler feature](https://emscripten.org/docs/porting/asyncify.html) that allows calling async functions from a synchronized context. In short, it generates code to unwind and rewind call stack to emulate async execution. With this feature, we are able to call the async function inside `OrtRun()` call. ## Design Overview **Inter-op** JSEP is doing pretty much same thing to just another EP. It exposes an interface for inter-op with JavaScript, which is defined in onnxruntime/wasm/js_internal_api.js: ```js // init JSEP Module["jsepInit"] = function (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, run) { Module.jsepBackend = backend; Module.jsepAlloc = alloc; Module.jsepFree = free; Module.jsepCopy = copy; Module.jsepCopyAsync = copyAsync; Module.jsepCreateKernel = createKernel; Module.jsepReleaseKernel = releaseKernel; Module.jsepRun = run; }; ``` This simple JavaScript snippet defines all language barrier level functions that requires by JSEP to achieve implementing kernels and data transfers using JavaScript inside ONNX Runtime: - `jsepBackend`: assign the singleton object to webassembly module - `jsepAlloc` and `jsepFree`: implementation of data transfer's Alloc() and Free() - `jsepCopy`: synchronized copy ( GPU to GPU, CPU to GPU) - `jsepCopyAsync`: asynchronized copy ( GPU to CPU) - `jsepCreateKernel` and `jsepReleaseKernel`: a corresponding object that maintained in JS to match lifecycle of Kernel in ORT - `jsepRun`: OpKernel::Compute() should call into this The abstraction above allows to tie as little as possible connections and dependencies between C/C++ and TypeScript/JavaScript. **Resource Management** Lifecycle of tensor data and kernels are managed by ORT(C/C++) but the implementation are left to JavaScript. JavaScript code are responsible to implement the callbacks correctly. For WebGPU, the GPU data is managed by JavaScript using a singleton map (tensot_data_id => GPUBuffer). GPU pipeline is managed as singleton. Shaders are managed using a singletonmap (shader_key => gpu_program), while shader_key is generated by cache_key (OP specific, including attributes) and input shapes. **about data transfer** `js::DataTransfer::CopyTensor` implemented to call either synchronized or asynchronized copy callback, depending on the destination is GPU or not. Emscripten's macro `EM_ASYNC_JS` is used to wrap the async function to be called in the synchronized context. **run kernel in JS** Kernel class constructor calls once `jsepCreateKernel()` with an optional per-kernel specific serialization to pass attributes into JavaScript. `Compute()` are implemented in a way that a metadata serialization is performed in a base class and JavaScript code can access the data using the Emscripten specific builtin macro `EM_ASM_*`. **disabled features** memory pattern is force disabled, because the WebGPU data is not presented by a general memory model (a buffer can be represented by offset + size). concurrent run support is disabled. WebGPU is stateful and it also has async function call. To support concurrent run will significantly increase the complexity and we don't get any real benefit from it. **prefer channels last** JSEP prefers channels last and returns `DataLayout::NHWC` in method `GetPreferredLayout()`. This will let the graph transformers to preprocess the graph into a channels last form so that a more optimized WebGPU shader can be used. **Testing code** It's impossible to test JSEP directly because JSEP itself does not contain any kernel implementation. However, it has the kernel registration which need to work together with the corresponding JavaScript code. There are unit tests that run onnx models from JavaScript API. --------- Co-authored-by: Scott McKay <[email protected]>
1 parent e12d44c commit 5c4f5bb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+6177
-367
lines changed

.eslintrc.js

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,12 @@ module.exports = {
182182
'import/no-extraneous-dependencies': 'off',
183183
'no-console': 'off'
184184
}
185+
}, {
186+
files: ['web/lib/**/3rd-party/**/*.ts'], rules: {
187+
'header/header': 'off',
188+
'unicorn/filename-case': 'off',
189+
'@typescript-eslint/explicit-module-boundary-types': 'off',
190+
}
185191
}],
186192
extends: [
187193
'eslint:recommended',

common/lib/env-impl.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ export class EnvImpl implements Env {
88
constructor() {
99
this.wasm = {};
1010
this.webgl = {};
11+
this.webgpu = {};
1112
this.logLevelInternal = 'warning';
1213
}
1314

@@ -28,8 +29,8 @@ export class EnvImpl implements Env {
2829
debug?: boolean;
2930

3031
wasm: Env.WebAssemblyFlags;
31-
3232
webgl: Env.WebGLFlags;
33+
webgpu: Env.WebGpuFlags;
3334

3435
[name: string]: unknown;
3536

common/lib/env.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ export declare namespace Env {
8686
*/
8787
async?: boolean;
8888
}
89+
90+
export interface WebGpuFlags {
91+
profilingMode?: 'off'|'default';
92+
}
8993
}
9094

9195
export interface Env {
@@ -112,6 +116,11 @@ export interface Env {
112116
*/
113117
webgl: Env.WebGLFlags;
114118

119+
/**
120+
* Represent a set of flags for WebGPU
121+
*/
122+
webgpu: Env.WebGpuFlags;
123+
115124
[name: string]: unknown;
116125
}
117126

web/karma.conf.js

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
const bundleMode = require('minimist')(process.argv)['bundle-mode'] || 'dev'; // 'dev'|'perf'|undefined;
77
const karmaPlugins = require('minimist')(process.argv)['karma-plugins'] || undefined;
88
const timeoutMocha = require('minimist')(process.argv)['timeout-mocha'] || 60000;
9+
const forceLocalHost = !!require('minimist')(process.argv)['force-localhost'];
910
const commonFile = bundleMode === 'dev' ? '../common/dist/ort-common.js' : '../common/dist/ort-common.min.js'
1011
const mainFile = bundleMode === 'dev' ? 'test/ort.dev.js' : 'test/ort.perf.js';
1112

@@ -16,25 +17,32 @@ const mainFile = bundleMode === 'dev' ? 'test/ort.dev.js' : 'test/ort.perf.js';
1617
// https://stackoverflow.com/a/8440736
1718
//
1819
function getMachineIpAddress() {
19-
var os = require('os');
20-
var ifaces = os.networkInterfaces();
20+
if (!forceLocalHost) {
21+
var os = require('os');
22+
var ifaces = os.networkInterfaces();
2123

22-
for (const ifname in ifaces) {
23-
for (const iface of ifaces[ifname]) {
24-
if ('IPv4' !== iface.family || iface.internal !== false) {
25-
// skip over internal (i.e. 127.0.0.1) and non-ipv4 addresses
26-
continue;
27-
}
24+
for (const ifname in ifaces) {
25+
for (const iface of ifaces[ifname]) {
26+
if ('IPv4' !== iface.family || iface.internal !== false) {
27+
// skip over internal (i.e. 127.0.0.1) and non-ipv4 addresses
28+
continue;
29+
}
2830

29-
// returns the first available IP address
30-
return iface.address;
31+
// returns the first available IP address
32+
return iface.address;
33+
}
3134
}
3235
}
3336

3437
// if no available IP address, fallback to "localhost".
3538
return 'localhost';
3639
}
3740

41+
const hostname = getMachineIpAddress();
42+
// In Node.js v16 and below, 'localhost' is using IPv4, so need to listen to '0.0.0.0'
43+
// In Node.js v17+, 'localhost' is using IPv6, so need to listen to '::'
44+
const listenAddress = Number.parseInt(process.versions.node.split('.')[0]) >= 17 ? '::' : '0.0.0.0';
45+
3846
module.exports = function (config) {
3947
config.set({
4048
// global config of your BrowserStack account
@@ -75,12 +83,16 @@ module.exports = function (config) {
7583
browserNoActivityTimeout: 300000,
7684
browserDisconnectTolerance: 0,
7785
browserSocketTimeout: 60000,
78-
hostname: getMachineIpAddress(),
86+
hostname,
87+
listenAddress,
7988
customLaunchers: {
8089
ChromeTest: { base: 'ChromeHeadless', flags: ['--enable-features=SharedArrayBuffer'] },
8190
ChromePerf: { base: 'Chrome', flags: ['--window-size=1,1', '--enable-features=SharedArrayBuffer'] },
8291
ChromeDebug: { debug: true, base: 'Chrome', flags: ['--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer'] },
83-
92+
ChromeCanaryTest: { base: 'ChromeCanary', flags: ['--window-size=1,1', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu'] },
93+
ChromeCanaryProfileTest: { base: 'ChromeCanary', flags: ['--window-size=1,1', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu', '--disable-dawn-features=disallow_unsafe_apis'] },
94+
ChromeCanaryDebug: { debug: true, base: 'ChromeCanary', flags: ['--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu'] },
95+
ChromeCanaryProfileDebug: { debug: true, base: 'ChromeCanary', flags: ['--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu', '--disable-dawn-features=disallow_unsafe_apis'] },
8496
//
8597
// ==== BrowserStack browsers ====
8698
//

web/lib/build-def.d.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ interface BuildDefinitions {
1414
* defines whether to disable the whole WebGL backend in the build.
1515
*/
1616
DISABLE_WEBGL: boolean;
17+
/**
18+
* defines whether to disable the whole WebGpu backend in the build.
19+
*/
20+
DISABLE_WEBGPU: boolean;
1721
/**
1822
* defines whether to disable the whole WebAssembly backend in the build.
1923
*/

web/lib/index.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,12 @@ if (!BUILD_DEFS.DISABLE_WEBGL) {
1313
const onnxjsBackend = require('./backend-onnxjs').onnxjsBackend;
1414
registerBackend('webgl', onnxjsBackend, -10);
1515
}
16+
1617
if (!BUILD_DEFS.DISABLE_WASM) {
1718
const wasmBackend = require('./backend-wasm').wasmBackend;
19+
if (!BUILD_DEFS.DISABLE_WEBGPU) {
20+
registerBackend('webgpu', wasmBackend, 5);
21+
}
1822
registerBackend('cpu', wasmBackend, 10);
1923
registerBackend('wasm', wasmBackend, 10);
2024
registerBackend('xnnpack', wasmBackend, 9);

web/lib/onnxjs/backend.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ export interface Backend {
7878
const backendsCache: Map<string, Backend> = new Map();
7979

8080
export const backend: {[name: string]: Backend} = {
81-
webgl: new WebGLBackend(),
81+
webgl: new WebGLBackend()
8282
};
8383

8484
/**

web/lib/onnxjs/backends/webgl/ops/reduce.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ const createReduceProgramInfo =
9898
};
9999

100100
const validateInputs = (inputs: Tensor[]): void => {
101+
// TODO: support Reduce* operators with 2 inputs.
101102
if (!inputs || inputs.length !== 1) {
102103
throw new Error('Reduce op requires 1 input.');
103104
}
@@ -174,4 +175,4 @@ export const reduceLogSumSquare: OperatorImplementation<ReduceAttributes> =
174175
(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes): Tensor[] => {
175176
const reduceOp: ReduceOp = (): string[] => ['float t; value = 0.0;', 't = _A(inputIdx); value += t * t;', ''];
176177
return reduce(inferenceHandler, inputs, attributes, 'ReduceLogSumSquare', reduceOp);
177-
};
178+
};

web/lib/onnxjs/opset.ts

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,11 @@ export interface OpSet {
88
domain: string;
99
version: number;
1010
}
11-
1211
export declare namespace OpSet {
1312
/**
1413
* Domain of an opset, it can be an empty string(default value, represent for ai.onnx), or 'ai.onnx.ml'
1514
*/
1615
type Domain = ''|'ai.onnx.ml'|'com.microsoft';
17-
1816
/**
1917
* A resolve rule consists of 4 or 5 items: opType, opSetDomain, versionSelector, operatorImplementation and
2018
* operatorInitialization (optional)

web/lib/wasm/binding/ort-wasm.d.ts

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4+
declare namespace JSEP {
5+
type BackendType = unknown;
6+
type AllocFunction = (size: number) => number;
7+
type FreeFunction = (size: number) => number;
8+
type UploadFunction = (dataOffset: number, gpuDataId: number, size: number) => void;
9+
type DownloadFunction = (gpuDataId: number, dataOffset: number, size: number) => Promise<void>;
10+
type CreateKernelFunction = (name: string, kernel: number, attribute: unknown) => void;
11+
type ReleaseKernelFunction = (kernel: number) => void;
12+
type RunFunction = (kernel: number, contextDataOffset: number) => number;
13+
}
14+
415
export interface OrtWasmModule extends EmscriptenModule {
516
// #region emscripten functions
617
stackSave(): number;
@@ -51,6 +62,17 @@ export interface OrtWasmModule extends EmscriptenModule {
5162
// #region config
5263
mainScriptUrlOrBlob?: string|Blob;
5364
// #endregion
65+
66+
// #region JSEP
67+
jsepInit?
68+
(backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction,
69+
download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction,
70+
releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void;
71+
72+
_JsepOutput(context: number, index: number, data: number): number;
73+
74+
jsepRunPromise?: Promise<number>;
75+
// #endregion
5476
}
5577

5678
declare const moduleFactory: EmscriptenModuleFactory<OrtWasmModule>;

0 commit comments

Comments
 (0)