@@ -242,21 +242,22 @@ def dense_relu_dense(x,
   return mtf.einsum([h, wo])
 
 
-def local_1d_halo_exchange(k, v, num_w_blocks, w_dim, memory_w_dim, mask_right):
+def local_1d_halo_exchange(k, v, num_w_blocks, w_dim, mask_right):
   """Halo exchange for keys and values for Local 1D attention."""
   if num_w_blocks is not None:
     if mask_right:
-      k = mtf.left_halo_exchange(k, num_w_blocks, w_dim, memory_w_dim.size)
-      v = mtf.left_halo_exchange(v, num_w_blocks, w_dim, memory_w_dim.size)
+      k = mtf.left_halo_exchange(k, num_w_blocks, w_dim, w_dim.size)
+      v = mtf.left_halo_exchange(v, num_w_blocks, w_dim, w_dim.size)
     else:
-      k = mtf.halo_exchange(k, num_w_blocks, w_dim, memory_w_dim.size)
-      v = mtf.halo_exchange(v, num_w_blocks, w_dim, memory_w_dim.size)
   else:
     if mask_right:
-      k = mtf.pad(k, [memory_w_dim, None], w_dim.name)
+      k = mtf.pad(k, [w_dim, None], w_dim.name)
+      v = mtf.pad(v, [w_dim, None], w_dim.name)
     else:
-      k = mtf.pad(k, [memory_w_dim, memory_w_dim], w_dim.name)
-      v = mtf.pad(v, [memory_w_dim, memory_w_dim], w_dim.name)
+      k = mtf.pad(k, [w_dim, w_dim], w_dim.name)
+      v = mtf.pad(v, [w_dim, w_dim], w_dim.name)
   return k, v
 
 
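
For intuition about the mask_right branch above: a minimal NumPy sketch (an analogue for illustration, not mesh_tensorflow API; the helper name and array layout are assumptions) of what a left halo of one block size gives each block — the previous block's keys/values are prepended, with zeros for the first block.

import numpy as np

def left_halo_1d(kv_blocks):
  """kv_blocks: [num_blocks, block, channels] -> [num_blocks, 2 * block, channels]."""
  # Shift the block axis by one so block i sees block i - 1 (zeros for block 0).
  prev = np.concatenate([np.zeros_like(kv_blocks[:1]), kv_blocks[:-1]], axis=0)
  # Prepend that halo along the within-block axis.
  return np.concatenate([prev, kv_blocks], axis=1)

kv = np.arange(2 * 3 * 1, dtype=np.float32).reshape(2, 3, 1)
print(left_halo_1d(kv).shape)  # (2, 6, 1)
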
@@ -304,23 +305,22 @@ def local_self_attention_spatial_blocks(
 
     # Rename dimensions for the memory height and width.
     memory_antecedent = mtf.rename_dimension(
-        query_antecedent, w_dim.name, memory_w_dim.name)
+        query_antecedent, w_dim.name, "memory_" + w_dim.name)
+    memory_w_dim = memory_antecedent.shape.dims[-2]
 
     # Call einsum over the query and memory to get query q, keys k and values v.
     q = mtf.einsum(
         [query_antecedent, q_var],
         mtf.Shape([batch, heads, num_w_blocks, w_dim, kv_channels]))
     k = mtf.einsum(
         [memory_antecedent, k_var],
-        mtf.Shape([batch, heads, num_w_blocks, w_dim, kv_channels]))
+        mtf.Shape([batch, heads, num_w_blocks, memory_w_dim, kv_channels]))
     v = mtf.einsum(
         [memory_antecedent, v_var],
-        mtf.Shape([batch, heads, num_w_blocks, w_dim, kv_channels]))
+        mtf.Shape([batch, heads, num_w_blocks, memory_w_dim, kv_channels]))
 
     # Halo exchange for memory blocks.
-    if memory_w_dim is not None:
-      k, v = local_1d_halo_exchange(
-          k, v, num_w_blocks, w_dim, memory_w_dim, mask_right)
+    k, v = local_1d_halo_exchange(k, v, num_w_blocks, memory_w_dim, mask_right)
 
     # Calculate the causal mask to avoid peeking into the future. We compute
     # this once and reuse it for all blocks since the block_size is known.
@@ -332,8 +332,7 @@ def local_self_attention_spatial_blocks(
     output = dot_product_attention(q, k, v, mask=mask)
 
     return mtf.einsum(
-        [output, o_var],
-        mtf.Shape([batch, num_w_blocks, w_dim, io_channels]))
+        [output, o_var], mtf.Shape([batch, num_w_blocks, w_dim, io_channels]))
 
 
 def masked_local_attention_1d(query_antecedent,
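
To see how the pieces of local_self_attention_spatial_blocks fit together, here is a plain-NumPy sketch of blocked attention under assumed shapes (queries [num_blocks, B, d], halo-exchanged keys/values [num_blocks, M, d], optional additive bias [B, M]); this is an illustration only, not the mtf implementation.

import numpy as np

def blocked_attention(q, k, v, bias=None):
  """q: [nb, B, d]; k, v: [nb, M, d]; bias: [B, M], -1e9 marks disallowed positions."""
  logits = np.einsum("nqd,nkd->nqk", q, k)
  if bias is not None:
    logits = logits + bias  # the same mask is broadcast over all blocks
  weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
  weights = weights / weights.sum(axis=-1, keepdims=True)
  return np.einsum("nqk,nkd->nqd", weights, v)

q = np.random.rand(2, 3, 4)
k = v = np.random.rand(2, 6, 4)
print(blocked_attention(q, k, v).shape)  # (2, 3, 4)
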
@@ -456,6 +455,112 @@ def local(x):
         mtf.Shape([batch, query_length, io_channels]))
 
 
+def local_2d_halo_exchange(k, v, num_h_blocks, h_dim,
+                           num_w_blocks, w_dim, mask_right):
+  """Halo exchange for keys and values for Local 2D attention."""
+  for blocks_dim, block_size_dim, halo_size in [
+      (num_h_blocks, h_dim, h_dim.size),
+      (num_w_blocks, w_dim, w_dim.size)]:
+    # shape of k is [num_h_blocks, num_w_blocks, h_dim, w_dim, kv_channels]
+    if halo_size > 0:
+      if blocks_dim is not None:
+        if mask_right:
+          k = mtf.left_halo_exchange(k, blocks_dim, block_size_dim, halo_size)
+          v = mtf.left_halo_exchange(v, blocks_dim, block_size_dim, halo_size)
+        else:
+          k = mtf.halo_exchange(k, blocks_dim, block_size_dim, halo_size)
+          v = mtf.halo_exchange(v, blocks_dim, block_size_dim, halo_size)
+      else:
+        if mask_right:
+          k = mtf.pad(k, [halo_size, None], block_size_dim.name)
+          v = mtf.pad(v, [halo_size, None], block_size_dim.name)
+        else:
+          k = mtf.pad(k, [halo_size, halo_size], block_size_dim.name)
+          v = mtf.pad(v, [halo_size, halo_size], block_size_dim.name)
+  return k, v
+
+
+def local_2d_self_attention_spatial_blocks(query_antecedent,
+                                           kv_channels,
+                                           heads,
+                                           memory_h_dim=None,
+                                           memory_w_dim=None,
+                                           mask_right=False,
+                                           name=None):
+  """Attention to the source position and a neighborhood to the left or right.
+
+  The sequence is divided into blocks of length block_size.
+  Attention for a given query position can only see memory positions
+  less than or equal to the query position, in the corresponding block
+  and the previous block.
+
+  Args:
+    query_antecedent: a mtf.Tensor with shape [batch, num_h_blocks,
+      num_w_blocks, h_dim, w_dim, io_channels] must have the same size as
+      query_length, but a different name.
+    kv_channels: a mtf.Dimension (the size of the key and value vectors)
+    heads: a mtf.Dimension (the number of heads)
+    memory_h_dim: mtf Dimension, for the memory height block.
+    memory_w_dim: mtf Dimension, for the memory width block.
+    mask_right: bool, flag specifying whether we mask out attention to the right
+      for the decoder.
+    name: an optional string.
+
+  Returns:
+    a Tensor of shape
+    [batch, num_h_blocks, num_w_blocks, h_dim, w_dim, io_channels]
+
+  Raises:
+    ValueError: if channels or depth don't match.
+  """
+  with tf.variable_scope(
+      name, default_name="multihead_attention", values=[query_antecedent]):
+
+    h_dim, w_dim, io_channels = query_antecedent.shape.dims[-3:]
+    batch, num_h_blocks, num_w_blocks = query_antecedent.shape.dims[:3]
+    q_var, k_var, v_var, o_var = multihead_attention_vars(
+        query_antecedent.mesh, heads, io_channels, kv_channels,
+        query_antecedent.dtype)
+
+    # Rename dimensions for the memory height and width.
+    memory_antecedent = mtf.rename_dimension(query_antecedent, h_dim.name,
+                                             "memory_" + h_dim.name)
+    memory_antecedent = mtf.rename_dimension(memory_antecedent, w_dim.name,
+                                             "memory_" + w_dim.name)
+    memory_h_dim, memory_w_dim = memory_antecedent.shape.dims[-3:-1]
+
+    # Call einsum over the query and memory to get query q, keys k and values v.
+    q = mtf.einsum([query_antecedent, q_var],
+                   mtf.Shape([
+                       batch, heads, num_h_blocks, num_w_blocks, h_dim, w_dim,
+                       kv_channels
+                   ]))
+    k = mtf.einsum([memory_antecedent, k_var],
+                   mtf.Shape([batch, heads, num_h_blocks, num_w_blocks,
+                              memory_h_dim, memory_w_dim, kv_channels]))
+    v = mtf.einsum([memory_antecedent, v_var],
+                   mtf.Shape([batch, heads, num_h_blocks, num_w_blocks,
+                              memory_h_dim, memory_w_dim, kv_channels]))
+
+    # Halo exchange for memory blocks.
+    k, v = local_2d_halo_exchange(k, v, num_h_blocks, memory_h_dim,
+                                  num_w_blocks, memory_w_dim, mask_right)
+
+    # Calculate the causal mask to avoid peeking into the future. We compute
+    # this once and reuse it for all blocks since the block_size is known.
+    mask = None
+    if mask_right:
+      mask = attention_bias_local_2d_block(query_antecedent.mesh, h_dim, w_dim,
+                                           memory_h_dim, memory_w_dim)
+
+    output = dot_product_attention(q, k, v, mask=mask)
+
+    return mtf.einsum(
+        [output, o_var],
+        mtf.Shape(
+            [batch, num_h_blocks, num_w_blocks, h_dim, w_dim, io_channels]))
+
+
 def rename_length_to_memory_length(
     x, length_name="length", memory_length_name="memory_length"):
   return mtf.rename_dimension(x, length_name, memory_length_name)
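
The 2D halo added above applies the same idea along both block axes, height first and then width. A NumPy shape sketch for the mask_right case (again an analogue, not the mtf code): because the exchanges run sequentially, the width halo also carries the height halo of the left neighbour, so each block's memory window grows from [H, W] to [2H, 2W].

import numpy as np

def left_halo_2d(x):
  """x: [nh, nw, H, W, d] -> [nh, nw, 2 * H, 2 * W, d], zeros at the top/left edges."""
  up = np.concatenate([np.zeros_like(x[:1]), x[:-1]], axis=0)          # block above
  x = np.concatenate([up, x], axis=2)                                  # halo along H
  left = np.concatenate([np.zeros_like(x[:, :1]), x[:, :-1]], axis=1)  # block to the left
  return np.concatenate([left, x], axis=3)                             # halo along W

x = np.ones((2, 2, 4, 4, 8), dtype=np.float32)
print(left_halo_2d(x).shape)  # (2, 2, 8, 8, 8)
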
@@ -759,6 +864,44 @@ def attention_bias_local_block(mesh, block_length, memory_length,
   mask = mtf.cast(mtf.less(mtf.range(mesh, block_length, dtype=dtype),
                            mtf.range(mesh, memory_length, dtype=dtype)),
                   dtype=dtype)
-  mask = mtf.cast(mtf.concat([memory_mask, mask], memory_length.name),
-                  dtype=tf.float32) * -1e9
+  mask = mtf.cast(
+      mtf.concat([memory_mask, mask], memory_length.name),
+      dtype=tf.float32) * -1e9
+  return mask
+
+
+def attention_bias_local_2d_block(mesh,
+                                  h_dim,
+                                  w_dim,
+                                  memory_h_dim,
+                                  memory_w_dim,
+                                  dtype=tf.int32):
+  """Bias for attention for local blocks where attention to right is disallowed.
+
+  Create the bias matrix by using two separate masks, one for the memory part
+  which doesn't overlap with the query and second which interacts with the query
+  and should be disallowed to look to the right of the current query position.
+
+  Args:
+    mesh: a MeshTensorflow object
+    h_dim: a mtf.Dimension
+    w_dim: a mtf.Dimension
+    memory_h_dim: a mtf.Dimension
+    memory_w_dim: a mtf.Dimension
+    dtype: a tf.dtype
+
+  Returns:
+    a mtf.Tensor with shape [block_length, memory_length]
+  """
+  memory_height = mtf.Dimension(memory_h_dim.name, h_dim.size)
+  memory_width = mtf.Dimension(memory_w_dim.name, w_dim.size)
+  mask_top_visible = mtf.zeros(mesh, [h_dim, memory_height], dtype=dtype)
+  mask_left_visible = mtf.zeros(mesh, [w_dim, memory_width], dtype=dtype)
+  mask_query = mtf.greater(
+      mtf.range(mesh, memory_height, dtype=tf.int32),
+      mtf.range(mesh, memory_width, dtype=dtype))
+  width_mask = mtf.concat([mask_left_visible, mask_query], memory_width.name)
+  mask = mtf.cast(
+      mtf.concat([mask_top_visible, width_mask], memory_height.name),
+      dtype=tf.float32) * -1e9
   return mask
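
For reference, the layout of the 1D bias reformatted in the hunk above can be reproduced in NumPy as follows, assuming the memory_mask defined just outside this hunk is all zeros (the previous block fully visible); within the current block, position i may not attend to j > i. The 2D variant builds the same idea separately along the height and width axes. This sketch illustrates the intent and is not a translation of the mtf code.

import numpy as np

def local_block_bias_1d(block):
  """Returns a [block, 2 * block] additive bias: previous block visible, no attending right."""
  prev_visible = np.zeros((block, block), dtype=np.float32)
  future = (np.arange(block)[:, None] < np.arange(block)[None, :]).astype(np.float32)
  return np.concatenate([prev_visible, future * -1e9], axis=1)

bias = local_block_bias_1d(3)
print(bias.shape)        # (3, 6)
print((bias < 0).sum())  # 3 disallowed positions: (0, 4), (0, 5), (1, 5)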