Skip to content

Commit 069950d

Browse files
authored
[js/webgpu] make RunFunction return void (microsoft#15669)
### Description make `RunFunction` return `void`. the return value is meaningless in the OpResolveRule context. Allows any JavaScript error to be caught and returns non-zero return value from `computeKernel()`
1 parent 5c4f5bb commit 069950d

File tree

10 files changed

+53
-83
lines changed

10 files changed

+53
-83
lines changed

web/lib/wasm/jsep/backend-webgpu.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,11 @@ export class WebGpuBackend {
333333

334334
this.temporaryData = [];
335335
try {
336-
return kernelEntry(context, attributes[1]);
336+
kernelEntry(context, attributes[1]);
337+
return 0; // ORT_OK
338+
} catch (e) {
339+
LOG_DEBUG('warning', `[WebGPU] Kernel "${name}" failed. Error: ${e}`);
340+
return 1; // ORT_FAIL
337341
} finally {
338342
for (const data of this.temporaryData) {
339343
this.gpuDataManager.release(data.id);

web/lib/wasm/jsep/webgpu/op-resolve-rules.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import {parseTransposeAttributes, transpose} from './ops/transpose';
1010
import * as unaryOps from './ops/unary-op';
1111
import {ComputeContext} from './types';
1212

13-
export type RunFunction = (context: ComputeContext, attribute?: unknown) => number;
13+
export type RunFunction = (context: ComputeContext, attribute?: unknown) => void;
1414
export type ParseAttributeFunction = (attributeRaw: unknown) => unknown;
1515
export type OperatorImplementation = [RunFunction]|[RunFunction, ParseAttributeFunction];
1616

web/lib/wasm/jsep/webgpu/ops/binary-op.ts

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -173,22 +173,19 @@ const createBinaryOpProgramInfoLoader =
173173
};
174174
};
175175

176-
export const add = (context: ComputeContext): number => {
176+
export const add = (context: ComputeContext): void => {
177177
context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Add', (a, b) => `${a}+${b}`));
178-
return 0;
179178
};
180179

181-
export const div = (context: ComputeContext): number => {
180+
export const div = (context: ComputeContext): void => {
182181
context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Div', (a, b) => `${a}/${b}`));
183-
return 0;
184182
};
185183

186-
export const mul = (context: ComputeContext): number => {
184+
export const mul = (context: ComputeContext): void => {
187185
context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Mul', (a, b) => `${a}*${b}`));
188-
return 0;
189186
};
190187

191-
export const pow = (context: ComputeContext): number => {
188+
export const pow = (context: ComputeContext): void => {
192189
context.compute(createBinaryOpProgramInfoLoader(
193190
context.inputs, 'Pow', ({scalar: (a, b) => `pow_f32(${a},${b})`, vector: (a, b) => `pow_vf32(${a},${b})`}), `
194191
fn pow_f32(a : f32, b : f32) -> f32 {
@@ -204,10 +201,8 @@ export const pow = (context: ComputeContext): number => {
204201
return vec4<f32>(pow_f32(a.x, b.x), pow_f32(a.y, b.y), pow_f32(a.z, b.z), pow_f32(a.w, b.w));
205202
}
206203
`));
207-
return 0;
208204
};
209205

210-
export const sub = (context: ComputeContext): number => {
206+
export const sub = (context: ComputeContext): void => {
211207
context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Sub', (a, b) => `${a}-${b}`));
212-
return 0;
213208
};

web/lib/wasm/jsep/webgpu/ops/concat.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,7 @@ const createConcatProgramInfoLoader =
151151
return {...metadata, get: () => createConcatProgramInfo(metadata, inputs, attributes.axis)};
152152
};
153153

154-
export const concat = (context: ComputeContext, attributes: ConcatAttributes): number => {
154+
export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => {
155155
validateInputs(context.inputs);
156156
context.compute(createConcatProgramInfoLoader(context.inputs, attributes));
157-
return 0;
158157
};

web/lib/wasm/jsep/webgpu/ops/conv.ts

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ export const parseConvAttributes = (attributes: Record<string, unknown>): ConvAt
138138
{autoPad, format, dilations, group, kernelShape, pads, strides, wIsConst, ...activationAttributes});
139139
};
140140

141-
const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): number => {
141+
const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => {
142142
const adjustedAttributes = getAdjustedConvAttributes(attributes, inputs);
143143

144144
// check attributes
@@ -170,12 +170,12 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
170170
attributes.autoPad === 'VALID'))) {
171171
// TODO: implement conv2dByMatMul()
172172
context.compute(createGroupedConvProgramInfoLoader(inputs, adjustedAttributes));
173-
return 0;
173+
return;
174174
}
175175

176176
if (!isChannelsLast || attributes.group !== 1) {
177177
context.compute(createGroupedConvProgramInfoLoader(inputs, adjustedAttributes));
178-
return 0;
178+
return;
179179
}
180180

181181
// TODO: implement conv2dWithIm2Col()
@@ -215,10 +215,9 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
215215
convInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias,
216216
sequentialAccessByThreads),
217217
{inputs: convInputs});
218-
return 0;
219218
};
220219

