Skip to content

Commit 60500e3

Browse files
committed
LengthLayer, add dim_tag attrib
1 parent e742a37 commit 60500e3

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

returnn/tf/layers/basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,6 +1673,7 @@ def __init__(self, axis="T", add_time_axis=False, dtype="int32", sparse=False, *
16731673
source = self.sources[0].output
16741674
axis = source.get_axis_from_description(axis, allow_int=False)
16751675
dim = source.dim_tags[axis]
1676+
self.dim_tag = dim
16761677
if add_time_axis:
16771678
self.output.placeholder = tf.expand_dims(dim.dyn_size, axis=self.output.time_dim_axis)
16781679
else:

0 commit comments

Comments
 (0)