9 changes: 8 additions & 1 deletion returnn/tf/layers/base.py
@@ -393,6 +393,13 @@ def fixup_out_data(cls, output, network):
         extern_data.init_batch_info()  # this should create it and also set it
       assert output.batch
       output.batch = output.batch.copy_set_beam(output.beam)
+    if output.control_flow_ctx != network.get_control_flow_ctx():
+      x = output.placeholder
+      output = output.copy_template_set_ctx(network.get_control_flow_ctx())
+      if x is not None:
+        # Some layers might just copy the input. But the input might have buggy ctx.
+        # Just leave the placeholder as-is. Most layers should anyway reset this.
+        output.placeholder = x
     return output

   def get_full_ctx_name(self):
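Note for reviewers: a minimal sketch of what the new `fixup_out_data` branch does, using only API that appears in this diff (`copy_template_set_ctx`, the `control_flow_ctx` attribute/kwarg); constructing a `ControlFlowContext` without an outer ctx is an assumption here.

```python
from returnn.tf.util.data import Data, ControlFlowContext

# A loop ctx, as _SubnetworkRecCell creates one below (outer_ctx omitted).
loop_ctx = ControlFlowContext(kind=ControlFlowContext.Types.Loop)

# A layer output template whose ctx does not match its network's ctx
# (here: no ctx yet, while the net is assumed to run inside the loop).
output = Data(name="out", shape=(10,))

# The new branch: re-template into the net's ctx, but keep an already-set
# placeholder, since most layers reset it right afterwards anyway.
x = output.placeholder
output = output.copy_template_set_ctx(loop_ctx)
if x is not None:
  output.placeholder = x
assert output.control_flow_ctx is loop_ctx
```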
@@ -1730,7 +1737,7 @@ def opt_get_layer(layer_name):
       # Don't return layer, could be inside loop and that won't work.
       output = net.layers[layer_name].output.copy_template()
       if not output.have_time_axis() and with_time_dim:
-        output = output.copy_template_adding_time_dim()
+        output = output.copy_template_adding_time_dim().copy_template_set_ctx(network.get_control_flow_ctx())
       if not output:
         layer_desc_ = net.layers_desc[layer_name].copy()
         class_name_ = layer_desc_.pop("class")
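The `opt_get_layer` change above follows the pattern used throughout this PR: a template that gains or loses the time axis also moves into the ctx of the network that will consume it. A hedged round-trip sketch (method names as in this diff; passing `None` to `copy_template_set_ctx` to mean "no ctx" is an assumption):

```python
from returnn.tf.util.data import Data, ControlFlowContext

loop_ctx = ControlFlowContext(kind=ControlFlowContext.Types.Loop)
enc = Data(name="enc", shape=(None, 512))  # [B,T,512] in the outer net, ctx None

# Entering the loop: drop the time axis, per-frame template in the loop ctx.
frame = enc.copy_template_excluding_time_dim().copy_template_set_ctx(loop_ctx)
assert not frame.have_time_axis() and frame.control_flow_ctx is loop_ctx

# Leaving the loop: re-add the time axis, back to the outer ctx.
acc = frame.copy_template_adding_time_dim(time_dim_axis=0).copy_template_set_ctx(None)
assert acc.have_time_axis() and acc.control_flow_ctx is None
```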
46 changes: 32 additions & 14 deletions returnn/tf/layers/rec.py
@@ -406,8 +406,10 @@ def get_out_data_from_opts(cls, network, unit, _time_dim_tag=None, sources=(), i
     out = None
     if isinstance(unit, _SubnetworkRecCell):  # subnetwork
       subnet = unit
-      sub_out = subnet.layer_data_templates["output"].output.copy_template_adding_time_dim(
-        name="%s_output" % kwargs["name"], time_dim_axis=0)
+      sub_out = (
+        subnet.layer_data_templates["output"].output
+        .copy_template_adding_time_dim(name="%s_output" % kwargs["name"], time_dim_axis=0)
+        .copy_template_set_ctx(network.get_control_flow_ctx()))
       if out:
         assert sub_out.dim == out.dim
         assert sub_out.shape == out.shape
@@ -993,21 +995,26 @@ def __init__(self, net_dict, source_data, time_dim_tag, rec_layer_name, parent_n
     self.parent_net = parent_net
     self.net_dict = safe_deep_copy(net_dict)
     from returnn.tf.network import TFNetwork, ExternData, LossHolder
+    from returnn.tf.util.data import ControlFlowContext
+    control_flow_ctx = ControlFlowContext(
+      kind=ControlFlowContext.Types.Loop, outer_ctx=parent_net.get_control_flow_ctx())
+    control_flow_ctx.loop_spatial_dim = time_dim_tag
     self.net = TFNetwork(
       name="%s/%s(rec-subnet)" % (parent_net.name, rec_layer_name),
       extern_data=ExternData(),
       train_flag=parent_net.train_flag,
       search_flag=parent_net.search_flag,
       eval_flag=False,
       inside_rec_time_dim=time_dim_tag,
+      control_flow_ctx=control_flow_ctx,
       absolute_name_prefix="%s%s/" % (parent_net.get_absolute_name_prefix(), rec_layer_name),
       parent_net=parent_net)
     self.net.is_root_in_ctx = True
     self.net.layers_desc.update(self.net_dict)
     self.source_data = source_data
     if source_data:
       self.net.extern_data.data["source"] = (
-        source_data.copy_template_excluding_time_dim())
+        source_data.copy_template_excluding_time_dim().copy_template_set_ctx(control_flow_ctx))
     self.time_dim_tag = time_dim_tag
     self._time_dim_tags = {time_dim_tag}  # type: typing.Set[DimensionTag]
     if source_data:
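For context, the ctx constructed above chains to the parent's ctx via `outer_ctx`, so nested rec layers stack their loop contexts. Just the ctx construction as a standalone sketch (the `DimensionTag` constructor arguments are an assumption based on its common usage elsewhere in RETURNN):

```python
from returnn.tf.util.data import DimensionTag, ControlFlowContext

time_dim_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="rec-time")

outer_ctx = None  # parent_net.get_control_flow_ctx(); None if the parent is a root net
control_flow_ctx = ControlFlowContext(
  kind=ControlFlowContext.Types.Loop, outer_ctx=outer_ctx)
control_flow_ctx.loop_spatial_dim = time_dim_tag
# Every extern-data template handed into the loop ("source", the targets) is
# per-frame and lives in this ctx, as done above.
```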
@@ -1020,7 +1027,7 @@ def __init__(self, net_dict, source_data, time_dim_tag, rec_layer_name, parent_n
       # These are just templates. You can use them as possible targets for dimension information,
       # but not as actual sources or targets.
       # Note: We maybe should check data.is_same_time_dim()...
-      self.net.extern_data.data[key] = data.copy_template_excluding_time_dim()
+      self.net.extern_data.data[key] = data.copy_template_excluding_time_dim().copy_template_set_ctx(control_flow_ctx)
     self.layer_data_templates = {}  # type: typing.Dict[str,_TemplateLayer]
     self.prev_layers_needed = set()  # type: typing.Set[str]
     self.prev_layer_templates = {}  # type: typing.Dict[str,_TemplateLayer]
@@ -1545,7 +1552,7 @@ def get_input_moved_out(name):
           self.parent_rec_layer, self.parent_rec_layer.output.get_time_dim_tag(),
           layer, layer.output.get_time_dim_tag())
         return layer
-      output = layer.output.copy_template_excluding_time_dim()
+      output = layer.output.copy_template_excluding_time_dim().copy_template_set_ctx(self.net.control_flow_ctx)
       with tf.name_scope("%s_moved_input" % name.replace(":", "_")):
         if prev:
           output.placeholder = tf.cond(
@@ -2513,7 +2520,8 @@ def cond(i, net_vars, acc_tas, seq_len_info=None):
             self.parent_rec_layer, input_beam, output_beam,
             self.parent_rec_layer.sources, self.parent_rec_layer.target))
         assert output_template.output.batch.beam == output_beam
-        time_dim_tag = time_dim_tag.get_for_batch(output_template.output.batch)
+        time_dim_tag = time_dim_tag.get_for_batch_ctx(
+          batch=output_template.output.batch, ctx=self.net.control_flow_ctx)
         assert time_dim_tag.dyn_size is not None
         seq_len = time_dim_tag.dyn_size
       else:
@@ -2772,7 +2780,7 @@ def get_choice_seq(choice_base):
         latest_batch = (
           latest_layer_choice.output.batch
           or self.parent_rec_layer.output.batch.copy_set_beam(latest_layer_choice.output.beam))
-        tag = tag.get_for_batch(latest_batch)
+        tag = tag.get_for_batch_ctx(batch=latest_batch, ctx=self.net.control_flow_ctx)
         assert tag.dyn_size is not None
         assert tag.batch == latest_batch and tag.batch.beam == latest_layer_choice.output.beam
         seq_len = tag.dyn_size
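`get_for_batch` becomes `get_for_batch_ctx` here and in the hunks below: a dim tag's dynamic sizes now vary not only with the batch (e.g. beam expansion) but also with the control flow ctx (inside vs. outside the loop). A pure-Python mimic of the idea, not RETURNN code:

```python
class DimTagVariants:
  """One logical dim tag; one concrete dyn_size per (batch, ctx) variant."""

  def __init__(self, name):
    self.name = name
    self._variants = {}  # (batch, ctx) -> dyn_size tensor (here: any object)

  def set_dyn_size(self, batch, ctx, dyn_size):
    self._variants[(batch, ctx)] = dyn_size

  def get_for_batch_ctx(self, batch, ctx):
    # The real implementation can also derive new variants (e.g. tile seq lens
    # for a beam); this mimic only looks up what was registered.
    return self._variants.get((batch, ctx))


tag = DimTagVariants("time")
tag.set_dyn_size(batch="global_batch", ctx=None, dyn_size="seq_len[B]")
tag.set_dyn_size(batch="batch_beam", ctx="loop", dyn_size="seq_len[B*beam]")
assert tag.get_for_batch_ctx(batch="batch_beam", ctx="loop") == "seq_len[B*beam]"
```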
@@ -3216,7 +3224,10 @@ def get_loop_acc_layer(name):
       acc_ta, latest_layer_choice_name, search_choices, resolved_seq_len = self._opt_search_resolve(
         layer_name=name, acc_ta=acc_ta, final_net_vars=final_net_vars, seq_len=seq_len,
         search_choices_cache=search_choices_cache)
-      output = self.layer_data_templates[name].output.copy_template_adding_time_dim(time_dim_axis=0)
+      output = (
+        self.layer_data_templates[name].output
+        .copy_template_adding_time_dim(time_dim_axis=0)
+        .copy_template_set_ctx(self.parent_net.get_control_flow_ctx()))
       if latest_layer_choice_name:
         output.beam = self.net.layers[latest_layer_choice_name].search_choices.get_beam_info()
       elif search_choices:
@@ -3303,7 +3314,10 @@ def get_layer(name):
     for name, search_choices in search_choices_cache.items():
       if name not in self.output_layers_net.layers:
         # Create dummy layer.
-        output = self.layer_data_templates[name].output.copy_template_adding_time_dim(time_dim_axis=0)
+        output = (
+          self.layer_data_templates[name].output
+          .copy_template_adding_time_dim(time_dim_axis=0)
+          .copy_template_set_ctx(self.output_layers_net.get_control_flow_ctx()))
         output.beam = search_choices.get_beam_info()
         layer = InternalLayer(name=name, network=self.output_layers_net, output=output)
         self.output_layers_net.layers[name] = layer
@@ -3350,7 +3364,8 @@ def __init__(self, network, name, construct_stack=None, cell=None):
       output=Data(
         name="dummy_initial_template_data",
         batch_dim_axis=0, time_dim_axis=None,
-        shape=()),  # (B,). no time-dim
+        shape=(),
+        control_flow_ctx=network.get_control_flow_ctx()),  # (B,). no time-dim
       name=name, network=network)
     self.output.size_placeholder = {}  # must be initialized
     self.layer_class = ":uninitialized-template"
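The template layer's dummy output now records its network's ctx up front, so ctx comparisons during template construction (like the new `fixup_out_data` branch in base.py) do not spuriously re-template it. Minimal sketch with a stand-in loop ctx:

```python
from returnn.tf.util.data import Data, ControlFlowContext

loop_ctx = ControlFlowContext(kind=ControlFlowContext.Types.Loop)  # i.e. network.get_control_flow_ctx()
dummy = Data(
  name="dummy_initial_template_data",
  batch_dim_axis=0, time_dim_axis=None, shape=(),  # [B], no time dim
  control_flow_ctx=loop_ctx)
assert dummy.control_flow_ctx is loop_ctx
```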
@@ -5226,7 +5241,7 @@ def decide(cls, src, output=None, owner=None, name=None, length_normalization=Fa
     for i, size in src_data.size_placeholder.items():
       tag = DimensionTag.get_tag_from_size_tensor(size)
       assert tag
-      tag = tag.get_for_batch(output.batch)
+      tag = tag.get_for_batch_ctx(batch=output.batch, ctx=output.control_flow_ctx)
       if tag.dyn_size is None:
         size = tf.reshape(size, [batch_dim, beam_size])  # (batch, beam)
         size = tf.gather_nd(size, indices=beam_idxs_ext)  # (batch,)
@@ -7071,8 +7086,11 @@ def _create_template(cls, name, network, sources, masked_from, unit,
     # We don't care about the right masked input here, but just about deriving the right output shape.
     if masked_from:
       if network.is_inside_rec_layer(inside_loop=True):
-        source_data = masked_from.output.copy_template_excluding_time_dim(
-          name="%s_%s_masked_input_frame" % (masked_from.output.name, name))
+        source_data = (
+          masked_from.output
+          .copy_template_excluding_time_dim(
+            name="%s_%s_masked_input_frame" % (masked_from.output.name, name))
+          .copy_template_set_ctx(network.get_control_flow_ctx()))
       else:
         source_data = masked_from.output.copy_template(
           name="%s_%s_masked_input" % (masked_from.output.name, name))
@@ -7347,7 +7365,7 @@ def get_out_data_from_opts(cls, name, network, sources, mask, **kwargs):
         # thus when we unroll it to get into the loop, the RecLayer would have kept it as-is,
         # i.e. it should still have that time-dim-axis.
         # Maybe we should do some extra checks if that is like we assume, but for now, just assume that.
-        return out.copy_template_excluding_time_dim()
+        return out.copy_template_excluding_time_dim().copy_template_set_ctx(network.get_control_flow_ctx())
       return out
     assert out.have_time_axis()
     out = out.copy_as_time_major()
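The last two hunks apply the same rule to masked computation: inside the loop, the template is per-frame and must carry the loop ctx; outside, it keeps its time axis and ctx. A hedged sketch, where `loop_ctx` stands in for `network.get_control_flow_ctx()` inside the rec loop:

```python
from returnn.tf.util.data import Data, ControlFlowContext

loop_ctx = ControlFlowContext(kind=ControlFlowContext.Types.Loop)
masked_out = Data(name="masked", shape=(None, 128))  # [B,T,128]

inside_loop = True  # i.e. network.is_inside_rec_layer(inside_loop=True)
if inside_loop:
  # Per-frame template for one step of the loop, in the loop ctx:
  src = masked_out.copy_template_excluding_time_dim().copy_template_set_ctx(loop_ctx)
  assert src.control_flow_ctx is loop_ctx
else:
  # Keeps the time axis and whatever ctx it already had:
  src = masked_out.copy_template()
```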
14 changes: 14 additions & 0 deletions returnn/tf/network.py
@@ -354,6 +354,7 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None,
                train_flag=None, eval_flag=None, search_flag=None,
                parent_layer=None, parent_net=None, extra_parent_net=None, extra_name_prefix=None,
                inside_rec_time_dim=None, over_rec_time_dim=None, over_rec_time_dim_subs=None,
+               control_flow_ctx=None,
                absolute_name_prefix=None, name=None):
     """
     :param returnn.config.Config config: only needed to init extern_data if not specified explicitly
@@ -370,6 +371,7 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None,
     :param DimensionTag|None inside_rec_time_dim: dim tag of outer rec layer, when run inside the loop (not optimized)
     :param DimensionTag|None over_rec_time_dim: dim tag of outer rec layer, when optimized out of the loop
     :param set[DimensionTag]|None over_rec_time_dim_subs: outer rec layer, out of loop, potential shorter
+    :param returnn.tf.util.data.ControlFlowContext control_flow_ctx:
     :param str|None absolute_name_prefix:
     :param str name: only for debugging
     """
@@ -432,6 +434,7 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None,
     self._inside_rec_time_dim = inside_rec_time_dim
     self._over_rec_time_dim = over_rec_time_dim
     self._over_rec_time_dim_subs = over_rec_time_dim_subs
+    self.control_flow_ctx = control_flow_ctx
     self.extra_parent_net = extra_parent_net
     self.extra_name_prefix = extra_name_prefix
     self.extra_deps_in_extra = False
@@ -519,6 +522,17 @@ def get_root_ctx_network(self):
         break
     return net, "".join(reversed(path))

+  def get_control_flow_ctx(self):
+    """
+    :rtype: returnn.tf.util.data.ControlFlowContext|None
+    """
+    net = self
+    while net:
+      if net.control_flow_ctx:
+        return net.control_flow_ctx
+      net = net.parent_net
+    return None
+
   def is_extra_internal_template_construction(self):
     """
     :rtype: LayerBase|None
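`get_control_flow_ctx` walks up the parent chain, so e.g. a plain subnetwork inside a rec loop inherits the loop's ctx even though only the rec subnet itself stores one. A self-contained mimic of the traversal (not RETURNN code):

```python
class NetMimic:
  """Minimal stand-in for TFNetwork's parent_net/control_flow_ctx chaining."""

  def __init__(self, parent_net=None, control_flow_ctx=None):
    self.parent_net = parent_net
    self.control_flow_ctx = control_flow_ctx

  def get_control_flow_ctx(self):
    net = self
    while net:
      if net.control_flow_ctx:
        return net.control_flow_ctx
      net = net.parent_net
    return None


root = NetMimic()
rec_subnet = NetMimic(parent_net=root, control_flow_ctx="loop-ctx")
subnet_in_loop = NetMimic(parent_net=rec_subnet)  # e.g. a subnetwork layer in the loop
assert subnet_in_loop.get_control_flow_ctx() == "loop-ctx"
assert root.get_control_flow_ctx() is None
```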
2 changes: 1 addition & 1 deletion returnn/tf/util/basic.py
@@ -4585,7 +4585,7 @@ def _maybe_to_base_seq_len(v):
     base_out_tag.set_tag_on_size_tensor(base_out_seq_len)

   assert base_out_tag.batch
-  out_tag = base_out_tag.get_for_batch(in_tag.batch)
+  out_tag = base_out_tag.get_for_batch_ctx(batch=in_tag.batch, ctx=in_tag.control_flow_ctx)
   assert out_tag.dyn_size is not None
   return out_tag.dyn_size