Skip to content

Commit 4687567

Browse files
authored
[tfjs-layers] Fix outputs and mask mismatch (#7993)
Bug: #7974
1 parent ac9519e commit 4687567

File tree

1 file changed

+20
-25
lines changed

1 file changed

+20
-25
lines changed

tfjs-layers/src/engine/topology.ts

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,8 +1066,8 @@ export abstract class Layer extends serialization.Serializable {
10661066
If the input tensor(s) had no previous history,
10671067
this does nothing.
10681068
*/
1069-
this.addInboundNode(inputs, output, null, null,
1070-
inputShape, outputShape, kwargs);
1069+
this.addInboundNode(
1070+
inputs, output, null, null, inputShape, outputShape, kwargs);
10711071
this._refCount++;
10721072

10731073
if (this.activityRegularizer != null) {
@@ -1387,29 +1387,24 @@ export abstract class Layer extends serialization.Serializable {
13871387
return mask;
13881388
}
13891389

1390-
private setMaskMetadata(inputs: Tensor|Tensor[], outputs: Tensor|Tensor[],
1391-
previousMask?: Tensor|Tensor[]): void {
1390+
private setMaskMetadata(
1391+
inputs: Tensor|Tensor[], outputs: Tensor|Tensor[],
1392+
previousMask?: Tensor|Tensor[]): void {
13921393
if (!this.supportsMasking) {
13931394
return;
13941395
}
13951396

13961397
const outputMasks = this.computeMask(inputs, previousMask);
1397-
if (outputs instanceof Array && outputMasks instanceof Array) {
1398-
if (outputs.length !== outputMasks.length) {
1399-
throw new Error(`${this.name} outputs ${outputs.length} tensors `
1400-
+ `but ${outputMasks.length} masks for those tensors`);
1401-
}
1402-
for (let i = 0; i < outputs.length; i++) {
1403-
outputs[i].kerasMask = outputMasks[i];
1404-
}
1405-
} else if (outputMasks instanceof Array) {
1406-
throw new Error(`{this.name} outputs a single tensor `
1407-
+ `but ${outputMasks.length} masks`);
1408-
} else if (outputs instanceof Array) {
1409-
throw new Error(`{this.name} outputs ${outputs.length} tensors `
1410-
+ `but only one mask`);
1411-
} else {
1412-
outputs.kerasMask = outputMasks;
1398+
const outputsList = generic_utils.toList(outputs);
1399+
const outputMasksList = generic_utils.toList(outputMasks);
1400+
1401+
if (outputsList.length !== outputMasksList.length) {
1402+
throw new Error(
1403+
`${this.name} outputs ${outputsList.length} tensors ` +
1404+
`but ${outputsList.length} masks for those tensors`);
1405+
}
1406+
for (let i = 0; i < outputsList.length; i++) {
1407+
outputsList[i].kerasMask = outputMasksList[i];
14131408
}
14141409
}
14151410

@@ -1661,10 +1656,10 @@ export function getSourceInputs(
16611656
}
16621657
}
16631658

1664-
type MaybeSymbolic = SymbolicTensor | Tensor;
1659+
type MaybeSymbolic = SymbolicTensor|Tensor;
16651660

1666-
function checkAllSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
1667-
): tensors is SymbolicTensor | SymbolicTensor[] {
1661+
function checkAllSymbolic(tensors: MaybeSymbolic|MaybeSymbolic[]):
1662+
tensors is SymbolicTensor|SymbolicTensor[] {
16681663
let allAreSymbolic = true;
16691664
for (const tensor of generic_utils.toList(tensors)) {
16701665
if (!(tensor instanceof SymbolicTensor)) {
@@ -1675,8 +1670,8 @@ function checkAllSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
16751670
return allAreSymbolic;
16761671
}
16771672

1678-
function checkNoneSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
1679-
): tensors is Tensor | Tensor[] {
1673+
function checkNoneSymbolic(tensors: MaybeSymbolic|
1674+
MaybeSymbolic[]): tensors is Tensor|Tensor[] {
16801675
let noneAreSymbolic = true;
16811676
for (const tensor of generic_utils.toList(tensors)) {
16821677
if (tensor instanceof SymbolicTensor) {

0 commit comments

Comments
 (0)