Merged
164 changes: 153 additions & 11 deletions returnn/tf/layers/rec.py
@@ -1473,6 +1473,36 @@ def __call__(lself, name, is_prev_time_frame=False):
print(s)
raise

def _add_template_layer(self, layer_name, layer_dict):
"""
Use this for simple helpers, after the main template net is already created.
It does not handle layer-creation exceptions,
and it expects that all dependencies are already created
and that they are other layers inside our subnet.

:param str layer_name:
:param dict[str] layer_dict: not yet transformed
:rtype: _TemplateLayer
"""
from returnn.tf.network import get_layer_class
assert layer_name not in self.layer_data_templates
self.net_dict[layer_name] = layer_dict
# We replicate the _construct_template logic here, but simplified,
# i.e. not expecting exceptions, and expecting that all dep layers are created.
layer = _TemplateLayer(name=layer_name, network=self.net, cell=self)
layer_dict = layer_dict.copy()
layer_class_name = layer_dict.pop("class")
layer_class = get_layer_class(layer_class_name)
layer_dict["_network"] = self.net
layer_dict["_name"] = layer_name
layer_class.transform_config_dict(
layer_dict, network=self.net, get_layer=lambda _name: self.layer_data_templates[_name])
out = layer_class.get_out_data_from_opts(name=layer_name, network=self.net, **layer_dict)
out = layer_class.fixup_out_data(output=out, network=self.net)
layer.init(output=out, layer_class=layer_class, **layer_dict.copy())
self.layer_data_templates[layer_name] = layer
return layer

def _handle_construct_exception(self, description, exception):
"""
:param str description:
@@ -1873,18 +1903,33 @@ class OutputToAccumulate:
"""

# noinspection PyShadowingNames
def __init__(self, name, dtype, element_shape, get):
def __init__(self, name, get, dtype=None, element_shape=None, same_shape_every_frame=None, data=None):
"""
:param str name:
:param ()->(Data|tf.Tensor|None) get:
:param tf.DType|str dtype:
:param tuple[int|None] element_shape:
:param ()->(tf.Tensor|None) get:
:param bool same_shape_every_frame:
:param Data|None data:
"""
from returnn.tf.util.basic import get_valid_scope_name_from_str
self.name = name
self.tf_scope_name = get_valid_scope_name_from_str(name.replace("/", "_"))
self.data = data
if dtype is None:
assert data
dtype = data.dtype
self.dtype = dtype
if element_shape is None:
assert data
element_shape = data.batch_shape
self.element_shape = element_shape
if same_shape_every_frame is None:
if data:
same_shape_every_frame = not data.have_varying_shape_in_ctx()
else:
same_shape_every_frame = True
self.same_shape_every_frame = same_shape_every_frame
self.get = get
self.get_returned_none = None # type: typing.Optional[bool]
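
For illustration only (a hedged sketch, not part of this diff; it assumes RETURNN is importable at the state of this change): OutputToAccumulate now supports two construction styles, either passing dtype/element_shape explicitly as before, or passing a Data template, which also derives same_shape_every_frame via Data.have_varying_shape_in_ctx(). The names below are made up for the example.

from returnn.tf.layers.rec import _SubnetworkRecCell
from returnn.tf.util.data import Data

template = Data(name="att_weights", shape=(None, 8))  # e.g. [B,T,8]; a made-up template
acc_explicit = _SubnetworkRecCell.OutputToAccumulate(
  name="output_att_weights_explicit", get=lambda: None,
  dtype=template.dtype, element_shape=template.batch_shape)
acc_from_data = _SubnetworkRecCell.OutputToAccumulate(
  name="output_att_weights", get=lambda: None, data=template)
assert acc_from_data.same_shape_every_frame  # no in-loop-ctx dims in this template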

@@ -1904,6 +1949,16 @@ def write_to_tensor_array(self, ta, index):
self.get_returned_none = True
return ta
else:
if isinstance(value, tf.Tensor):
pass
elif isinstance(value, Data):
assert self.data
assert value.dtype == self.data.dtype
assert value.batch_shape == self.data.batch_shape
assert value.have_varying_shape_in_ctx() == self.data.have_varying_shape_in_ctx()
value = value.placeholder
else:
raise TypeError("OutputToAccumulate.get: expected tf.Tensor or Data but got %r" % type(value))
self.get_returned_none = False
return ta.write(index=index, value=value, name="%s_acc_ta_write" % self.tf_scope_name)

@@ -1916,8 +1971,73 @@ def get_final_tensor_array(self, ta):
assert self.get_returned_none is not None
if self.get_returned_none:
return None
if not self.same_shape_every_frame:
ta = self._make_padded_tensor_array(ta)
return ta

def _make_padded_tensor_array(self, ta):
"""
:param tf.TensorArray ta:
:rtype: tf.TensorArray
"""
assert not self.same_shape_every_frame
# First get the max shape over all elements, then create a new TensorArray which can hold all of them.
# Because the input ta uses clear_after_read, reading the elements to get the max shape already consumes them,
# so we have to copy them into a new (buffer) TensorArray while doing so.
assert self.data
size = ta.size()
buffer_ta = tf.TensorArray(
name="acc_ta_infer_max_shape_%s" % self.tf_scope_name,
dtype=self.dtype,
element_shape=tf.TensorShape(self.element_shape),
size=size,
clear_after_read=True,
infer_shape=False)

def _body_infer_max_shape(i, max_shape_, new_ta_):
"""
:param tf.Tensor i: scalar
:param tf.Tensor max_shape_:
:param tf.TensorArray new_ta_:
"""
elem = ta.read(i)
max_shape_ = tf.maximum(max_shape_, tf.shape(elem))
new_ta_ = new_ta_.write(i, elem)
return i + 1, max_shape_, new_ta_

max_shape = tf.convert_to_tensor([d if d else 0 for d in self.data.batch_shape], dtype=tf.int32)
_, max_shape, buffer_ta = tf.while_loop(
cond=lambda i, *_args: tf.less(i, size),
body=_body_infer_max_shape,
loop_vars=(0, max_shape, buffer_ta))

# Now again create a new TensorArray.
new_ta_padded = tf.TensorArray(
name="acc_ta_pad_max_shape_%s" % self.tf_scope_name,
dtype=self.dtype,
element_shape=tf.TensorShape(self.element_shape),
size=size,
clear_after_read=True,
infer_shape=True)

def _body_pad_max_shape(i, new_ta_):
"""
:param tf.Tensor i: scalar
:param tf.TensorArray new_ta_:
"""
from returnn.tf.util.basic import get_shape
elem = buffer_ta.read(i)
elem_shape = get_shape(elem)
pad_values = [(0, max_shape[a] - elem_shape[a]) for a in range(len(elem_shape))]
elem_padded = tf.pad(elem, pad_values)
new_ta_ = new_ta_.write(i, elem_padded)
return i + 1, new_ta_

_, new_ta_padded = tf.while_loop(
cond=lambda i, *_args: tf.less(i, size),
body=_body_pad_max_shape,
loop_vars=(0, new_ta_padded))
return new_ta_padded
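
To make the two-pass logic above easier to follow, here is a standalone plain-TF sketch of the same idea (infer the maximum shape, then pad every element up to it). It operates on a Python list instead of a clear_after_read TensorArray inside tf.while_loop, and pad_frames_to_max_shape is an illustrative name, not RETURNN API.

import tensorflow as tf

def pad_frames_to_max_shape(frames):
  """
  :param list[tf.Tensor] frames: all of the same rank, but with possibly different dynamic shapes
  :return: tensor of shape [len(frames)] + max_shape, each frame zero-padded up to max_shape
  :rtype: tf.Tensor
  """
  rank = frames[0].shape.rank
  # Pass 1: element-wise maximum shape over all frames.
  max_shape = tf.zeros([rank], dtype=tf.int32)
  for frame in frames:
    max_shape = tf.maximum(max_shape, tf.shape(frame))
  # Pass 2: pad every frame up to max_shape, then stack into one tensor.
  ta = tf.TensorArray(dtype=frames[0].dtype, size=len(frames), infer_shape=True)
  for i, frame in enumerate(frames):
    paddings = [(0, max_shape[a] - tf.shape(frame)[a]) for a in range(rank)]
    ta = ta.write(i, tf.pad(frame, paddings))
  return ta.stack()

print(pad_frames_to_max_shape([tf.ones([3, 2]), tf.ones([3, 5])]).shape)  # (2, 3, 5) with eager execution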

def get_output(self):
"""
:return: output of shape (time, batch, dim), search choices
@@ -2085,11 +2205,11 @@ def add_output_to_acc(layer_name):
name_ = "output_%s" % layer_name
if any([(out.name == name_) for out in outputs_to_accumulate]):
return
template_layer = self.layer_data_templates[layer_name]
outputs_to_accumulate.append(_SubnetworkRecCell.OutputToAccumulate(
name=name_,
dtype=self.layer_data_templates[layer_name].output.dtype,
element_shape=self.layer_data_templates[layer_name].output.batch_shape,
get=lambda: self.net.get_layer(layer_name).output.placeholder))
data=template_layer.output,
get=lambda: self.net.get_layer(layer_name).output))

for name, template in self.layer_data_templates.items():
if template.is_output_layer():
@@ -2248,9 +2368,8 @@ def get_loop_loss():
if layer_name in self.layers_in_loop:
outputs_to_accumulate.append(_SubnetworkRecCell.OutputToAccumulate(
name="debug_output_%s" % layer_name,
dtype=self.layer_data_templates[layer_name].output.dtype,
element_shape=self.layer_data_templates[layer_name].output.batch_shape,
get=lambda name_=layer_name: self.net.get_layer(name_).output.placeholder))
data=self.layer_data_templates[layer_name].output,
get=lambda name_=layer_name: self.net.get_layer(name_).output))

# Maybe some of the moved-out output-layers depend on data inside the loop,
# so we should accumulate it to have access to it.
@@ -2262,6 +2381,24 @@ def get_loop_loss():
add_output_to_acc(dep.name)
needed_outputs.add(dep.name)

in_loop_ctx_dim_tags = set()
for layer_name in self.layers_in_loop:
if layer_name in list(needed_outputs):
layer = self.layer_data_templates[layer_name]
for tag in layer.output.dim_tags:
if tag.control_flow_ctx == self.net.control_flow_ctx:
if tag not in in_loop_ctx_dim_tags:
in_loop_ctx_dim_tags.add(tag)
# The helper layer name should not clash with other layers; beyond debugging it matters
# only in the sense that it must sort before other extra layers,
# such that we construct it first in _construct_output_layers_moved_out.
helper_layer_name = ":dyn-tag-accum:%i:%s" % (len(in_loop_ctx_dim_tags), layer_name)
helper_layer_dict = {"class": "length", "from": layer_name, "axis": tag}
self._add_template_layer(helper_layer_name, helper_layer_dict)
add_output_to_acc(helper_layer_name)
needed_outputs.add(helper_layer_name)
extra_output_layers.add(helper_layer_name)
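
As a rough standalone illustration of what these ":dyn-tag-accum:*" helpers achieve (plain TF only, not the RETURNN layer API; all names below are made up): a per-frame dynamic size of shape [B] produced inside the loop gets accumulated into a [T,B] tensor, which outside the loop can then serve as the dyn_size_ext of the corresponding dim tag.

import tensorflow as tf

num_frames, batch = 5, 3
acc_ta = tf.TensorArray(dtype=tf.int32, size=num_frames, element_shape=tf.TensorShape([batch]))
for t in range(num_frames):
  per_frame_lens = tf.fill([batch], t + 1)  # stand-in for the in-loop "length" helper layer output
  acc_ta = acc_ta.write(t, per_frame_lens)
accumulated_lens = acc_ta.stack()  # shape [T,B]
print(accumulated_lens.shape)  # (5, 3)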

# Tensor arrays for any layers which were moved out.
input_layers_moved_out_tas = {}
if self.input_layers_moved_out:
Expand Down Expand Up @@ -2306,7 +2443,7 @@ def get_loop_loss():
size=min_loop_len,
dynamic_size=True, # we will automatically grow it when needed
clear_after_read=not out.name.startswith("choice_"),
infer_shape=True)
infer_shape=out.same_shape_every_frame)
for out in outputs_to_accumulate]

def body(i, net_vars, acc_tas, seq_len_info=None):
@@ -3199,6 +3336,7 @@ def _construct_output_layers_moved_out(self, loop_accumulated, seq_len, extra_ou
from returnn.tf.util.basic import tensor_array_stack, concat_with_opt_broadcast
from returnn.tf.network import TFNetwork, ExternData
from .base import InternalLayer
from .basic import LengthLayer

self.output_layers_net = TFNetwork(
name="%s/%s(rec-subnet-output)" % (
@@ -3269,6 +3407,10 @@ def get_loop_acc_layer(name):
if latest_layer_choice_name:
loop_acc_layers_search_choices[name] = latest_layer_choice_name
loop_acc_layers[name] = layer_
if isinstance(in_loop_layer, LengthLayer):
tag = in_loop_layer.dim_tag.get_for_batch_ctx(layer_.output.batch, layer_.output.control_flow_ctx)
if not tag.dyn_size_ext:
tag.dyn_size_ext = layer_.output
return layer_

# noinspection PyShadowingNames
@@ -3327,10 +3469,10 @@ def get_layer(name):
# Same scope as the main subnet, so that it stays compatible.
# noinspection PyProtectedMember
with reuse_name_scope(self.parent_rec_layer._rec_scope):
for layer_name in sorted(extra_output_layers):
self.output_layers_net.layers[layer_name] = get_layer(layer_name)
for layer_name in self.output_layers_moved_out:
get_layer(layer_name)
for layer_name in extra_output_layers:
self.output_layers_net.layers[layer_name] = get_layer(layer_name)

# We want to have one single layer with search choices.
for name, search_choices in search_choices_cache.items():
12 changes: 12 additions & 0 deletions returnn/tf/util/data.py
@@ -2664,6 +2664,18 @@ def get_dynamic_batch_shape(self):
"""
return [self.get_dim(axis) for axis in range(self.batch_ndim)]

def have_varying_shape_in_ctx(self):
"""
:return: whether the (dynamic) shape can change in this control flow context.
E.g. when self.control_flow_context is a loop, and we have one dynamic dim
where dyn_size_ext has the same control_flow_context
(such that dyn_size_ext has e.g. shape [B,T] outside the loop).
This can be relevant for accumulating values of self.placeholder
e.g. via tf.TensorArray.
:rtype: bool
"""
return any(tag.control_flow_ctx for tag in self.dim_tags)

@property
def size_placeholder(self):
"""
36 changes: 36 additions & 0 deletions tests/test_TFNetworkRecLayer.py
@@ -3506,6 +3506,42 @@ def test_reclayer_optimize_out_cum_concat_gen_self_att():
})


def test_reclayer_optimize_out_accum_loop_dyn_size():
# We want to test for the case where some layer inside the loop
# generates some dyn size of shape [B] which is different in each loop frame.
# So outside the loop, the accumulated dyn size should be of shape [T,B] or [B,T].
# To test this, we first generate some random seq lens based on the input data (shape [B,T,D]).
from returnn.tf.util.basic import py_print

def _eval_seq_lens(source, **_kwargs):
# Get some random varying seq lens.
res = tf.cast(4. * source(0) / source(1) + 0.3 * tf.cast(source(2), tf.float32), tf.int32) + 1
res = py_print(res, ["seq lens", res, "step :i", source(2)])
return res

check_reclayer_optimize_out(
subnet_layer_dict={"class": "linear", "from": "combine", "activation": None, "n_out": 3},
other_subnet_layers={
"exp_data": {"class": "activation", "from": "data:source", "activation": "exp"}, # >0
"sum_exp_data": {"class": "reduce", "mode": "sum", "from": "exp_data", "axis": "F"}, # [B]
"seq_lens": {
"class": "eval", "from": ["sum_exp_data", "base:max_sum_exp_data", ":i"],
"out_type": {"dtype": "int32"},
"eval": _eval_seq_lens}, # [B]
"range": {"class": "range_from_length", "from": "seq_lens"}, # [T_new]
"combine": {
"class": "eval", "from": ["data:source", "range"],
"eval": "source(0) + 0.1 * tf.cast(source(1), tf.float32)"}, # [B,T_new,D]
},
shared_base_net={
"exp_data": {"class": "activation", "from": "data", "activation": "exp"}, # >0
"sum_exp_data": {"class": "reduce", "mode": "sum", "from": "exp_data", "axis": "F"}, # [B,T]
"max_sum_exp_data": {
"class": "reduce", "mode": "max", "from": "sum_exp_data", "axis": "T",
"is_output_layer": True}, # [B]
})


def test_reclayer_optimize_out_dot():
# Used for multi-head dot-attention.
AttNumHeads = 4