Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit fb41a04

Browse files
nshazeerCopybara-Service
authored andcommitted
mesh_tensorflow/transformer - fixes to incremental decoding.
PiperOrigin-RevId: 224011242
1 parent 18f356e commit fb41a04

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

mesh_tensorflow/transformer/moe.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ def __init__(self,
6262

6363
def call(self, context, x, losses=None):
6464
"""Call the layer."""
65+
has_length_dim = context.length_dim in x.shape.dims
66+
if not has_length_dim:
67+
x_shape = x.shape
68+
shape_with_length = mtf.Shape(
69+
x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
70+
+ x_shape.dims[-1:])
71+
x = mtf.reshape(x, shape_with_length)
6572
y, loss = transformer_moe_layer_v1(
6673
x,
6774
context.model_dim,
@@ -70,6 +77,8 @@ def call(self, context, x, losses=None):
7077
context.variable_dtype)
7178
if context.losses is not None:
7279
context.losses.append(loss)
80+
if not has_length_dim:
81+
y = mtf.reshape(y, x_shape)
7382
return y
7483

7584

@@ -111,6 +120,13 @@ def __init__(self,
111120

112121
def call(self, context, x, losses=None):
113122
"""Call the layer."""
123+
has_length_dim = context.length_dim in x.shape.dims
124+
if not has_length_dim:
125+
x_shape = x.shape
126+
shape_with_length = mtf.Shape(
127+
x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
128+
+ x_shape.dims[-1:])
129+
x = mtf.reshape(x, shape_with_length)
114130
y, loss = transformer_moe_layer_v2(
115131
x,
116132
context.model_dim,
@@ -119,6 +135,8 @@ def call(self, context, x, losses=None):
119135
context.variable_dtype)
120136
if context.losses is not None:
121137
context.losses.append(loss)
138+
if not has_length_dim:
139+
y = mtf.reshape(y, x_shape)
122140
return y
123141

124142

mesh_tensorflow/transformer/transformer_layers.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,7 @@ def call(self, context, x, losses=None):
153153
if context.mode == "incremental":
154154
prev_k, prev_v = context.next_states(2)
155155
y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
156-
x, prev_k, prev_v,
157-
context.position, context.master_dtype, context.slice_dtype,
158-
params=params)
156+
x, prev_k, prev_v, context.position, params=params)
159157
context.new_states.extend([new_k, new_v])
160158
return y
161159
else:

0 commit comments

Comments
 (0)