Skip to content

Commit a0f860b

Browse files
fujunweichromium-wpt-export-bot
authored andcommitted
webnn: Limit the tensor size with the opSupportLimits
The tensor byte length is UINT32_MAX for DirectML backend on Windows, INT32_MAX for CoreML and TFLite backend, `OperandDescriptor::Create` function will be used to validate the limit. This limitation also prevents the mask tensor of triangular to allocate large amounts of memory. The `OperandDescriptor::CreateForDeserialization` only be used by mojom traits that need to be validated tensor size limit later. Add some unittests and WPT tests to validate the byte length limit. Bug: 359729258 Change-Id: Ic174fe47da3c110a6d263d22318738f5656a07cc Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6079317 Commit-Queue: Junwei Fu <[email protected]> Reviewed-by: Reilly Grant <[email protected]> Reviewed-by: ningxin hu <[email protected]> Reviewed-by: Alex Gough <[email protected]> Cr-Commit-Position: refs/heads/main@{#1407058}
1 parent fc0e6aa commit a0f860b

17 files changed

+200
-20
lines changed

webnn/conformance_tests/tensor.https.any.js

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,17 @@ const testCreateTensorFails = (testName, tensorDescriptor) => {
128128
}, `${testName} / ${tensorDescriptor.dataType}`);
129129
};
130130

