@@ -62,6 +62,13 @@ def __init__(self,
6262
6363 def call (self , context , x , losses = None ):
6464 """Call the layer."""
65+ has_length_dim = context .length_dim in x .shape .dims
66+ if not has_length_dim :
67+ x_shape = x .shape
68+ shape_with_length = mtf .Shape (
69+ x_shape .dims [:- 1 ] + [mtf .Dimension ("length" , 1 )]
70+ + x_shape .dims [- 1 :])
71+ x = mtf .reshape (x , shape_with_length )
6572 y , loss = transformer_moe_layer_v1 (
6673 x ,
6774 context .model_dim ,
@@ -70,6 +77,8 @@ def call(self, context, x, losses=None):
7077 context .variable_dtype )
7178 if context .losses is not None :
7279 context .losses .append (loss )
80+ if not has_length_dim :
81+ y = mtf .reshape (y , x_shape )
7382 return y
7483
7584
@@ -111,6 +120,13 @@ def __init__(self,
111120
112121 def call (self , context , x , losses = None ):
113122 """Call the layer."""
123+ has_length_dim = context .length_dim in x .shape .dims
124+ if not has_length_dim :
125+ x_shape = x .shape
126+ shape_with_length = mtf .Shape (
127+ x_shape .dims [:- 1 ] + [mtf .Dimension ("length" , 1 )]
128+ + x_shape .dims [- 1 :])
129+ x = mtf .reshape (x , shape_with_length )
114130 y , loss = transformer_moe_layer_v2 (
115131 x ,
116132 context .model_dim ,
@@ -119,6 +135,8 @@ def call(self, context, x, losses=None):
119135 context .variable_dtype )
120136 if context .losses is not None :
121137 context .losses .append (loss )
138+ if not has_length_dim :
139+ y = mtf .reshape (y , x_shape )
122140 return y
123141
124142
0 commit comments