@@ -8495,3 +8495,154 @@ def get_out_data_from_opts(cls, name, sources, n_out, **kwargs):
       kind=DimensionTag.Types.Spatial, description="%s_rel_pos_enc_time" % name, dimension=None)
     data = data.copy_template_new_dim_tags((dummy_dim_tag, time_dim_tag, feature_dim_tag))
     return data
+
+
+class CumConcatLayer(_ConcatInputLayer):
+  """
+  Concatenates all previous frames of a time-axis.
+  Like :class:`CumsumLayer` uses `sum`, this layer uses `concat`.
+
+  This layer expects to be inside a :class:`RecLayer`.
+
+  Inside a rec loop (not optimized out),
+  this will concatenate the current input
+  to the previously accumulated inputs.
+  For an input of shape `input_shape`,
+  it will output a tensor of shape `[new_dim] + input_shape`.
+  `new_dim` is a special dimension, usually of length `i`,
+  where `i` is the current loop frame,
+  i.e. the length increases in every loop frame.
+  `new_dim` must be specified by the user via a separate dim tag.
+  For example, in the first frame,
+  the output will have shape `[1] + input_shape`,
+  in the second frame shape `[2] + input_shape`,
+  and so on,
+  and in the last frame shape `[T] + input_shape`.
+
+  Outside the rec loop (optimized out),
+  this layer expects an input with the time dim of the rec layer,
+  and returns the input as-is,
+  but replaces the time dim tag with the dim tag `new_dim`
+  converted as outside the loop.
+
+  Normally the optimization should not matter for the user,
+  i.e. for the user, the logical behavior is always as being inside the rec loop.
+  Outside the loop,
+  the output represents a tensor of shape `[T, new_dim] + input_shape`,
+  although we actually have a different `new_dim` outside the loop,
+  and `T` is not actually there;
+  no information is lost, though,
+  because the last frame contains the full history.
+  """
+  layer_class = "cum_concat"
+  recurrent = True  # order matters
+
+  def __init__(self, new_dim, **kwargs):
+    """
+    :param DimensionTag new_dim: dim tag for the new accumulated (growing) axis
+    """
+    super(CumConcatLayer, self).__init__(**kwargs)
+    assert self.network.is_inside_rec_layer(inside_loop=False), "%r must be inside a RecLayer" % self
+    out_axis = None
+    for axis, tag in enumerate(self.output.dim_tags):
+      if tag == new_dim:
+        out_axis = axis
+        break
+    assert out_axis is not None, "%r: new_dim %r not found in output %r" % (self, new_dim, self.output)
+
+    if self.network.is_inside_rec_layer(inside_loop=True):
+      # Concatenate the current frame to the accumulated previous frames, kept in the rec state.
+      current_data = self.input_data.copy_compatible_to(self.output, unbroadcast=False)
+      current_frame = current_data.placeholder  # [B, 1, ..., D]
+      last_frames = self._rec_previous_layer.rec_vars_outputs["state"]  # [B, t, ..., D]
+      concat_frames = tf.concat([last_frames, current_frame], axis=out_axis)  # [B, t+1, ..., D]
+      self.rec_vars_outputs["state"] = concat_frames
+      self.output.placeholder = concat_frames
+
+      # In loop frame i, the new axis has length i + 1, the same for every batch entry.
+      dyn_size = tf.broadcast_to(self.network.get_rec_step_index() + 1, [self.output.get_batch_dim()])
+
+    else:
+      # Optimized out of the loop, this layer is a no-op on the tensor;
+      # only the dim tag of the time axis changes (see get_out_data_from_opts).
+      self.output.placeholder = self.input_data.placeholder
+      self.output.size_placeholder = self.input_data.size_placeholder.copy()
+      # tf.identity to get a new size tensor which we can set the new dim tag on.
+      dyn_size = tf.identity(self.output.get_dynamic_size(out_axis))
+
+    # Set the dynamic size of the new axis, and attach its dim tag to the size tensor.
+    self.output.size_placeholder[self.output.get_batch_axis_excluding_batch(out_axis)] = dyn_size
+    new_dim_ = self.output.dim_tags[out_axis]
+    new_dim_.set_tag_on_size_tensor(dyn_size)
+
+  @classmethod
+  def get_out_data_from_opts(cls, name, network, sources, new_dim, **kwargs):
+    """
+    :param str name:
+    :param returnn.tf.network.TFNetwork network:
+    :param list[LayerBase] sources:
+    :param DimensionTag new_dim:
+    :rtype: Data
+    """
+    rec_layer = network.get_rec_parent_layer(inside_loop=False)
+    assert rec_layer, "layer %r must be inside a RecLayer" % name
+    input_data = get_concat_sources_data_template(sources, name="%s_output" % name)
+    if network.is_inside_rec_layer(inside_loop=True):
+      # Currently SelectSearchSourcesLayer assumes that all rec_vars_outputs are batch-major.
+      # Therefore we copy the input as batch-major here, and then add the new axis at axis 1.
+      # Once SelectSearchSourcesLayer supports other layouts, this can operate on axis 0,
+      # which should be more efficient.
+      out = input_data.copy_as_batch_major()
+      out = out.copy_add_dim_by_tag(new_dim, unbroadcast=True, axis=1)
+      # The dynamic size of new_dim is set per loop frame in __init__.
+      return out
+
+    else:  # outside the loop
+      out = input_data.copy_as_batch_major()
+      rec_time = rec_layer.output.get_time_dim_tag()
+      matches = [i for (i, tag) in enumerate(out.dim_tags) if tag == rec_time]
+      assert len(matches) == 1, "layer %r: expected the rec time dim exactly once in input %r" % (name, out)
+      out = out.copy_move_axis(matches[0], 1)
+      # TODO: use a separate variant of new_dim outside the loop ...
+      out = out.copy_template_replace_dim_tag(axis=1, new_dim_tag=new_dim)
+      return out
+
+  # noinspection PyMethodOverriding
+  @classmethod
+  def get_rec_initial_extra_outputs(cls, network, batch_dim, rec_layer, sources, output, new_dim, **kwargs):
+    """
+    :param returnn.tf.network.TFNetwork network:
+    :param tf.Tensor batch_dim:
+    :param TFNetworkRecLayer.RecLayer|LayerBase rec_layer:
+    :param list[LayerBase] sources:
+    :param Data output:
+    :param DimensionTag new_dim:
+    :rtype: dict[str,tf.Tensor]
+    """
+    if network.is_inside_rec_layer():
+      shape = []
+      for tag in output.dim_tags:
+        if tag.is_batch_dim():
+          shape.append(batch_dim)
+        elif tag == new_dim:
+          shape.append(0)  # the accumulated history starts out empty
+        elif tag.dimension is not None:
+          shape.append(tag.dimension)
+        else:
+          assert tag.dyn_size is not None
+          shape.append(tf.math.reduce_max(tag.dyn_size))
+      return {"state": tf.zeros(shape, dtype=output.dtype)}
+    else:
+      return {}
+
+  @classmethod
+  def get_rec_initial_extra_outputs_shape_invariants(cls, network, sources, output, **kwargs):
+    """
+    :param returnn.tf.network.TFNetwork network:
+    :param list[LayerBase] sources:
+    :param Data output:
+    :rtype: dict[str,tf.TensorShape]
+    """
+    if network.is_inside_rec_layer():
+      # output.batch_shape has None for the new_dim axis, so the state may grow in every frame.
+      return {"state": tf.TensorShape(output.batch_shape)}
+    else:
+      return {}
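
For reference, a minimal NumPy sketch (not part of the commit) of the inside-loop semantics the docstring describes: in loop frame `t`, the output is the concatenation of all inputs up to and including frame `t`, so the last frame holds the full sequence.

```python
import numpy as np

def cum_concat_reference(x):
  """For x of shape [T, ...], frame t yields x[:t + 1], i.e. shape [t + 1, ...]."""
  return [x[:t + 1] for t in range(x.shape[0])]

x = np.arange(6).reshape(3, 2)  # [T=3, D=2]
frames = cum_concat_reference(x)
assert [f.shape[0] for f in frames] == [1, 2, 3]  # the history grows per frame
assert np.array_equal(frames[-1], x)  # the last frame holds the full sequence
```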
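A hypothetical usage sketch in a net dict: the layer names, `n_out` values, and surrounding setup are made up for illustration; only `"class": "cum_concat"` and its `new_dim` argument come from this commit.

```python
from returnn.tf.util.basic import DimensionTag

# Hypothetical config sketch: accumulate per-frame keys/values inside a rec loop.
history_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="att-history")
network = {
  "decoder": {"class": "rec", "from": "data", "unit": {
    "kv": {"class": "linear", "from": "prev:output_embed", "n_out": 64},  # made-up layer
    # Inside the loop, "kv_history" has shape [B, i, 64], growing with loop frame i.
    "kv_history": {"class": "cum_concat", "from": "kv", "new_dim": history_dim},
    # ... attention over "kv_history", output layer, etc.
  }},
}
```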
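And a simplified, RETURNN-free sketch (an assumption about the mechanism, not code from the commit) of what `get_rec_initial_extra_outputs` and `get_rec_initial_extra_outputs_shape_invariants` set up: the state starts with size 0 on the concat axis, and the shape invariant leaves that axis as `None` so `tf.while_loop` allows it to grow.

```python
import tensorflow as tf

def cum_concat_loop(x):
  """x: [T, B, D]; returns the accumulated state after the last frame, [B, T, D]."""
  t0 = tf.constant(0)
  state0 = tf.zeros([tf.shape(x)[1], 0, tf.shape(x)[2]], dtype=x.dtype)  # [B, 0, D], empty history

  def body(t, state):
    frame = tf.expand_dims(x[t], axis=1)  # [B, 1, D]
    return t + 1, tf.concat([state, frame], axis=1)  # [B, t+1, D]

  _, final = tf.while_loop(
    cond=lambda t, _state: t < tf.shape(x)[0],
    body=body,
    loop_vars=(t0, state0),
    # None on the history axis allows the state to grow in every iteration.
    shape_invariants=(t0.get_shape(), tf.TensorShape([None, None, None])))
  return final
```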