131+
132+
promise_test(async t => {
133+
const tensorDescriptor = {
134+
dataType: 'int32',
135+
shape: [(context.opSupportLimits().maxTensorByteLength + 1) / 4],
136+
writable: true,
137+
};
138+
await promise_rejects_js(
139+
t, TypeError, context.createTensor(tensorDescriptor));
140+
}, `create too large tensor byte length that exceeds limit`);
141+
131142
/**
132143
* Asserts the tensor data in MLTensor matches expected.
133144
* @param {MLContext} mlContext - The context used to create the tensor.

webnn/resources/utils_validation.js

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,24 @@ function validateTwoInputsBroadcastable(operationName, label) {
248248
}, `[${operationName}] TypeError is expected if two inputs aren't broadcastable`);
249249
}
250250

251+
function validateTwoBroadcastableInputsTensorLimit(operationName, label) {
252+
if (navigator.ml === undefined) {
253+
return;
254+
}
255+
promise_test(async t => {
256+
const builder = new MLGraphBuilder(context);
257+
258+
const a = builder.input('a', {dataType: 'float32',
259+
shape: [context.opSupportLimits().maxTensorByteLength / 4, 1]});
260+
const b = builder.input('b', {dataType: 'float32', shape: [1, 5] });
261+
262+
const options = {label};
263+
const regrexp = new RegExp('\\[' + label + '\\]');
264+
assert_throws_with_label(
265+
() => builder[operationName](a, b, options), regrexp);
266+
}, `[${operationName}] throw if the output tensor byte length exceeds limit`);
267+
}
268+
251269
function validateTwoInputsOfSameDataType(operationName, label) {
252270
if (navigator.ml === undefined) {
253271
return;

webnn/validation_tests/cast.https.any.js

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,12 @@ multi_builder_test(async (t, builder, otherBuilder) => {
1414
assert_throws_js(
1515
TypeError, () => builder.cast(inputFromOtherBuilder, 'int64'));
1616
}, '[cast] throw if input is from another builder');
17+
18+
promise_test(async t => {
19+
const builder = new MLGraphBuilder(context);
20+
const input = builder.input('input', {
21+
dataType: 'int8',
22+
shape: [context.opSupportLimits().maxTensorByteLength / 2]});
23+
assert_throws_js(
24+
TypeError, () => builder.cast(input, 'int64'));
25+
}, '[cast] throw if the output tensor byte length exceeds limit');

webnn/validation_tests/concat.https.any.js

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ const tests = [
8383
],
8484
axis: 1,
8585
},
86-
8786
];
8887

8988
tests.forEach(
@@ -120,3 +119,18 @@ multi_builder_test(async (t, builder, otherBuilder) => {
120119
TypeError,
121120
() => builder.concat([input1, input2, inputFromOtherBuilder, input3]));
122121
}, '[concat] throw if any input is from another builder');
122+
123+
promise_test(async t => {
124+
const builder = new MLGraphBuilder(context);
125+
126+
const operandDescriptor = {
127+
dataType: 'float32',
128+
shape: [context.opSupportLimits().maxTensorByteLength / 4]
129+
};
130+
const input1 = builder.input('input1', operandDescriptor);
131+
const input2 = builder.input('input2', operandDescriptor);
132+
const input3 = builder.input('input3', operandDescriptor);
133+
134+
assert_throws_js(
135+
TypeError, () => builder.concat(input1, input2, input3));
136+
}, '[concat] throw if the output tensor byte length exceeds limit');

webnn/validation_tests/dequantizeLinear.https.any.js

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
'use strict';
99

10+
const label = 'dequantize_linear_123';
11+
const regrexp = new RegExp('\\[' + label + '\\]');
1012
const tests = [
1113
{
1214
name:
@@ -94,9 +96,7 @@ tests.forEach(
9496
assert_equals(output.dataType, test.output.dataType);
9597
assert_array_equals(output.shape, test.output.shape);
9698
} else {
97-
const label = 'dequantize_linear_123';
9899
const options = {label};
99-
const regrexp = new RegExp('\\[' + label + '\\]');
100100
assert_throws_with_label(
101101
() => builder.dequantizeLinear(input, scale, zeroPoint, options),
102102
regrexp);
@@ -143,3 +143,17 @@ multi_builder_test(async (t, builder, otherBuilder) => {
143143
TypeError,
144144
() => builder.dequantizeLinear(input, scale, zeroPointFromOtherBuilder));
145145
}, '[dequantizeLinear] throw if zeroPoint is from another builder');
146+
147+
promise_test(async t => {
148+
const builder = new MLGraphBuilder(context);
149+
150+
const input = builder.input('input', {
151+
dataType: 'int8',
152+
shape: [context.opSupportLimits().maxTensorByteLength / 5, 5]});
153+
const scale = builder.input('scale', {dataType: 'float32', shape: [5]});
154+
const zeroPoint = builder.input('zeroPoint', {dataType: 'int8', shape: [5]});
155+
156+
const options = {label};
157+
assert_throws_with_label(
158+
() => builder.dequantizeLinear(input, scale, zeroPoint, options), regrexp);
159+
}, '[dequantizeLinear] throw if the output tensor byte length exceeds limit');

webnn/validation_tests/elementwise-binary.https.any.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,5 +88,6 @@ kElementwiseBinaryOperators.forEach((operatorName) => {
8888
validateTwoInputsOfSameDataType(operatorName, label);
8989
validateTwoInputsBroadcastable(operatorName, label);
9090
validateTwoInputsFromMultipleBuilders(operatorName);
91+
validateTwoBroadcastableInputsTensorLimit(operatorName, label);
9192
runElementWiseBinaryTests(operatorName, tests);
9293
});

webnn/validation_tests/expand.https.any.js

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ multi_builder_test(async (t, builder, otherBuilder) => {
1717
}, '[expand] throw if input is from another builder');
1818

1919
const label = 'xxx_expand';
20-
20+
const regrexp = new RegExp('\\[' + label + '\\]');
2121
const tests = [
2222
{
2323
name: '[expand] Test with 0-D scalar to 3-D tensor.',
@@ -76,7 +76,6 @@ tests.forEach(
7676
} else {
7777
const options = {...test.options};
7878
if (options.label) {
79-
const regrexp = new RegExp('\\[' + label + '\\]');
8079
assert_throws_with_label(
8180
() => builder.expand(input, test.newShape, options), regrexp);
8281
} else {
@@ -104,3 +103,15 @@ promise_test(async t => {
104103
}
105104
}
106105
}, `[expand] Test expand with all of the data types.`);
106+
107+
promise_test(async t => {
108+
const builder = new MLGraphBuilder(context);
109+
110+
const input = builder.input('input', {
111+
dataType: 'float32', shape: [1, 2, 1, 1]});
112+
const newShape = [1, 2, context.opSupportLimits().maxTensorByteLength, 1];
113+
114+
const options = {label};
115+
assert_throws_with_label(
116+
() => builder.expand(input, newShape, options), regrexp);
117+
}, '[expand] throw if the output tensor byte length exceeds limit');

webnn/validation_tests/gather.https.any.js

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
'use strict';
99

10+
const label = 'gather_'
11+
const regrexp = new RegExp('\\[' + label + '\\]');
1012
const tests = [
1113
{
1214
name: '[gather] Test gather with default options and 0-D indices',
@@ -58,7 +60,7 @@ const tests = [
5860
'[gather] TypeError is expected if the data type of indices is uint64 which is invalid',
5961
input: {dataType: 'float16', shape: [1, 2, 3, 4]},
6062
indices: {dataType: 'uint64', shape: [5, 6]},
61-
}
63+
},
6264
];
6365

6466
tests.forEach(
@@ -77,9 +79,7 @@ tests.forEach(
7779
assert_equals(output.dataType, test.output.dataType);
7880
assert_array_equals(output.shape, test.output.shape);
7981
} else {
80-
const label = 'gather_'
8182
options.label = label;
82-
const regrexp = new RegExp('\\[' + label + '\\]');
8383
assert_throws_with_label(
8484
() => builder.gather(input, indices, options), regrexp);
8585
}
@@ -102,3 +102,19 @@ multi_builder_test(async (t, builder, otherBuilder) => {
102102
assert_throws_js(
103103
TypeError, () => builder.gather(input, indicesFromOtherBuilder));
104104
}, '[gather] throw if indices is from another builder');
105+
106+
promise_test(async t => {
107+
const builder = new MLGraphBuilder(context);
108+
109+
const input = builder.input('input', {
110+
dataType: 'float32', shape: [1, 3, 3, 4]});
111+
const indices = builder.input('indices', {
112+
dataType: 'int32',
113+
shape: [context.opSupportLimits().maxTensorByteLength / 4] });
114+
115+
const options = {};
116+
options.label = label;
117+
options.axis = 2;
118+
assert_throws_with_label(
119+
() => builder.gather(input, indices, options), regrexp);
120+
}, '[gather] throw if the output tensor byte length exceeds limit');

webnn/validation_tests/gatherND.https.any.js

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
'use strict';
99

10+
const label = 'gatherND_';
11+
const regexp = new RegExp('\\[' + label + '\\]');
1012
const tests = [
1113
{
1214
name: '[gatherND] Test gatherND with 5D input 3D indices',
@@ -49,9 +51,7 @@ tests.forEach(test => promise_test(async t => {
4951
assert_equals(output.dataType, test.output.dataType);
5052
assert_array_equals(output.shape, test.output.shape);
5153
} else {
52-
const label = 'gatherND_';
5354
const options = {label: label};
54-
const regexp = new RegExp('\\[' + label + '\\]');
5555
assert_throws_with_label(
5656
() => builder.gatherND(input, indices, options), regexp);
5757
}
@@ -74,3 +74,17 @@ multi_builder_test(async (t, builder, otherBuilder) => {
7474
assert_throws_js(
7575
TypeError, () => builder.gatherND(input, indicesFromOtherBuilder));
7676
}, '[gatherND] Throw if indices is from another builder');
77+
78+
promise_test(async t => {
79+
const builder = new MLGraphBuilder(context);
80+
81+
const input = builder.input('input', {
82+
dataType: 'float32', shape: [2, 2, 3, 3, 4]});
83+
const indices = builder.input('indices', {
84+
dataType: 'int32',
85+
shape: [context.opSupportLimits().maxTensorByteLength / 4, 1, 1]});
86+
87+
const options = {label};
88+
assert_throws_with_label(
89+
() => builder.gatherND(input, indices, options), regexp);
90+
}, '[gatherND] throw if the output tensor byte length exceeds limit');

webnn/validation_tests/gemm.https.any.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77

88
'use strict';
99

10+
const label = 'gemm_xxx';
1011
const kExampleInputDescriptor = {
1112
dataType: 'float32',
1213
shape: [2, 2]
1314
};
1415

1516
validateTwoInputsFromMultipleBuilders('gemm');
17+
validateTwoBroadcastableInputsTensorLimit('gemm', label);
1618

1719
multi_builder_test(async (t, builder, otherBuilder) => {
1820
const cFromOtherBuilder = otherBuilder.input('c', kExampleInputDescriptor);
@@ -23,8 +25,6 @@ multi_builder_test(async (t, builder, otherBuilder) => {
2325
assert_throws_js(TypeError, () => builder.gemm(a, b, options));
2426
}, '[gemm] throw if c option is from another builder');
2527

26-
const label = 'gemm_xxx';
27-
2828
const tests = [
2929
{
3030
name: '[gemm] Test building gemm with default option.',

0 commit comments

Comments
 (0)