From e633cb734a10ce9351426a7b2dfcc62ad665141a Mon Sep 17 00:00:00 2001 From: Robin Schmitt Date: Fri, 17 Sep 2021 09:14:03 +0200 Subject: [PATCH 1/5] SliceNdLayer support a layer as the size argument --- returnn/tf/layers/basic.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index de621dd7d..f7981ad6f 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -859,7 +859,8 @@ class SliceNdLayer(_ConcatInputLayer): def __init__(self, start, size, min_size=None, **kwargs): """ :param LayerBase start: (B,...) - :param int|None size: if None, it uses the max possible size, and it becomes a dynamic axis + :param int|LayerBase|None size: if None, it uses the max possible size, + and it becomes a dynamic axis. If LayerBase, it needs to have the same shape as start :param int|None min_size: if size is None, but we want to have a min-size """ super(SliceNdLayer, self).__init__(**kwargs) @@ -867,6 +868,7 @@ def __init__(self, start, size, min_size=None, **kwargs): x = self.input_data.copy_as_batch_major() seq_lens = x.get_sequence_lengths() if x.is_time_axis_dynamic() else None # (B,) or None self.start = start + self.size = size start_data = start.output.copy_as_batch_major() # e.g. (B,) or (B,T) start_t = start_data.placeholder if size is None: @@ -874,13 +876,19 @@ def __init__(self, start, size, min_size=None, **kwargs): min_size = 0 if seq_lens is None: assert isinstance(x.batch_shape[x.time_dim_axis], int) - size = tf.maximum(tf.reduce_max(x.batch_shape[x.time_dim_axis] - start_t), min_size) # scalar + size_t = x.batch_shape[x.time_dim_axis] - start_t else: # make seq_lens compatible with start_t seq_lens = expand_multiple_dims( # e.g. 
(B,) or (B,1) x=seq_lens, axes=[-1] * (len(start_t.shape) - len(seq_lens.shape))) - size = tf.maximum(tf.reduce_max(seq_lens - start_t), min_size) # scalar + size_t = seq_lens - start_t + size = tf.maximum(tf.reduce_max(size_t), min_size) # scalar + elif isinstance(size, LayerBase): + assert size.output.batch_ndim == start_data.batch_ndim + min_size = 0 + size_t = size.output.get_placeholder_as_batch_major() + size = tf.maximum(tf.reduce_max(size_t), min_size) # scalar # for each start index in start_data, we want to gather a slice # therefore, the output's first axes are the same as the ones from start_data # and the next axis will therefore be the slice axis @@ -888,10 +896,7 @@ def __init__(self, start, size, min_size=None, **kwargs): assert slice_tag.description.startswith("sliced-time:") if not isinstance(size, int): # in this case, size is not known before runtime and becomes dynamic and we need to set dyn_size - if seq_lens is None: - dyn_size = tf.maximum(x.batch_shape[x.time_dim_axis] - start_t, min_size) # (B,) or (B,T) - else: - dyn_size = tf.maximum(seq_lens - start_t, min_size) # (B,) or (B,T) + dyn_size = tf.maximum(size_t, min_size) # (B,) or (B,T) dyn_size_ext = Data( name=("%s:dyn_size" % slice_tag.description), dtype=Data.size_dtype, @@ -923,6 +928,8 @@ def __init__(self, start, size, min_size=None, **kwargs): tf.greater(gather_positions, x.batch_shape[1] - 1), tf.less(gather_positions, 0)) gather_positions = tf.clip_by_value(gather_positions, 0, x.batch_shape[1] - 1) + if isinstance(self.size, LayerBase): + pad_mask = tf.logical_or(tf.greater(gather_positions, tf.expand_dims(start_t + size_t - 1, -1)), pad_mask) gather_positions_data.placeholder = gather_positions position = InternalLayer( network=self.network, @@ -951,7 +958,10 @@ def get_dep_layers(self): """ :rtype: list[LayerBase] """ - return super(SliceNdLayer, self).get_dep_layers() + [self.start] + dep_layers = super(SliceNdLayer, self).get_dep_layers() + [self.start] + if isinstance(self.size, LayerBase): + dep_layers += [self.size] + return dep_layers @classmethod def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, **kwargs): @@ -966,6 +976,8 @@ def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, **kwarg start_data = start.output.copy_as_batch_major() input_data = sources[0].output.copy_as_batch_major() gather_positions_data = start_data.copy_template(name="%s_gather_positions" % name) + if isinstance(size, LayerBase): + size = None # size might be None here in which case we set the dyn_size in __init__ tag = DimensionTag( kind=DimensionTag.Types.Spatial, @@ -991,6 +1003,8 @@ def transform_config_dict(cls, d, network, get_layer): """ super(SliceNdLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer) d["start"] = get_layer(d["start"]) + if isinstance(d["size"], str): + d["size"] = get_layer(d["size"]) class GatherLayer(_ConcatInputLayer): From d08d33561ec09465a4ac77041f8feba2ed12d5bc Mon Sep 17 00:00:00 2001 From: Robin Schmitt Date: Mon, 20 Sep 2021 08:42:49 +0200 Subject: [PATCH 2/5] test_SliceNdLayer_multidimensional_size --- tests/test_TFNetworkLayer.py | 54 ++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py index cc5589a48..4aa72a062 100644 --- a/tests/test_TFNetworkLayer.py +++ b/tests/test_TFNetworkLayer.py @@ -2820,6 +2820,60 @@ def test_SliceNdLayer_multidimensional_start(): numpy.testing.assert_equal(orig_seq[t2], segments[b, t, t2]) +def 
test_SliceNdLayer_multidimensional_size(): + with make_scope() as session: + n_out = 5 + n_batch = 3 + max_seq_len = 10 + config = Config({ + "debug_print_layer_output_template": True, + "extern_data": { + "data": {"dim": n_out}, + "classes": {"dim": n_out, "sparse": True} + }}) + net = TFNetwork(config=config, train_flag=True) + net.construct_from_dict({ + "output": { + "class": "rec", "from": "data:data", "unit": { + "const1": {"class": "constant", "value": 1}, + "start": {"class": "reinterpret_data", "from": "prev:choice", "set_sparse": False}, + "size": {"class": "combine", "from": ["const1", "start"], "kind": "add"}, + "slices": {"class": "slice_nd", "from": "base:data:data", "start": "start", "size": "size"}, + "output": {"class": "reduce", "from": "slices", "mode": "max", "axes": "dyn:-1"}, + "prob": {"class": "softmax", "from": "data:source", "target": "classes", "loss": "ce"}, + 'choice': { + 'class': 'choice', 'target': "classes", 'beam_size': 3, 'from': "prob", "input_type": "prob", + "initial_output": 0,}}}}) + session.run(tf_compat.v1.global_variables_initializer()) + output_layer = net.layers["output"] + starts = output_layer.cell.output_layers_net.layers["start"].output.get_placeholder_as_batch_major() + sizes = output_layer.cell.output_layers_net.layers["size"].output.get_placeholder_as_batch_major() + segments = output_layer.cell.output_layers_net.layers["slices"].output.get_placeholder_as_batch_major() + feed = make_feed_dict(net.extern_data.data.values(), n_batch=n_batch, n_time=max_seq_len, same_time=True) + starts = session.run(starts, feed_dict=feed) + sizes = session.run(sizes, feed_dict=feed) + segments = session.run(segments, feed_dict=feed) + seq_lens = feed[net.extern_data.data["data"].size_placeholder[0]] + input_data = feed[net.extern_data.data["data"].placeholder] + max_size = numpy.amax(sizes) + max_size = max(max_size, 0) + assert segments.shape == (n_batch, max_seq_len, max_size, n_out) + for b in range(n_batch): + for t in range(max_seq_len): + s = starts[b, t] + size = sizes[b, t] + end = min(s + size, seq_lens[b]) + orig_seq = input_data[b, s:end] + if len(orig_seq) < max_size: + orig_seq = numpy.pad(orig_seq, [(0, max_size - len(orig_seq)), (0, 0)], "constant") + elif len(orig_seq) > max_size: + orig_seq = orig_seq[:max_size] + assert orig_seq.shape == (max_size, n_out) + orig_seq = numpy.where((numpy.arange(s, s + max_size) >= seq_lens[b])[:, None], 0.0, orig_seq) + for t2 in range(max_size): + numpy.testing.assert_equal(orig_seq[t2], segments[b, t, t2]) + + def test_SliceNdLayer_set_tag_on_size_tensor(): with make_scope(): n_out = 5 From d58191edcbf9bd314a8fcaacae36066a26e5f5b0 Mon Sep 17 00:00:00 2001 From: Robin Schmitt Date: Mon, 20 Sep 2021 10:44:12 +0200 Subject: [PATCH 3/5] SliceNdLayer make seq_lens more generic and use the compatible_to logic more --- returnn/tf/layers/basic.py | 55 +++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index f7981ad6f..482d16e17 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -860,35 +860,40 @@ def __init__(self, start, size, min_size=None, **kwargs): """ :param LayerBase start: (B,...) :param int|LayerBase|None size: if None, it uses the max possible size, - and it becomes a dynamic axis. If LayerBase, it needs to have the same shape as start + and it becomes a dynamic axis. 
:param int|None min_size: if size is None, but we want to have a min-size """ super(SliceNdLayer, self).__init__(**kwargs) from returnn.tf.util.basic import where_bc, expand_multiple_dims - x = self.input_data.copy_as_batch_major() - seq_lens = x.get_sequence_lengths() if x.is_time_axis_dynamic() else None # (B,) or None + from returnn.tf.util.data import Data + x = self.input_data.copy() + seq_lens_data = x.get_time_dim_tag().dyn_size_ext # (B,) or None self.start = start self.size = size - start_data = start.output.copy_as_batch_major() # e.g. (B,) or (B,T) - start_t = start_data.placeholder + start_data = start.output.copy() # e.g. (B,) or (B,T) + start_t = start_data.get_placeholder_as_batch_major() if size is None: if min_size is None: min_size = 0 - if seq_lens is None: + if seq_lens_data is None: assert isinstance(x.batch_shape[x.time_dim_axis], int) size_t = x.batch_shape[x.time_dim_axis] - start_t else: - # make seq_lens compatible with start_t - seq_lens = expand_multiple_dims( # e.g. (B,) or (B,1) - x=seq_lens, - axes=[-1] * (len(start_t.shape) - len(seq_lens.shape))) - size_t = seq_lens - start_t + seq_lens_t = seq_lens_data.copy_compatible_to(start_data, check_sparse=False).get_placeholder_as_batch_major() + size_t = seq_lens_t - start_t size = tf.maximum(tf.reduce_max(size_t), min_size) # scalar elif isinstance(size, LayerBase): - assert size.output.batch_ndim == start_data.batch_ndim + size_data = size.output.copy() + common_data = Data.get_common_data([start_data, size_data]) + size_data = size_data.copy_compatible_to(common_data) + size_t = size_data.get_placeholder_as_batch_major() + start_data = start_data.copy_compatible_to(common_data) + start_t = start_data.get_placeholder_as_batch_major() min_size = 0 - size_t = size.output.get_placeholder_as_batch_major() size = tf.maximum(tf.reduce_max(size_t), min_size) # scalar + else: + # in this case, size_t is never used but needs to be set to avoid a PyCharm warning + size_t = None # for each start index in start_data, we want to gather a slice # therefore, the output's first axes are the same as the ones from start_data # and the next axis will therefore be the slice axis @@ -901,28 +906,28 @@ def __init__(self, start, size, min_size=None, **kwargs): name=("%s:dyn_size" % slice_tag.description), dtype=Data.size_dtype, placeholder=dyn_size, - dim_tags=start_data.dim_tags, + dim_tags=start_data.copy_as_batch_major().dim_tags, # as batch major because the placeholder is too batch=slice_tag.batch, beam=slice_tag.batch.beam if slice_tag.batch else self.output.beam, control_flow_ctx=slice_tag.control_flow_ctx) slice_tag.dyn_size_ext = dyn_size_ext slice_tag.set_tag_on_size_tensor(dyn_size) - gather_positions_data = start_data.copy_template(name="%s_gather_positions" % self.name) + # get as batch major, because the placeholder later will also be batch major + gather_positions_data = start_data.copy_template(name="%s_gather_positions" % self.name).copy_as_batch_major() gather_positions_data = gather_positions_data.copy_add_dim_by_tag( slice_tag, unbroadcast=True, axis=start_data.batch_ndim) # [start+0, start+1, ...] gather_positions = tf.expand_dims(start_t, -1) + tf.range(0, size) # e.g. (B, size) or (B, T, size) - if seq_lens is not None: - # broadcast from (B,) to the shape of the indices - seq_lens = expand_multiple_dims( # e.g. 
(B,1) or (B,1,1) - x=seq_lens, - axes=[-1] * (len(gather_positions.shape) - len(seq_lens.shape))) + if seq_lens_data is not None: + seq_lens_t = seq_lens_data.copy_compatible_to( + gather_positions_data, + check_sparse=False).placeholder pad_mask = tf.logical_or( # shape like gather_positions - tf.greater(gather_positions, seq_lens - 1), + tf.greater(gather_positions, seq_lens_t - 1), tf.less(gather_positions, 0)) - gather_positions = tf.clip_by_value(gather_positions, 0, seq_lens - 1) + gather_positions = tf.clip_by_value(gather_positions, 0, seq_lens_t - 1) else: pad_mask = tf.logical_or( # shape like gather_positions tf.greater(gather_positions, x.batch_shape[1] - 1), @@ -969,12 +974,12 @@ def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, **kwarg :param str name: :param list[LayerBase] sources: :param LayerBase|None start: - :param int|None size: + :param int|LayerBase|None size: :rtype: Data """ from ..util.data import DimensionTag - start_data = start.output.copy_as_batch_major() - input_data = sources[0].output.copy_as_batch_major() + start_data = start.output.copy() + input_data = sources[0].output.copy() gather_positions_data = start_data.copy_template(name="%s_gather_positions" % name) if isinstance(size, LayerBase): size = None From 2afe7567398b88784b855e5586a3bae3cb3779ad Mon Sep 17 00:00:00 2001 From: Robin Schmitt Date: Mon, 20 Sep 2021 13:23:12 +0200 Subject: [PATCH 4/5] SliceNdLayer remove as_batch_major calls --- returnn/tf/layers/basic.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 482d16e17..8db4b62c0 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -871,7 +871,7 @@ def __init__(self, start, size, min_size=None, **kwargs): self.start = start self.size = size start_data = start.output.copy() # e.g. 
(B,) or (B,T) - start_t = start_data.get_placeholder_as_batch_major() + start_t = start_data.placeholder if size is None: if min_size is None: min_size = 0 @@ -879,20 +879,19 @@ def __init__(self, start, size, min_size=None, **kwargs): assert isinstance(x.batch_shape[x.time_dim_axis], int) size_t = x.batch_shape[x.time_dim_axis] - start_t else: - seq_lens_t = seq_lens_data.copy_compatible_to(start_data, check_sparse=False).get_placeholder_as_batch_major() + seq_lens_t = seq_lens_data.copy_compatible_to(start_data, check_sparse=False).placeholder size_t = seq_lens_t - start_t size = tf.maximum(tf.reduce_max(size_t), min_size) # scalar elif isinstance(size, LayerBase): size_data = size.output.copy() common_data = Data.get_common_data([start_data, size_data]) size_data = size_data.copy_compatible_to(common_data) - size_t = size_data.get_placeholder_as_batch_major() + size_t = size_data.placeholder start_data = start_data.copy_compatible_to(common_data) - start_t = start_data.get_placeholder_as_batch_major() + start_t = start_data.placeholder min_size = 0 size = tf.maximum(tf.reduce_max(size_t), min_size) # scalar else: - # in this case, size_t is never used but needs to be set to avoid a PyCharm warning size_t = None # for each start index in start_data, we want to gather a slice # therefore, the output's first axes are the same as the ones from start_data @@ -906,14 +905,13 @@ def __init__(self, start, size, min_size=None, **kwargs): name=("%s:dyn_size" % slice_tag.description), dtype=Data.size_dtype, placeholder=dyn_size, - dim_tags=start_data.copy_as_batch_major().dim_tags, # as batch major because the placeholder is too + dim_tags=start_data.dim_tags, batch=slice_tag.batch, beam=slice_tag.batch.beam if slice_tag.batch else self.output.beam, control_flow_ctx=slice_tag.control_flow_ctx) slice_tag.dyn_size_ext = dyn_size_ext slice_tag.set_tag_on_size_tensor(dyn_size) - # get as batch major, because the placeholder later will also be batch major - gather_positions_data = start_data.copy_template(name="%s_gather_positions" % self.name).copy_as_batch_major() + gather_positions_data = start_data.copy_template(name="%s_gather_positions" % self.name) gather_positions_data = gather_positions_data.copy_add_dim_by_tag( slice_tag, unbroadcast=True, @@ -935,6 +933,10 @@ def __init__(self, start, size, min_size=None, **kwargs): gather_positions = tf.clip_by_value(gather_positions, 0, x.batch_shape[1] - 1) if isinstance(self.size, LayerBase): pad_mask = tf.logical_or(tf.greater(gather_positions, tf.expand_dims(start_t + size_t - 1, -1)), pad_mask) + pad_mask_data = gather_positions_data.copy_template( + name="%s_gather_positions" % self.name, + dtype="bool") + pad_mask_data.placeholder = pad_mask gather_positions_data.placeholder = gather_positions position = InternalLayer( network=self.network, @@ -956,7 +958,8 @@ def __init__(self, start, size, min_size=None, **kwargs): # the gradient flow would go into wrong frames # and might lead to unexpected behavior. # So to be on the safe side, we do the masking here. 
- pad_mask = expand_multiple_dims(pad_mask, [-1] * (len(placeholder.shape) - len(pad_mask.shape))) + pad_mask_data = pad_mask_data.copy_compatible_to(gather_layer.output, check_sparse=False, check_dtype=False) + pad_mask = pad_mask_data.placeholder self.output.placeholder = where_bc(pad_mask, tf.zeros_like(placeholder), placeholder) def get_dep_layers(self): From d582ed76c00d95f44f5bb1509eed7b28e6da5a78 Mon Sep 17 00:00:00 2001 From: Robin Schmitt Date: Mon, 20 Sep 2021 14:05:01 +0200 Subject: [PATCH 5/5] SliceNdLayer get_common_data on all 3 sources --- returnn/tf/layers/basic.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 8db4b62c0..9be398c1a 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -864,13 +864,18 @@ def __init__(self, start, size, min_size=None, **kwargs): :param int|None min_size: if size is None, but we want to have a min-size """ super(SliceNdLayer, self).__init__(**kwargs) - from returnn.tf.util.basic import where_bc, expand_multiple_dims + from returnn.tf.util.basic import where_bc from returnn.tf.util.data import Data x = self.input_data.copy() seq_lens_data = x.get_time_dim_tag().dyn_size_ext # (B,) or None self.start = start self.size = size start_data = start.output.copy() # e.g. (B,) or (B,T) + data_objs = [start_data] + data_objs += [size.output] if isinstance(size, LayerBase) else [] + data_objs += [seq_lens_data] if isinstance(seq_lens_data, Data) else [] + common_data = Data.get_common_data(data_objs) + start_data = start_data.copy_compatible_to(common_data, check_sparse=False) start_t = start_data.placeholder if size is None: if min_size is None: @@ -879,16 +884,12 @@ def __init__(self, start, size, min_size=None, **kwargs): assert isinstance(x.batch_shape[x.time_dim_axis], int) size_t = x.batch_shape[x.time_dim_axis] - start_t else: - seq_lens_t = seq_lens_data.copy_compatible_to(start_data, check_sparse=False).placeholder + seq_lens_t = seq_lens_data.copy_compatible_to(common_data, check_sparse=False).placeholder size_t = seq_lens_t - start_t size = tf.maximum(tf.reduce_max(size_t), min_size) # scalar elif isinstance(size, LayerBase): - size_data = size.output.copy() - common_data = Data.get_common_data([start_data, size_data]) - size_data = size_data.copy_compatible_to(common_data) + size_data = size.output.copy_compatible_to(common_data, check_sparse=False) size_t = size_data.placeholder - start_data = start_data.copy_compatible_to(common_data) - start_t = start_data.placeholder min_size = 0 size = tf.maximum(tf.reduce_max(size_t), min_size) # scalar else: @@ -898,8 +899,9 @@ def __init__(self, start, size, min_size=None, **kwargs): # and the next axis will therefore be the slice axis slice_tag = self.output.dim_tags[start_data.batch_ndim] assert slice_tag.description.startswith("sliced-time:") - if not isinstance(size, int): + if size_t is not None: # in this case, size is not known before runtime and becomes dynamic and we need to set dyn_size + assert not isinstance(size, int) dyn_size = tf.maximum(size_t, min_size) # (B,) or (B,T) dyn_size_ext = Data( name=("%s:dyn_size" % slice_tag.description),
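Usage sketch (grounded in test_SliceNdLayer_multidimensional_size above, not part of the patches themselves): the new option is exercised by passing the name of another layer as "size" to "slice_nd". A minimal excerpt of the rec-unit layer dict from that test could look as follows; here "size" is simply start + 1, so it has the same shape as "start" and broadcasts against it.

# Rec-unit layers as in the test (illustrative excerpt; assumes this dict is the "unit" of a "rec" layer
# and that the encoder features are reachable as "base:data:data"):
rec_unit = {
    "const1": {"class": "constant", "value": 1},
    "start": {"class": "reinterpret_data", "from": "prev:choice", "set_sparse": False},  # (B,) start indices
    "size": {"class": "combine", "from": ["const1", "start"], "kind": "add"},  # start + 1, same shape as start
    # "size" may now be a layer: slice_nd gathers [start, start + size) per batch entry,
    # and the slice axis of "slices" becomes a dynamic axis
    "slices": {"class": "slice_nd", "from": "base:data:data", "start": "start", "size": "size"},
    "output": {"class": "reduce", "from": "slices", "mode": "max", "axes": "dyn:-1"},
}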