diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py
index 5c6805950..b7ee2b864 100644
--- a/returnn/tf/layers/rec.py
+++ b/returnn/tf/layers/rec.py
@@ -8531,3 +8531,173 @@ def get_out_data_from_opts(cls, name, sources, n_out, **kwargs):
       kind=DimensionTag.Types.Spatial, description="%s_rel_pos_enc_time" % name, dimension=None)
     data = data.copy_template_new_dim_tags((dummy_dim_tag, time_dim_tag, feature_dim_tag))
     return data
+
+
+class CumConcatLayer(_ConcatInputLayer):
+  """
+  Concatenates all previous frames of a time-axis.
+  Where :class:`CumsumLayer` uses `sum`, this layer uses `concat`.
+
+  This layer can be used as a base for auto-regressive self-attention.
+
+  This layer expects to be inside a :class:`RecLayer`.
+
+  Inside a rec loop (not optimized out),
+  this will concatenate the current input to the previously accumulated inputs.
+  For an input of shape `input_shape`,
+  it will output a tensor of shape `[new_dim] + input_shape`.
+  `new_dim` is a special dimension, usually of length `i`,
+  where `i` is the current loop frame,
+  i.e. the length increases in every loop frame.
+  `new_dim` is specified via its own separate dim tag.
+  For example, in the first frame,
+  the output will be of shape `[1] + input_shape`,
+  in the second frame of shape `[2] + input_shape`,
+  and so on,
+  and in the last frame of shape `[T] + input_shape`.
+
+  Outside the rec loop (optimized out),
+  this layer expects an input with the time dim of the rec layer,
+  and returns the input as-is,
+  but with the time dim tag replaced by the dim tag `new_dim`,
+  converted as outside the loop.
+
+  Normally the optimization should not matter for the user,
+  i.e. for the user, the logical behavior is always as if it were inside the rec loop.
+  Outside the loop,
+  the output represents a tensor of shape `[T, new_dim] + input_shape`,
+  although the actual tensor has only the (outside-loop) `new_dim`, and `T` is not actually there;
+  we still have all the information, because the last frame contains all of it.
+  This `new_dim` outside the loop stores all the dynamic seq lengths
+  per frame of the loop, i.e. the dynamic seq lengths are extended to shape [B,T] or [T]
+  (instead of the usual [B]).
+  This way, following layers use different seq lengths of `new_dim` for different loop frames,
+  just as if the `T` dim actually existed.
+  """
+  layer_class = "cum_concat"
+  recurrent = True  # order matters
+
+  def __init__(self, new_dim, **kwargs):
+    """
+    :param DimensionTag new_dim:
+    """
+    super(CumConcatLayer, self).__init__(**kwargs)
+    rec_layer = self.network.get_rec_parent_layer(inside_loop=False)
+    assert rec_layer, "%r must be used inside a RecLayer" % self
+    out_axis = self.output.get_axis_from_description(new_dim)
+    new_dim_ = self.output.dim_tags[out_axis]
+    assert new_dim_.control_flow_ctx == self.output.control_flow_ctx == self.network.get_control_flow_ctx()
+
+    if not self.input_data.has_axis(rec_layer.time_dim_tag):  # inside loop
+      current_data = self.input_data.copy_compatible_to(self.output, unbroadcast=False)
+      current_frame = current_data.placeholder  # [B, 1, ..., D]
+      last_frames = self._rec_previous_layer.rec_vars_outputs["state"]  # [B, t, ..., D]
+      concat_frames = tf.concat([last_frames, current_frame], axis=out_axis)  # [B, t+1, ..., D]
+      self.rec_vars_outputs["state"] = concat_frames
+      self.output.placeholder = concat_frames
+
+      if not new_dim_.dyn_size_ext:
+        # Unbroadcasting to [B] is not needed because any layers operating on this
+        # should be able to handle extended dyn sizes.
+        # Clipping it to the max length for sequences in the loop which have already ended
+        # (i.e. considering the end flag)
+        # is also not needed because any calculations after the end are irrelevant.
+        # Note: In case we have some initial state/output, this can be extended.
+        dyn_size = self.network.get_rec_step_index() + 1  # scalar
+        new_dim_.dyn_size_ext = Data(
+          name="%s:cum-concat:size-inside" % self.name,
+          dim_tags=[],  # scalar
+          placeholder=dyn_size, dtype="int32",
+          batch=self.output.batch, control_flow_ctx=self.network.get_control_flow_ctx())
+
+    else:  # outside loop
+      # If not inside a rec loop, this layer is a no-op on the tensor.
+      self.output.placeholder = self.input_data.placeholder
+
+      # However, we used new dim tags, which were already prepared.
+      # We now must fill in the extended dynamic size information.
+      if not new_dim_.dyn_size_ext:
+        # This must match the logic above for inside the loop.
+        # Note: In case we have some initial state/output, this can be extended.
+        dyn_size = tf.range(tf.math.reduce_max(rec_layer.time_dim_tag.dyn_size)) + 1  # [T]
+        new_dim_.dyn_size_ext = Data(
+          name="%s:cum-concat:size-outside" % self.name,
+          dim_tags=[rec_layer.time_dim_tag],
+          placeholder=dyn_size, dtype="int32",
+          batch=self.output.batch, control_flow_ctx=self.network.get_control_flow_ctx())
+
+  @classmethod
+  def get_out_data_from_opts(cls, name, network, sources, new_dim, **kwargs):
+    """
+    :param str name:
+    :param returnn.tf.network.TFNetwork network:
+    :param list[LayerBase] sources:
+    :param DimensionTag new_dim:
+    :rtype: Data
+    """
+    input_data = get_concat_sources_data_template(sources, name="%s_output" % name)
+    assert network.is_inside_rec_layer(inside_loop=False), "CumConcatLayer %r must be used inside a RecLayer" % name
+    rec_time_dim = network.get_inside_rec_time_dim(inside_loop=False)
+    assert rec_time_dim
+    ctx = network.get_control_flow_ctx()
+    assert ctx == input_data.control_flow_ctx
+    new_dim_in_ctx = new_dim.get_for_batch_ctx(batch=input_data.batch, ctx=ctx)
+
+    if not input_data.has_axis(rec_time_dim):  # inside loop
+      assert ctx and ctx.is_loop() and ctx.loop_spatial_dim == rec_time_dim
+      # Currently SelectSearchSourcesLayer assumes that all rec_vars_outputs are batch-major.
+      # Therefore we copy the input as batch-major here, and then add the new dim at axis 1.
+      # In the future, when SelectSearchSourcesLayer has support for this, we can change this to operate on axis 0,
+      # which should be more efficient.
+      out = input_data.copy_as_batch_major()
+      out = out.copy_add_dim_by_tag(new_dim_in_ctx, unbroadcast=True, axis=1)
+      return out
+
+    else:  # outside loop
+      # Assume that the input has the time dim from the rec layer.
+      axis = input_data.get_axis_from_description(rec_time_dim)
+      return input_data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=new_dim_in_ctx)
+
+  # noinspection PyMethodOverriding
+  @classmethod
+  def get_rec_initial_extra_outputs(cls, network, batch_dim, rec_layer, sources, output, new_dim, **kwargs):
+    """
+    :param returnn.tf.network.TFNetwork network:
+    :param tf.Tensor batch_dim:
+    :param returnn.tf.layers.rec.RecLayer|LayerBase rec_layer:
+    :param list[LayerBase] sources:
+    :param Data output:
+    :param DimensionTag new_dim:
+    :rtype: dict[str,tf.Tensor]
+    """
+    if network.is_inside_rec_layer():
+      shape = []
+      for tag in output.dim_tags:
+        if tag.is_batch_dim():
+          shape.append(batch_dim)
+        elif tag == new_dim:
+          shape.append(0)
+        elif tag.dimension is not None:
+          shape.append(tag.dimension)
+        else:
+          assert tag.dyn_size is not None
+          shape.append(tf.math.reduce_max(tag.dyn_size))
+      return {"state": tf.zeros(shape, dtype=output.dtype)}
+    else:
+      return {}
+
+  @classmethod
+  def get_rec_initial_extra_outputs_shape_invariants(cls, network, sources, output, **kwargs):
+    """
+    :param returnn.tf.network.TFNetwork network:
+    :param list[LayerBase] sources:
+    :param Data output:
+    :rtype: dict[str, tf.TensorShape]
+    """
+    if network.is_inside_rec_layer():
+      return {"state": tf.TensorShape(output.batch_shape)}
+    else:
+      return {}
diff --git a/tests/test_TFNetworkRecLayer.py b/tests/test_TFNetworkRecLayer.py
index 7595b3df4..cfcd55309 100644
--- a/tests/test_TFNetworkRecLayer.py
+++ b/tests/test_TFNetworkRecLayer.py
@@ -3385,8 +3385,7 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
     rec_layer_dict["unit"].update(other_subnet_layers)
   config = Config({
     "debug_print_layer_output_template": True,
-    "num_inputs": n_in,
-    "num_outputs": n_out
+    "extern_data": {"data": {"dim": n_in}},
   })
   from returnn.tf.layers.rec import _SubnetworkRecCell
   with make_scope() as session:
@@ -3463,6 +3462,40 @@ def test_reclayer_optimize_out_selfatt_left():
     "class": "self_attention", "attention_left_only": True,
     "num_heads": 2, "total_key_dim": 6, "n_out": 18})
 
+def test_reclayer_optimize_out_cum_concat_gen_self_att():
+  new_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="cum_concat_new_dim")
+  n_key = 5
+  n_value = 7
+  check_reclayer_optimize_out(
+    {"class": "linear", "from": "att", "activation": None},
+    {
+      # This is very much the vanilla self-attention,
+      # implemented via the new generic way.
+      # See https://github.com/rwth-i6/returnn/issues/391 for a long discussion.
+      # Commented shapes are always for the layers inside the loop (not optimized).
+      "qkv": {"class": "linear", "from": "data:source", "activation": None, "n_out": n_key * 2 + n_value},  # [B,2*K+V]
+      "qkv_split": {"class": "split", "from": "qkv", "size_splits": [n_key, n_key, n_value]},
+      "q": {"class": "copy", "from": "qkv_split/0"},  # inside [B,K]. optimized out [T,B,K]
+      "k": {"class": "copy", "from": "qkv_split/1"},  # inside [B,K]. optimized out [T,B,K]
+      "v": {"class": "copy", "from": "qkv_split/2"},  # inside [B,V]. optimized out [T,B,V]
+      # cum_concat here. Note that the optimized-out shape is not as you might expect [T,max(t),B,K],
+      # but instead uses the optimized format, with extended dyn size on the special dim tag,
+      # i.e. [t*,B,K], representing [T,t*,B,K].
+      "k_accum": {"class": "cum_concat", "new_dim": new_dim, "from": "k"},  # inside [t,B,K]. opt out [t*,B,K]
+      "v_accum": {"class": "cum_concat", "new_dim": new_dim, "from": "v"},  # inside [t,B,V]. opt out [t*,B,V]
+      "energy": {
+        "class": "dot", "from": ["q", "k_accum"],
+        "red1": "static:-1", "red2": "static:-1",
+        "var1": None, "var2": new_dim},  # inside [B,t]. optimized out [T,B,t*]
+      "att_weights": {
+        "class": "softmax_over_spatial", "from": "energy", "axis": new_dim},  # inside [B,t]. opt out [T,B,t*]
+      "att": {
+        "class": "dot", "from": ["att_weights", "v_accum"],
+        "red1": new_dim, "red2": new_dim,
+        "var1": None, "var2": "static:-1"},  # inside [B,V]. opt out [T,B,V]
+    })
+
+
 def test_reclayer_optimize_out_dot():
   # Used for multi-head dot-attention.
   AttNumHeads = 4
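
Illustration (not part of the diff): the docstring above describes the per-frame behavior inside the rec loop, where the accumulated output grows by one frame per loop iteration along `new_dim`. The following standalone NumPy sketch mimics that behavior; the name `toy_cum_concat` and the concrete shapes are made up for illustration only and are not RETURNN API.

import numpy as np


def toy_cum_concat(frames):
  """For each loop frame i, yield the concatenation of frames[0..i] along a new leading axis."""
  state = np.zeros((0,) + frames[0].shape, dtype=frames[0].dtype)  # like the initial "state", length 0 on new_dim
  for x in frames:
    # Concat the current frame to the previously accumulated frames, like rec_vars_outputs["state"] above.
    state = np.concatenate([state, x[np.newaxis]], axis=0)
    yield state  # at frame i, the shape is [i + 1] + input_shape


# Three loop frames, each of input_shape [2]:
for i, acc in enumerate(toy_cum_concat([np.full(2, t) for t in range(3)])):
  print(i, acc.shape)  # (1, 2), then (2, 2), then (3, 2)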
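
A second illustration (also not part of the diff): outside the loop, `new_dim` carries one dynamic length per loop frame (`tf.range(max_time) + 1` in `__init__` above), so a layer that masks along `new_dim`, such as `softmax_over_spatial` in the test, effectively sees a causal mask. A small NumPy sketch, assuming a single sequence with T = 4 loop frames:

import numpy as np

T = 4  # number of loop frames (rec time dim); chosen arbitrarily for illustration
dyn_size = np.arange(T) + 1  # per-frame length of new_dim: [1, 2, 3, 4]

# Mask of shape [T, new_dim]: frame t may attend to positions j with j < dyn_size[t], i.e. j <= t.
mask = np.arange(T)[np.newaxis, :] < dyn_size[:, np.newaxis]
print(mask.astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]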