This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit d80f0bf

nshazeer authored and Mesh TensorFlow Team committed
add ParallelLayer
PiperOrigin-RevId: 351906296
1 parent 5ce9683 commit d80f0bf

File tree

1 file changed: +23 -0 lines changed


mesh_tensorflow/transformer/transformer_layers.py

Lines changed: 23 additions & 0 deletions
@@ -2096,3 +2096,26 @@ def call(self, context, x, losses=None):
         name="local_conv_attn_op_projection")

     return op_projection
+
+
+@gin.configurable
+class ParallelLayer(transformer.TransformerLayer):
+  """Multiple layers in parallel.
+
+  Outputs are summed and divided by sqrt(n).
+  """
+
+  def __init__(self, layer_classes=(DenseReluDense, SelfAttention)):
+    """Create a ParallelLayer.
+
+    Args:
+      layer_classes: a list of TransformerLayer classes
+    """
+    self.layer_classes = [l() for l in layer_classes]
+
+  def call(self, context, x, losses=None):
+    """Call the layer."""
+    return (
+        mtf.add_n(
+            [l.call(context, x, losses=losses) for l in self.layer_classes])
+        * (len(self.layer_classes) ** -0.5))
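
For intuition on the "summed and divided by sqrt(n)" note in the docstring: if the n parallel layer outputs are roughly independent with similar variance, their sum has n times that variance, so multiplying by n ** -0.5 restores the scale of a single layer's output. A minimal NumPy sketch of this (stand-in random arrays, not Mesh TensorFlow code):

import numpy as np

# Stand-ins for n parallel layer outputs with roughly unit variance.
rng = np.random.default_rng(0)
n, d = 2, 100_000
outputs = [rng.standard_normal(d) for _ in range(n)]

# Sum and rescale exactly as ParallelLayer.call does.
combined = sum(outputs) * (n ** -0.5)

print(np.var(outputs[0]))  # ~1.0
print(np.var(combined))    # ~1.0, same scale as a single output

With the default layer_classes=(DenseReluDense, SelfAttention), one ParallelLayer applies the feed-forward and self-attention transformations to the same input in parallel and merges their outputs, rather than running them as two sequential layers.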

0 commit comments
