Skip to content

Commit ec1b574

Browse files
[layers] Import arraysEqual from the public import path (#7959)
1 parent 0bd03f7 commit ec1b574

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

tfjs-layers/src/layers/nlp/multihead_attention.ts

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,7 @@
2020
*/
2121

2222
/* Original source: keras/layers/attention/multi_head_attention.py */
23-
import { Tensor, einsum, linalg, logicalAnd, mul, ones, serialization, tidy } from '@tensorflow/tfjs-core';
24-
// tslint:disable-next-line: no-imports-from-dist
25-
import { arraysEqual } from '@tensorflow/tfjs-core/dist/util_base';
23+
import { Tensor, einsum, linalg, logicalAnd, mul, ones, serialization, tidy, util } from '@tensorflow/tfjs-core';
2624

2725
import { cast, expandDims } from '../../backend/tfjs_backend';
2826
import { Constraint, ConstraintIdentifier, getConstraint, serializeConstraint } from '../../constraints';
@@ -882,7 +880,7 @@ export class MultiHeadAttention extends Layer {
882880
);
883881
}
884882

885-
if (!arraysEqual(valueShape.slice(1, -1), keyShape.slice(1, -1))) {
883+
if (!util.arraysEqual(valueShape.slice(1, -1), keyShape.slice(1, -1))) {
886884
throw new Error(
887885
`All dimensions of 'value' and 'key', except the last one, must be ` +
888886
`equal. Received ${valueShape} and ${keyShape}`

0 commit comments

Comments
 (0)