class LengthLayer(LayerBase):
  """
  Outputs the dynamic size (e.g. the sequence lengths) of one axis of its single source layer.

  NOTE(review): this body is the post-image of a patch hunk whose line breaks were lost.
  Anything the class defines outside that hunk (e.g. its original docstring) is not visible here
  and must be kept when merging.
  """
  layer_class = "length"

  # noinspection PyUnusedLocal
  def __init__(self, axis="T", add_time_axis=False, dtype="int32", sparse=False, **kwargs):
    """
    :param str|DimensionTag axis: description of the source axis whose dynamic size is returned
    :param bool add_time_axis: if set, an extra axis of size 1 is inserted at ``self.output.time_dim_axis``
    :param str dtype: requested dtype; the effective dtype is the one of the output template
      as returned by :func:`get_out_data_from_opts`
    :param bool sparse: not used here, only by :func:`get_out_data_from_opts`
    """
    super(LengthLayer, self).__init__(**kwargs)
    assert len(self.sources) == 1, "%s: expects one source" % self
    source = self.sources[0].output
    axis = source.get_axis_from_description(axis, allow_int=False)
    dim = source.dim_tags[axis]
    self.dim_tag = dim
    # Both branches below read the dynamic size; fail with a clear message instead of an
    # AttributeError / TF error on None when the axis has none.
    assert dim.dyn_size_ext, "%s: axis %r of source %r has no dynamic size" % (self, dim, source)
    if add_time_axis:
      out = tf.expand_dims(dim.dyn_size, axis=self.output.time_dim_axis)
    else:
      out = dim.dyn_size_ext.placeholder
    # The output template can declare a dtype different from the one of the size tensor
    # (it is built from the `dtype` option when `add_time_axis` is set or when the size was
    # not yet defined at template time), so cast to keep tensor and template consistent.
    # For an already matching dtype, tf.cast returns the tensor unchanged.
    self.output.placeholder = tf.cast(out, self.output.dtype)

  @classmethod
  def get_out_data_from_opts(cls, name, sources, axis="T", add_time_axis=False, dtype="int32", sparse=False, **kwargs):
    """
    :param str name: layer name, used to name the returned template
    :param list[LayerBase] sources: exactly one source layer
    :param str|DimensionTag axis: description of the source axis whose dynamic size is returned
    :param bool add_time_axis: if set, the template gets an extra time axis of size 1
    :param str dtype: dtype of a newly created template
    :param bool sparse: whether a newly created template is sparse
    :rtype: Data
    """
    assert len(sources) == 1, "%s %r: expects one source" % (cls.__name__, name)
    source = sources[0].output
    axis = source.get_axis_from_description(axis, allow_int=False)
    dim = source.dim_tags[axis]
    if add_time_axis:
      # The extra axis is only supported for a plain per-batch size.
      assert dim.dyn_size_ext and dim.dyn_size_ext.have_batch_axis() and dim.dyn_size_ext.batch_ndim == 1  # [B]
      return Data(
        name="%s_length" % name,
        shape=[1], batch_dim_axis=0, time_dim_axis=1,
        dtype=dtype, sparse=sparse, dim=None if sparse else NotSpecified)
    if not dim.dyn_size_ext:  # yet undefined
      # Fall back to a scalar-per-batch template built from the options.
      return Data(
        name="%s_length" % name,
        shape=(), batch_dim_axis=0, time_dim_axis=None,
        dtype=dtype, sparse=sparse, dim=None if sparse else NotSpecified)
    # NOTE(review): this returns the dim tag's own size object, so `name`, `dtype` and `sparse`
    # are ignored on this path and the result is shared with the dim tag -- confirm this is
    # intended, or return a renamed copy instead.
    return dim.dyn_size_ext