From 6b39ceda1feb9b6157af4f2a5ac35934f4972fe8 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Sat, 18 Sep 2021 23:09:01 +0200 Subject: [PATCH 1/5] Rec subnet dim tag dyn sizes accum logic, accum helper layer --- returnn/tf/layers/rec.py | 48 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py index 91f3cf067..d8de29511 100644 --- a/returnn/tf/layers/rec.py +++ b/returnn/tf/layers/rec.py @@ -1473,6 +1473,36 @@ def __call__(lself, name, is_prev_time_frame=False): print(s) raise + def _add_template_layer(self, layer_name, layer_dict): + """ + Use this for simple helpers, after the main template net is already created. + This does not expect layer creation exceptions, + and expects that all dependencies are already created, + and all dependencies are other layers inside our subnet. + + :param str layer_name: + :param dict[str] layer_dict: not yet transformed + :rtype: _TemplateLayer + """ + from returnn.tf.network import get_layer_class + assert layer_name not in self.layer_data_templates + self.net_dict[layer_name] = layer_dict + # We replicate the _construct_template logic here, but simplified, + # i.e. not expecting exceptions, and expecting that all dep layers are created. + layer = _TemplateLayer(name=layer_name, network=self.net, cell=self) + layer_dict = layer_dict.copy() + layer_class_name = layer_dict.pop("class") + layer_class = get_layer_class(layer_class_name) + layer_dict["_network"] = self.net + layer_dict["_name"] = layer_name + layer_class.transform_config_dict( + layer_dict, network=self.net, get_layer=lambda _name: self.layer_data_templates[_name]) + out = layer_class.get_out_data_from_opts(name=layer_name, network=self.net, **layer_dict) + out = layer_class.fixup_out_data(output=out, network=self.net) + layer.init(output=out, layer_class=layer_class, **layer_dict.copy()) + self.layer_data_templates[layer_name] = layer + return layer + def _handle_construct_exception(self, description, exception): """ :param str description: @@ -2262,6 +2292,24 @@ def get_loop_loss(): add_output_to_acc(dep.name) needed_outputs.add(dep.name) + in_loop_ctx_dim_tags = set() + for layer_name in self.layers_in_loop: + if layer_name in list(needed_outputs): + layer = self.layer_data_templates[layer_name] + for tag in layer.output.dim_tags: + if tag.control_flow_ctx == self.net.control_flow_ctx: + if tag not in in_loop_ctx_dim_tags: + in_loop_ctx_dim_tags.add(tag) + # The helper layer name does not matter except for debugging and it should not clash with other layers. + # The helper layer name matters only on the sense that it must come sorted before other extra layers, + # such that we construct it first in _construct_output_layers_moved_out. + helper_layer_name = ":dyn-tag-accum:%i:%s" % (len(in_loop_ctx_dim_tags), layer_name) + helper_layer_dict = {"class": "length", "from": layer_name, "axis": tag} + self._add_template_layer(helper_layer_name, helper_layer_dict) + add_output_to_acc(helper_layer_name) + needed_outputs.add(helper_layer_name) + extra_output_layers.add(helper_layer_name) + # Tensor arrays for any layers which were moved out. 
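    # --------------------------------------------------------------------
    # Editor's illustrative sketch, not part of the patch. Assume a
    # hypothetical in-loop layer "att_weights" whose output carries a dim tag
    # `tag` in the loop's control flow ctx, i.e. its dynamic size can differ
    # in every loop frame. For such a tag, the loop above registers a hidden
    # LengthLayer template and accumulates its output like any other needed
    # loop output:
    #
    #   helper_layer_name = ":dyn-tag-accum:1:att_weights"
    #   helper_layer_dict = {"class": "length", "from": "att_weights", "axis": tag}
    #   self._add_template_layer(helper_layer_name, helper_layer_dict)
    #   add_output_to_acc(helper_layer_name)
    #
    # Every frame then contributes one length per batch entry, so outside the
    # loop the accumulated sizes have shape [T,B]; patch 3 below attaches them
    # to the dim tag as dyn_size_ext.
    # --------------------------------------------------------------------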
input_layers_moved_out_tas = {} if self.input_layers_moved_out: From 0b7e7fae0f881afac9e168f1f225e1556a37b0dc Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Mon, 20 Sep 2021 17:49:38 +0200 Subject: [PATCH 2/5] Rec subnet, construct output extra layers first and in order --- returnn/tf/layers/rec.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py index d8de29511..6f7a0b66a 100644 --- a/returnn/tf/layers/rec.py +++ b/returnn/tf/layers/rec.py @@ -3375,10 +3375,10 @@ def get_layer(name): # Same scope as the main subnet, so that it stays compatible. # noinspection PyProtectedMember with reuse_name_scope(self.parent_rec_layer._rec_scope): + for layer_name in sorted(extra_output_layers): + self.output_layers_net.layers[layer_name] = get_layer(layer_name) for layer_name in self.output_layers_moved_out: get_layer(layer_name) - for layer_name in extra_output_layers: - self.output_layers_net.layers[layer_name] = get_layer(layer_name) # We want to have one single layer with search choices. for name, search_choices in search_choices_cache.items(): From eebb257cf9b164245d20b17db8edd2f7c927d361 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Mon, 20 Sep 2021 17:50:48 +0200 Subject: [PATCH 3/5] Rec subnet, construct output, handle accum length layer dim tag --- returnn/tf/layers/rec.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py index 6f7a0b66a..4c584df2e 100644 --- a/returnn/tf/layers/rec.py +++ b/returnn/tf/layers/rec.py @@ -3247,6 +3247,7 @@ def _construct_output_layers_moved_out(self, loop_accumulated, seq_len, extra_ou from returnn.tf.util.basic import tensor_array_stack, concat_with_opt_broadcast from returnn.tf.network import TFNetwork, ExternData from .base import InternalLayer + from .basic import LengthLayer self.output_layers_net = TFNetwork( name="%s/%s(rec-subnet-output)" % ( @@ -3317,6 +3318,10 @@ def get_loop_acc_layer(name): if latest_layer_choice_name: loop_acc_layers_search_choices[name] = latest_layer_choice_name loop_acc_layers[name] = layer_ + if isinstance(in_loop_layer, LengthLayer): + tag = in_loop_layer.dim_tag.get_for_batch_ctx(layer_.output.batch, layer_.output.control_flow_ctx) + if not tag.dyn_size_ext: + tag.dyn_size_ext = layer_.output return layer_ # noinspection PyShadowingNames From 6a89d62a761693940a13587907e7b28146554744 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 21 Sep 2021 01:51:01 +0200 Subject: [PATCH 4/5] Rec subnet support accum TensorArray with different shaped elements --- returnn/tf/layers/rec.py | 107 +++++++++++++++++++++++++++++++++++---- returnn/tf/util/data.py | 12 +++++ 2 files changed, 110 insertions(+), 9 deletions(-) diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py index 4c584df2e..28d0e4a62 100644 --- a/returnn/tf/layers/rec.py +++ b/returnn/tf/layers/rec.py @@ -1903,18 +1903,33 @@ class OutputToAccumulate: """ # noinspection PyShadowingNames - def __init__(self, name, dtype, element_shape, get): + def __init__(self, name, get, dtype=None, element_shape=None, same_shape_every_frame=None, data=None): """ :param str name: + :param ()->(Data|tf.Tensor|None) get: :param tf.DType|str dtype: :param tuple[int|None] element_shape: - :param ()->(tf.Tensor|None) get: + :param bool same_shape_every_frame: + :param Data|None data: """ from returnn.tf.util.basic import get_valid_scope_name_from_str self.name = name self.tf_scope_name = get_valid_scope_name_from_str(name.replace("/", "_")) + 
self.data = data + if dtype is None: + assert data + dtype = data.dtype self.dtype = dtype + if element_shape is None: + assert data + element_shape = data.batch_shape self.element_shape = element_shape + if same_shape_every_frame is None: + if data: + same_shape_every_frame = not data.have_varying_shape_in_ctx() + else: + same_shape_every_frame = True + self.same_shape_every_frame = same_shape_every_frame self.get = get self.get_returned_none = None # type: typing.Optional[bool] @@ -1934,6 +1949,16 @@ def write_to_tensor_array(self, ta, index): self.get_returned_none = True return ta else: + if isinstance(value, tf.Tensor): + pass + elif isinstance(value, Data): + assert self.data + assert value.dtype == self.data.dtype + assert value.batch_shape == self.data.batch_shape + assert value.have_varying_shape_in_ctx() == self.data.have_varying_shape_in_ctx() + value = value.placeholder + else: + raise TypeError("OutputToAccumulate.get: expected tf.Tensor or Data but got %r" % type(value)) self.get_returned_none = False return ta.write(index=index, value=value, name="%s_acc_ta_write" % self.tf_scope_name) @@ -1946,8 +1971,73 @@ def get_final_tensor_array(self, ta): assert self.get_returned_none is not None if self.get_returned_none: return None + if not self.same_shape_every_frame: + ta = self._make_padded_tensor_array(ta) return ta + def _make_padded_tensor_array(self, ta): + """ + :param tf.TensorArray ta: + :rtype: tf.TensorArray + """ + assert not self.same_shape_every_frame + # First get the max shape of each element. Then create a new TensorArray which can hold all elements. + # Because we use clear_after_read in ta, even to get the max shape, we have to create a new TensorArray. + assert self.data + size = ta.size() + buffer_ta = tf.TensorArray( + name="acc_ta_infer_max_shape_%s" % self.tf_scope_name, + dtype=self.dtype, + element_shape=tf.TensorShape(self.element_shape), + size=size, + clear_after_read=True, + infer_shape=False) + + def _body_infer_max_shape(i, max_shape_, new_ta_): + """ + :param tf.Tensor i: scalar + :param tf.Tensor max_shape_: + :param tf.TensorArray new_ta_: + """ + elem = ta.read(i) + max_shape_ = tf.maximum(max_shape_, tf.shape(elem)) + new_ta_ = new_ta_.write(i, elem) + return i + 1, max_shape_, new_ta_ + + max_shape = tf.convert_to_tensor([d if d else 0 for d in self.data.batch_shape], dtype=tf.int32) + _, max_shape, buffer_ta = tf.while_loop( + cond=lambda i, *_args: tf.less(i, size), + body=_body_infer_max_shape, + loop_vars=(0, max_shape, buffer_ta)) + + # Now again create a new TensorArray. 
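      # --------------------------------------------------------------------
      # Worked example (editor's note, not part of the patch), with assumed
      # per-frame element shapes [B=4, T_0=2], [4, 5], [4, 3]:
      #   Pass 1 above yields max_shape = [4, 5]; the elements are copied into
      #   buffer_ta because clear_after_read=True allows reading each index of
      #   the original ta only once.
      #   Pass 2 below pads frame 0 with pad_values [(0, 0), (0, 3)] and
      #   frame 2 with [(0, 0), (0, 2)], so every frame ends up with shape
      #   [4, 5] and the stacked result is [T=3, 4, 5]. Which positions are
      #   padding is known from the ":dyn-tag-accum:*" length layer of patch 1,
      #   which accumulates the per-batch lengths of that axis for every frame.
      # The padded copy is built next:
      # --------------------------------------------------------------------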
+ new_ta_padded = tf.TensorArray( + name="acc_ta_pad_max_shape_%s" % self.tf_scope_name, + dtype=self.dtype, + element_shape=tf.TensorShape(self.element_shape), + size=size, + clear_after_read=True, + infer_shape=True) + + def _body_pad_max_shape(i, new_ta_): + """ + :param tf.Tensor i: scalar + :param tf.TensorArray new_ta_: + """ + from returnn.tf.util.basic import get_shape + elem = buffer_ta.read(i) + elem_shape = get_shape(elem) + pad_values = [(0, max_shape[a] - elem_shape[a]) for a in range(len(elem_shape))] + elem_padded = tf.pad(elem, pad_values) + new_ta_ = new_ta_.write(i, elem_padded) + return i + 1, new_ta_ + + _, new_ta_padded = tf.while_loop( + cond=lambda i, *_args: tf.less(i, size), + body=_body_pad_max_shape, + loop_vars=(0, new_ta_padded)) + return new_ta_padded + def get_output(self): """ :return: output of shape (time, batch, dim), search choices @@ -2115,11 +2205,11 @@ def add_output_to_acc(layer_name): name_ = "output_%s" % layer_name if any([(out.name == name_) for out in outputs_to_accumulate]): return + template_layer = self.layer_data_templates[layer_name] outputs_to_accumulate.append(_SubnetworkRecCell.OutputToAccumulate( name=name_, - dtype=self.layer_data_templates[layer_name].output.dtype, - element_shape=self.layer_data_templates[layer_name].output.batch_shape, - get=lambda: self.net.get_layer(layer_name).output.placeholder)) + data=template_layer.output, + get=lambda: self.net.get_layer(layer_name).output)) for name, template in self.layer_data_templates.items(): if template.is_output_layer(): @@ -2278,9 +2368,8 @@ def get_loop_loss(): if layer_name in self.layers_in_loop: outputs_to_accumulate.append(_SubnetworkRecCell.OutputToAccumulate( name="debug_output_%s" % layer_name, - dtype=self.layer_data_templates[layer_name].output.dtype, - element_shape=self.layer_data_templates[layer_name].output.batch_shape, - get=lambda name_=layer_name: self.net.get_layer(name_).output.placeholder)) + data=self.layer_data_templates[layer_name].output, + get=lambda name_=layer_name: self.net.get_layer(name_).output)) # Maybe some of the moved-out output-layers depend on data inside the loop, # so we should accumulate it to have access to it. @@ -2354,7 +2443,7 @@ def get_loop_loss(): size=min_loop_len, dynamic_size=True, # we will automatically grow it when needed clear_after_read=not out.name.startswith("choice_"), - infer_shape=True) + infer_shape=out.same_shape_every_frame) for out in outputs_to_accumulate] def body(i, net_vars, acc_tas, seq_len_info=None): diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 10e7d494c..572d4fa92 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -2664,6 +2664,18 @@ def get_dynamic_batch_shape(self): """ return [self.get_dim(axis) for axis in range(self.batch_ndim)] + def have_varying_shape_in_ctx(self): + """ + :return: whether the (dynamic) shape can change in this control flow context. + E.g. when self.control_flow_context is a loop, and we have one dynamic dim + where dyn_size_ext has the same control_flow_context + (such that dyn_size_ext has e.g. shape [B,T] outside the loop). + This can be relevant for accumulating values of self.placeholder + e.g. via tf.TensorArray. 
+ :rtype: bool + """ + return any(tag.control_flow_ctx for tag in self.dim_tags) + @property def size_placeholder(self): """ From 3d17449d1af0494ec95fe39ab3d23df29a0cdc19 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 15 Sep 2021 17:07:19 +0200 Subject: [PATCH 5/5] test_reclayer_optimize_out_accum_loop_dyn_size --- tests/test_TFNetworkRecLayer.py | 36 +++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/test_TFNetworkRecLayer.py b/tests/test_TFNetworkRecLayer.py index bbdeb7fc7..e1f2d0552 100644 --- a/tests/test_TFNetworkRecLayer.py +++ b/tests/test_TFNetworkRecLayer.py @@ -3506,6 +3506,42 @@ def test_reclayer_optimize_out_cum_concat_gen_self_att(): }) +def test_reclayer_optimize_out_accum_loop_dyn_size(): + # We want to test for the case where some layer inside the loop + # generates some dyn size of shape [B] which is different in each loop frame. + # So outside the loop, the accumulated dyn size should be of shape [T,B] or [B,T]. + # To test this, we first generate some random seq lens based on the input data (shape [B,T,D]). + from returnn.tf.util.basic import py_print + + def _eval_seq_lens(source, **_kwargs): + # Get some random varying seq lens. + res = tf.cast(4. * source(0) / source(1) + 0.3 * tf.cast(source(2), tf.float32), tf.int32) + 1 + res = py_print(res, ["seq lens", res, "step :i", source(2)]) + return res + + check_reclayer_optimize_out( + subnet_layer_dict={"class": "linear", "from": "combine", "activation": None, "n_out": 3}, + other_subnet_layers={ + "exp_data": {"class": "activation", "from": "data:source", "activation": "exp"}, # >0 + "sum_exp_data": {"class": "reduce", "mode": "sum", "from": "exp_data", "axis": "F"}, # [B] + "seq_lens": { + "class": "eval", "from": ["sum_exp_data", "base:max_sum_exp_data", ":i"], + "out_type": {"dtype": "int32"}, + "eval": _eval_seq_lens}, # [B] + "range": {"class": "range_from_length", "from": "seq_lens"}, # [T_new] + "combine": { + "class": "eval", "from": ["data:source", "range"], + "eval": "source(0) + 0.1 * tf.cast(source(1), tf.float32)"}, # [B,T_new,D] + }, + shared_base_net={ + "exp_data": {"class": "activation", "from": "data", "activation": "exp"}, # >0 + "sum_exp_data": {"class": "reduce", "mode": "sum", "from": "exp_data", "axis": "F"}, # [B,T] + "max_sum_exp_data": { + "class": "reduce", "mode": "max", "from": "sum_exp_data", "axis": "T", + "is_output_layer": True}, # [B] + }) + + def test_reclayer_optimize_out_dot(): # Used for multi-head dot-attention. AttNumHeads = 4
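# --------------------------------------------------------------------------
# Editor's note: standalone sketch (plain TensorFlow, eager mode; the helper
# name below is made up, this is not RETURNN API) of the accumulation scheme
# that patch 4 implements in OutputToAccumulate. When
# Data.have_varying_shape_in_ctx() is true, the TensorArray is created with
# infer_shape=False (via same_shape_every_frame), and before stacking, every
# element is zero-padded to the element-wise max shape:

import tensorflow as tf


def _stack_with_padding(ta, size, rank):
  """
  :param tf.TensorArray ta: created with infer_shape=False, holding `size` elements of rank `rank`
  :param int size:
  :param int rank:
  :return: stacked tensor of shape [size] + max_shape, each element zero-padded at the end of each axis
  :rtype: tf.Tensor
  """
  elems = [ta.read(i) for i in range(size)]  # each index is read exactly once (clear_after_read)
  max_shape = tf.reduce_max(tf.stack([tf.shape(e) for e in elems]), axis=0)  # element-wise max, shape [rank]
  padded = [
    tf.pad(e, tf.stack([tf.zeros([rank], tf.int32), max_shape - tf.shape(e)], axis=1))
    for e in elems]
  return tf.stack(padded)


# Per-frame values of shape [B=4, T_i] with varying T_i = 2, 5, 3, analogous to
# the per-frame outputs of the "range"/"combine" layers in the test above.
acc_ta = tf.TensorArray(dtype=tf.float32, size=3, infer_shape=False, clear_after_read=True)
for i, t in enumerate([2, 5, 3]):
  acc_ta = acc_ta.write(i, tf.ones([4, t]))
stacked = _stack_with_padding(acc_ta, size=3, rank=2)
assert stacked.shape == (3, 4, 5)
# --------------------------------------------------------------------------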