
Commit 767a15e

extend doc and comments

1 parent fbc998f commit 767a15e

File tree

2 files changed: +20 −12 lines changed

  returnn/tf/layers/rec.py
  returnn/tf/util/basic.py

returnn/tf/layers/rec.py

Lines changed: 6 additions & 4 deletions

@@ -8418,7 +8418,8 @@ class RelativePositionalEncodingLayer(_ConcatInputLayer):
         "total_key_dim": self.EncKeyTotalDim,
         "n_out": self.EncValueTotalDim, "from": [output + '_self_att_laynorm'],
         "attention_left_only": False, "attention_dropout": self.attention_dropout,
-        "forward_weights_init": self.ff_init, "key_shift": output + '_rel_pos'}
+        "forward_weights_init": self.ff_init,
+        "key_shift": output + '_rel_pos'}

     """
     layer_class = "relative_positional_encoding"

@@ -8444,16 +8445,17 @@ def __init__(self, n_out, forward_weights_init="glorot_uniform", clipping=16, fi
     if fixed:
       from returnn.tf.util.basic import get_positional_encoding
       encoding_matrix = get_positional_encoding(
-        length=tf.constant(2 * clipping + 1),
-        num_channels=n_out)
+        length=2 * clipping + 1,
+        num_channels=n_out)  # shape [2 * clipping + 1, n_out]
     else:
       fwd_weights_initializer = get_initializer(
         forward_weights_init, seed=self.network.random.randint(2 ** 31), eval_local_ns={"layer": self})
       with self.var_creation_scope():
         encoding_matrix = self.add_param(tf_compat.v1.get_variable(
           name="encoding_matrix", shape=(2 * clipping + 1, n_out), initializer=fwd_weights_initializer))
+    # encoding_matrix has shape [2 * clipping + 1, n_out]

-    range_vec = tf.range(length) - offset
+    range_vec = tf.range(length) - offset  # [length]

     if self.input_data.have_time_axis():
       range_mat = tf.reshape(tf.tile(range_vec, [length]), [length, length])
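
Note: the docstring excerpt in the first hunk is the usage example embedded in RelativePositionalEncodingLayer. As a rough illustration of how the layer pairs with self-attention via "key_shift" in a network config, here is a minimal, hypothetical sketch; layer names and dimensions are illustrative, not taken from this commit:

network = {
  "rel_pos": {
    "class": "relative_positional_encoding", "from": "data",
    "n_out": 64, "clipping": 16, "fixed": True},  # fixed sinusoidal table of shape [2 * 16 + 1, 64]
  "self_att": {
    "class": "self_attention", "from": "data",
    "num_heads": 8, "total_key_dim": 512, "n_out": 512,
    "attention_left_only": False, "attention_dropout": 0.1,
    "key_shift": "rel_pos"},  # attention keys are shifted by the relative positional encoding
  "output": {"class": "copy", "from": "self_att"},
}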

returnn/tf/util/basic.py

Lines changed: 14 additions & 8 deletions

@@ -6368,30 +6368,36 @@ def get_positional_encoding(num_channels, length=None, position=None, min_timesc
   :param int num_channels: scalar, size of timing embeddings to create. The number of
     different timescales is equal to channels / 2.
-  :param tf.Tensor|None length: scalar, length of timing signal sequence.
-  :param tf.Tensor|None position: could be provided directly. int32. Can have any shape.
+  :param tf.Tensor|int|None length: scalar, length of timing signal sequence.
+  :param tf.Tensor|None position: could be provided directly. int32. Can have any shape, e.g. [length] or [B,len].
+    If not given, will be tf.range(length), i.e. of shape [length].
   :param float min_timescale: a float.
   :param float max_timescale: a float.
-  :return: a Tensor of timing signals of shape (length, channels) or (batch, length, channels).
+  :return: a Tensor of timing signals of shape position.shape + [num_channels], e.g. [length,num_channels]
   :rtype: tf.Tensor
   """
   import math
   if position is None:
     assert length is not None
-    position = tf.range(length)
+    position = tf.range(length)  # [length]
+    assert isinstance(position, tf.Tensor)
+    if isinstance(length, int):
+      position.set_shape([length])
   else:
     assert length is None
+    assert isinstance(position, tf.Tensor)
   position = tf.cast(position, tf.float32)
-  num_timescales = num_channels // 2
+  num_timescales = num_channels // 2  # D//2
   log_timescale_increment = (
     math.log(float(max_timescale) / float(min_timescale)) / (float(num_timescales - 1)))
   inv_timescales = min_timescale * tf.exp(
     tf.cast(tf.range(num_timescales), tf.float32) * -log_timescale_increment)
   scale = tf.reshape(inv_timescales, [1] * len(position.shape) + [num_timescales])  # Usually (1, D//2) or (1, 1, D//2).
-  scaled_time = tf.expand_dims(position, -1) * scale
-  signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=-1)
+  scaled_time = tf.expand_dims(position, -1) * scale  # pos.shape + [D//2]
+  signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=-1)  # pos.shape + [2*D//2]
   # (length, channels) or (batch, length, channels).
-  signal = tf.pad(signal, [[0, 0]] * len(position.shape) + [[0, num_channels % 2]])
+  if num_channels % 2 != 0:
+    signal = tf.pad(signal, [[0, 0]] * len(position.shape) + [[0, num_channels % 2]])  # pos.shape + [D]
   return signal
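
Note: a quick usage sketch of the shape behavior documented above (assuming an eager TF2 context for printing; the dimensions are illustrative):

import tensorflow as tf
from returnn.tf.util.basic import get_positional_encoding

# length given as a Python int: position defaults to tf.range(length), and
# set_shape([length]) makes the static shape known: [length, num_channels].
signal = get_positional_encoding(num_channels=64, length=2 * 16 + 1)
print(signal.shape)  # (33, 64)

# position given directly: the result has shape position.shape + [num_channels].
position = tf.constant([[0, 1, 2, 3], [0, 2, 4, 6]])  # int32, shape [B=2, len=4]
signal_b = get_positional_encoding(num_channels=64, position=position)
print(signal_b.shape)  # (2, 4, 64)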