
Commit 66a36b0

albertz and Zettelkasten committed
CumConcatLayer wip ...
Co-authored-by: Frithjof <[email protected]>
1 parent bd2f2d8 commit 66a36b0

File tree

1 file changed: +151 −0 lines changed


returnn/tf/layers/rec.py

Lines changed: 151 additions & 0 deletions
@@ -8495,3 +8495,154 @@ def get_out_data_from_opts(cls, name, sources, n_out, **kwargs):
      kind=DimensionTag.Types.Spatial, description="%s_rel_pos_enc_time" % name, dimension=None)
    data = data.copy_template_new_dim_tags((dummy_dim_tag, time_dim_tag, feature_dim_tag))
    return data


class CumConcatLayer(_ConcatInputLayer):
  """
  Concatenates all previous frames over a time axis.
  Where :class:`CumsumLayer` uses `sum`, this layer uses `concat`.

  This layer expects to be inside a :class:`RecLayer`.

  Inside a rec loop (not optimized out),
  this will concatenate the current input
  to the previously accumulated inputs.
  For an input of shape `input_shape`,
  it will output a tensor of shape `[new_dim] + input_shape`.
  `new_dim` is a special dimension, usually of length `i`,
  where `i` is the current loop frame,
  i.e. the length increases in every loop frame.
  `new_dim` is specified by its own separate dim tag.
  For example, in the first frame,
  the output will have shape `[1] + input_shape`,
  in the second frame shape `[2] + input_shape`,
  and so on,
  and in the last frame shape `[T] + input_shape`.

  Outside the rec loop (optimized out),
  this layer expects an input with the time dim of the rec layer,
  and returns the input as-is,
  but with the time dim tag replaced by the dim tag `new_dim`,
  converted as outside the loop.

  Normally the optimization should not matter for the user:
  the logical behavior is always as if this were inside the rec loop.
  Outside the loop,
  the output represents a tensor of shape `[T, new_dim] + input_shape`,
  although physically we only have `new_dim` (converted as outside the loop),
  and `T` is not actually there;
  but we still have all the information,
  because the last frame contains all previous frames.
  """
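  # Usage sketch (an illustrative assumption, not part of this commit):
  # with a dim tag created in the config, e.g.
  #   new_dim_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="cum-concat-dim")
  # a layer inside a RecLayer unit might look like:
  #   "history": {"class": "cum_concat", "from": "am", "new_dim": new_dim_tag}
  # where "history" and "am" are hypothetical layer names,
  # e.g. to accumulate keys/values over all previous frames for self-attention.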
  layer_class = "cum_concat"
  recurrent = True  # order matters

  def __init__(self, new_dim, **kwargs):
    """
    :param DimensionTag new_dim:
    """
    super(CumConcatLayer, self).__init__(**kwargs)
    # Must be inside a RecLayer, either in the loop or optimized out of it.
    assert self.network.is_inside_rec_layer(inside_loop=False)
    out_axis = None
    for a, tag in enumerate(self.output.dim_tags):
      if tag == new_dim:
        out_axis = a
        break
    assert out_axis is not None

    if self.network.is_inside_rec_layer(inside_loop=True):
      # Accumulate the current frame into the rec state, e.g. frame 0 -> [B, 1, ..., D],
      # frame 1 -> [B, 2, ..., D], ..., frame T-1 -> [B, T, ..., D].
      current_data = self.input_data.copy_compatible_to(self.output, unbroadcast=False)
      current_frame = current_data.placeholder  # [B, 1, ..., D]
      last_frames = self._rec_previous_layer.rec_vars_outputs["state"]  # [B, t, ..., D]
      concat_frames = tf.concat([last_frames, current_frame], axis=out_axis)  # [B, t+1, ..., D]
      self.rec_vars_outputs["state"] = concat_frames
      self.output.placeholder = concat_frames

      # In frame i, the accumulated history has length i + 1, the same for every sequence in the batch.
      dyn_size = tf.broadcast_to(self.network.get_rec_step_index() + 1, [self.output.get_batch_dim()])

    else:
      # Not inside the rec loop, i.e. optimized out: this layer is a no-op on the data itself.
      # The original left a TODO here; this pass-through is an assumption: rearrange the input
      # as batch-major with the rec time axis at out_axis, matching get_out_data_from_opts.
      input_data = self.input_data.copy_as_batch_major()
      input_data = input_data.copy_move_axis(input_data.time_dim_axis, out_axis)
      self.output.placeholder = input_data.placeholder
      self.output.size_placeholder = input_data.size_placeholder.copy()
      dyn_size = tf.identity(self.output.get_dynamic_size(out_axis))

    # The size for the accumulated dim was only a dummy so far; now set it properly,
    # with its own rec-history dim tag.
    from returnn.tf.util.basic import DimensionTag
    tag = DimensionTag(
      description="rec-history:%s" % self.get_absolute_name(),
      kind=DimensionTag.Types.Time)
    self.output.size_placeholder[self.output.get_batch_axis_excluding_batch(out_axis)] = dyn_size
    tag.set_tag_on_size_tensor(dyn_size)

  @classmethod
  def get_out_data_from_opts(cls, name, network, sources, new_dim, **kwargs):
    """
    :param str name:
    :param returnn.tf.network.TFNetwork network:
    :param list[LayerBase] sources:
    :param DimensionTag new_dim:
    :rtype: Data
    """
    rec_layer = network.get_rec_parent_layer(inside_loop=False)
    assert rec_layer, "%s: this layer must be inside a RecLayer" % name
    input_data = get_concat_sources_data_template(sources, name="%s_output" % name)
    if network.is_inside_rec_layer(inside_loop=True):
      # Currently SelectSearchSourcesLayer assumes that all rec_vars_outputs are batch-major.
      # Therefore we here copy the input as batch-major, and then add the new dim at axis 1.
      # In the future, when SelectSearchSourcesLayer has support for this,
      # we can change this to operate on axis 0, which should be more efficient.
      out = input_data.copy_as_batch_major()
      out = out.copy_add_dim_by_tag(new_dim, unbroadcast=True, axis=1)
      # TODO set new_dim per spatial frame ...
      return out
    else:  # outside loop
      out = input_data.copy_as_batch_major()
      rec_time = rec_layer.output.get_time_dim_tag()
      _matches = [i for (i, tag) in enumerate(out.dim_tags) if tag == rec_time]
      assert len(_matches) == 1
      out = out.copy_move_axis(_matches[0], 1)
      # TODO use separate new_dim outside loop ...
      out = out.copy_template_replace_dim_tag(axis=1, new_dim_tag=new_dim)
      return out

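  # Illustration (hypothetical shapes, not from this commit): for an input template [B, T_rec, D]
  # and new_dim K, the per-frame template inside the loop is [B, K, D], and the optimized-out
  # template is likewise [B, K, D] with K replacing T_rec; the dim tags match in both cases,
  # which keeps the loop optimization transparent to the user.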
  # noinspection PyMethodOverriding
  @classmethod
  def get_rec_initial_extra_outputs(cls, network, batch_dim, rec_layer, sources, output, new_dim, **kwargs):
    """
    :param returnn.tf.network.TFNetwork network:
    :param tf.Tensor batch_dim:
    :param TFNetworkRecLayer.RecLayer|LayerBase rec_layer:
    :param list[LayerBase] sources:
    :param Data output:
    :param DimensionTag new_dim:
    :rtype: dict[str,tf.Tensor]
    """
    if network.is_inside_rec_layer():
      shape = []
      for tag in output.dim_tags:
        if tag.is_batch_dim():
          shape.append(batch_dim)
        elif tag == new_dim:
          shape.append(0)  # the initial state is an empty history; it grows by one frame per loop iteration
        elif tag.dimension is not None:
          shape.append(tag.dimension)
        else:
          assert tag.dyn_size is not None
          shape.append(tf.math.reduce_max(tag.dyn_size))
      return {"state": tf.zeros(shape, dtype=output.dtype)}
    else:
      return {}

  @classmethod
  def get_rec_initial_extra_outputs_shape_invariants(cls, network, sources, output, **kwargs):
    """
    :param returnn.tf.network.TFNetwork network:
    :param list[LayerBase] sources:
    :param Data output:
    :rtype: dict[str, tf.TensorShape]
    """
    if network.is_inside_rec_layer():
      # output.batch_shape has None for every dynamic dim, in particular for new_dim,
      # so the accumulated state may grow along that axis inside the tf.while_loop.
      return {"state": tf.TensorShape(output.batch_shape)}
    else:
      return {}
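
For reference, the growing-state mechanism this layer relies on can be shown standalone. Below is a minimal sketch (plain TensorFlow, independent of RETURNN; all names and shapes are illustrative assumptions) of a tf.while_loop whose state gains one frame per iteration. This is why get_rec_initial_extra_outputs returns a zero-length initial state and get_rec_initial_extra_outputs_shape_invariants returns a shape with None along the growing axis.

import tensorflow as tf

batch, feat, time = 2, 3, 5
xs = tf.random.normal([time, batch, feat])  # one input frame per time step

def cond(i, state):
  return i < time

def body(i, state):
  current_frame = xs[i][:, None, :]  # [B, 1, D], the current frame
  state = tf.concat([state, current_frame], axis=1)  # [B, i+1, D], cumulative concat
  return i + 1, state

initial_state = tf.zeros([batch, 0, feat])  # empty history, like the "state" rec var above
_, final_state = tf.while_loop(
  cond, body, [tf.constant(0), initial_state],
  shape_invariants=[tf.TensorShape([]), tf.TensorShape([batch, None, feat])])
# final_state has shape [B, T, D]: the last frame has accumulated all frames.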
