From 301f323a0c9fd10da7583162299f895d68a0ccf6 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 19 Oct 2022 17:14:25 +0200
Subject: [PATCH 1/5] Data same_dim_tags_as fix auto_create_placeholders

Don't first create a new size placeholder and then later call
declare_same_as. This is especially required when declare_same_as
becomes stricter (#1143).
---
 returnn/tf/util/data.py | 17 ++++++++---------
 tests/test_TFUtil.py    | 16 ++++++++++++++++
 2 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py
index 0e67a23bc..aee649624 100644
--- a/returnn/tf/util/data.py
+++ b/returnn/tf/util/data.py
@@ -2823,8 +2823,6 @@ def __init__(self, name,
       if time_dim_axis is NotSpecified:
         time_dim_axis = _default_time_dim_axis_dim_tags(dim_tags)
       dim_tags = tuple(dim_tags)
-      if auto_create_placeholders:
-        _auto_create_size_placeholders_on_dim_tags(name=name, dim_tags=dim_tags)
       if batch_dim_axis_ is not None:
         if dim_tags[batch_dim_axis_].batch and not self._batch:
           self._batch = dim_tags[batch_dim_axis_].batch
@@ -2846,7 +2844,7 @@ def __init__(self, name,
       dim_tags = _infer_dim_tags_tuple_from_shape(
         shape, batch_dim_axis=batch_dim_axis, time_dim_axis=time_dim_axis, feature_dim_axis=feature_dim_axis,
         size_placeholder=size_placeholder, name=name,
-        auto_create_placeholders=auto_create_placeholders,
+        extern_data=auto_create_placeholders,
         dim_tags=dim_tags, sparse=sparse)
       del batch_dim_axis
       del shape
@@ -2893,8 +2891,10 @@ def __init__(self, name,
         base_tag = self._dim_tags[_axis]
         if base_tag != _dim_tag:
           base_tag.declare_same_as(_dim_tag)
-        if _dim_tag.dyn_size is not None:
-          self.set_dynamic_size(_axis, _dim_tag.dyn_size)
+        self._dim_tags = self._dim_tags[:_axis] + (_dim_tag,) + self._dim_tags[_axis + 1:]
+    if auto_create_placeholders:
+      # Do that after same_dim_tags_as handling.
+      _auto_create_size_placeholders_on_dim_tags(name=name, dim_tags=self._dim_tags)
     self._adapt_batch_consistent_dim_tags()
     self.sanity_check(assume_complete=False)
 
@@ -5780,7 +5780,7 @@ def _infer_dim_tags_tuple_from_shape(
     size_placeholder,
     dim_tags,
     name,
-    auto_create_placeholders
+    extern_data
 ):
   """
   :param tuple[int|None]|list[int|None] shape: this is without batch-dim-axis
@@ -5790,7 +5790,7 @@ def _infer_dim_tags_tuple_from_shape(
   :param bool sparse:
   :param dict[int,tf.Tensor]|None size_placeholder: key is axis without batch-dim
   :param dict[int,Dim]|None dim_tags: some existing explicitly specified dim tags. key is axis with batch-dim
-  :param bool auto_create_placeholders:
+  :param bool extern_data:
   :param str name:
   :return: dim tags tuple
   :rtype: tuple[Dim]
@@ -5833,7 +5833,7 @@ def _infer_dim_tags_tuple_from_shape(
     axis_wo_b = _get_axis_wo_b(axis, batch_dim_axis=batch_dim_axis)
     dyn_size = size_placeholder.get(axis_wo_b) if (size_placeholder and axis_wo_b is not None) else None
     dim = batch_shape[axis]
-    if auto_create_placeholders and dim is None and dyn_size is None and axis != batch_dim_axis:
+    if extern_data and dim is None and dyn_size is None and axis != batch_dim_axis:
       if not tag:
         if axis == time_dim_axis:
           tag_name = "time"
@@ -5845,7 +5845,6 @@ def _infer_dim_tags_tuple_from_shape(
           # This is such that Dim.is_equal behaves as before, e.g. in Data.get_common_data.
           kind=Dim.Types.Spatial)
         dim_tags[axis] = tag
-      _create_size_placeholder(name=name, axis_wo_b=axis_wo_b, tag=tag, batch_dim=batch_dim)
       dyn_size = tag.dyn_size
     if tag:
       # Just some sanity checks.
diff --git a/tests/test_TFUtil.py b/tests/test_TFUtil.py
index 9a8eac7de..bb0b9c989 100644
--- a/tests/test_TFUtil.py
+++ b/tests/test_TFUtil.py
@@ -1243,6 +1243,22 @@ def test_Data_verify_out_shape_optional_implicit_dim():
     x.verify_out_shape({time_dim, feat_dim}, allow_missing_implicit_dims=True)
 
 
+def test_Data_auto_create_placeholders_same_dim_tags_as_existing():
+  # Came up via: https://github.com/rwth-i6/returnn/pull/1143
+  n_out = 3
+  time_tag = SpatialDim("time")
+  with tf.Graph().as_default() as graph, tf_compat.v1.Session(graph=graph) as session:
+    assert isinstance(graph, tf.Graph)
+    data = Data("data", dim=n_out, same_dim_tags_as={"t": time_tag}, auto_create_placeholders=True)
+    classes = Data("classes", dim=n_out, sparse=True, same_dim_tags_as={"t": time_tag}, auto_create_placeholders=True)
+    assert time_tag.dyn_size is not None  # this is not so relevant and might change
+    seq_len = time_tag.dyn_size
+    assert seq_len is data.get_sequence_lengths() is classes.get_sequence_lengths()
+    assert seq_len.op.type == "Placeholder"
+    placeholder_ops = [op for op in graph.get_operations() if op.type == "Placeholder"]
+    assert_equal(set(placeholder_ops), {data.placeholder.op, classes.placeholder.op, time_tag.dyn_size.op})
+
+
 def test_Dim_copy():
   # https://github.com/rwth-i6/returnn/issues/860
   import copy
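Note for reviewers (an aside, not part of the commit): the new test above can be
condensed into a minimal usage sketch of what patch 1 guarantees. With the fixed
ordering, Data.__init__ applies same_dim_tags_as before creating any size
placeholder, so both extern-data inputs end up sharing the one dyn-size
placeholder of the given tag. The imports assume RETURNN's returnn.tf.util.data
module, as in the test suite.

    import tensorflow as tf
    from returnn.tf.util.data import Data, SpatialDim

    time_tag = SpatialDim("time")
    with tf.Graph().as_default():
      # Both inputs reference the same time tag via same_dim_tags_as.
      # With this patch, the placeholder for the dynamic size is created only
      # after declare_same_as has run, so it is created once and then shared.
      data = Data("data", dim=3, same_dim_tags_as={"t": time_tag}, auto_create_placeholders=True)
      classes = Data("classes", dim=3, sparse=True, same_dim_tags_as={"t": time_tag}, auto_create_placeholders=True)
      assert data.get_sequence_lengths() is classes.get_sequence_lengths() is time_tag.dyn_size

This is also why the patch replaces the tag inside self._dim_tags and defers
_auto_create_size_placeholders_on_dim_tags until after the same_dim_tags_as loop.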
From 78fd39ccdea632d0802e357f21263437bc597f71 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 19 Oct 2022 17:37:07 +0200
Subject: [PATCH 2/5] Data fix wrong batch info

The dim tag could have old, invalid batch info,
e.g. the global batch_dim when it comes from an old run.

If we really need this, we should validate the dim tag first.
But probably it's better to remove it and clean it up.
---
 returnn/tf/util/data.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py
index aee649624..91d25f2e1 100644
--- a/returnn/tf/util/data.py
+++ b/returnn/tf/util/data.py
@@ -2823,9 +2823,6 @@ def __init__(self, name,
       if time_dim_axis is NotSpecified:
         time_dim_axis = _default_time_dim_axis_dim_tags(dim_tags)
       dim_tags = tuple(dim_tags)
-      if batch_dim_axis_ is not None:
-        if dim_tags[batch_dim_axis_].batch and not self._batch:
-          self._batch = dim_tags[batch_dim_axis_].batch
       del shape_
       del batch_dim_axis_
     else:

From 7ac561ff78fdecb957c3a7e8e48bc640b812456f Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 19 Oct 2022 17:38:21 +0200
Subject: [PATCH 3/5] small cleanup fix

---
 returnn/tf/util/data.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py
index 91d25f2e1..1a8e610fc 100644
--- a/returnn/tf/util/data.py
+++ b/returnn/tf/util/data.py
@@ -5805,8 +5805,6 @@ def _infer_dim_tags_tuple_from_shape(
   dim_tags = dim_tags.copy() if dim_tags else {}
   if batch_dim_axis is not None and batch_dim_axis not in dim_tags:
     dim_tags[batch_dim_axis] = Dim(kind=Dim.Types.Batch, description="batch:%s" % name)
-  # noinspection PyShadowingNames
-  batch_dim = dim_tags[batch_dim_axis] if batch_dim_axis is not None else None
   # Note: Consistent to Data.get_dim_tag,
   # prefer interpretation as spatial axis if there is a dynamic size or this is marked as time axis.
   if size_placeholder:
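Note for reviewers (an aside, not part of the commits): patches 2 and 5 address
the same hazard from two sides. A rough illustration, assuming only what the
commits themselves state, namely that the global batch_dim Dim instance
outlives a single graph/run:

    from returnn.tf.util.data import batch_dim

    # batch_dim is a global Dim, shared across graphs/runs, so its .batch
    # attribute can still reference batch info from an earlier run.
    # Patch 2 stops Data.__init__ from adopting such possibly stale batch info
    # from dim_tags[batch_dim_axis_]; patch 5 (below) additionally resets the
    # global explicitly before a new network is built:
    batch_dim.batch = None

If the batch info were really needed at that point, the dim tag would have to
be validated first, as the commit message of patch 2 notes.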
From d56a953505c69122efa6605a0025ed00f5b0c4fc Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 19 Oct 2022 17:40:28 +0200
Subject: [PATCH 4/5] tests small fixes

---
 tests/test_TFNetworkLayer.py | 26 ++++++++++++++------------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py
index 3e7e515e2..61dd1fcf7 100644
--- a/tests/test_TFNetworkLayer.py
+++ b/tests/test_TFNetworkLayer.py
@@ -2223,12 +2223,13 @@ def test_SplitDimsLayer_dim_tags():
   feat_dim = FeatureDim("feat", 3)
   config = Config({
     "extern_data": {"data": {"dim_tags": [batch_dim, time_dim, feat_dim]}}})
-  net = TFNetwork(config=config)
-  net.construct_from_dict({
-    "output": {
-      'class': 'split_dims', 'from': 'data', 'axis': time_dim, 'dims': [rem_dim, window_dim],
-      'out_shape': {batch_dim, rem_dim, window_dim, feat_dim}}
-  })
+  with make_scope():
+    net = TFNetwork(config=config)
+    net.construct_from_dict({
+      "output": {
+        'class': 'split_dims', 'from': 'data', 'axis': time_dim, 'dims': [rem_dim, window_dim],
+        'out_shape': {batch_dim, rem_dim, window_dim, feat_dim}}
+    })
 
 
 def test_SplitDimsLayer_dim_tags_expand():
@@ -2238,12 +2239,13 @@ def test_SplitDimsLayer_dim_tags_expand():
   expand_dim = SpatialDim("expand_dim", 1)
   config = Config({
     "extern_data": {"data": {"dim_tags": [batch_dim, time_dim, feat_dim]}}})
-  net = TFNetwork(config=config)
-  net.construct_from_dict({
-    "output": {
-      'class': 'split_dims', 'from': 'data', 'axis': feat_dim, 'dims': [feat_dim, expand_dim],
-      'out_shape': {batch_dim, time_dim, feat_dim, expand_dim}}
-  })
+  with make_scope():
+    net = TFNetwork(config=config)
+    net.construct_from_dict({
+      "output": {
+        'class': 'split_dims', 'from': 'data', 'axis': feat_dim, 'dims': [feat_dim, expand_dim],
+        'out_shape': {batch_dim, time_dim, feat_dim, expand_dim}}
+    })
 
 
 def test_SplitDimsLayer_dim_tags_split_batch_simple():

From 0f6c740909274208320ac6dcbde1d6eadc8093a2 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 19 Oct 2022 20:31:30 +0200
Subject: [PATCH 5/5] Engine reset global batch dim

---
 returnn/tf/engine.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/returnn/tf/engine.py b/returnn/tf/engine.py
index e302d9d25..addb6a248 100644
--- a/returnn/tf/engine.py
+++ b/returnn/tf/engine.py
@@ -1293,6 +1293,8 @@ def _init_network(self, net_desc, epoch=None):
     use_dataset_pipeline = False
     if self.config.is_true("dataset_pipeline"):
       use_dataset_pipeline = True
+    from returnn.tf.util.data import batch_dim
+    batch_dim.batch = None  # make sure it is reset
     extern_data = ExternData()
     extern_data.init_from_config(config=self.config, auto_create_placeholders=not use_dataset_pipeline)
     if use_dataset_pipeline:
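Closing note (an aside, not part of the series): patch 4 wraps graph
construction in make_scope() so that each test builds its network in its own
graph and session, and dim tags or placeholders cannot leak between tests.
The helper lives in tests/test_TFNetworkLayer.py; a rough sketch of what it
does, under the assumption that it just pairs a fresh graph with a session:

    import contextlib
    import tensorflow as tf
    from returnn.tf import compat as tf_compat

    @contextlib.contextmanager
    def make_scope():
      """
      Rough sketch of the test helper: a fresh graph plus session per test.
      """
      with tf.Graph().as_default() as graph:
        with tf_compat.v1.Session(graph=graph) as session:
          yield session

Together with the batch_dim.batch reset in patch 5, this keeps state from one
test or run from influencing the next.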