@@ -1775,7 +1775,7 @@ class SeparableConv1DLayer(Conv1D):
   across packed examples.
   """

-  def __init__(self,
+  def __init__(self,  # pylint: disable=super-init-not-called
                min_relative_pos,
                max_relative_pos,
                output_size,
@@ -2106,19 +2106,33 @@ class ParallelLayer(transformer.TransformerLayer):
   Outputs are summed and divided by sqrt(n).
   """

-  def __init__(self, layer_classes=(DenseReluDense, SelfAttention)):
+  def __init__(self,
+               layer_classes=(DenseReluDense, SelfAttention),
+               use_scope=True):
     """Create a ParallelLayer.

     Args:
       layer_classes: a list of TransformerLayer classes
+      use_scope: boolean, default True, indicating whether to give each
+        parallel layer a unique variable name scope. Kept for backward
+        compatibility.
     """
     self.layer_classes = [l() for l in layer_classes]
+    self.use_scope = use_scope

   def call(self, context, x, losses=None):
     """Call the layer."""
-    return (
-        mtf.add_n(
-            [l.call(context, x, losses=losses) for l in self.layer_classes])
-        * (len(self.layer_classes) ** -0.5))
+    layer_outputs = []
+
+    if self.use_scope:
+      # Provide unique variable name scopes to avoid overwriting.
+      for i, l in enumerate(self.layer_classes):
+        with tf.variable_scope("parallel_layer_%d" % i):
+          layer_output = l.call(context, x, losses=losses)
+          layer_outputs.append(layer_output)
+    else:
+      layer_outputs = [
+          l.call(context, x, losses=losses) for l in self.layer_classes
+      ]
+    return mtf.add_n(layer_outputs) * (len(self.layer_classes)**-0.5)


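For context only (not part of the patch): a minimal standalone sketch of the collision the new "parallel_layer_%d" scopes prevent. It uses plain TF1-style variable handling via tensorflow.compat.v1 rather than the real mesh_tensorflow layers, and the sublayer_weight helper is hypothetical. Without distinct scope names, two sublayers that each call tf.get_variable("w", ...) in the same scope would raise a "variable already exists" error (or silently share weights under reuse), which is the overwriting the scoped path avoids.

# Sketch under the assumptions above; not the mesh_tensorflow implementation.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

def sublayer_weight(scope_name):
  # Hypothetical stand-in for one parallel sublayer creating a weight.
  # get_variable keys variables on "<scope>/<name>", so distinct scopes
  # yield distinct variables instead of a name collision.
  with tf.variable_scope(scope_name):
    return tf.get_variable("w", shape=[4, 4])

weights = [sublayer_weight("parallel_layer_%d" % i) for i in range(2)]
print([w.name for w in weights])
# ['parallel_layer_0/w:0', 'parallel_layer_1/w:0'] -- separate variables.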