@@ -343,128 +343,142 @@ def local_self_attention_spatial_blocks(
       [output, o_var], mtf.Shape([batch, num_w_blocks, w_dim, io_channels]))
 
 
-def masked_local_attention_1d(query_antecedent,
-                              memory_antecedent,
+def masked_local_attention_1d(x,
                               kv_channels,
                               heads,
-                              block_length=128,
+                              window_size=128,
                               master_dtype=tf.float32,
                               slice_dtype=tf.float32,
+                              length_per_split=None,
                               name=None):
354354 """Attention to the source position and a neighborhood to the left of it.
355355
356- The sequence is divided into blocks of length block_size.
357- Attention for a given query position can only see memory positions
358- less than or equal to the query position, in the corresponding block
359- and the previous block.
356+ Attention for a given query position p can only see memory positions
357+ in the range (p - window_size, p].
360358
361359 Args:
362- query_antecedent: a mtf.Tensor with shape [batch, query_length, io_channels]
363- memory_antecedent: a mtf.Tensor with shape
364- [batch, memory_length, io_channels] (optional). Currently, memory_length
365- must have the same size as query_length, but a different name.
360+ x: a mtf.Tensor with shape batch_dims + [length, io_channels]
366361 kv_channels: a mtf.Dimension (the size of the key and value vectors)
367362 heads: a mtf.Dimension (the number of heads)
368- block_length : an integer, representing receptive fields for attention.
363+ window_size : an integer
369364 master_dtype: a tf.dtype
370365 slice_dtype: a tf.dtype
366+ length_per_split: an optional integer indicating the part of the length
367+ dimension per processor. You can omit if the length dimension is not
368+ split.
371369 name: an optional string.
372370
373371 Returns:
374- a Tensor of shape [batch, query_length, io_channels]
372+ a Tensor with the same shape as x
375373
376374 Raises:
377375 ValueError: if channels or depth don't match.
378376 """
   with tf.variable_scope(
-      name, default_name="multihead_attention",
-      values=[query_antecedent, memory_antecedent]):
+      name, default_name="masked_local_attention_1d", values=[x]):
 
-    batch, query_length, io_channels = query_antecedent.shape.dims
+    batch_dims = x.shape.dims[:-2]
+    length, io_channels = x.shape.dims[-2:]
     q_var, k_var, v_var, o_var = multihead_attention_vars(
-        query_antecedent.mesh, heads, io_channels, kv_channels,
-        master_dtype, slice_dtype, query_antecedent.dtype)
-
-    if memory_antecedent is None:
-      memory_antecedent = rename_length_to_memory_length(
-          query_antecedent, query_length.name)
-    memory_batch, memory_length, memory_channels = memory_antecedent.shape.dims
-    if memory_batch != batch:
-      raise ValueError("memory batch must equal query batch")
-    if memory_channels != io_channels:
-      raise ValueError("memory channels must equal query channels")
+        x.mesh, heads, io_channels, kv_channels,
+        master_dtype, slice_dtype, x.dtype)
 
     # Get query q, keys k and values v.
-    q = mtf.einsum(
-        [query_antecedent, q_var],
-        mtf.Shape([batch, heads, query_length, kv_channels]))
-    k = mtf.einsum(
-        [memory_antecedent, k_var],
-        mtf.Shape([batch, heads, memory_length, kv_channels]))
-    v = mtf.einsum(
-        [memory_antecedent, v_var],
-        mtf.Shape([batch, heads, memory_length, kv_channels]))
-
-    # Let's assume for now we don't have padding and the block length equally
-    # divides the memory length.
-    block_length = (query_length.size
-                    if query_length.size < block_length * 2 else block_length)
-    blength = mtf.Dimension("block_length", block_length)
-    mlength = mtf.Dimension("mem_block_length", block_length)
-    num_blocks = mtf.Dimension("num_blocks", query_length.size // block_length)
-
-    q = mtf.reshape(
-        q, mtf.Shape([batch, heads, num_blocks, blength, kv_channels]))
-    k = mtf.reshape(
-        k, mtf.Shape([batch, heads, num_blocks, mlength, kv_channels]))
-    v = mtf.reshape(
-        v, mtf.Shape([batch, heads, num_blocks, mlength, kv_channels]))
-
-    # compute attention for the first query block.
-    def first_block_attention():
-      """Compute attention for the first block."""
-      first_q = mtf.slice(q, 0, 1, num_blocks.name)
-      first_k = mtf.slice(k, 0, 1, num_blocks.name)
-      first_v = mtf.slice(v, 0, 1, num_blocks.name)
-      first_output = dot_product_attention(first_q,
-                                           first_k,
-                                           first_v,
-                                           mask=None)
-      return first_output
-
-    # Attention for first block, since query_length = key_length.
-    first_output = first_block_attention()
-
-    # Concatenate two adjacent blocks to compute the overlapping memory block.
-    def local(x):
-      """Helper function to get memory blocks."""
-      prev_block = mtf.slice(x, 0, num_blocks.size - 1, num_blocks.name)
-      cur_block = mtf.slice(x, 1, num_blocks.size - 1, num_blocks.name)
-      local_block = mtf.concat([prev_block, cur_block], mlength.name)
-      return local_block
-
-    local_k = local(k)
-    local_v = local(v)
-    # Calculate the causal mask to avoid peeking into the future. We compute
-    # this once and reuse it for all blocks since the block_size is known.
-    mlength = local_k.shape.dims[3]
-    mask = attention_bias_local_block(query_antecedent.mesh,
-                                      blength, mlength)
-
-    # Remove the first block from q since we already computed that.
-    tail_q = mtf.slice(q, 1, num_blocks.size - 1, num_blocks.name)
-
-    tail_output = dot_product_attention(tail_q,
-                                        local_k,
-                                        local_v,
-                                        mask=mask)
-
-    # Now concatenate the first and rest of the blocks.
-    final_output = mtf.concat([first_output, tail_output], num_blocks.name)
-    final_output = mtf.reshape(final_output, mtf.Shape(
-        [batch, heads, query_length, kv_channels]))
-    return mtf.einsum([final_output, o_var],
-                      mtf.Shape([batch, query_length, io_channels]))
+    qkv_shape = mtf.Shape(batch_dims + [heads, length, kv_channels])
+    q = mtf.einsum([x, q_var], qkv_shape)
+    k = mtf.einsum([x, k_var], qkv_shape)
+    v = mtf.einsum([x, v_var], qkv_shape)
+
+    # Choose a suitable block size.
+    # We choose the greatest divisor of length_per_split less than or equal
+    # to max(window_size, 128)
+    if length_per_split is None:
+      length_per_split = length.size
+    block_length = max(window_size, 128)
+    while length_per_split % block_length != 0:
+      block_length -= 1
+
+    query_block_length = mtf.Dimension("query_block_length", block_length)
+    memory_block_length = mtf.Dimension("memory_block_length", block_length)
+    # The num_blocks dimension gets the same name as the length dimension,
+    # so it will be split in the same way.
+    num_blocks = mtf.Dimension(length.name, length.size // block_length)
+    q_shape = batch_dims + [heads, num_blocks, query_block_length, kv_channels]
+    kv_shape = batch_dims + [
+        heads, num_blocks, memory_block_length, kv_channels]
+    q = mtf.reshape(q, q_shape)
+    k = mtf.reshape(k, kv_shape)
+    v = mtf.reshape(v, kv_shape)
+    # augment the keys and values for each block with keys and values for
+    # the previous window_size timesteps.
+    k = mtf.left_halo_exchange(k, num_blocks, memory_block_length, window_size)
+    v = mtf.left_halo_exchange(v, num_blocks, memory_block_length, window_size)
+    padded_memory_block_length = mtf.Dimension(
+        "memory_block_length", window_size + block_length)
+    mpos = mtf.range(x.mesh, padded_memory_block_length, tf.float32)
+    qpos = mtf.range(x.mesh, query_block_length, tf.float32) + window_size
+    # prevent looking forward
+    mask = mtf.cast(mtf.greater(mpos, qpos), x.dtype) * -1e9
+    # prevent looking >=block_length timesteps backward
+    mask += mtf.cast(mtf.less_equal(mpos, qpos - block_length), x.dtype) * -1e9
+    # Note: The first window_size-1 positions can see back into pre-time
+    # where all the keys and values are zero.  We could mask this out, but we
+    # don't.
+    o = dot_product_attention(q, k, v, mask=mask)
+    o = mtf.reshape(o, batch_dims + [heads, length, kv_channels])
+    return mtf.einsum([o, o_var], mtf.Shape(batch_dims + [length, io_channels]))
+
+
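# Illustrative sketch (not part of the patch above): a plain-NumPy mock-up of
# the blocking scheme used by the new masked_local_attention_1d.  The sequence
# is cut into blocks of block_length, each block's keys/values are augmented
# with a "halo" of the window_size positions to its left, and a bias mask
# hides future positions and positions at least block_length steps back.
# When block_length == window_size (the common case with the default
# window_size=128 and lengths divisible by 128), every query position p sees
# exactly the range (p - window_size, p], clipped at the start of the
# sequence, as described in the docstring.  The helper name and the tiny
# sizes below are made up for illustration.
import numpy as np


def local_attention_visibility(length, window_size, block_length):
  """Boolean [length, length] matrix: visible[q, m] iff query q may see m."""
  assert length % block_length == 0
  visible = np.zeros([length, length], dtype=bool)
  for block_start in range(0, length, block_length):
    for i in range(block_length):                  # query index within block
      q, qpos = block_start + i, i + window_size   # qpos: padded-block coords
      for mpos in range(window_size + block_length):
        m = block_start - window_size + mpos       # absolute memory position
        if m < 0:
          continue                                 # "pre-time" (zeros above)
        if mpos > qpos or mpos <= qpos - block_length:
          continue                                 # future / too far back
        visible[q, m] = True
  return visible


# Each row shows the (at most) window_size positions a query may attend to.
print(local_attention_visibility(8, window_size=4, block_length=4).astype(int))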
+def masked_local_attention_1d_incremental(x,
+                                          prev_k,
+                                          prev_v,
+                                          step_num,
+                                          master_dtype,
+                                          slice_dtype,
+                                          name=None):
+  """Incremental local self-attention (one decode step).
+
+  Incremental version of masked_local_attention_1d()
+
+  Args:
+    x: a mtf.Tensor with shape [batch..., io_channels]
+    prev_k: mtf.Tensor with shape
+      [batch..., heads, window_length, kv_channels]
+    prev_v: mtf.Tensor with shape
+      [batch..., heads, window_length, kv_channels]
+    step_num: mtf Scalar with dtype tf.int32
+    master_dtype: a tf.dtype
+    slice_dtype: a tf.dtype
+    name: an optional string.
+
+  Returns:
+    y: A mtf.Tensor with shape [batch..., io_channels]
+    new_k: mtf.Tensor with shape
+      [batch..., heads, window_length, kv_channels]
+    new_v: mtf.Tensor with shape
+      [batch..., heads, window_length, kv_channels]
+
+  Raises:
+    ValueError: if the dimensions do not match.
+  """
+  batch_dims = x.shape.dims[:-1]
+  io_channels = x.shape.dims[-1]
+  heads, window_length, kv_channels = prev_k.shape.dims[-3:]
+  with tf.variable_scope(name, default_name="multihead_attention"):
+    q_var, k_var, v_var, o_var = multihead_attention_vars(
+        x.mesh, heads, io_channels, kv_channels,
+        master_dtype, slice_dtype, x.dtype)
+    q = mtf.einsum([x, q_var], mtf.Shape(batch_dims + [heads, kv_channels]))
+    k = mtf.einsum([x, k_var], mtf.Shape(batch_dims + [heads, kv_channels]))
+    v = mtf.einsum([x, v_var], mtf.Shape(batch_dims + [heads, kv_channels]))
+    current_position = mtf.equal(
+        mtf.range(x.mesh, window_length, dtype=tf.int32),
+        mtf.mod(step_num, window_length.size))
+    k = mtf.where(current_position, k, prev_k, output_shape=prev_k.shape)
+    v = mtf.where(current_position, v, prev_v, output_shape=prev_v.shape)
+    o = dot_product_attention(q, k, v, mask=None)
+    y = mtf.einsum([o, o_var], x.shape)
+    return y, k, v
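# Illustrative sketch (not part of the patch above): the incremental step
# keeps a rolling cache of the last window_length keys/values.  At decode
# step t the new entry overwrites slot t % window_length, so no shifting is
# needed, and since dot-product attention with mask=None is invariant to the
# ordering of memory positions, the scrambled slot order does not matter.
# Plain-NumPy analogue with made-up sizes (the real op updates mtf.Tensors
# via mtf.where):
import numpy as np

window_length, kv_channels = 4, 2
prev_k = np.zeros([window_length, kv_channels])    # zero-initialized cache

for step_num in range(6):
  new_k = np.full([kv_channels], float(step_num))  # this step's key
  slot = step_num % window_length                  # mtf.mod(step_num, window_length.size)
  prev_k[slot] = new_k                             # mtf.where(current_position, k, prev_k)

# After 6 steps the cache holds the keys from steps 4, 5, 2, 3 in slots 0..3:
print(prev_k[:, 0])   # -> [4. 5. 2. 3.]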
 
 
 def local_2d_halo_exchange(k, v, num_h_blocks, h_dim,