returnn/tf/layers/rec.py: 19 changes (10 additions, 9 deletions)
@@ -1988,7 +1988,7 @@ def get_output(self):
   fixed_seq_len = input_seq_len
 if fixed_seq_len is not None:
   time_dim_tag = DimensionTag.get_tag_from_size_tensor(fixed_seq_len)
-  assert time_dim_tag is self.time_dim_tag
+  assert time_dim_tag == self.time_dim_tag
   with tf.name_scope("check_seq_len_batch_size"):
     fixed_seq_len = check_input_dim(
       fixed_seq_len, axis=0, dim=batch_dim * (input_beam.beam_size if input_beam else 1))
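
The `is` → `==` change matters because `declare_same_as` links two distinct `DimensionTag` objects rather than merging them into one object. A minimal sketch of the distinction, assuming the constructor and equality semantics of `DimensionTag` as implied by this diff (the names `tag_a`/`tag_b` are hypothetical):

```python
from returnn.tf.util.data import DimensionTag

# Two separate tag objects for what is logically the same dim.
tag_a = DimensionTag(kind=DimensionTag.Types.Spatial, description="time")
tag_b = DimensionTag(kind=DimensionTag.Types.Spatial, description="time")
tag_b.declare_same_as(tag_a)  # link b to a; both now resolve to the same base tag

assert tag_a == tag_b      # equality follows the declared same-as relation
assert tag_a is not tag_b  # but they remain distinct Python objects
```

With identity (`is`), the assert would spuriously fail whenever the tag on the size tensor is a linked duplicate rather than the very same object.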
@@ -2631,8 +2631,15 @@ def cond(i, net_vars, acc_tas, seq_len_info=None):
 if output_layer:
   assert isinstance(output_layer, LayerBase)
   output_data = output_layer.output.copy_as_time_major()
-  assert 0 in output_data.size_placeholder
-  rec_layer.output.size_placeholder = output_data.size_placeholder.copy()
+  self.time_dim_tag.declare_same_as(output_data.get_time_dim_tag())
+  assert len(rec_layer.output.dim_tags) == len(output_data.dim_tags)
+  for tag1, tag2 in zip(rec_layer.output.dim_tags, output_data.dim_tags):
+    assert tag1.is_equal(tag2, allow_same_feature_dim=True)
+    # Make sure they are the same.
+    # It can happen that they are not when the dim tag is created inside,
+    # and then created once for the template layer, and again for the real layer.
+    # Make sure they are really the same such that we get all information like dyn sizes.
+    tag1.declare_same_as(tag2)
   output = output_data.placeholder
 else:
   assert seq_len is not None
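
The comment in the added loop describes the duplicate-tag situation: the dim tag gets created once for the template layer and again for the real layer, so two equal but distinct tags exist, and only one of them carries the dynamic sizes. A hedged sketch of that situation and of why `declare_same_as` is needed, assuming the `DimensionTag` registry behavior as used in this diff (`seq_len`, `real_tag`, and `template_tag` are hypothetical; TF1-style graph mode is assumed, as in RETURNN):

```python
import tensorflow as tf
from returnn.tf.util.data import DimensionTag

with tf.Graph().as_default():
  seq_len = tf.compat.v1.placeholder(tf.int32, shape=(None,), name="seq_len")  # [batch]

  # The real layer registered its tag on the size tensor...
  real_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="rec-time")
  real_tag.set_tag_on_size_tensor(seq_len)

  # ...but the template layer created its own tag for the same dim.
  template_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="rec-time")

  # Without declare_same_as, template_tag knows nothing about seq_len.
  template_tag.declare_same_as(real_tag)
  # Both tags now resolve to the same base, so dynamic-size information
  # (the seq_len registration) is shared between them.
  assert DimensionTag.get_tag_from_size_tensor(seq_len) is real_tag
```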
@@ -2641,12 +2648,6 @@ def cond(i, net_vars, acc_tas, seq_len_info=None):
   output = tensor_array_stack(
     self.final_acc_tas_dict["output_output"], stop=max_seq_len, name="output_stack")  # e.g. (time, batch, dim)
 
-existing_time_dim_tag = DimensionTag.get_tag_from_size_tensor(rec_layer.output.size_placeholder[0])
-if existing_time_dim_tag:
-  self.time_dim_tag.declare_same_as(existing_time_dim_tag)
-else:
-  self.time_dim_tag.set_tag_on_size_tensor(rec_layer.output.size_placeholder[0], batch=rec_layer.output.batch)
-
 for key in (
   self.net.used_data_keys |
   (self.input_layers_net.used_data_keys if self.input_layers_net else set()) |
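
The deleted block was the older, manual version of the same unification: look up whether the size tensor already carries a tag, then either link to it or register the new tag. After this PR, the per-tag `declare_same_as` loop in the `output_layer` branch presumably makes that fallback redundant, since the unification now happens on the `Data` dim tags directly. A small sketch of the registry round trip the removed code relied on, under the same assumptions as above (`sizes` is a hypothetical placeholder):

```python
import tensorflow as tf
from returnn.tf.util.data import DimensionTag

with tf.Graph().as_default():
  sizes = tf.compat.v1.placeholder(tf.int32, shape=(None,), name="sizes")  # [batch]
  tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="time")

  assert DimensionTag.get_tag_from_size_tensor(sizes) is None  # nothing registered yet
  tag.set_tag_on_size_tensor(sizes)                            # register: sizes -> tag
  assert DimensionTag.get_tag_from_size_tensor(sizes) is tag   # lookup round-trips
```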