@@ -1066,8 +1066,8 @@ export abstract class Layer extends serialization.Serializable {
1066
1066
If the input tensor(s) had no previous history,
1067
1067
this does nothing.
1068
1068
*/
1069
- this . addInboundNode ( inputs , output , null , null ,
1070
- inputShape , outputShape , kwargs ) ;
1069
+ this . addInboundNode (
1070
+ inputs , output , null , null , inputShape , outputShape , kwargs ) ;
1071
1071
this . _refCount ++ ;
1072
1072
1073
1073
if ( this . activityRegularizer != null ) {
@@ -1387,29 +1387,24 @@ export abstract class Layer extends serialization.Serializable {
1387
1387
return mask ;
1388
1388
}
1389
1389
1390
- private setMaskMetadata ( inputs : Tensor | Tensor [ ] , outputs : Tensor | Tensor [ ] ,
1391
- previousMask ?: Tensor | Tensor [ ] ) : void {
1390
+ private setMaskMetadata (
1391
+ inputs : Tensor | Tensor [ ] , outputs : Tensor | Tensor [ ] ,
1392
+ previousMask ?: Tensor | Tensor [ ] ) : void {
1392
1393
if ( ! this . supportsMasking ) {
1393
1394
return ;
1394
1395
}
1395
1396
1396
1397
const outputMasks = this . computeMask ( inputs , previousMask ) ;
1397
- if ( outputs instanceof Array && outputMasks instanceof Array ) {
1398
- if ( outputs . length !== outputMasks . length ) {
1399
- throw new Error ( `${ this . name } outputs ${ outputs . length } tensors `
1400
- + `but ${ outputMasks . length } masks for those tensors` ) ;
1401
- }
1402
- for ( let i = 0 ; i < outputs . length ; i ++ ) {
1403
- outputs [ i ] . kerasMask = outputMasks [ i ] ;
1404
- }
1405
- } else if ( outputMasks instanceof Array ) {
1406
- throw new Error ( `{this.name} outputs a single tensor `
1407
- + `but ${ outputMasks . length } masks` ) ;
1408
- } else if ( outputs instanceof Array ) {
1409
- throw new Error ( `{this.name} outputs ${ outputs . length } tensors `
1410
- + `but only one mask` ) ;
1411
- } else {
1412
- outputs . kerasMask = outputMasks ;
1398
+ const outputsList = generic_utils . toList ( outputs ) ;
1399
+ const outputMasksList = generic_utils . toList ( outputMasks ) ;
1400
+
1401
+ if ( outputsList . length !== outputMasksList . length ) {
1402
+ throw new Error (
1403
+ `${ this . name } outputs ${ outputsList . length } tensors ` +
1404
+ `but ${ outputsList . length } masks for those tensors` ) ;
1405
+ }
1406
+ for ( let i = 0 ; i < outputsList . length ; i ++ ) {
1407
+ outputsList [ i ] . kerasMask = outputMasksList [ i ] ;
1413
1408
}
1414
1409
}
1415
1410
@@ -1661,10 +1656,10 @@ export function getSourceInputs(
1661
1656
}
1662
1657
}
1663
1658
1664
- type MaybeSymbolic = SymbolicTensor | Tensor ;
1659
+ type MaybeSymbolic = SymbolicTensor | Tensor ;
1665
1660
1666
- function checkAllSymbolic ( tensors : MaybeSymbolic | MaybeSymbolic [ ]
1667
- ) : tensors is SymbolicTensor | SymbolicTensor [ ] {
1661
+ function checkAllSymbolic ( tensors : MaybeSymbolic | MaybeSymbolic [ ] ) :
1662
+ tensors is SymbolicTensor | SymbolicTensor [ ] {
1668
1663
let allAreSymbolic = true ;
1669
1664
for ( const tensor of generic_utils . toList ( tensors ) ) {
1670
1665
if ( ! ( tensor instanceof SymbolicTensor ) ) {
@@ -1675,8 +1670,8 @@ function checkAllSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
1675
1670
return allAreSymbolic ;
1676
1671
}
1677
1672
1678
- function checkNoneSymbolic ( tensors : MaybeSymbolic | MaybeSymbolic [ ]
1679
- ) : tensors is Tensor | Tensor [ ] {
1673
+ function checkNoneSymbolic ( tensors : MaybeSymbolic |
1674
+ MaybeSymbolic [ ] ) : tensors is Tensor | Tensor [ ] {
1680
1675
let noneAreSymbolic = true ;
1681
1676
for ( const tensor of generic_utils . toList ( tensors ) ) {
1682
1677
if ( tensor instanceof SymbolicTensor ) {
0 commit comments