221-
const conv1d = (context: ComputeContext, attributes: ConvAttributes): number => {
220+
const conv1d = (context: ComputeContext, attributes: ConvAttributes): void => {
222221
// extend the input to 2D by adding H dimension
223222
const isChannelLast = attributes.format === 'NHWC';
224223
const inputs = [
@@ -242,11 +241,13 @@ const conv1d = (context: ComputeContext, attributes: ConvAttributes): number =>
242241
context.compute(createGroupedConvProgramInfoLoader(
243242
inputs, adjustedAttributes,
244243
outputShape => isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : []));
245-
return 0;
246244
};
247245

248-
export const conv = (context: ComputeContext, attributes: ConvAttributes): number => {
246+
export const conv = (context: ComputeContext, attributes: ConvAttributes): void => {
249247
validateInputs(context.inputs, attributes); // currently will fail if not conv1D/2D
250-
return context.inputs[0].dims.length === 3 ? conv1d(context, attributes) :
251-
conv2d(context, context.inputs, attributes);
248+
if (context.inputs[0].dims.length === 3) {
249+
conv1d(context, attributes);
250+
} else {
251+
conv2d(context, context.inputs, attributes);
252+
}
252253
};

web/lib/wasm/jsep/webgpu/ops/gemm.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,9 @@ const createGemmProgramInfoLoader = (inputs: readonly TensorView[], attributes:
136136
return {...metadata, get: () => createGemmProgramInfo(metadata, inputs, attributes)};
137137
};
138138

139-
export const gemm = (context: ComputeContext, attributes: GemmAttributes): number => {
139+
export const gemm = (context: ComputeContext, attributes: GemmAttributes): void => {
140140
validateInputs(context.inputs);
141141
context.compute(createGemmProgramInfoLoader(context.inputs, attributes));
142-
return 0;
143142
};
144143

145144
export const parseGemmAttributes = (attributes: Record<string, unknown>): GemmAttributes =>

web/lib/wasm/jsep/webgpu/ops/matmul.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
9292
}
9393
};
9494

95-
export const matMul = (context: ComputeContext): number => {
95+
export const matMul = (context: ComputeContext): void => {
9696
validateInputs(context.inputs);
9797

9898
context.compute(createMatmulProgramInfoLoader(context.inputs, {activation: '', activationCacheKey: ''}));
99-
return 0;
10099
};

web/lib/wasm/jsep/webgpu/ops/pool.ts

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -289,11 +289,10 @@ export const parseAveragePoolAttributes = (attributes: Record<string, unknown>):
289289
return createAttributeWithCacheKey({countIncludePad, ...attr});
290290
};
291291

292-
export const averagePool = (context: ComputeContext, attributes: AveragePoolAttributes): number => {
292+
export const averagePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
293293
validateInputs(context.inputs);
294294
const metadata = {name: 'AveragePool', inputTypes: [GpuDataType.default], cacheHint: attributes.cacheKey};
295295
context.compute({...metadata, get: () => createAveragePoolProgramInfo(context.inputs, metadata, false, attributes)});
296-
return 0;
297296
};
298297

299298
const globalPoolAttributes = {
@@ -313,11 +312,10 @@ export const parseGlobalAveragePoolAttributes = (attributes: Record<string, unkn
313312
return {format, ...globalPoolAttributes, cacheKey: format};
314313
};
315314

316-
export const globalAveragePool = (context: ComputeContext, attributes: AveragePoolAttributes): number => {
315+
export const globalAveragePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
317316
validateInputs(context.inputs);
318317
const metadata = {name: 'GlobalAveragePool', inputTypes: [GpuDataType.default], cacheHint: attributes.cacheKey};
319318
context.compute({...metadata, get: () => createAveragePoolProgramInfo(context.inputs, metadata, true, attributes)});
320-
return 0;
321319
};
322320

323321
export interface MaxPoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
@@ -343,11 +341,10 @@ const createMaxPoolProgramInfo =
343341
};
344342
};
345343

346-
export const maxPool = (context: ComputeContext, attributes: MaxPoolAttributes): number => {
344+
export const maxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
347345
validateInputs(context.inputs);
348346
const metadata = {name: 'MaxPool', inputTypes: [GpuDataType.default], cacheHint: attributes.cacheKey};
349347
context.compute({...metadata, get: () => createMaxPoolProgramInfo(context.inputs, metadata, false, attributes)});
350-
return 0;
351348
};
352349

353350
export const parseMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
@@ -371,9 +368,8 @@ export const parseGlobalMaxPoolAttributes = (attributes: Record<string, unknown>
371368
return {format, ...globalPoolAttributes, cacheKey: format};
372369
};
373370

374-
export const globalMaxPool = (context: ComputeContext, attributes: MaxPoolAttributes): number => {
371+
export const globalMaxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
375372
validateInputs(context.inputs);
376373
const metadata = {name: 'GlobalMaxPool', inputTypes: [GpuDataType.default], cacheHint: attributes.cacheKey};
377374
context.compute({...metadata, get: () => createMaxPoolProgramInfo(context.inputs, metadata, true, attributes)});
378-
return 0;
379375
};

web/lib/wasm/jsep/webgpu/ops/transpose.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,13 @@ export const createTransposeProgramInfo = (input: TensorView, permAttr: number[]
8484
};
8585
};
8686

