diff --git a/returnn/tf/layers/base.py b/returnn/tf/layers/base.py
index a1633dd14..d955962a6 100644
--- a/returnn/tf/layers/base.py
+++ b/returnn/tf/layers/base.py
@@ -11,7 +11,7 @@
 from returnn.util.basic import NotSpecified, CollectionReadCheckCovered, BehaviorVersion
 import returnn.tf.compat as tf_compat
 import returnn.tf.util.basic as tf_util
-from returnn.tf.util.data import Data, SearchBeam
+from returnn.tf.util.data import Data
 from returnn.tf.util.basic import OutputWithActivation, CustomUpdate, reuse_name_scope
 from returnn.log import log
 
@@ -250,6 +250,7 @@ def _base_get_out_data_from_opts(cls, network, name, out_type=None, n_out=NotSpe
     :return: Data template (placeholder not set)
     :rtype: Data
     """
+    from ..util.data import DimensionTag
     if callable(out_type):
       return out_type(
         network=network, name=name, n_out=n_out, target=target, size_target=size_target, sources=sources, loss=loss,
@@ -268,9 +269,8 @@ def _base_get_out_data_from_opts(cls, network, name, out_type=None, n_out=NotSpe
         network=network, mark_data_key_as_used=False).dim
     if n_out is not NotSpecified:
       assert out_type["dim"] == n_out
-    sources_data = None
-    if sources and sources[0]:
-      sources_data = sources[0].output.copy_template()
+    sources_data_list = [src.output for src in sources if src]
+    sources_data = Data.get_common_data(sources_data_list, ignore_feature_dim=True) if sources_data_list else None
     if sources_data and not sources_data.sparse and not out_type.get("sparse", False):
       out_type.setdefault("dtype", sources_data.dtype)
     # You are supposed to set self.output.{batch_dim_axis,time_dim_axis} explicitly,
@@ -291,38 +291,30 @@ def _base_get_out_data_from_opts(cls, network, name, out_type=None, n_out=NotSpe
     if "shape" not in out_type and "dim_tags" not in out_type:
       if sources_data:
         if out_type.get("sparse", False):
-          out_type.setdefault("shape", sources_data.shape_sparse)
+          out_type["dim_tags"] = sources_data.dim_tags_sparse
         else:  # not sparse
           feature_dim_axis = out_type.get("feature_dim_axis", NotSpecified)
-          if feature_dim_axis is NotSpecified:
-            if sources_data.feature_dim_axis is not None:
-              feature_dim_axis = sources_data.feature_dim_axis
-            else:
-              feature_dim_axis = -1
-          if sources_data.shape:
-            default_shape = list(sources_data.shape_dense)
-            if sources_data.batch_dim_axis is not None:
-              default_shape.insert(sources_data.batch_dim_axis, None)
-            default_shape[feature_dim_axis] = out_type.get("dim", None)
-            if out_type.get("batch_dim_axis") is not None:
-              default_shape.pop(out_type.get("batch_dim_axis"))
-          else:  # source is scalar
-            if out_type.get("dim") or out_type.get("feature_dim_axis") is not None:
-              default_shape = (out_type.get("dim"),)
+          dim = out_type.get("dim", None)
+          dim_tags = list(sources_data.dim_tags_sparse)
+          feature_dim_tag = DimensionTag(
+            kind=DimensionTag.Types.Feature, description="%s:feature-dense" % name, dimension=dim)
+          if feature_dim_axis in (NotSpecified, None):
+            if sources_data.feature_dim_axis is None:
+              feature_dim_axis = len(dim_tags)
             else:
-              default_shape = ()
-          out_type.setdefault("shape", tuple(default_shape))
+              feature_dim_axis = sources_data.feature_dim_axis
+          dim_tags.insert(feature_dim_axis, feature_dim_tag)
+          out_type["dim_tags"] = dim_tags
       elif network.is_inside_rec_layer():
         if out_type.get("sparse", False):
           out_type.setdefault("shape", ())
         else:
           out_type.setdefault("shape", (out_type.get("dim", None),))
     # Note: No special handling for feature_dim_axis here for now...
-    beam = None
-    for src in sources:
-      if src:  # might be None if template construction
-        beam = SearchBeam.get_combined_beam(beam, src.output.beam)
-    out_type.setdefault("beam", beam)
+    if sources_data and sources_data.batch:
+      out_type.setdefault("batch", sources_data.batch)
+    if sources_data and sources_data.beam:
+      out_type.setdefault("beam", sources_data.beam)
     output = Data(**out_type)
     cls._post_init_output(
       output=output, network=network, target=target, size_target=size_target, _target_layers=_target_layers,
diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py
index 923df26ac..d3c2e2191 100644
--- a/returnn/tf/layers/basic.py
+++ b/returnn/tf/layers/basic.py
@@ -436,6 +436,9 @@ def transform(v):
       assert isinstance(v, (tf.Tensor, tf.TensorArray))
       if isinstance(v, tf.Tensor) and v.get_shape().ndims == 0:
         return v  # leave scalars as-is
+      if isinstance(v, tf.Tensor) and getattr(v, "_RETURNN_beam_expanded_base_data", None):
+        # This tensor was just expanded by a beam. Selecting beams are not needed.
+        return v
       for i, base_src_choices in enumerate(reversed(search_choices_seq)):
         assert isinstance(base_src_choices, SearchChoices)
         assert base_src_choices.src_beams is not None, (
@@ -452,7 +455,7 @@ def transform(v):
             i, get_valid_scope_name_from_str(base_src_choices.owner.name),
             len(search_choices_seq), get_valid_scope_name_from_str(search_choices.owner.name)))
       if tag:
-        tag.set_tag_on_size_tensor(v, batch=self.output.batch)
+        tag.set_tag_on_size_tensor(v, batch=self.output.batch.copy_set_beam(base_src_choices.get_beam_info()))
       self.used_search_choices_beams = True
       return v
@@ -3128,7 +3131,7 @@ def __init__(self, sizes, num_axes, declare_same_sizes_as=None, **kwargs):
       for i, other in declare_same_sizes_as.items():
        assert 0 <= i < num_axes
        other_dim_tag = other.output.get_size_dim_tag(0)
-        other_dim_tag.set_tag_on_size_tensor(size_placeholder[i], batch=self.output.batch)
+        other_dim_tag.set_tag_on_size_tensor(size_placeholder[i], batch=self.output.batch, same_as_before=True)
     self.output.size_placeholder = size_placeholder
 
   def get_dep_layers(self):
@@ -5487,6 +5490,7 @@ def get_out_data_from_opts(cls, name, sources, red1=-1, red2=-2, var1=-2, var2=-
     :param bool add_var2_if_empty:
     :rtype: Data
     """
+    from ..util.data import DimensionTag, BatchInfo
     assert len(sources) == 2, "dot-layer %r: needs exactly two sources" % (name,)
     # See __init__.
     a_out = sources[0].output.copy()
@@ -5508,11 +5512,9 @@ def get_out_data_from_opts(cls, name, sources, red1=-1, red2=-2, var1=-2, var2=-
     assert all(b_axis in map_a_to_b_rem_axes.values() for b_axis in b_rem_axes)
     b_rem_axes = [map_a_to_b_rem_axes[a_axis] for a_axis in a_rem_axes]
 
-    a_shape = a_out.batch_shape
-    b_shape = b_out.batch_shape
-    a_rem_dims = [a_shape[i] for i in a_rem_axes]
-    a_var_dims = [a_shape[i] for i in a_var_axes]
-    b_var_dims = [b_shape[i] for i in b_var_axes]
+    a_rem_dims = [a_out.dim_tags[i] for i in a_rem_axes]
+    a_var_dims = [a_out.dim_tags[i] for i in a_var_axes]
+    b_var_dims = [b_out.dim_tags[i] for i in b_var_axes]
 
     def find_axis(a_axis, b_axis):
       """
@@ -5533,49 +5535,18 @@ def find_axis(a_axis, b_axis):
       return axis
 
     time_dim_axis = find_axis(a_out.time_dim_axis, b_out.time_dim_axis)
-    batch_dim_axis = find_axis(a_out.batch_dim_axis, b_out.batch_dim_axis)
-    assert batch_dim_axis != NotSpecified or (a_out.batch_dim_axis is None and b_out.batch_dim_axis is None)
 
     if not b_var_dims and add_var2_if_empty:
-      b_var_dims.append(1)
-
-    def get_batch_axis_excluding_batch(axis):
-      """
-      :param int axis:
-      :rtype: int
-      """
-      if batch_dim_axis is None:
-        return axis
-      assert axis != batch_dim_axis
-      if axis < batch_dim_axis:
-        return axis
-      return axis - 1
-
-    # Collect dynamic size info.
-    size_placeholder = {}
-    for axis1_wo_b in sorted(a_out.size_placeholder.keys()):
-      axis_out_wb = cls._axis1_to_output(a_out.get_batch_axis(axis1_wo_b), a_rem_axes=a_rem_axes, a_var_axes=a_var_axes)
-      if axis_out_wb is None:
-        continue
-      size_placeholder[get_batch_axis_excluding_batch(axis_out_wb)] = a_out.size_placeholder[axis1_wo_b]
-    for axis2_wo_b in sorted(b_out.size_placeholder.keys()):
-      axis_out_wb = cls._axis2_to_output(
-        b_out.get_batch_axis(axis2_wo_b), b_rem_axes=b_rem_axes, a_var_axes=a_var_axes, b_var_axes=b_var_axes)
-      if axis_out_wb is None or axis_out_wb in size_placeholder:
-        continue
-      size_placeholder[get_batch_axis_excluding_batch(axis_out_wb)] = b_out.size_placeholder[axis2_wo_b]
-
-    shape = list(a_rem_dims + a_var_dims + b_var_dims)
-    if batch_dim_axis is not None and batch_dim_axis is not NotSpecified:
-      shape.pop(batch_dim_axis)
+      b_var_dims.append(
+        DimensionTag(kind=DimensionTag.Types.Spatial, description="%s:dot:dummy-var2" % name, dimension=1))
 
+    dim_tags = list(a_rem_dims + a_var_dims + b_var_dims)
     return Data(
       name="%s_output" % name,
-      shape=tuple(shape),
-      batch_dim_axis=batch_dim_axis,
+      dim_tags=dim_tags,
       time_dim_axis=time_dim_axis,
       dtype=a_out.dtype,
-      size_placeholder=size_placeholder,
+      batch=BatchInfo.get_common_batch_info([src.batch for src in (a_out, b_out)]),
       beam=SearchBeam.get_combined_beam(a_out.beam, b_out.beam))
@@ -6361,7 +6332,7 @@ def __init__(self, condition, true_layer, false_layer,
         old_size = self.output.size_placeholder[i]
         old_tag = DimensionTag.get_tag_from_size_tensor(old_size)
         assert old_tag
-        old_tag.set_tag_on_size_tensor(size, batch=self.output.batch)
+        old_tag.set_tag_on_size_tensor(size, batch=self.output.batch, same_as_before=True)
         self.output.size_placeholder[i] = size
 
   def _cond_layer_return(self, layer):
diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py
index f818e0781..1f0b29b69 100644
--- a/returnn/tf/layers/rec.py
+++ b/returnn/tf/layers/rec.py
@@ -310,11 +310,14 @@ def transform_config_dict(cls, d, network, get_layer):
     # We need to figure out the output time dim tag at this early point,
     # because _SubnetworkRecCell might need it during template construction.
     source_data = get_concat_sources_data_template(d["sources"]) if d["sources"] else None
+    have_dyn_seq_len_end = False
+    if isinstance(d.get("unit"), dict):
+      have_dyn_seq_len_end = "end" in d["unit"]
     if source_data and not source_data.have_time_axis():  # expect to be inside other RecLayer
       time_dim_tag = None
     else:  # We will output a time-dim.
-      if source_data:
+      if source_data and not have_dyn_seq_len_end:
         assert source_data.have_time_axis()
         time_dim_tag = source_data.get_time_dim_tag()
       elif d.get("size_target"):  # if this is set, always use it
@@ -322,7 +325,7 @@ def transform_config_dict(cls, d, network, get_layer):
           target=d["size_target"], mark_data_key_as_used=True, network=network)
         assert target_data.have_time_axis()
         time_dim_tag = target_data.get_time_dim_tag()
-      elif d.get("target") and network.eval_flag:
+      elif d.get("target") and network.eval_flag and not network.search_flag:
         target_data = cls._static_get_target_value(
           target=d["target"][0] if isinstance(d["target"], (list, tuple)) else d["target"],
           mark_data_key_as_used=True, network=network)
@@ -413,7 +416,6 @@ def get_out_data_from_opts(cls, network, unit, _time_dim_tag=None, sources=(), i
     assert out
     if out.have_time_axis() and _time_dim_tag:
       out = out.copy_template_replace_dim_tag(axis=out.time_dim_axis, new_dim_tag=_time_dim_tag)
-    cls._post_init_output(output=out, sources=sources, network=network, **kwargs)
     for dep in deps:
       if dep:
         out.beam = SearchBeam.get_combined_beam(out.beam, dep.output.beam)
@@ -1007,6 +1009,11 @@ def __init__(self, net_dict, source_data, time_dim_tag, rec_layer_name, parent_n
      self.net.extern_data.data["source"] = (
        source_data.copy_template_excluding_time_dim())
    self.time_dim_tag = time_dim_tag
+    self._time_dim_tags = {time_dim_tag}  # type: typing.Set[DimensionTag]
+    if source_data:
+      # Maybe the input has a different time dim tag, but we still have new dynamic length here
+      # because of custom endings ("end" layer).
+      self._time_dim_tags.add(source_data.get_time_dim_tag())
    for key, data in parent_net.extern_data.data.items():
      if key in self.net.extern_data.data or data.time_dim_axis is None:
        continue  # Don't overwrite existing, e.g. "source".
@@ -1533,7 +1540,7 @@ def get_input_moved_out(name):
        assert isinstance(self.input_layers_net, TFNetwork)
        layer = self.input_layers_net.layers[layer_name]
        assert isinstance(layer, LayerBase)
-        if not self.parent_rec_layer.output.is_same_time_dim(layer.output):
+        if layer_name not in inputs_moved_out_tas:
          assert name != "output" and not prev, "Time dim does not match: RecLayer %s (%r) vs sub layer %s (%r)." % (
            self.parent_rec_layer, self.parent_rec_layer.output.get_time_dim_tag(),
            layer, layer.output.get_time_dim_tag())
@@ -1960,17 +1967,21 @@ def get_output(self):
        fixed_seq_len = input_seq_len
      if fixed_seq_len is not None:
        time_dim_tag = DimensionTag.get_tag_from_size_tensor(fixed_seq_len)
+        assert time_dim_tag is self.time_dim_tag
        with tf.name_scope("check_seq_len_batch_size"):
          fixed_seq_len = check_input_dim(
            fixed_seq_len, axis=0, dim=batch_dim * (input_beam.beam_size if input_beam else 1))
        if time_dim_tag:
-          time_dim_tag.set_tag_on_size_tensor(fixed_seq_len, batch=output_template.output.batch)
+          time_dim_tag.set_tag_on_size_tensor(fixed_seq_len, batch=time_dim_tag.batch, same_as_before=True)
        max_seq_len = tf.reduce_max(fixed_seq_len, name="max_seq_len")
        have_known_seq_len = True
      else:
        assert "end" in self.layer_data_templates, "length not defined, provide 'end' layer"
        max_seq_len = None
        have_known_seq_len = False
+        # self.time_dim_tag (via transform_config_dict) should match the logic in here.
+        # Ie. for this case the dyn size should not be set, as it will be dynamic later on via "end".
+        assert self.time_dim_tag.dyn_size is None
      if time_dim_tag:
        self.time_dim_tag.declare_same_as(time_dim_tag)
      else:
@@ -1980,6 +1991,8 @@ def get_output(self):
      used_keys = self.net.used_data_keys.copy()
      for key in sorted(used_keys):
        data = rec_layer.network.get_extern_data(key, mark_data_key_as_used=True)
+        # This could be another dim tag from the rec time dim tag.
+        self._time_dim_tags.add(data.get_time_dim_tag())
        data_placeholder = data.get_placeholder_as_time_major()
        with tf.name_scope("check_data_len"):
          data_len = tf.shape(data_placeholder)[0]
@@ -2233,8 +2246,6 @@ def get_loop_loss():
    if self.input_layers_moved_out:
      with tf.name_scope("input_layers_moved_out"):
        self._construct_input_layers_moved_out()
-        if fixed_seq_len is None and rec_layer.output.size_placeholder:  # might have set it by now
-          fixed_seq_len = rec_layer.output.size_placeholder[0]
        for layer_name in self.input_layers_moved_out:
          # Create only Tensor arrays for those which we use inside the loop.
          if not self._input_layer_used_inside_loop(layer_name):
@@ -2245,21 +2256,21 @@ def get_loop_loss():
          assert layer.output.have_time_axis()
          assert rec_layer.output.is_same_time_dim(layer.output)
          # Only unroll if that is the same time dim.
-          if not layer.output.mark_same_time(rec_layer.output):
+          if not layer.output.mark_same_time(self._time_dim_tags):
            continue
-          assert fixed_seq_len is not None
+          assert max_seq_len is not None
          inp_ta = tf.TensorArray(
            name="%s_ta" % layer_name,
            dtype=self.layer_data_templates[layer_name].output.dtype,
            element_shape=self.layer_data_templates[layer_name].output.batch_shape,
-            size=tf.reduce_max(fixed_seq_len),
+            size=layer.output.time_dimension(),
            infer_shape=True)
          with tf.control_dependencies([
-              tf.Assert(tf.equal(
-                tf.shape(layer.output.placeholder)[layer.output.time_dim_axis], tf.reduce_max(fixed_seq_len)),
+              tf.Assert(tf.greater_equal(
+                layer.output.time_dimension(), max_seq_len),
                ["input TA unstack", str(layer.output), "shape", tf.shape(layer.output.placeholder),
                 "seq len", layer.output.get_sequence_lengths(), "do not match",
-                 "fixed seq len", fixed_seq_len, "max", tf.reduce_max(fixed_seq_len)])]):
+                 "max seq len", max_seq_len])]):
            inp_ta = inp_ta.unstack(
              layer.output.get_placeholder_as_time_major(),
              name="%s_ta_unstack" % layer_name)
@@ -2501,14 +2512,17 @@ def cond(i, net_vars, acc_tas, seq_len_info=None):
          "%s: input beam %r, output beam %r, sources %r, target %r" % (
            self.parent_rec_layer, input_beam, output_beam,
            self.parent_rec_layer.sources, self.parent_rec_layer.target))
-        from returnn.tf.util.basic import tile_transposed
-        seq_len = tile_transposed(seq_len, axis=0, multiples=output_beam.beam_size)  # (batch * beam,)
-        seq_len._RETURNN_dyn_size_beam = rec_layer.output.beam
-        time_dim_tag.set_tag_on_size_tensor(seq_len, batch=rec_layer.output.batch)
+        assert output_template.output.batch.beam == output_beam
+        time_dim_tag = time_dim_tag.get_for_batch(output_template.output.batch)
+        assert time_dim_tag.dyn_size is not None
+        seq_len = time_dim_tag.dyn_size
    else:
      _, final_net_vars, final_acc_tas, (_, seq_len) = final_loop_vars
-      seq_len._RETURNN_dyn_size_beam = rec_layer.output.beam
-      time_dim_tag.set_tag_on_size_tensor(seq_len, batch=rec_layer.output.batch)
+      # Note: In case of search, the seq len would have the beam from the end layer.
+      # This might not be the same as the final output beam.
+      # This should correctly be resolved in _construct_output_layers_moved_out and _opt_search_resolve.
+      # So we do not assign it to the dim tag at this point.
+      assert "output" in extra_output_layers
      max_seq_len = tf.reduce_max(seq_len, name="dyn_max_seq_len")
      self.get_final_rec_vars = lambda layer_name_: self.get_layer_rec_var_from_loop_vars(
        loop_vars=final_net_vars, layer_name=layer_name_, final_frame=True, seq_len=seq_len)
@@ -2643,6 +2657,7 @@ def _opt_search_resolve(self, layer_name, acc_ta, final_net_vars, seq_len, searc
    import os
    from returnn.tf.util.basic import nd_indices, assert_min_tf_version, expand_dims_unbroadcast
    from returnn.tf.util.basic import get_shape_dim, get_valid_scope_name_from_str
+    from returnn.tf.util.basic import TensorCachedComputation
    rec_layer = self.parent_rec_layer
    try:
      layer = self.net.get_layer(layer_name)
@@ -2655,6 +2670,16 @@ def _opt_search_resolve(self, layer_name, acc_ta, final_net_vars, seq_len, searc
        return acc_ta, None, search_choices, seq_len
      if search_choices.keep_raw:
        search_choices_cache[search_choices.owner.name] = search_choices
+        # Make a new seq_len tensor, to be able to attach a new dim tag to it.
+        # This is needed as long as we make use of get_tag_from_size_tensor.
+        # Cache it such that we have it unique.
+        cache = TensorCachedComputation(search_choices, key=("seq_len_raw", seq_len))
+        if not cache.has_cache():
+          assert search_choices.owner.output.batch and search_choices.owner.output.batch.beam
+          seq_len = tf.identity(seq_len)
+          seq_len._RETURNN_dyn_size_beam = search_choices.owner.output.batch.beam
+          cache.set_cache(seq_len)
+        seq_len = cache.get_cache()
        return acc_ta, search_choices.owner.name, search_choices, seq_len
    layer_choice = search_choices.owner
    is_prev_choice = False
@@ -2735,20 +2760,22 @@ def get_choice_seq(choice_base):
        src_choice_beams = self.final_acc_tas_dict["choice_%s" % choice_.name].read(
          max_seq_len - 1, name="ta_read_choice")  # (batch, beam) -> beam_in idx
        seq_len = select_src_beams(seq_len, src_choice_beams)
-      else:
+        assert choice_.output.batch and choice_.output.batch.beam
+        assert getattr(seq_len, "_RETURNN_dyn_size_beam", NotSpecified) in (NotSpecified, choice_.output.batch.beam)
+        seq_len._RETURNN_dyn_size_beam = choice_.output.batch.beam
+
+      else:  # not end_layer
        # Here we don't need to resolve anything, as the sequence length is the same for all hyps in the beam.
        # However, beam size for the current output may be different from the "output" layer.
-        # Therefore take the first len in the beam and tile it to the desired beam size.
-
-        # Separate batch and beam dims
-        seq_len_beam_size = rec_layer.output.beam.beam_size
-        seq_len = tf.reshape(seq_len, [batch_dim, seq_len_beam_size], name="split_batch_beam")
-
-        seq_len = seq_len[:, 0:1]
-        seq_len = tf.tile(seq_len, [1, latest_beam_size], name="resize_seq_len_beam")
-
-        # Recombine batch and beam dims
-        seq_len = tf.reshape(seq_len, [batch_dim * latest_beam_size], name="merge_batch_beam")
+        tag = DimensionTag.get_tag_from_size_tensor(seq_len)
+        assert tag
+        latest_batch = (
+          latest_layer_choice.output.batch
+          or self.parent_rec_layer.output.batch.copy_set_beam(latest_layer_choice.output.beam))
+        tag = tag.get_for_batch(latest_batch)
+        assert tag.dyn_size is not None
+        assert tag.batch == latest_batch and tag.batch.beam == latest_layer_choice.output.beam
+        seq_len = tag.dyn_size
 
    new_acc_output_ta = tf.TensorArray(
      name="search_resolved_%s" % os.path.basename(acc_ta.handle.op.name),
@@ -3131,19 +3158,6 @@ def get_layer(name):
    for layer_name in self.input_layers_moved_out:
      get_layer(layer_name)
-    # We might have figured out the real output seq length (and dim tag) by now.
-    if not self.parent_rec_layer.output.size_placeholder and "output" in self.input_layers_moved_out:
-      output_layer = self.input_layers_net.layers["output"]
-      assert output_layer.output.have_time_axis()
-      self.parent_rec_layer.output.size_placeholder = {0: output_layer.output.get_sequence_lengths()}
-    # This might be set e.g. by ChoiceLayer, or losses.
-    if not self.parent_rec_layer.output.size_placeholder and self.input_layers_net.used_data_keys:
-      for data_key in sorted(self.input_layers_net.used_data_keys):
-        data = self.input_layers_net.extern_data.data[data_key]
-        if data.have_time_axis():
-          self.parent_rec_layer.output.size_placeholder = {0: data.get_sequence_lengths()}
-          break
-
  def _construct_output_layers_moved_out(self, loop_accumulated, seq_len, extra_output_layers, final_net_vars):
    """
    See self._move_outside_loop().
@@ -3158,15 +3172,10 @@ def _construct_output_layers_moved_out(self, loop_accumulated, seq_len, extra_ou
    """
    if not self.output_layers_moved_out and not extra_output_layers:
      return
-    from returnn.tf.util.basic import tensor_array_stack, has_control_flow_context, concat_with_opt_broadcast
-    from returnn.tf.util.basic import DimensionTag, tile_transposed
+    from returnn.tf.util.basic import tensor_array_stack, concat_with_opt_broadcast
    from returnn.tf.network import TFNetwork, ExternData
    from .base import InternalLayer
 
-    if seq_len is not None:
-      time_dim_tag = DimensionTag.get_tag_from_size_tensor(seq_len)
-    else:
-      time_dim_tag = None
    self.output_layers_net = TFNetwork(
      name="%s/%s(rec-subnet-output)" % (
        self.parent_net.name, self.parent_rec_layer.name if self.parent_rec_layer else "?"),
@@ -3199,7 +3208,6 @@ def get_loop_acc_layer(name):
      if name in loop_acc_layers:
        return loop_acc_layers[name]
      with tf.name_scope(self.layer_data_templates[name].layer_class_type.cls_get_tf_scope_name(name)):
-        inner_layer = self.net.get_layer(name)
        acc_ta = loop_accumulated["output_%s" % name]
        acc_ta, latest_layer_choice_name, search_choices, resolved_seq_len = self._opt_search_resolve(
          layer_name=name, acc_ta=acc_ta, final_net_vars=final_net_vars, seq_len=seq_len,
@@ -3213,39 +3221,10 @@ def get_loop_acc_layer(name):
          output.beam = None
          if output.batch:
            output.batch = output.batch.copy_set_beam(output.beam)
-        resolved_seq_len._RETURNN_dyn_size_beam = output.beam
-        if time_dim_tag:
-          time_dim_tag.set_tag_on_size_tensor(resolved_seq_len, batch=output.batch)
+        output.set_dynamic_size(axis=0, sizes=resolved_seq_len)
        max_len = tf.reduce_max(resolved_seq_len)
        # We should have accumulated it.
        output.placeholder = tensor_array_stack(acc_ta, stop=max_len)  # e.g. (time,batch,dim)
-        output.size_placeholder = {0: resolved_seq_len}
-        if latest_layer_choice_name and search_choices and search_choices.keep_raw:
-          if output.beam != self.parent_rec_layer.output.beam:
-            # TODO this is not quite correct...
-            # (It is correct only if you use keep_beam or so...)
-            if output.beam.beam_size % self.parent_rec_layer.output.beam.beam_size == 0:
-              size = tile_transposed(
-                seq_len, axis=0, multiples=output.beam.beam_size // self.parent_rec_layer.output.beam.beam_size)
-              size._RETURNN_dyn_size_beam = output.beam
-              if time_dim_tag:
-                time_dim_tag.set_tag_on_size_tensor(size, batch=output.batch)
-              output.size_placeholder[0] = size
-        if inner_layer.output.size_placeholder:
-          for i, size in inner_layer.output.size_placeholder.items():
-            tag = DimensionTag.get_tag_from_size_tensor(size)
-            if tag and tag.dyn_size is not None:
-              size = tag.dyn_size  # this is more likely out of the loop
-            if not has_control_flow_context(size):  # copy if this size comes from outside the loop
-              if inner_layer.output.beam:
-                # Might need tiling...
-                size = tile_transposed(
-                  size, axis=0,
-                  multiples=tf.shape(output.size_placeholder[0])[0] // tf.shape(size)[0])
-                size._RETURNN_dyn_size_beam = output.beam
-              if tag:
-                tag.set_tag_on_size_tensor(size, batch=output.batch)
-              output.size_placeholder[i + 1] = size
        assert isinstance(self.output_layers_net, TFNetwork)
        layer_ = self.output_layers_net.add_layer(
          name=name, output=output, layer_class=InternalLayer, sources=[])
@@ -3342,7 +3321,7 @@ def get_layer(name):
    # However, after construction, when accessing any of these layers,
    # we would expect that their time-dim-axis matches the same as from the rec loop.
    for layer in self.output_layers_net.layers.values():
-      layer.output.mark_same_time(self.parent_rec_layer.output)
+      layer.output.mark_same_time(self._time_dim_tags)
 
 
 RecLayer.SubnetworkRecCell = _SubnetworkRecCell
@@ -5231,10 +5210,14 @@ def decide(cls, src, output=None, owner=None, name=None, length_normalization=Fa
    output.size_placeholder = {}
    for i, size in src_data.size_placeholder.items():
      tag = DimensionTag.get_tag_from_size_tensor(size)
-      size = tf.reshape(size, [batch_dim, beam_size])  # (batch, beam)
-      size = tf.gather_nd(size, indices=beam_idxs_ext)  # (batch,)
-      if tag:
+      assert tag
+      tag = tag.get_for_batch(output.batch)
+      if tag.dyn_size is None:
+        size = tf.reshape(size, [batch_dim, beam_size])  # (batch, beam)
+        size = tf.gather_nd(size, indices=beam_idxs_ext)  # (batch,)
        tag.set_tag_on_size_tensor(size, batch=output.batch)
+      else:
+        size = tag.dyn_size
      output.size_placeholder[i] = size
    final_search_choices = SearchChoices(owner=owner, is_decided=True, beam_size=1)
    if owner:
diff --git a/returnn/tf/util/basic.py b/returnn/tf/util/basic.py
index 31f87a28d..318f4abf6 100644
--- a/returnn/tf/util/basic.py
+++ b/returnn/tf/util/basic.py
@@ -4496,39 +4496,33 @@ def _maybe_to_base_seq_len(v):
  in_seq_len = _get_main_seq_len()
  in_tag = DimensionTag.get_tag_from_size_tensor(in_seq_len)
-  assert in_tag
+  assert in_tag and in_tag.batch
+  # The base_in* is via dim tag same_base.
+  # This might be the global batch without beam.
+  # But it might also be the same as in*.
  base_in_seq_len = _maybe_to_base_seq_len(in_seq_len)
  base_in_tag = DimensionTag.get_tag_from_size_tensor(base_in_seq_len)
  assert base_in_tag
  base_kwargs = {k: _maybe_to_base_seq_len(v) for (k, v) in kwargs.items()}
  base_cache_key = (key, tuple(sorted(base_kwargs.items())))
-  cache_key = (key, tuple(sorted(kwargs.items())))
-  kwargs_tensors = [v for (_, v) in sorted(kwargs.items()) if isinstance(v, tf.Tensor)]
  base_kwargs_tensors = [v for (_, v) in sorted(base_kwargs.items()) if isinstance(v, tf.Tensor)]
 
  cache = TensorCachedComputation(base_in_seq_len, key=base_cache_key)
  if cache.has_cache():
    base_out_seq_len = cache.get_cache()
-    base_tag = DimensionTag.get_tag_from_size_tensor(base_out_seq_len)
-    assert base_tag
+    base_out_tag = DimensionTag.get_tag_from_size_tensor(base_out_seq_len)
+    assert base_out_tag
  else:
    with same_control_flow_ctx(base_kwargs_tensors):
      base_out_seq_len = func(**base_kwargs)
    cache.set_cache(base_out_seq_len)
-    base_tag = DimensionTag(description=dim_tag_desc, kind=DimensionTag.Types.Spatial, batch=base_in_tag.batch)
-    base_tag.set_tag_on_size_tensor(base_out_seq_len)
+    base_out_tag = DimensionTag(description=dim_tag_desc, kind=DimensionTag.Types.Spatial, batch=base_in_tag.batch)
+    base_out_tag.set_tag_on_size_tensor(base_out_seq_len)
 
-  cache = TensorCachedComputation(in_seq_len, key=cache_key)
-  if cache.has_cache():
-    out_seq_len = cache.get_cache()
-    tag_ = DimensionTag.get_tag_from_size_tensor(out_seq_len)
-    assert tag_ == base_tag
-  else:
-    with same_control_flow_ctx(kwargs_tensors):
-      out_seq_len = func(**kwargs)
-    cache.set_cache(out_seq_len)
-    base_tag.set_tag_on_size_tensor(out_seq_len, batch=in_tag.batch)
-  return out_seq_len
+  assert base_out_tag.batch
+  out_tag = base_out_tag.get_for_batch(in_tag.batch)
+  assert out_tag.dyn_size is not None
+  return out_tag.dyn_size
 
 
 def smoothing_cross_entropy(logits,
diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py
index ffa8e7ca9..f256fafda 100644
--- a/returnn/tf/util/data.py
+++ b/returnn/tf/util/data.py
@@ -10,6 +10,7 @@
 
 import os
 import typing
 import tensorflow as tf
+import traceback
 
 from returnn.util.basic import NotSpecified
 import returnn.tf.compat as tf_compat
@@ -56,6 +57,7 @@ def __init__(self, kind=Types.Unspecified, description=None,
    self.description = description
    self.dimension = dimension
    self.same_as = None  # type: typing.Optional[DimensionTag]
+    self._same_as_tb = None  # type: typing.Optional[traceback.StackSummary]  # for debugging
    if src_data:
      assert isinstance(src_data, Data) and isinstance(src_axis, int)
    if not batch and dyn_size_ext:
@@ -65,12 +67,16 @@ def __init__(self, kind=Types.Unspecified, description=None,
    self.src_axis = src_axis
    if dyn_size_ext and not dyn_size_ext.batch and batch:
      dyn_size_ext.batch = batch
-    if dyn_size_ext and dyn_size_ext.batch and batch:
+    if dyn_size_ext:
      assert batch == dyn_size_ext.batch
    self.dyn_size_ext = dyn_size_ext  # type: typing.Optional[Data]
    if dyn_size is not None:
      assert not dyn_size_ext
      self.dyn_size = dyn_size
+    self._dyn_size_same = set()  # type: typing.Set[tf.Tensor]
+    # We can have different tag variants per batch info (e.g. with beam).
+    # They each have same_as = self. The same_base should have the base (global) batch info.
+    self._same_for_batch = {}  # type: typing.Dict[BatchInfo,DimensionTag]
    # When we have some dynamic size, this dynamic size could be inside a loop (RecLayer),
    # and different per each loop frame.
    # In that case, we can not access it from outside (except when we accumulate it).
@@ -112,8 +118,71 @@ def copy(self, kind=None):
      batch=self.batch, src_data=self.src_data, src_axis=self.src_axis)
    tag.same_as = self  # not declare_same_as, none of the extra checks needed
+    tag._same_as_tb = traceback.extract_stack()
    return tag
 
+  def get_for_batch(self, batch):
+    """
+    :param BatchInfo batch:
+    :rtype: DimensionTag
+    """
+    if self.batch == batch:
+      return self
+    if self.is_batch_dim():
+      # For now, just create another copy in case of batch dim tag.
+      return DimensionTag(kind=DimensionTag.Types.Batch, description="batch:%s" % batch.short_repr(), batch=batch)
+    if self.dimension is not None:
+      # If static dim, no effect.
+      assert not self.batch
+      return self
+    same_base = self.get_same_base()
+    # Might be uninitialized in some cases. Assume batch is global.
+    if not same_base.batch:
+      batch_base = batch.get_global_base()
+      if same_base.dyn_size_ext:
+        assert batch == batch_base
+        same_base.batch = batch
+        assert not same_base.dyn_size_ext.batch
+        same_base.dyn_size_ext.batch = batch
+      else:
+        same_base.batch = batch_base
+    if same_base.dyn_size_ext:
+      assert same_base.batch == same_base.dyn_size_ext.batch
+    if same_base.batch == batch:
+      return same_base
+    if batch in same_base._same_for_batch:
+      return same_base._same_for_batch[batch]
+    if batch.copy_remove_beam() == batch.get_global_base() and batch.beam and same_base.dyn_size_ext:
+      dyn_size_ext = same_base.dyn_size_ext.copy_extend_with_beam(batch.beam)
+      assert dyn_size_ext.batch == batch
+      beam_expanded_base_data = getattr(dyn_size_ext.placeholder, "_RETURNN_beam_expanded_base_data", None)
+      assert beam_expanded_base_data
+      # Note: The beam expansion used tiling, which can be cached.
+      # This means that we could end up with the same size tensor (placeholder) for multiple different beams,
+      # when there are different beams with same beam size!
+      # This breaks the current logic in get_tag_from_size_tensor.
+      # As a workaround, we make an explicit new tensor here.
+      from .basic import get_valid_scope_name_from_str, same_control_flow_ctx
+      with same_control_flow_ctx(dyn_size_ext.placeholder):
+        dyn_size_ext.placeholder = tf.identity(
+          dyn_size_ext.placeholder,
+          name=get_valid_scope_name_from_str("%s_identity_for_beam_%s" % (dyn_size_ext.name, batch.beam.name)))
+      dyn_size_ext.placeholder._RETURNN_dyn_size_beam = batch.beam
+      dyn_size_ext.placeholder._RETURNN_beam_expanded_base_data = beam_expanded_base_data
+    else:
+      # We have some more custom batch info (via merge dims or so).
+      # Just leave uninitialized for now.
+      dyn_size_ext = None
+    dim_tag = DimensionTag(
+      kind=self.kind, description=self.description, dimension=self.dimension,
+      batch=batch, dyn_size_ext=dyn_size_ext)
+    dim_tag.same_as = same_base
+    dim_tag._same_as_tb = traceback.extract_stack()
+    if dyn_size_ext:
+      dim_tag.set_tag_on_size_tensor(dyn_size_ext.placeholder, batch=batch)
+    same_base._same_for_batch[batch] = dim_tag
+    return dim_tag
+
  @property
  def dyn_size(self):
    """
@@ -167,25 +236,68 @@ def is_spatial_dim(self):
    """
    return self.kind == DimensionTag.Types.Spatial
 
-  def set_tag_on_size_tensor(self, x, batch=None):
+  def is_same_size_tensor(self, x):
    """
+    :param tf.Tensor x:
+    :return: whether this dim tag for this specific batch (incl beam) is the same as the given size
+    :rtype: bool
+    """
+    if x is self.dyn_size:
+      return True
+    if x in self._dyn_size_same:
+      return True
+    return False
+
+  def set_tag_on_size_tensor(self, x, batch=None, same_as_before=False):
+    """
+    This function is used
+    to couple a tf.Tensor instance representing the dyn size
+    with the dim tag.
+
+    This is usually a newly created dim tag,
+    which is yet unset.
+
+    It is also used to couple an existing dim tag with other dyn sizes
+    which just differ by an expansion of the batch (e.g. search beam).
+
+    See also :func:`get_tag_from_size_tensor`.
+
+    :param tf.Tensor x:
    :param BatchInfo|None batch:
+    :param bool same_as_before: implies it was set before, and the new size is the same.
+      e.g. it could be some identity with added checks, or other change.
+    :return: self or new dim tag
+    :rtype: DimensionTag
    """
    # It's unusual if self.dimension is not None, but let's accept that.
    if hasattr(x, "_is_size_of_dim_tag"):
      # noinspection PyProtectedMember
      assert x._is_size_of_dim_tag in (None, self)
    # If we already have another dyn size set or different batch, create a new DimensionTag instance.
-    if (self.dyn_size is not None and self.dyn_size is not x) or (self.batch and batch and self.batch != batch):
-      if self.batch:
-        assert self.dyn_size is not None
-      new_dim_tag = self.copy()
-      new_dim_tag.dyn_size_ext = None
-      if batch:
-        new_dim_tag.batch = batch
-      new_dim_tag.set_tag_on_size_tensor(x)
-      return
+    if self.batch and batch and self.batch != batch:
+      assert not same_as_before  # it cannot be the same when it is another batch...
+      new_dim_tag = self.get_for_batch(batch)
+      new_dim_tag.set_tag_on_size_tensor(x, batch=batch)
+      return new_dim_tag
+    if self.dyn_size is not None and self.dyn_size is not x:
+      if x in self._dyn_size_same:
+        pass  # ok, pass on
+      elif same_as_before:
+        self._dyn_size_same.add(x)
+        # And now pass on.
+      else:
+        assert self.batch and batch
+        # It's not clear what to do. We could create a new dim tag, but the sizes might be different.
+        # Usually we should not get here.
+        # So for now, just error.
+        from .basic import format_graph_output
+        raise Exception("\n".join([
+          "%r (%r) already has size %r, and another incompatible size %r (batch %r) is being assigned." % (
+            self, self.description, self.dyn_size, x, batch),
+          "\nNew size computation graph:",
+          format_graph_output(x, max_depth=3),
+          "\nThis is maybe the result of an incorrect declare_same_as. Traceback of declare_same_as:",
+          "".join(self._same_as_tb.format()) if self._same_as_tb else ("same_as = %s" % self.same_as)]))
    if batch and getattr(x, "_RETURNN_dyn_size_beam", None):
      assert batch.beam == getattr(x, "_RETURNN_dyn_size_beam")
    if self.batch and batch:
@@ -196,6 +308,7 @@ def set_tag_on_size_tensor(self, x, batch=None):
    setattr(x, "_is_size_of_dim_tag", self)
    if self.dyn_size is None:
      self.dyn_size = x
+    return self
 
  @classmethod
  def get_tag_from_size_tensor(cls, x):
@@ -318,7 +431,6 @@ def declare_same_as(self, other):
    """
    :param DimensionTag other:
    """
-    from .basic import same_control_flow_ctx, tile_transposed
    if self is other:
      return
    other_same_base = other.get_same_base()
@@ -330,6 +442,7 @@ def declare_same_as(self, other):
    if self_same_as is other_same_base:
      return
    self_same_as.same_as = other_same_base
+    self_same_as._same_as_tb = traceback.extract_stack()
    if self_same_as.dyn_size_ext is None:
      self_same_as.dyn_size_ext = other_same_base.dyn_size_ext
    elif other_same_base.dyn_size_ext is None:
@@ -337,9 +450,10 @@ def declare_same_as(self, other):
    if self.dyn_size_ext is None:
      self.dyn_size_ext = self_same_as.dyn_size_ext
    self.same_as = other_same_base
+    self._same_as_tb = traceback.extract_stack()
    if self.dyn_size is not None and other_same_base.dyn_size is not None:
      if self.dyn_size is not other_same_base.dyn_size:
-        if self.src_data and other_same_base.src_data and self.src_data.beam == other_same_base.src_data.beam:
+        if self.batch == other_same_base.batch:
          # Note: Instead of making this a warning, we could also enforce this at some point.
          # The user should be able to fix `extern_data` in the config such that this is correct in the first place.
          # Also, in addition to this warning, we might want to add some runtime check on the eq of the dyn sizes.
@@ -353,13 +467,6 @@ def declare_same_as(self, other):
      if self.src_data.get_dim_tag(self.src_axis).description == self.description:
        self.src_data.size_placeholder[
          self.src_data.get_batch_axis_excluding_batch(self.src_axis)] = self.same_as.dyn_size
-      # if the tag is used in a recurrent layer during search, the placeholder has to be expanded by the beam size
-      if self.src_data.beam and (not self.same_as.src_data or not self.same_as.src_data.beam):
-        for i, v in sorted(self.src_data.size_placeholder.items()):
-          with same_control_flow_ctx(v):
-            size = tile_transposed(v, axis=0, multiples=self.src_data.beam.beam_size)
-          self.set_tag_on_size_tensor(size, batch=self.src_data.batch)
-          self.src_data.size_placeholder[i] = size
    # If others dyn_size is None but we have a dyn_size, maybe update others dyn_size.
    if self.dyn_size is not None and self.same_as.dyn_size is not self.dyn_size:
      # Could be unset if it comes from the config, or from prev graph creation.
@@ -660,18 +767,22 @@ def make_global_broadcast_batch_info(cls):
  @classmethod
  def get_common_batch_info(cls, batches):
    """
-    :param list[BatchInfo] batches:
+    :param list[BatchInfo|None] batches:
    :rtype: BatchInfo|None
    """
+    # Fast paths.
    if not batches:
      return None
    if len(batches) == 1:
      return batches[0]
+    # Make unique, and filter non-none.
    batches_ = []
    for batch in batches:
-      if batch not in batches_:
+      if batch and batch not in batches_:
        batches_.append(batch)
    batches = batches_
+    if not batches_:
+      return None
    if len(batches) == 1:
      return batches[0]
    base = batches[0].get_global_base()
@@ -1178,8 +1289,10 @@ def __init__(self, name,
      else:
        dtype = "float32"
    self.dtype = dtype  # type: str
-    self.batch = batch
-    self.beam = beam
+    if beam and batch:
+      assert batch.beam == beam
+    self._batch = batch
+    self._beam = beam
    if isinstance(dim_tags, (tuple, list)):
      # We do a couple of sanity checks, and maybe set special axes attribs.
      shape_ = tuple(tag.dimension for tag in dim_tags if not tag.is_batch_dim())
@@ -1260,6 +1373,7 @@ def __init__(self, name,
          base_tag.declare_same_as(_dim_tag)
        if _dim_tag.dyn_size is not None:
          self.set_dynamic_size(_axis, _dim_tag.dyn_size)
+    self._adapt_batch_consistent_dim_tags()
    self.sanity_check(assume_complete=False)
 
  @classmethod
@@ -1326,8 +1440,11 @@ def sanity_check(self, ignore_placeholder=False, assume_complete=True):
        continue  # further checks will assume not batch
      assert axis != self.batch_dim_axis, "%s: invalid %s" % (self, tag)
      # Note: tag.kind (feature or spatial) is independent from self.feature_dim_axis.
+      if tag.batch and self.batch:
+        assert tag.batch == self.batch
      if tag.dyn_size_ext:
        assert tag.dyn_size_ext.dtype in {"int32", "int64"}
+        assert tag.batch == tag.dyn_size_ext.batch
        tag.dyn_size_ext.sanity_check()
    if not ignore_placeholder and self.placeholder is not None:
      # Note: We could just call self.placeholder.set_shape.
@@ -1524,6 +1641,11 @@ def __repr__(self):
  def __hash__(self):
    return id(self)
 
+  def _adapt_batch_consistent_dim_tags(self):
+    if not self.batch:
+      return
+    self._dim_tags = tuple(tag.get_for_batch(self.batch) for tag in self._dim_tags)
+
  def copy(self, name=None):
    """
    :param str name: if given, will overwrite this name
@@ -1748,11 +1870,11 @@ def copy_add_batch_dim(self, batch_dim_axis, batch=None, dim_tag=None):
    dim_tags = list(self.dim_tags)
    if dim_tag:
      assert dim_tag.is_batch_dim()
-      assert dim_tag.dimension == (batch.dim if isinstance(batch.dim, int) else None)
+      assert dim_tag.dimension == batch.static_dim
+      assert dim_tag.batch == batch
    else:
      dim_tag = DimensionTag(
-        kind=DimensionTag.Types.Batch, description="batch",
-        dimension=batch.dim if isinstance(batch.dim, int) else None)
+        kind=DimensionTag.Types.Batch, description="batch", dimension=batch.static_dim, batch=batch)
    dim_tags.insert(batch_dim_axis, dim_tag)
    data_opts["dim_tags"] = dim_tags
    data_opts["batch"] = batch
@@ -1846,7 +1968,8 @@ def copy_add_dim_by_tag(self, dim_tag, unbroadcast=False, axis=None):
      else:
        batch_info = BatchInfo.make_global_broadcast_batch_info()
      return self.copy_add_batch_dim(
-        batch_dim_axis=axis, batch=batch_info, dim_tag=dim_tag if dim_tag.dimension == 1 else None)
+        batch_dim_axis=axis, batch=batch_info,
+        dim_tag=dim_tag if (dim_tag.dimension == 1 and dim_tag.batch == batch_info) else None)
 
    data_opts = self.get_kwargs()
    # Note: if dim_tag is feature, but we are sparse, we just make it spatial
@@ -2049,21 +2172,15 @@ def copy_extend_with_beam(self, beam):
    assert data.beam is None, "incompatible beam (%r vs %r)" % (data.beam, beam)
    if beam is None:
      return data
-    if data.batch:
-      data.batch = data.batch.copy_set_beam(beam)
    data.beam = beam
+    assert data.batch
+    data.batch = data.batch.copy_set_beam(beam)
    with tf.name_scope("%s_data_extend_with_beam" % get_valid_scope_name_from_str(self.name)):
      if data.placeholder is not None:
        with same_control_flow_ctx(data.placeholder):
          data.placeholder = tile_transposed(data.placeholder, axis=data.batch_dim_axis, multiples=beam.beam_size)
-      for i, v in sorted(data.size_placeholder.items()):
-        tag = DimensionTag.get_tag_from_size_tensor(v)
-        with same_control_flow_ctx(v):
-          sizes = tile_transposed(v, axis=0, multiples=beam.beam_size)
-        sizes._RETURNN_dyn_size_beam = beam
-        if tag is not None:
-          tag.set_tag_on_size_tensor(sizes, batch=data.batch)
-        data.size_placeholder[i] = sizes
+        setattr(data.placeholder, "_RETURNN_beam_expanded_base_data", self)
+    data._adapt_batch_consistent_dim_tags()
    return data
 
  def copy_squeeze_axes(self, axes):
@@ -2343,23 +2460,23 @@ def shape_dense(self):
    return self.shape
 
  @property
-  def shape_sparse(self):
+  def batch_shape_dense(self):
    """
-    :return: shape without feature dim axis
    :rtype: tuple[int|None]
    """
    if self.sparse:
-      return self.shape
-    return self.shape[:self.feature_dim_axis] + self.shape[self.feature_dim_axis + 1:]
+      return self.batch_shape + (self.dim,)
+    return self.batch_shape
 
  @property
-  def batch_shape_dense(self):
+  def dim_tags_sparse(self):
    """
-    :rtype: tuple[int|None]
+    :return: dim tags without feature dim axis
+    :rtype: tuple[DimensionTag]
    """
-    if self.sparse:
-      return self.batch_shape + (self.dim,)
-    return self.batch_shape
+    if self.sparse or not self.have_feature_axis():
+      return self.dim_tags
+    return self.dim_tags[:self.feature_dim_axis] + self.dim_tags[self.feature_dim_axis + 1:]
 
  @property
  def ndim(self):
@@ -2501,6 +2618,44 @@ def placeholder(self, value):
    self._placeholder = value
    self.sanity_check(assume_complete=False)
 
+  @property
+  def batch(self):
+    """
+    :rtype: BatchInfo|None
+    """
+    return self._batch
+
+  @batch.setter
+  def batch(self, batch):
+    """
+    :param BatchInfo|None batch:
+    """
+    if batch:
+      assert batch.beam == self.beam
+    self._batch = batch
+    self._adapt_batch_consistent_dim_tags()
+
+  @property
+  def beam(self):
+    """
+    :rtype: SearchBeam|None
+    """
+    if self._beam:
+      return self._beam
+    if self._batch:
+      return self._batch.beam
+    return None
+
+  @beam.setter
+  def beam(self, beam):
+    """
+    :param SearchBeam|None beam:
+    """
+    # No check for batch.beam, as the batch is usually set only later.
+    self._beam = beam
+    if self._batch:
+      self._batch = self._batch.copy_set_beam(beam=beam)
+
  def time_dimension(self):
    """
    :return: shape(placeholder)[time_dim_axis], int scalar
@@ -2894,34 +3049,31 @@ def set_dynamic_size(self, axis, sizes):
    if getattr(sizes, "_RETURNN_dyn_size_beam", NotSpecified) is NotSpecified:
      sizes._RETURNN_dyn_size_beam = self.beam
    if self.beam and getattr(sizes, "_RETURNN_dyn_size_beam", None) != self.beam:
-      from returnn.tf.util import basic as tf_util
      tag = DimensionTag.get_tag_from_size_tensor(sizes)
-      if tag:
-        # Just to be sure, we tile it as it should be.
-        base_size = tag.get_same_base().dyn_size
-        if not getattr(base_size, "_RETURNN_dyn_size_beam", None):
-          with tf_util.same_control_flow_ctx(base_size):
-            sizes = tf_util.tile_transposed(base_size, axis=0, multiples=self.beam.beam_size)
-            sizes._RETURNN_dyn_size_beam = self.beam
-            tag.set_tag_on_size_tensor(sizes, batch=self.batch)
+      assert tag and self.batch
+      tag = tag.get_for_batch(self.batch)
+      assert tag.dyn_size is not None
+      sizes = tag.dyn_size
    sizes_tag = DimensionTag.get_tag_from_size_tensor(sizes)
    if sizes_tag:
-      assert sizes_tag.dyn_size is sizes
+      assert sizes_tag.is_same_size_tensor(sizes)
    tag = self.dim_tags[axis]
    assert tag.dimension is None  # dynamic axis
-    if tag.dyn_size is sizes:
+    if tag.is_same_size_tensor(sizes):
      return  # nothing to do
    if tag.dyn_size is None:
      if sizes_tag:  # special rule for older code: overtake previous existing
-        assert sizes_tag.dyn_size is sizes
+        assert sizes_tag.is_same_size_tensor(sizes)
        self._dim_tags = self.dim_tags[:axis] + (sizes_tag,) + self.dim_tags[axis + 1:]
        # Also assume the existing dim tag should be expected as equal.
        # Likely there is anyway no reference so this does not matter.
        tag.declare_same_as(sizes_tag)
      else:
        # Assign now. This should also set the dim tag on sizes.
-        tag.dyn_size = sizes
+        new_tag = tag.set_tag_on_size_tensor(sizes, batch=self.batch)
+        if new_tag is not tag:
+          self._dim_tags = self.dim_tags[:axis] + (new_tag,) + self.dim_tags[axis + 1:]
    else:
      # Reset to some new size.
      # Use new dim tag, or previous existing attached to size.
@@ -2944,18 +3096,16 @@ def get_static_axes(self):
    return [axis for axis, dim in enumerate(self.batch_shape)
            if axis != self.batch_dim_axis and dim is not None]
 
-  def mark_same_time(self, other):
+  def mark_same_time(self, tags):
    """
-    If the dimension tag of others time axis matches any of our axes, we set our time axis to the selected one.
+    If the given dimension tag matches any of our axes, we set our time axis to the selected one.
 
-    :param Data other:
+    :param set[DimensionTag] tags:
    :return: whether we have found the same
    :rtype: bool
    """
-    assert other.have_time_axis()
-    tag_other = other.get_dim_tag(other.time_dim_axis)
-    for axis, dim_tag in enumerate(self.get_batch_shape_dim_tags()):
-      if dim_tag == tag_other:
+    for axis, dim_tag in enumerate(self.dim_tags):
+      if dim_tag in tags:
        self.time_dim_axis = axis
        return True
    return False
@@ -3235,19 +3385,22 @@ def get_common_data(cls, sources, warnings_out=None, ignore_feature_dim=False):
    :param bool ignore_feature_dim: when set, the feature dim does not have to match in the sources
    :return: some generic data where the sources should be compatible to (with copy_compatible_to),
      i.e. it contains the union of all axes from all sources (least common multiple).
+      This is always a template, and a new copy.
    :rtype: Data|None
    """
    if not sources:
      return None
    assert sources
    if len(sources) == 1:
-      return sources[0]
+      return sources[0].copy_template()
    max_ndim = max([s.batch_ndim for s in sources])
    common_batch = BatchInfo.get_common_batch_info([src.batch for src in sources if src.batch])
    # Try with the (first) largest.
    common = [s for s in sources if s.batch_ndim == max_ndim][0]
-    common = common.copy()
-    common.batch = common_batch  # no copy_ext_batch because we don't want TF ops
+    common = common.copy_template()
+    common.beam = None  # this will be reset
+    if common_batch:
+      common.batch = common_batch.copy_set_beam(None)  # the beam will be reset
    if any([s.beam for s in sources]):
      # Note: we don't use copy_extend_with_beam because we don't want to create any ops in the TF graph at this point.
      common.beam = SearchBeam.get_combined_beam(*[s.beam for s in sources])
@@ -3576,7 +3729,7 @@ def _infer_dim_tags_tuple_from_shape(
        # Just some sanity checks.
        assert isinstance(tag, DimensionTag)
        assert tag.dimension == dim
-        assert tag.dyn_size is dyn_size
+        assert tag.is_same_size_tensor(dyn_size)
        continue
      if axis == feature_dim_axis and dyn_size is None and axis != time_dim_axis:
        tag = DimensionTag(kind=DimensionTag.Types.Feature, dimension=dim, description="feature:%s" % name)
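--
Usage sketch (illustrative only, not part of the patch): roughly how the per-batch dim tag
variants introduced above in returnn/tf/util/data.py are meant to fit together. This is a
sketch under assumptions: it presumes RETURNN's usual graph-mode setup, and the batch size,
beam size and names below are made up for the example.

    # Illustrative sketch, not part of the patch. Assumes RETURNN graph-mode usage;
    # batch size, beam size and all names here are made up.
    import tensorflow as tf
    from returnn.tf.util.data import DimensionTag, BatchInfo, SearchBeam

    with tf.Graph().as_default():
      batch = BatchInfo.make_global_batch_info(tf.constant(2, name="batch_dim"))
      time_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="time")
      seq_len = tf.constant([5, 3], name="seq_len")
      # Couple the dyn size tensor with the tag for the global (beam-less) batch.
      # With this patch, set_tag_on_size_tensor returns self or a new tag variant.
      time_tag = time_tag.set_tag_on_size_tensor(seq_len, batch=batch)

      # During beam search, ask for the variant of the same tag for the beam-expanded batch.
      beam = SearchBeam(beam_size=4, name="dec_beam")
      batch_with_beam = batch.copy_set_beam(beam)
      time_tag_beam = time_tag.get_for_batch(batch_with_beam)
      # time_tag_beam.dyn_size is the beam-tiled sequence length; the variant is cached
      # in the base tag (_same_for_batch), so repeated calls return the same tag object.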