diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py index 1bcedf993..de6ff0da0 100644 --- a/returnn/tf/layers/rec.py +++ b/returnn/tf/layers/rec.py @@ -5220,10 +5220,14 @@ def decide(cls, src, output=None, owner=None, name=None, length_normalization=Fa output.size_placeholder = {} for i, size in src_data.size_placeholder.items(): tag = DimensionTag.get_tag_from_size_tensor(size) - size = tf.reshape(size, [batch_dim, beam_size]) # (batch, beam) - size = tf.gather_nd(size, indices=beam_idxs_ext) # (batch,) - if tag: + assert tag + tag = tag.get_for_batch(output.batch) + if tag.dyn_size is None: + size = tf.reshape(size, [batch_dim, beam_size]) # (batch, beam) + size = tf.gather_nd(size, indices=beam_idxs_ext) # (batch,) tag.set_tag_on_size_tensor(size, batch=output.batch) + else: + size = tag.dyn_size output.size_placeholder[i] = size final_search_choices = SearchChoices(owner=owner, is_decided=True, beam_size=1) if owner: