returnn/tf/layers/basic.py (26 additions & 17 deletions)
@@ -1637,42 +1637,51 @@ class LengthLayer(LayerBase):
   layer_class = "length"
 
   # noinspection PyUnusedLocal
-  def __init__(self, add_time_axis=False, dtype="int32", sparse=False, **kwargs):
+  def __init__(self, axis="T", add_time_axis=False, dtype="int32", sparse=False, **kwargs):
     """
+    :param str|DimensionTag axis:
     :param bool add_time_axis:
     :param str dtype:
     :param bool sparse:
     """
     super(LengthLayer, self).__init__(**kwargs)
     assert len(self.sources) == 1, "%s: expects one source" % self
-    out = tf.cast(self.sources[0].output.get_sequence_lengths(), dtype)
+    source = self.sources[0].output
+    axis = source.get_axis_from_description(axis, allow_int=False)
+    dim = source.dim_tags[axis]
+    self.dim_tag = dim
     if add_time_axis:
-      out = tf.expand_dims(out, axis=self.output.time_dim_axis)
-    self.output.placeholder = out
+      self.output.placeholder = tf.expand_dims(dim.dyn_size, axis=self.output.time_dim_axis)
+    else:
+      self.output.placeholder = dim.dyn_size_ext.placeholder
 
   @classmethod
-  def get_out_data_from_opts(cls, name, sources, add_time_axis=False, dtype="int32", sparse=False, **kwargs):
+  def get_out_data_from_opts(cls, name, sources, axis="T", add_time_axis=False, dtype="int32", sparse=False, **kwargs):
     """
     :param str name:
     :param list[LayerBase] sources:
+    :param str|DimensionTag axis:
     :param bool add_time_axis:
     :param str dtype:
     :param bool sparse:
     :rtype: Data
     """
+    assert len(sources) == 1
+    source = sources[0].output
+    axis = source.get_axis_from_description(axis, allow_int=False)
+    dim = source.dim_tags[axis]
     if add_time_axis:
-      shape = (1,)
-      time_dim_axis = 1
-    else:
-      shape = ()
-      time_dim_axis = None
-    return Data(
-      name="%s_length" % name,
-      shape=shape,
-      batch_dim_axis=0,
-      time_dim_axis=time_dim_axis,
-      dtype=dtype,
-      sparse=sparse, dim=None if sparse else NotSpecified)
+      assert dim.dyn_size_ext and dim.dyn_size_ext.have_batch_axis() and dim.dyn_size_ext.batch_ndim == 1  # [B]
+      return Data(
+        name="%s_length" % name,
+        shape=[1], batch_dim_axis=0, time_dim_axis=1,
+        dtype=dtype, sparse=sparse, dim=None if sparse else NotSpecified)
+    if not dim.dyn_size_ext:  # yet undefined
+      return Data(
+        name="%s_length" % name,
+        shape=(), batch_dim_axis=0, time_dim_axis=None,
+        dtype=dtype, sparse=sparse, dim=None if sparse else NotSpecified)
+    return dim.dyn_size_ext
 
 
 class SoftmaxOverSpatialLayer(_ConcatInputLayer):
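
For context, the change above lets the "length" layer report the dynamic size of any tagged axis rather than only the time axis. A minimal usage sketch, assuming a standard RETURNN network config dict; the layer name "enc_lens" and the source layer "encoder" are hypothetical, and "axis": "T" is the default, written out here only for clarity:

network = {
  # ... other layers ...
  # Returns the sequence lengths of the time axis of "encoder",
  # by default an int32 vector of shape [B].
  "enc_lens": {"class": "length", "from": "encoder", "axis": "T"},
}

With add_time_axis=True the output would instead have shape [B, 1], matching the first branch of the new get_out_data_from_opts.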