def get_positional_encoding(num_channels, length=None, position=None, min_timescale=1.0, max_timescale=1.0e4):
  """
  Generates sinusoidal positional-encoding signals (as in the Transformer, "Attention Is All You Need"),
  i.e. a bunch of sinusoids of different frequencies, one pair (sin, cos) per timescale.
  NOTE(review): the original signature is truncated in the diff after ``min_timescale=``;
  the defaults 1.0 / 1.0e4 are the standard Tensor2Tensor values -- confirm against the full file.

  :param int num_channels: scalar, size of timing embeddings to create. The number of
    different timescales is equal to num_channels // 2.
  :param tf.Tensor|int|None length: scalar, length of timing signal sequence.
    Must be given iff ``position`` is not given.
  :param tf.Tensor|None position: could be provided directly. int32. Can have any shape,
    e.g. [length] or [B,len]. If not given, will be ``tf.range(length)``, i.e. of shape [length].
  :param float min_timescale: a float.
  :param float max_timescale: a float.
  :return: a Tensor of timing signals of shape position.shape + [num_channels],
    e.g. [length,num_channels] or [B,len,num_channels].
  :rtype: tf.Tensor
  """
  import math
  if position is None:
    assert length is not None
    position = tf.range(length)  # [length]
    assert isinstance(position, tf.Tensor)
    if isinstance(length, int):
      # Static length known -> propagate it into the static shape for downstream shape inference.
      position.set_shape([length])
  else:
    assert length is None
    assert isinstance(position, tf.Tensor)
  position = tf.cast(position, tf.float32)
  num_timescales = num_channels // 2  # D//2
  # Geometric progression of timescales from min_timescale to max_timescale.
  log_timescale_increment = (
    math.log(float(max_timescale) / float(min_timescale)) / (float(num_timescales - 1)))
  inv_timescales = min_timescale * tf.exp(
    tf.cast(tf.range(num_timescales), tf.float32) * -log_timescale_increment)
  # Reshape for broadcasting against position. Usually (1, D//2) or (1, 1, D//2).
  scale = tf.reshape(inv_timescales, [1] * len(position.shape) + [num_timescales])
  scaled_time = tf.expand_dims(position, -1) * scale  # pos.shape + [D//2]
  signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=-1)  # pos.shape + [2*(D//2)]
  if num_channels % 2 != 0:
    # Odd num_channels: pad one zero channel at the end so the last dim is exactly num_channels.
    signal = tf.pad(signal, [[0, 0]] * len(position.shape) + [[0, num_channels % 2]])  # pos.shape + [D]
  return signal
63966402
63976403
0 commit comments