diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py index 9fd6e4956..2fb0826a5 100644 --- a/returnn/tf/layers/rec.py +++ b/returnn/tf/layers/rec.py @@ -1988,7 +1988,7 @@ def get_output(self): fixed_seq_len = input_seq_len if fixed_seq_len is not None: time_dim_tag = DimensionTag.get_tag_from_size_tensor(fixed_seq_len) - assert time_dim_tag is self.time_dim_tag + assert time_dim_tag == self.time_dim_tag with tf.name_scope("check_seq_len_batch_size"): fixed_seq_len = check_input_dim( fixed_seq_len, axis=0, dim=batch_dim * (input_beam.beam_size if input_beam else 1)) @@ -2631,8 +2631,15 @@ def cond(i, net_vars, acc_tas, seq_len_info=None): if output_layer: assert isinstance(output_layer, LayerBase) output_data = output_layer.output.copy_as_time_major() - assert 0 in output_data.size_placeholder - rec_layer.output.size_placeholder = output_data.size_placeholder.copy() + self.time_dim_tag.declare_same_as(output_data.get_time_dim_tag()) + assert len(rec_layer.output.dim_tags) == len(output_data.dim_tags) + for tag1, tag2 in zip(rec_layer.output.dim_tags, output_data.dim_tags): + assert tag1.is_equal(tag2, allow_same_feature_dim=True) + # Make sure they are the same. + # It can happen that they are not when the dim tag is created inside, + # and then created once for the template layer, and again for the real layer. + # Make sure they are really the same such that we get all information like dyn sizes. + tag1.declare_same_as(tag2) output = output_data.placeholder else: assert seq_len is not None @@ -2641,12 +2648,6 @@ def cond(i, net_vars, acc_tas, seq_len_info=None): output = tensor_array_stack( self.final_acc_tas_dict["output_output"], stop=max_seq_len, name="output_stack") # e.g. (time, batch, dim) - existing_time_dim_tag = DimensionTag.get_tag_from_size_tensor(rec_layer.output.size_placeholder[0]) - if existing_time_dim_tag: - self.time_dim_tag.declare_same_as(existing_time_dim_tag) - else: - self.time_dim_tag.set_tag_on_size_tensor(rec_layer.output.size_placeholder[0], batch=rec_layer.output.batch) - for key in ( self.net.used_data_keys | (self.input_layers_net.used_data_keys if self.input_layers_net else set()) |