Skip to content

Commit e742a37

Browse files
committed
LengthLayer, support dyn_size_ext
1 parent de0ac69 commit e742a37

File tree

1 file changed

+25
-17
lines changed

1 file changed

+25
-17
lines changed

returnn/tf/layers/basic.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1661,42 +1661,50 @@ class LengthLayer(LayerBase):
16611661
layer_class = "length"
16621662

16631663
# noinspection PyUnusedLocal
def __init__(self, axis="T", add_time_axis=False, dtype="int32", sparse=False, **kwargs):
  """
  :param str|DimensionTag axis: axis to take the length of (default: the time axis)
  :param bool add_time_axis: if True, output has an extra time axis of size 1
  :param str dtype: dtype of the output, e.g. "int32"
  :param bool sparse: whether the output is marked sparse
  """
  super(LengthLayer, self).__init__(**kwargs)
  assert len(self.sources) == 1, "%s: expects one source" % self
  source = self.sources[0].output
  axis = source.get_axis_from_description(axis, allow_int=False)
  dim = source.dim_tags[axis]
  if add_time_axis:
    # Fix: the output template for this branch (see get_out_data_from_opts)
    # is created with the requested dtype, but dim.dyn_size was assigned
    # without a cast. Cast so placeholder dtype and template dtype agree
    # (the pre-existing code also casted via tf.cast(..., dtype)).
    self.output.placeholder = tf.expand_dims(
      tf.cast(dim.dyn_size, dtype), axis=self.output.time_dim_axis)
  else:
    # Here the output template is dim.dyn_size_ext itself, so placeholder
    # and template agree by construction.
    # NOTE(review): dtype/sparse are effectively ignored in this branch --
    # confirm that is intended.
    self.output.placeholder = dim.dyn_size_ext.placeholder
16761680

16771681
@classmethod
def get_out_data_from_opts(cls, name, sources, axis="T", add_time_axis=False, dtype="int32", sparse=False, **kwargs):
  """
  :param str name:
  :param list[LayerBase] sources:
  :param str|DimensionTag axis: axis to take the length of (default: the time axis)
  :param bool add_time_axis:
  :param str dtype:
  :param bool sparse:
  :rtype: Data
  """
  assert len(sources) == 1
  source = sources[0].output
  axis = source.get_axis_from_description(axis, allow_int=False)
  dim = source.dim_tags[axis]
  if add_time_axis:
    assert dim.dyn_size_ext and dim.dyn_size_ext.have_batch_axis() and dim.dyn_size_ext.batch_ndim == 1  # [B]
    return Data(
      name="%s_length" % name,
      shape=[1], batch_dim_axis=0, time_dim_axis=1,
      dtype=dtype, sparse=sparse, dim=None if sparse else NotSpecified)
  if not dim.dyn_size_ext:  # yet undefined
    return Data(
      name="%s_length" % name,
      shape=(), batch_dim_axis=0, time_dim_axis=None,
      dtype=dtype, sparse=sparse, dim=None if sparse else NotSpecified)
  # Fix: return a copy, not the dim tag's own Data instance.
  # Returning dim.dyn_size_ext directly aliases shared dim-tag metadata as the
  # layer output template, which the layer machinery then mutates in place
  # (e.g. assigning self.output.placeholder in __init__).
  # NOTE(review): dtype/sparse are ignored in this branch (dyn_size_ext keeps
  # its own dtype) -- confirm this is intended.
  return dim.dyn_size_ext.copy(name="%s_length" % name)
17001708

17011709

17021710
class SoftmaxOverSpatialLayer(_ConcatInputLayer):

0 commit comments

Comments (0)