
Commit 7ff0a4d

Niki Parmar authored and Copybara-Service committed

Image transformer with local 1D and local 2D spatial partitioning.
PiperOrigin-RevId: 218263114
1 parent a9d6a74 commit 7ff0a4d

File tree: 1 file changed


mesh_tensorflow/layers.py

Lines changed: 161 additions & 18 deletions
@@ -242,21 +242,22 @@ def dense_relu_dense(x,
   return mtf.einsum([h, wo])
 
 
-def local_1d_halo_exchange(k, v, num_w_blocks, w_dim, memory_w_dim, mask_right):
+def local_1d_halo_exchange(k, v, num_w_blocks, w_dim, mask_right):
   """Halo exchange for keys and values for Local 1D attention."""
   if num_w_blocks is not None:
     if mask_right:
-      k = mtf.left_halo_exchange(k, num_w_blocks, w_dim, memory_w_dim.size)
-      v = mtf.left_halo_exchange(v, num_w_blocks, w_dim, memory_w_dim.size)
+      k = mtf.left_halo_exchange(k, num_w_blocks, w_dim, w_dim.size)
+      v = mtf.left_halo_exchange(v, num_w_blocks, w_dim, w_dim.size)
     else:
-      k = mtf.halo_exchange(k, num_w_blocks, w_dim, memory_w_dim.size)
-      v = mtf.halo_exchange(v, num_w_blocks, w_dim, memory_w_dim.size)
+      k = mtf.halo_exchange(k, num_w_blocks, w_dim, w_dim.size)
+      v = mtf.halo_exchange(v, num_w_blocks, w_dim, w_dim.size)
   else:
     if mask_right:
-      k = mtf.pad(k, [memory_w_dim, None], w_dim.name)
+      k = mtf.pad(k, [w_dim, None], w_dim.name)
+      v = mtf.pad(v, [w_dim, None], w_dim.name)
     else:
-      k = mtf.pad(k, [memory_w_dim, memory_w_dim], w_dim.name)
-      v = mtf.pad(v, [memory_w_dim, memory_w_dim], w_dim.name)
+      k = mtf.pad(k, [w_dim, w_dim], w_dim.name)
+      v = mtf.pad(v, [w_dim, w_dim], w_dim.name)
   return k, v
 
 
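To make the bookkeeping in local_1d_halo_exchange concrete: each block's keys and values are extended with positions taken from the neighboring block (only the left neighbor when mask_right is set). The following is a plain NumPy sketch of that idea, not the mtf implementation; the [num_blocks, block_size, channels] layout and the helper name left_halo_exchange_1d are assumptions for illustration, and the halo size matches the commit's choice of a full block (w_dim.size).

import numpy as np

def left_halo_exchange_1d(x, halo_size):
  """Illustrative left halo exchange over 1D blocks (not the mtf op).

  x: [num_blocks, block_size, channels]. Returns
  [num_blocks, halo_size + block_size, channels], where block i is prefixed
  with the last halo_size positions of block i - 1 (zeros for block 0).
  """
  num_blocks, block_size, channels = x.shape
  halo = np.zeros((num_blocks, halo_size, channels), dtype=x.dtype)
  halo[1:] = x[:-1, block_size - halo_size:, :]
  return np.concatenate([halo, x], axis=1)

# 3 blocks of length 4 with a full-block halo, as in the commit.
blocks = np.arange(12, dtype=np.float32).reshape(3, 4, 1)
print(left_halo_exchange_1d(blocks, halo_size=4).shape)  # (3, 8, 1)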

@@ -304,23 +305,22 @@ def local_self_attention_spatial_blocks(
 
     # Rename dimensions for the memory height and width.
     memory_antecedent = mtf.rename_dimension(
-        query_antecedent, w_dim.name, memory_w_dim.name)
+        query_antecedent, w_dim.name, "memory_" + w_dim.name)
+    memory_w_dim = memory_antecedent.shape.dims[-2]
 
     # Call einsum over the query and memory to get query q, keys k and values v.
     q = mtf.einsum(
        [query_antecedent, q_var],
        mtf.Shape([batch, heads, num_w_blocks, w_dim, kv_channels]))
     k = mtf.einsum(
        [memory_antecedent, k_var],
-       mtf.Shape([batch, heads, num_w_blocks, w_dim, kv_channels]))
+       mtf.Shape([batch, heads, num_w_blocks, memory_w_dim, kv_channels]))
     v = mtf.einsum(
        [memory_antecedent, v_var],
-       mtf.Shape([batch, heads, num_w_blocks, w_dim, kv_channels]))
+       mtf.Shape([batch, heads, num_w_blocks, memory_w_dim, kv_channels]))
 
     # Halo exchange for memory blocks.
-    if memory_w_dim is not None:
-      k, v = local_1d_halo_exchange(
-          k, v, num_w_blocks, w_dim, memory_w_dim, mask_right)
+    k, v = local_1d_halo_exchange(k, v, num_w_blocks, memory_w_dim, mask_right)
 
     # Calculate the causal mask to avoid peeking into the future. We compute
     # this once and reuse it for all blocks since the block_size is known.
@@ -332,8 +332,7 @@ def local_self_attention_spatial_blocks(
     output = dot_product_attention(q, k, v, mask=mask)
 
     return mtf.einsum(
-        [output, o_var],
-        mtf.Shape([batch, num_w_blocks, w_dim, io_channels]))
+        [output, o_var], mtf.Shape([batch, num_w_blocks, w_dim, io_channels]))
 
 
 def masked_local_attention_1d(query_antecedent,
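Both the 1D and the new 2D code paths hand their additive mask to dot_product_attention. As a reminder of what that step computes, here is a minimal single-head NumPy sketch; it is only the underlying math, not the mtf code, and the placement of the 1/sqrt(d) scaling is an assumption (the library may fold scaling into its projection variables).

import numpy as np

def dot_product_attention_np(q, k, v, bias=None):
  """q: [lq, d], k and v: [lkv, d], bias: [lq, lkv] with 0 / -1e9 entries."""
  logits = q @ k.T / np.sqrt(q.shape[-1])  # scaling placement is an assumption
  if bias is not None:
    logits = logits + bias  # -1e9 entries vanish after the softmax
  weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
  weights /= weights.sum(axis=-1, keepdims=True)
  return weights @ v

# 4 query positions attending over a 4-position visible halo plus the causally
# masked query block itself (8 memory positions total).
q = np.random.randn(4, 8)
k = np.random.randn(8, 8)
v = np.random.randn(8, 8)
bias = np.zeros((4, 8))
bias[:, 4:] = np.triu(np.full((4, 4), -1e9), 1)
print(dot_product_attention_np(q, k, v, bias).shape)  # (4, 8)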
@@ -456,6 +455,112 @@ def local(x):
         mtf.Shape([batch, query_length, io_channels]))
 
 
+def local_2d_halo_exchange(k, v, num_h_blocks, h_dim,
+                           num_w_blocks, w_dim, mask_right):
+  """Halo exchange for keys and values for Local 2D attention."""
+  for blocks_dim, block_size_dim, halo_size in [
+      (num_h_blocks, h_dim, h_dim.size),
+      (num_w_blocks, w_dim, w_dim.size)]:
+    # shape of k is [num_h_blocks, num_w_blocks, h_dim, w_dim, kv_channels]
+    if halo_size > 0:
+      if blocks_dim is not None:
+        if mask_right:
+          k = mtf.left_halo_exchange(k, blocks_dim, block_size_dim, halo_size)
+          v = mtf.left_halo_exchange(v, blocks_dim, block_size_dim, halo_size)
+        else:
+          k = mtf.halo_exchange(k, blocks_dim, block_size_dim, halo_size)
+          v = mtf.halo_exchange(v, blocks_dim, block_size_dim, halo_size)
+      else:
+        if mask_right:
+          k = mtf.pad(k, [halo_size, None], block_size_dim.name)
+          v = mtf.pad(v, [halo_size, None], block_size_dim.name)
+        else:
+          k = mtf.pad(k, [halo_size, halo_size], block_size_dim.name)
+          v = mtf.pad(v, [halo_size, halo_size], block_size_dim.name)
+  return k, v
+
+
+def local_2d_self_attention_spatial_blocks(query_antecedent,
+                                           kv_channels,
+                                           heads,
+                                           memory_h_dim=None,
+                                           memory_w_dim=None,
+                                           mask_right=False,
+                                           name=None):
+  """Attention to the source position and a neighborhood to the left or right.
+
+  The sequence is divided into blocks of length block_size.
+  Attention for a given query position can only see memory positions
+  less than or equal to the query position, in the corresponding block
+  and the previous block.
+
+  Args:
+    query_antecedent: a mtf.Tensor with shape [batch, num_h_blocks,
+      num_w_blocks, h_dim, w_dim, io_channels] must have the same size as
+      query_length, but a different name.
+    kv_channels: a mtf.Dimension (the size of the key and value vectors)
+    heads: a mtf.Dimension (the number of heads)
+    memory_h_dim: mtf Dimension, for the memory height block.
+    memory_w_dim: mtf Dimension, for the memory width block.
+    mask_right: bool, flag specifying whether we mask out attention to the right
+      for the decoder.
+    name: an optional string.
+
+  Returns:
+    a Tensor of shape
+    [batch, num_h_blocks, num_w_blocks, h_dim, w_dim, io_channels]
+
+  Raises:
+    ValueError: if channels or depth don't match.
+  """
+  with tf.variable_scope(
+      name, default_name="multihead_attention", values=[query_antecedent]):
+
+    h_dim, w_dim, io_channels = query_antecedent.shape.dims[-3:]
+    batch, num_h_blocks, num_w_blocks = query_antecedent.shape.dims[:3]
+    q_var, k_var, v_var, o_var = multihead_attention_vars(
+        query_antecedent.mesh, heads, io_channels, kv_channels,
+        query_antecedent.dtype)
+
+    # Rename dimensions for the memory height and width.
+    memory_antecedent = mtf.rename_dimension(query_antecedent, h_dim.name,
+                                             "memory_" + h_dim.name)
+    memory_antecedent = mtf.rename_dimension(memory_antecedent, w_dim.name,
+                                             "memory_" + w_dim.name)
+    memory_h_dim, memory_w_dim = memory_antecedent.shape.dims[-3:-1]
+
+    # Call einsum over the query and memory to get query q, keys k and values v.
+    q = mtf.einsum([query_antecedent, q_var],
+                   mtf.Shape([
+                       batch, heads, num_h_blocks, num_w_blocks, h_dim, w_dim,
+                       kv_channels
+                   ]))
+    k = mtf.einsum([memory_antecedent, k_var],
+                   mtf.Shape([batch, heads, num_h_blocks, num_w_blocks,
+                              memory_h_dim, memory_w_dim, kv_channels]))
+    v = mtf.einsum([memory_antecedent, v_var],
+                   mtf.Shape([batch, heads, num_h_blocks, num_w_blocks,
+                              memory_h_dim, memory_w_dim, kv_channels]))
+
+    # Halo exchange for memory blocks.
+    k, v = local_2d_halo_exchange(k, v, num_h_blocks, memory_h_dim,
+                                  num_w_blocks, memory_w_dim, mask_right)
+
+    # Calculate the causal mask to avoid peeking into the future. We compute
+    # this once and reuse it for all blocks since the block_size is known.
+    mask = None
+    if mask_right:
+      mask = attention_bias_local_2d_block(query_antecedent.mesh, h_dim, w_dim,
+                                           memory_h_dim, memory_w_dim)
+
+    output = dot_product_attention(q, k, v, mask=mask)
+
+    return mtf.einsum(
+        [output, o_var],
+        mtf.Shape(
+            [batch, num_h_blocks, num_w_blocks, h_dim, w_dim, io_channels]))
+
+
 def rename_length_to_memory_length(
     x, length_name="length", memory_length_name="memory_length"):
   return mtf.rename_dimension(x, length_name, memory_length_name)
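The new local_2d_self_attention_spatial_blocks expects its input already partitioned into spatial blocks, i.e. a [batch, num_h_blocks, num_w_blocks, h_dim, w_dim, io_channels] tensor. A rough NumPy illustration of how such a layout can be produced from a [batch, H, W, C] image is shown below; the helper name and the assumption that H and W divide evenly by the block sizes are mine, not part of the commit.

import numpy as np

def to_spatial_blocks(images, block_h, block_w):
  """[batch, H, W, C] -> [batch, H//block_h, W//block_w, block_h, block_w, C]."""
  b, h, w, c = images.shape
  x = images.reshape(b, h // block_h, block_h, w // block_w, block_w, c)
  return x.transpose(0, 1, 3, 2, 4, 5)  # group the two block-index axes first

imgs = np.zeros((2, 32, 32, 3), dtype=np.float32)
print(to_spatial_blocks(imgs, 8, 8).shape)  # (2, 4, 4, 8, 8, 3)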
@@ -759,6 +864,44 @@ def attention_bias_local_block(mesh, block_length, memory_length,
   mask = mtf.cast(mtf.less(mtf.range(mesh, block_length, dtype=dtype),
                            mtf.range(mesh, memory_length, dtype=dtype)),
                   dtype=dtype)
-  mask = mtf.cast(mtf.concat([memory_mask, mask], memory_length.name),
-                  dtype=tf.float32) * -1e9
+  mask = mtf.cast(
+      mtf.concat([memory_mask, mask], memory_length.name),
+      dtype=tf.float32) * -1e9
+  return mask
+
+
+def attention_bias_local_2d_block(mesh,
+                                  h_dim,
+                                  w_dim,
+                                  memory_h_dim,
+                                  memory_w_dim,
+                                  dtype=tf.int32):
+  """Bias for attention for local blocks where attention to right is disallowed.
+
+  Create the bias matrix by using two separate masks, one for the memory part
+  which doesn't overlap with the query and second which interacts with the query
+  and should be disallowed to look to the right of the current query position.
+
+  Args:
+    mesh: a MeshTensorflow object
+    h_dim: a mtf.Dimension
+    w_dim: a mtf.Dimension
+    memory_h_dim: a mtf.Dimension
+    memory_w_dim: a mtf.Dimension
+    dtype: a tf.dtype
+
+  Returns:
+    a mtf.Tensor with shape [block_length, memory_length]
+  """
+  memory_height = mtf.Dimension(memory_h_dim.name, h_dim.size)
+  memory_width = mtf.Dimension(memory_w_dim.name, w_dim.size)
+  mask_top_visible = mtf.zeros(mesh, [h_dim, memory_height], dtype=dtype)
+  mask_left_visible = mtf.zeros(mesh, [w_dim, memory_width], dtype=dtype)
+  mask_query = mtf.greater(
+      mtf.range(mesh, memory_height, dtype=tf.int32),
+      mtf.range(mesh, memory_width, dtype=dtype))
+  width_mask = mtf.concat([mask_left_visible, mask_query], memory_width.name)
+  mask = mtf.cast(
+      mtf.concat([mask_top_visible, width_mask], memory_height.name),
+      dtype=tf.float32) * -1e9
   return mask
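As in the 1D case, the bias is additive: 0 where attention is allowed and -1e9 where it is not, so disallowed logits disappear after the softmax. The sketch below shows that pattern in NumPy for a single query block preceded by a fully visible halo; it is a conceptual analogue with made-up block and memory lengths, not the mtf construction above.

import numpy as np

def additive_local_bias(block_length, memory_length):
  """0 for visible positions, -1e9 for masked ones (single query block)."""
  halo = memory_length - block_length          # halo positions stay visible
  visible_halo = np.zeros((block_length, halo), dtype=np.float32)
  query = np.arange(block_length)[:, None]
  memory = np.arange(block_length)[None, :]
  causal = (memory > query).astype(np.float32) * -1e9  # no peeking right
  return np.concatenate([visible_halo, causal], axis=1)

print(additive_local_bias(block_length=4, memory_length=8))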
