Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 5c866ff

Browse files
William Fedus (Mesh TensorFlow Team)
authored and committed
Unique variable names for ParallelLayer
PiperOrigin-RevId: 378397935
1 parent b53f293 commit 5c866ff

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

mesh_tensorflow/transformer/transformer_layers.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1775,7 +1775,7 @@ class SeparableConv1DLayer(Conv1D):
17751775
across packed examples.
17761776
"""
17771777

1778-
def __init__(self,
1778+
def __init__(self, # pylint: disable=super-init-not-called
17791779
min_relative_pos,
17801780
max_relative_pos,
17811781
output_size,
@@ -2106,19 +2106,33 @@ class ParallelLayer(transformer.TransformerLayer):
21062106
Outputs are summed and divided by sqrt(n).
21072107
"""
21082108

2109-
def __init__(self, layer_classes=(DenseReluDense, SelfAttention)):
2109+
def __init__(self,
2110+
layer_classes=(DenseReluDense, SelfAttention),
2111+
use_scope=True):
21102112
"""Create a ParallelLayer.
21112113
21122114
Args:
21132115
layer_classes: a list of TransformerLayer classes
2116+
use_scope: boolean, default True, which indicates whether to use unique
2117+
variable names for each parallel_layer. Here for backward compatibility.
21142118
"""
21152119
self.layer_classes = [l() for l in layer_classes]
2120+
self.use_scope = use_scope
21162121

21172122
def call(self, context, x, losses=None):
21182123
"""Call the layer."""
2119-
return (
2120-
mtf.add_n(
2121-
[l.call(context, x, losses=losses) for l in self.layer_classes])
2122-
* (len(self.layer_classes) ** -0.5))
2124+
layer_outputs = []
2125+
2126+
if self.use_scope:
2127+
# Provide unique variable name scopes to avoid overwriting.
2128+
for i, l in enumerate(self.layer_classes):
2129+
with tf.variable_scope("parallel_layer_%d" % i):
2130+
layer_output = l.call(context, x, losses=losses)
2131+
layer_outputs.append(layer_output)
2132+
else:
2133+
layer_outputs = [
2134+
l.call(context, x, losses=losses) for l in self.layer_classes
2135+
]
2136+
return mtf.add_n(layer_outputs) * (len(self.layer_classes)**-0.5)
21232137

21242138

0 commit comments

Comments
 (0)