87-
export const transpose = (context: ComputeContext, attributes: TransposeAttributes): number => {
87+
export const transpose = (context: ComputeContext, attributes: TransposeAttributes): void => {
8888
validateInputs(context.inputs);
8989
context.compute({
9090
...transposeProgramMetadata,
9191
cacheHint: attributes.cacheKey,
9292
get: () => createTransposeProgramInfo(context.inputs[0], attributes.perm)
9393
});
94-
return 0;
9594
};
9695

9796
export const parseTransposeAttributes = (attributes: Record<string, unknown>): TransposeAttributes =>

web/lib/wasm/jsep/webgpu/ops/unary-op.ts

Lines changed: 24 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -58,46 +58,39 @@ const createElementwiseProgramInfoLoader =
5858
};
5959
};
6060

61-
export const abs = (context: ComputeContext): number => {
61+
export const abs = (context: ComputeContext): void => {
6262
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Abs', 'abs'));
63-
return 0;
6463
};
6564

66-
export const acos = (context: ComputeContext): number => {
65+
export const acos = (context: ComputeContext): void => {
6766
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Acos', 'acos'));
68-
return 0;
6967
};
7068

71-
export const acosh = (context: ComputeContext): number => {
69+
export const acosh = (context: ComputeContext): void => {
7270
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Acosh', 'acosh'));
73-
return 0;
7471
};
7572

76-
export const asin = (context: ComputeContext): number => {
73+
export const asin = (context: ComputeContext): void => {
7774
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Asin', 'asin'));
78-
return 0;
7975
};
8076

81-
export const asinh = (context: ComputeContext): number => {
77+
export const asinh = (context: ComputeContext): void => {
8278
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Asinh', 'asinh'));
83-
return 0;
8479
};
8580

86-
export const atan = (context: ComputeContext): number => {
81+
export const atan = (context: ComputeContext): void => {
8782
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Atan', 'atan'));
88-
return 0;
8983
};
90-
export const atanh = (context: ComputeContext): number => {
84+
export const atanh = (context: ComputeContext): void => {
9185
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Atanh', 'atanh'));
92-
return 0;
9386
};
9487

9588
export interface ClipAttributes extends AttributeWithCacheKey {
9689
readonly min: number;
9790
readonly max: number;
9891
}
9992

100-
export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): number => {
93+
export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => {
10194
context.compute(
10295
createElementwiseProgramInfoLoader(
10396
context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
@@ -106,39 +99,35 @@ export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): nu
10699
`,
107100
attributes.cacheKey),
108101
{inputs: [0]});
109-
return 0;
110102
};
111103
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
112104
const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
113105
const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
114106
return createAttributeWithCacheKey({min, max});
115107
};
116108

117-
export const clip = (context: ComputeContext): number => {
109+
export const clip = (context: ComputeContext): void => {
118110
const attributes = generateClipAttributesFromInputs(context.inputs);
119-
return clipV10(context, attributes);
111+
clipV10(context, attributes);
120112
};
121113

122-
export const ceil = (context: ComputeContext): number => {
114+
export const ceil = (context: ComputeContext): void => {
123115
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Ceil', 'ceil'));
124-
return 0;
125116
};
126117

127-
export const cos = (context: ComputeContext): number => {
118+
export const cos = (context: ComputeContext): void => {
128119
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Cos', 'cos'));
129-
return 0;
130120
};
131121

132-
export const cosh = (context: ComputeContext): number => {
122+
export const cosh = (context: ComputeContext): void => {
133123
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Cosh', 'cosh'));
134-
return 0;
135124
};
136125

137126
export interface EluAttributes extends AttributeWithCacheKey {
138127
readonly alpha: number;
139128
}
140129

141-
export const elu = (context: ComputeContext, attributes: EluAttributes): number => {
130+
export const elu = (context: ComputeContext, attributes: EluAttributes): void => {
142131
context.compute(createElementwiseProgramInfoLoader(
143132
context.inputs[0], 'Elu', a => `elu_vf32(${a})`, `
144133
const elu_alpha_: f32 = f32(${attributes.alpha});
@@ -151,13 +140,12 @@ export const elu = (context: ComputeContext, attributes: EluAttributes): number
151140
return vec4(elu_f32(v.x), elu_f32(v.y), elu_f32(v.z), elu_f32(v.w));
152141
}`,
153142
attributes.cacheKey));
154-
return 0;
155143
};
156144

157145
export const parseEluAttributes = (attributes: Record<string, unknown>): EluAttributes =>
158146
createAttributeWithCacheKey(attributes as {alpha: number});
159147

160-
export const erf = (context: ComputeContext): number => {
148+
export const erf = (context: ComputeContext): void => {
161149
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, `
162150
const r0: f32 = 0.3275911;
163151
const r1: f32 = 0.254829592;
@@ -171,50 +159,40 @@ export const erf = (context: ComputeContext): number => {
171159
let x = 1.0 / (1.0 + r0 * absv);
172160
return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv));
173161
}`));
174-
return 0;
175162
};
176163

177-
export const floor = (context: ComputeContext): number => {
164+
export const floor = (context: ComputeContext): void => {
178165
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Floor', 'floor'));
179-
return 0;
180166
};
181167

182-
export const neg = (context: ComputeContext): number => {
168+
export const neg = (context: ComputeContext): void => {
183169
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Neg', a => `-${a}`));
184-
return 0;
185170
};
186171

187-
export const reciprocal = (context: ComputeContext): number => {
172+
export const reciprocal = (context: ComputeContext): void => {
188173
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Reciprocal', a => `1.0/${a}`));
189-
return 0;
190174
};
191175

192-
export const sigmoid = (context: ComputeContext): number => {
176+
export const sigmoid = (context: ComputeContext): void => {
193177
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`));
194-
return 0;
195178
};
196179

197-
export const sin = (context: ComputeContext): number => {
180+
export const sin = (context: ComputeContext): void => {
198181
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sin', 'sin'));
199-
return 0;
200182
};
201183

202-
export const sinh = (context: ComputeContext): number => {
184+
export const sinh = (context: ComputeContext): void => {
203185
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sinh', 'sinh'));
204-
return 0;
205186
};
206187

207-
export const sqrt = (context: ComputeContext): number => {
188+
export const sqrt = (context: ComputeContext): void => {
208189
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sqrt', 'sqrt'));
209-
return 0;
210190
};
211191

212-
export const tan = (context: ComputeContext): number => {
192+
export const tan = (context: ComputeContext): void => {
213193
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Tan', 'tan'));
214-
return 0;
215194
};
216195

217-
export const tanh = (context: ComputeContext): number => {
196+
export const tanh = (context: ComputeContext): void => {
218197
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Tanh', 'tanh'));
219-
return 0;
220198
};

0 commit comments

Comments
 (0)