This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit f3c7bc6

nshazeer authored and Copybara-Service committed
Added local attention option for mtf transformer (including incremental decoding). Updated configurations for MoE experiments. Updated local attention 1d code.
PiperOrigin-RevId: 220198118
1 parent d530562 commit f3c7bc6

3 files changed: 150 additions, 109 deletions

mesh_tensorflow/layers.py

Lines changed: 111 additions & 97 deletions
@@ -343,128 +343,142 @@ def local_self_attention_spatial_blocks(
       [output, o_var], mtf.Shape([batch, num_w_blocks, w_dim, io_channels]))
 
 
-def masked_local_attention_1d(query_antecedent,
-                              memory_antecedent,
+def masked_local_attention_1d(x,
                               kv_channels,
                               heads,
-                              block_length=128,
+                              window_size=128,
                               master_dtype=tf.float32,
                               slice_dtype=tf.float32,
+                              length_per_split=None,
                               name=None):
   """Attention to the source position and a neighborhood to the left of it.
 
-  The sequence is divided into blocks of length block_size.
-  Attention for a given query position can only see memory positions
-  less than or equal to the query position, in the corresponding block
-  and the previous block.
+  Attention for a given query position p can only see memory positions
+  in the range (p - window_size, p].
 
   Args:
-    query_antecedent: a mtf.Tensor with shape [batch, query_length, io_channels]
-    memory_antecedent: a mtf.Tensor with shape
-      [batch, memory_length, io_channels] (optional). Currently, memory_length
-      must have the same size as query_length, but a different name.
+    x: a mtf.Tensor with shape batch_dims + [length, io_channels]
     kv_channels: a mtf.Dimension (the size of the key and value vectors)
     heads: a mtf.Dimension (the number of heads)
-    block_length: an integer, representing receptive fields for attention.
+    window_size: an integer
     master_dtype: a tf.dtype
     slice_dtype: a tf.dtype
+    length_per_split: an optional integer indicating the part of the length
+      dimension per processor. You can omit if the length dimension is not
+      split.
     name: an optional string.
 
   Returns:
-    a Tensor of shape [batch, query_length, io_channels]
+    a Tensor with the same shape as x
 
   Raises:
     ValueError: if channels or depth don't match.
   """
   with tf.variable_scope(
-      name, default_name="multihead_attention",
-      values=[query_antecedent, memory_antecedent]):
+      name, default_name="masked_local_attention_1d", values=[x]):
 
-    batch, query_length, io_channels = query_antecedent.shape.dims
+    batch_dims = x.shape.dims[:-2]
+    length, io_channels = x.shape.dims[-2:]
     q_var, k_var, v_var, o_var = multihead_attention_vars(
-        query_antecedent.mesh, heads, io_channels, kv_channels,
-        master_dtype, slice_dtype, query_antecedent.dtype)
-
-    if memory_antecedent is None:
-      memory_antecedent = rename_length_to_memory_length(
-          query_antecedent, query_length.name)
-    memory_batch, memory_length, memory_channels = memory_antecedent.shape.dims
-    if memory_batch != batch:
-      raise ValueError("memory batch must equal query batch")
-    if memory_channels != io_channels:
-      raise ValueError("memory channels must equal query channels")
+        x.mesh, heads, io_channels, kv_channels,
+        master_dtype, slice_dtype, x.dtype)
 
     # Get query q, keys k and values v.
-    q = mtf.einsum(
-        [query_antecedent, q_var],
-        mtf.Shape([batch, heads, query_length, kv_channels]))
-    k = mtf.einsum(
-        [memory_antecedent, k_var],
-        mtf.Shape([batch, heads, memory_length, kv_channels]))
-    v = mtf.einsum(
-        [memory_antecedent, v_var],
-        mtf.Shape([batch, heads, memory_length, kv_channels]))
-
-    # Let's assume for now we don't have padding and the block length equally
-    # divides the memory length.
-    block_length = (query_length.size
-                    if query_length.size < block_length * 2 else block_length)
-    blength = mtf.Dimension("block_length", block_length)
-    mlength = mtf.Dimension("mem_block_length", block_length)
-    num_blocks = mtf.Dimension("num_blocks", query_length.size // block_length)
-
-    q = mtf.reshape(
-        q, mtf.Shape([batch, heads, num_blocks, blength, kv_channels]))
-    k = mtf.reshape(
-        k, mtf.Shape([batch, heads, num_blocks, mlength, kv_channels]))
-    v = mtf.reshape(
-        v, mtf.Shape([batch, heads, num_blocks, mlength, kv_channels]))
-
-    # compute attention for the first query block.
-    def first_block_attention():
-      """Compute attention for the first block."""
-      first_q = mtf.slice(q, 0, 1, num_blocks.name)
-      first_k = mtf.slice(k, 0, 1, num_blocks.name)
-      first_v = mtf.slice(v, 0, 1, num_blocks.name)
-      first_output = dot_product_attention(first_q,
-                                           first_k,
-                                           first_v,
-                                           mask=None)
-      return first_output
-
-    # Attention for first block, since query_length = key_length.
-    first_output = first_block_attention()
-
-    # Concatenate two adjacent blocks to compute the overlapping memory block.
-    def local(x):
-      """Helper function to get memory blocks."""
-      prev_block = mtf.slice(x, 0, num_blocks.size-1, num_blocks.name)
-      cur_block = mtf.slice(x, 1, num_blocks.size-1, num_blocks.name)
-      local_block = mtf.concat([prev_block, cur_block], mlength.name)
-      return local_block
-
-    local_k = local(k)
-    local_v = local(v)
-    # Calculate the causal mask to avoid peeking into the future. We compute
-    # this once and reuse it for all blocks since the block_size is known.
-    mlength = local_k.shape.dims[3]
-    mask = attention_bias_local_block(query_antecedent.mesh,
-                                      blength, mlength)
-
-    # Remove the first block from q since we already computed that.
-    tail_q = mtf.slice(q, 1, num_blocks.size-1, num_blocks.name)
-
-    tail_output = dot_product_attention(tail_q,
-                                        local_k,
-                                        local_v,
-                                        mask=mask)
-
-    # Now concatenate the first and rest of the blocks.
-    final_output = mtf.concat([first_output, tail_output], num_blocks.name)
-    final_output = mtf.reshape(final_output, mtf.Shape(
-        [batch, heads, query_length, kv_channels]))
-    return mtf.einsum([final_output, o_var],
-                      mtf.Shape([batch, query_length, io_channels]))
+    qkv_shape = mtf.Shape(batch_dims + [heads, length, kv_channels])
+    q = mtf.einsum([x, q_var], qkv_shape)
+    k = mtf.einsum([x, k_var], qkv_shape)
+    v = mtf.einsum([x, v_var], qkv_shape)
+
+    # Choose a suitable block size.
+    # We choose the greatest divisor of length_per_split less than or equal
+    # to max(window_size, 128)
+    if length_per_split is None:
+      length_per_split = length.size
+    block_length = max(window_size, 128)
+    while length_per_split % block_length != 0:
+      block_length -= 1
+
+    query_block_length = mtf.Dimension("query_block_length", block_length)
+    memory_block_length = mtf.Dimension("memory_block_length", block_length)
+    # The num_blocks dimension gets the same name as the length dimension,
+    # so it will be split in the same way.
+    num_blocks = mtf.Dimension(length.name, length.size // block_length)
+    q_shape = batch_dims + [heads, num_blocks, query_block_length, kv_channels]
+    kv_shape = batch_dims + [
+        heads, num_blocks, memory_block_length, kv_channels]
+    q = mtf.reshape(q, q_shape)
+    k = mtf.reshape(k, kv_shape)
+    v = mtf.reshape(v, kv_shape)
+    # augment the keys and values for each block with keys and values for
+    # the previous window_size timesteps.
+    k = mtf.left_halo_exchange(k, num_blocks, memory_block_length, window_size)
+    v = mtf.left_halo_exchange(v, num_blocks, memory_block_length, window_size)
+    padded_memory_block_length = mtf.Dimension(
+        "memory_block_length", window_size + block_length)
+    mpos = mtf.range(x.mesh, padded_memory_block_length, tf.float32)
+    qpos = mtf.range(x.mesh, query_block_length, tf.float32) + window_size
+    # prevent looking forward
+    mask = mtf.cast(mtf.greater(mpos, qpos), x.dtype) * -1e9
+    # prevent looking >=block_length timesteps backward
+    mask += mtf.cast(mtf.less_equal(mpos, qpos - block_length), x.dtype) * -1e9
+    # Note: The first window_size-1 positions can see back into pre-time
+    # where all the keys and values are zero. We could mask this out, but we
+    # don't.
+    o = dot_product_attention(q, k, v, mask=mask)
+    o = mtf.reshape(o, batch_dims + [heads, length, kv_channels])
+    return mtf.einsum([o, o_var], mtf.Shape(batch_dims + [length, io_channels]))
+
+
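For intuition, the bias that the rewritten function builds from mpos and qpos can be transcribed into plain NumPy (an illustrative sketch, not mesh_tensorflow code; the helper name local_attention_bias is made up here): each block of queries sees a memory block extended on the left by a window_size halo, and the bias rules out looking forward as well as looking too far back.

import numpy as np

def local_attention_bias(block_length, window_size):
  """NumPy transcription of the mpos/qpos bias constructed in the diff above."""
  # Positions within one halo-extended memory block.
  mpos = np.arange(window_size + block_length, dtype=np.float32)
  # Query positions, shifted by window_size so they line up with mpos.
  qpos = np.arange(block_length, dtype=np.float32)[:, None] + window_size
  # prevent looking forward
  bias = np.where(mpos > qpos, -1e9, 0.0)
  # prevent looking too many timesteps backward
  bias += np.where(mpos <= qpos - block_length, -1e9, 0.0)
  return bias  # shape [block_length, window_size + block_length]

# With window_size == block_length, each query position p is left with
# exactly the window (p - window_size, p] described in the docstring.
print(local_attention_bias(block_length=4, window_size=4))

The hunk then continues with the new incremental-decoding companion function: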
+def masked_local_attention_1d_incremental(x,
+                                          prev_k,
+                                          prev_v,
+                                          step_num,
+                                          master_dtype,
+                                          slice_dtype,
+                                          name=None):
+  """Incremental local self-attention (one decode step).
+
+  Incremental version of masked_local_attention_1d()
+
+  Args:
+    x: a mtf.Tensor with shape [batch..., io_channels]
+    prev_k: mtf.Tensor with shape
+      [batch..., heads, window_length, kv_channels]
+    prev_v: mtf.Tensor with shape
+      [batch..., heads, window_length, kv_channels]
+    step_num: mtf Scalar with dtype tf.int32
+    master_dtype: a tf.dtype
+    slice_dtype: a tf.dtype
+    name: an optional string.
+
+  Returns:
+    y: A mtf.Tensor with shape [batch..., io_channels]
+    new_k: mtf.Tensor with shape
+      [batch..., heads, window_length, kv_channels]
+    new_v: mtf.Tensor with shape
+      [batch..., heads, window_length, kv_channels]
+
+  Raises:
+    ValueError: if the dimensions do not match.
+  """
+  batch_dims = x.shape.dims[:-1]
+  io_channels = x.shape.dims[-1]
+  heads, window_length, kv_channels = prev_k.shape.dims[-3:]
+  with tf.variable_scope(name, default_name="multihead_attention"):
+    q_var, k_var, v_var, o_var = multihead_attention_vars(
+        x.mesh, heads, io_channels, kv_channels,
+        master_dtype, slice_dtype, x.dtype)
+    q = mtf.einsum([x, q_var], mtf.Shape(batch_dims + [heads, kv_channels]))
+    k = mtf.einsum([x, k_var], mtf.Shape(batch_dims + [heads, kv_channels]))
+    v = mtf.einsum([x, v_var], mtf.Shape(batch_dims + [heads, kv_channels]))
+    current_position = mtf.equal(
+        mtf.range(x.mesh, window_length, dtype=tf.int32),
+        mtf.mod(step_num, window_length.size))
+    k = mtf.where(current_position, k, prev_k, output_shape=prev_k.shape)
+    v = mtf.where(current_position, v, prev_v, output_shape=prev_v.shape)
+    o = dot_product_attention(q, k, v, mask=None)
+    y = mtf.einsum([o, o_var], x.shape)
+    return y, k, v
 
 
 def local_2d_halo_exchange(k, v, num_h_blocks, h_dim,
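The incremental variant keeps a rolling cache of the last window_length keys and values and overwrites one slot per decode step. A minimal NumPy sketch of that ring-buffer update (illustrative only; update_kv_cache is a made-up name, and the real code expresses the update with mtf.where and output_shape=prev_k.shape):

import numpy as np

def update_kv_cache(prev_k, prev_v, new_k, new_v, step_num):
  """Overwrite slot (step_num % window_length) of the cached keys/values.

  prev_k, prev_v: arrays of shape [window_length, kv_channels]
  new_k, new_v:   arrays of shape [kv_channels] for the current position
  """
  window_length = prev_k.shape[0]
  current_position = np.arange(window_length) == (step_num % window_length)
  k = np.where(current_position[:, None], new_k, prev_k)
  v = np.where(current_position[:, None], new_v, prev_v)
  return k, v

prev_k = np.zeros((4, 2))
prev_v = np.zeros((4, 2))
k, v = update_kv_cache(prev_k, prev_v, np.ones(2), np.ones(2), step_num=6)
print(k)  # slot 6 % 4 == 2 now holds the new key

Because the decode step attends to every cached position with mask=None, the order of slots in the ring does not matter to the attention itself; only which window_length positions are present does.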

mesh_tensorflow/layers_test.py

Lines changed: 2 additions & 9 deletions
@@ -156,33 +156,26 @@ def testDenseReluDense(self):
       (1, 8, 5, 3, 1, 4),
   )
   def testMaskedLocalAttention1D(self, batch, length, io_channels, kv_channels,
-                                 heads, block_length):
+                                 heads, window_size):
     length_q = length
-    length_m = length
    query = tf.random_normal([batch, length_q, io_channels])
-    memory = tf.random_normal([batch, length_m, io_channels])
 
     graph = mtf.Graph()
     mesh = mtf.Mesh(graph, "my_mesh")
     batch_dim = mtf.Dimension("batch", batch)
     length_q_dim = mtf.Dimension("length_q", length_q)
-    length_m_dim = mtf.Dimension("length_m", length_m)
     io_channels_dim = mtf.Dimension("io_channels", io_channels)
     kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
     heads_dim = mtf.Dimension("heads", heads)
 
     mtf_query = mtf.import_tf_tensor(
         mesh, query,
         shape=mtf.Shape([batch_dim, length_q_dim, io_channels_dim]))
-    mtf_memory = mtf.import_tf_tensor(
-        mesh, memory,
-        shape=mtf.Shape([batch_dim, length_m_dim, io_channels_dim]))
     mtf_outputs = mtf.layers.masked_local_attention_1d(
         mtf_query,
-        mtf_memory,
         kv_channels=kv_channels_dim,
         heads=heads_dim,
-        block_length=block_length)
+        window_size=window_size)
     mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
         shape=[], layout={}, devices=[""])
     lowering = mtf.Lowering(graph, {mesh: mesh_impl})

mesh_tensorflow/ops.py

Lines changed: 37 additions & 3 deletions
@@ -3283,6 +3283,10 @@ def add(x1, x2, output_shape=None, name=None):
           x1.shape, x2.shape, output_shape)).outputs[0]
 
 
+def add_n(xs):
+  return reduce(add, xs)
+
+
 def sub(x1, x2, output_shape=None, name=None):
   """Binary subtraction with broadcasting.
 
@@ -4048,11 +4052,12 @@ def my_body_fn(*inputs):
       my_cond_fn, my_body_fn, inputs, kwargs).outputs
 
 
-def where(condition, if_true, if_false):
+def where(condition, if_true, if_false, output_shape=None):
   dtype = if_true.dtype
   return (
-      if_true * cast(condition, dtype) +
-      if_false * cast(logical_not(condition), dtype))
+      multiply(if_true, cast(condition, dtype), output_shape=output_shape) +
+      multiply(if_false,
+               cast(logical_not(condition), dtype), output_shape=output_shape))
 
 
 def _shape_union(shapes):
@@ -4241,3 +4246,32 @@ def conv2d_with_blocks(
     conv_input = pad(
         conv_input, [halo_size, halo_size], block_size_dim.name)
   return conv2d(conv_input, conv_filter, strides, "VALID", name)
+
+
+def tensor_dim_to_mesh_dim_size(layout, mesh_shape, tensor_dim):
+  """How many ways does a tensor dimension get split.
+
+  This is used to "cheat" when building the mtf graph and peek at how a
+  tensor dimension will be split. Returns 1 if the tensor dimension is not
+  split.
+
+  Args:
+    layout: an input to convert_to_layout_rules
+    mesh_shape: an input to convert_to_shape
+    tensor_dim: a Dimension
+
+  Returns:
+    an integer
+  """
+  layout_rules = convert_to_layout_rules(layout)
+  mesh_shape = convert_to_shape(mesh_shape)
+  mesh_axis = layout_rules.tensor_dimension_to_mesh_axis(tensor_dim, mesh_shape)
+  if mesh_axis is None:
+    return 1
+  else:
+    return mesh_shape.dims[mesh_axis].size
+
+
+def tensor_dim_to_size_per_split(layout, mesh_shape, tensor_dim):
+  return tensor_dim.size // tensor_dim_to_mesh_dim_size(
+      layout, mesh_shape, tensor_dim)
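The two new helpers exist so that model-building code can peek at how a tensor dimension will be split and, for example, pass length_per_split to masked_local_attention_1d when the length dimension is split across processors. A usage sketch (the layout and mesh_shape strings below are hypothetical, and the conventional `import mesh_tensorflow as mtf` alias is assumed):

import mesh_tensorflow as mtf

layout = "batch:b0;length:b1"      # hypothetical layout rules
mesh_shape = "b0:8;b1:4"           # hypothetical 32-core mesh
length_dim = mtf.Dimension("length", 4096)

# The length dimension maps to mesh axis b1, which has size 4 ...
assert mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, length_dim) == 4
# ... so each processor holds 4096 // 4 == 1024 positions.
length_per_split = mtf.tensor_dim_to_size_per_split(
    layout, mesh_shape, length_dim)

That value could then be fed to masked_local_attention_1d(..., length_per_split=length_per_split) so that the block length chosen in layers.py divides each processor's slice of the sequence.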
