diff --git a/returnn/tf/engine.py b/returnn/tf/engine.py
index e302d9d25..addb6a248 100644
--- a/returnn/tf/engine.py
+++ b/returnn/tf/engine.py
@@ -1293,6 +1293,8 @@ def _init_network(self, net_desc, epoch=None):
     use_dataset_pipeline = False
     if self.config.is_true("dataset_pipeline"):
       use_dataset_pipeline = True
+    from returnn.tf.util.data import batch_dim
+    batch_dim.batch = None  # make sure it is reset
     extern_data = ExternData()
     extern_data.init_from_config(config=self.config, auto_create_placeholders=not use_dataset_pipeline)
     if use_dataset_pipeline:
diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py
index 0e67a23bc..1a8e610fc 100644
--- a/returnn/tf/util/data.py
+++ b/returnn/tf/util/data.py
@@ -2823,11 +2823,6 @@ def __init__(self, name,
       if time_dim_axis is NotSpecified:
         time_dim_axis = _default_time_dim_axis_dim_tags(dim_tags)
       dim_tags = tuple(dim_tags)
-      if auto_create_placeholders:
-        _auto_create_size_placeholders_on_dim_tags(name=name, dim_tags=dim_tags)
-        if batch_dim_axis_ is not None:
-          if dim_tags[batch_dim_axis_].batch and not self._batch:
-            self._batch = dim_tags[batch_dim_axis_].batch
       del shape_
       del batch_dim_axis_
     else:
@@ -2846,7 +2841,7 @@ def __init__(self, name,
       dim_tags = _infer_dim_tags_tuple_from_shape(
         shape, batch_dim_axis=batch_dim_axis, time_dim_axis=time_dim_axis, feature_dim_axis=feature_dim_axis,
         size_placeholder=size_placeholder, name=name,
-        auto_create_placeholders=auto_create_placeholders,
+        extern_data=auto_create_placeholders,
         dim_tags=dim_tags, sparse=sparse)
       del batch_dim_axis
       del shape
@@ -2893,8 +2888,10 @@ def __init__(self, name,
         base_tag = self._dim_tags[_axis]
         if base_tag != _dim_tag:
           base_tag.declare_same_as(_dim_tag)
-          if _dim_tag.dyn_size is not None:
-            self.set_dynamic_size(_axis, _dim_tag.dyn_size)
+        self._dim_tags = self._dim_tags[:_axis] + (_dim_tag,) + self._dim_tags[_axis + 1:]
+    if auto_create_placeholders:
+      # Do that after same_dim_tags_as handling.
+      _auto_create_size_placeholders_on_dim_tags(name=name, dim_tags=self._dim_tags)
 
     self._adapt_batch_consistent_dim_tags()
     self.sanity_check(assume_complete=False)
@@ -5780,7 +5777,7 @@ def _infer_dim_tags_tuple_from_shape(
   size_placeholder,
   dim_tags,
   name,
-  auto_create_placeholders
+  extern_data
 ):
   """
   :param tuple[int|None]|list[int|None] shape: this is without batch-dim-axis
@@ -5790,7 +5787,7 @@ def _infer_dim_tags_tuple_from_shape(
   :param bool sparse:
   :param dict[int,tf.Tensor]|None size_placeholder: key is axis without batch-dim
   :param dict[int,Dim]|None dim_tags: some existing explicitly specified dim tags. key is axis with batch-dim
-  :param bool auto_create_placeholders:
+  :param bool extern_data:
   :param str name:
   :return: dim tags tuple
   :rtype: tuple[Dim]
@@ -5808,8 +5805,6 @@ def _infer_dim_tags_tuple_from_shape(
   dim_tags = dim_tags.copy() if dim_tags else {}
   if batch_dim_axis is not None and batch_dim_axis not in dim_tags:
     dim_tags[batch_dim_axis] = Dim(kind=Dim.Types.Batch, description="batch:%s" % name)
-  # noinspection PyShadowingNames
-  batch_dim = dim_tags[batch_dim_axis] if batch_dim_axis is not None else None
   # Note: Consistent to Data.get_dim_tag,
   # prefer interpretation as spatial axis if there is a dynamic size or this is marked as time axis.
   if size_placeholder:
@@ -5833,7 +5828,7 @@ def _infer_dim_tags_tuple_from_shape(
     axis_wo_b = _get_axis_wo_b(axis, batch_dim_axis=batch_dim_axis)
     dyn_size = size_placeholder.get(axis_wo_b) if (size_placeholder and axis_wo_b is not None) else None
     dim = batch_shape[axis]
-    if auto_create_placeholders and dim is None and dyn_size is None and axis != batch_dim_axis:
+    if extern_data and dim is None and dyn_size is None and axis != batch_dim_axis:
       if not tag:
         if axis == time_dim_axis:
           tag_name = "time"
@@ -5845,7 +5840,6 @@ def _infer_dim_tags_tuple_from_shape(
           # This is such that Dim.is_equal behaves as before, e.g. in Data.get_common_data.
           kind=Dim.Types.Spatial)
       dim_tags[axis] = tag
-      _create_size_placeholder(name=name, axis_wo_b=axis_wo_b, tag=tag, batch_dim=batch_dim)
       dyn_size = tag.dyn_size
     if tag:
       # Just some sanity checks.
diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py
index 3e7e515e2..61dd1fcf7 100644
--- a/tests/test_TFNetworkLayer.py
+++ b/tests/test_TFNetworkLayer.py
@@ -2223,12 +2223,13 @@ def test_SplitDimsLayer_dim_tags():
   feat_dim = FeatureDim("feat", 3)
   config = Config({
     "extern_data": {"data": {"dim_tags": [batch_dim, time_dim, feat_dim]}}})
-  net = TFNetwork(config=config)
-  net.construct_from_dict({
-    "output": {
-      'class': 'split_dims', 'from': 'data', 'axis': time_dim, 'dims': [rem_dim, window_dim],
-      'out_shape': {batch_dim, rem_dim, window_dim, feat_dim}}
-  })
+  with make_scope():
+    net = TFNetwork(config=config)
+    net.construct_from_dict({
+      "output": {
+        'class': 'split_dims', 'from': 'data', 'axis': time_dim, 'dims': [rem_dim, window_dim],
+        'out_shape': {batch_dim, rem_dim, window_dim, feat_dim}}
+    })
 
 
 def test_SplitDimsLayer_dim_tags_expand():
@@ -2238,12 +2239,13 @@ def test_SplitDimsLayer_dim_tags_expand():
   expand_dim = SpatialDim("expand_dim", 1)
   config = Config({
     "extern_data": {"data": {"dim_tags": [batch_dim, time_dim, feat_dim]}}})
-  net = TFNetwork(config=config)
-  net.construct_from_dict({
-    "output": {
-      'class': 'split_dims', 'from': 'data', 'axis': feat_dim, 'dims': [feat_dim, expand_dim],
-      'out_shape': {batch_dim, time_dim, feat_dim, expand_dim}}
-  })
+  with make_scope():
+    net = TFNetwork(config=config)
+    net.construct_from_dict({
+      "output": {
+        'class': 'split_dims', 'from': 'data', 'axis': feat_dim, 'dims': [feat_dim, expand_dim],
+        'out_shape': {batch_dim, time_dim, feat_dim, expand_dim}}
+    })
 
 
 def test_SplitDimsLayer_dim_tags_split_batch_simple():
diff --git a/tests/test_TFUtil.py b/tests/test_TFUtil.py
index 9a8eac7de..bb0b9c989 100644
--- a/tests/test_TFUtil.py
+++ b/tests/test_TFUtil.py
@@ -1243,6 +1243,22 @@ def test_Data_verify_out_shape_optional_implicit_dim():
   x.verify_out_shape({time_dim, feat_dim}, allow_missing_implicit_dims=True)
 
 
+def test_Data_auto_create_placeholders_same_dim_tags_as_existing():
+  # Came up via: https://github.com/rwth-i6/returnn/pull/1143
+  n_out = 3
+  time_tag = SpatialDim("time")
+  with tf.Graph().as_default() as graph, tf_compat.v1.Session(graph=graph) as session:
+    assert isinstance(graph, tf.Graph)
+    data = Data("data", dim=n_out, same_dim_tags_as={"t": time_tag}, auto_create_placeholders=True)
+    classes = Data("classes", dim=n_out, sparse=True, same_dim_tags_as={"t": time_tag}, auto_create_placeholders=True)
+    assert time_tag.dyn_size is not None  # this is not so relevant and might change
+    seq_len = time_tag.dyn_size
+    assert seq_len is data.get_sequence_lengths() is classes.get_sequence_lengths()
+    assert seq_len.op.type == "Placeholder"
+    placeholder_ops = [op for op in graph.get_operations() if op.type == "Placeholder"]
+    assert_equal(set(placeholder_ops), {data.placeholder.op, classes.placeholder.op, time_tag.dyn_size.op})
+
+
 def test_Dim_copy():
   # https://github.com/rwth-i6/returnn/issues/860
   import copy