diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index de621dd7d..9be398c1a 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -859,39 +859,50 @@ class SliceNdLayer(_ConcatInputLayer): def __init__(self, start, size, min_size=None, **kwargs): """ :param LayerBase start: (B,...) - :param int|None size: if None, it uses the max possible size, and it becomes a dynamic axis + :param int|LayerBase|None size: if None, it uses the max possible size, + and it becomes a dynamic axis. :param int|None min_size: if size is None, but we want to have a min-size """ super(SliceNdLayer, self).__init__(**kwargs) - from returnn.tf.util.basic import where_bc, expand_multiple_dims - x = self.input_data.copy_as_batch_major() - seq_lens = x.get_sequence_lengths() if x.is_time_axis_dynamic() else None # (B,) or None + from returnn.tf.util.basic import where_bc + from returnn.tf.util.data import Data + x = self.input_data.copy() + seq_lens_data = x.get_time_dim_tag().dyn_size_ext # (B,) or None self.start = start - start_data = start.output.copy_as_batch_major() # e.g. (B,) or (B,T) + self.size = size + start_data = start.output.copy() # e.g. (B,) or (B,T) + data_objs = [start_data] + data_objs += [size.output] if isinstance(size, LayerBase) else [] + data_objs += [seq_lens_data] if isinstance(seq_lens_data, Data) else [] + common_data = Data.get_common_data(data_objs) + start_data = start_data.copy_compatible_to(common_data, check_sparse=False) start_t = start_data.placeholder if size is None: if min_size is None: min_size = 0 - if seq_lens is None: + if seq_lens_data is None: assert isinstance(x.batch_shape[x.time_dim_axis], int) - size = tf.maximum(tf.reduce_max(x.batch_shape[x.time_dim_axis] - start_t), min_size) # scalar + size_t = x.batch_shape[x.time_dim_axis] - start_t else: - # make seq_lens compatible with start_t - seq_lens = expand_multiple_dims( # e.g. (B,) or (B,1) - x=seq_lens, - axes=[-1] * (len(start_t.shape) - len(seq_lens.shape))) - size = tf.maximum(tf.reduce_max(seq_lens - start_t), min_size) # scalar + seq_lens_t = seq_lens_data.copy_compatible_to(common_data, check_sparse=False).placeholder + size_t = seq_lens_t - start_t + size = tf.maximum(tf.reduce_max(size_t), min_size) # scalar + elif isinstance(size, LayerBase): + size_data = size.output.copy_compatible_to(common_data, check_sparse=False) + size_t = size_data.placeholder + min_size = 0 + size = tf.maximum(tf.reduce_max(size_t), min_size) # scalar + else: + size_t = None # for each start index in start_data, we want to gather a slice # therefore, the output's first axes are the same as the ones from start_data # and the next axis will therefore be the slice axis slice_tag = self.output.dim_tags[start_data.batch_ndim] assert slice_tag.description.startswith("sliced-time:") - if not isinstance(size, int): + if size_t is not None: # in this case, size is not known before runtime and becomes dynamic and we need to set dyn_size - if seq_lens is None: - dyn_size = tf.maximum(x.batch_shape[x.time_dim_axis] - start_t, min_size) # (B,) or (B,T) - else: - dyn_size = tf.maximum(seq_lens - start_t, min_size) # (B,) or (B,T) + assert not isinstance(size, int) + dyn_size = tf.maximum(size_t, min_size) # (B,) or (B,T) dyn_size_ext = Data( name=("%s:dyn_size" % slice_tag.description), dtype=Data.size_dtype, @@ -909,20 +920,25 @@ def __init__(self, start, size, min_size=None, **kwargs): axis=start_data.batch_ndim) # [start+0, start+1, ...] gather_positions = tf.expand_dims(start_t, -1) + tf.range(0, size) # e.g. (B, size) or (B, T, size) - if seq_lens is not None: - # broadcast from (B,) to the shape of the indices - seq_lens = expand_multiple_dims( # e.g. (B,1) or (B,1,1) - x=seq_lens, - axes=[-1] * (len(gather_positions.shape) - len(seq_lens.shape))) + if seq_lens_data is not None: + seq_lens_t = seq_lens_data.copy_compatible_to( + gather_positions_data, + check_sparse=False).placeholder pad_mask = tf.logical_or( # shape like gather_positions - tf.greater(gather_positions, seq_lens - 1), + tf.greater(gather_positions, seq_lens_t - 1), tf.less(gather_positions, 0)) - gather_positions = tf.clip_by_value(gather_positions, 0, seq_lens - 1) + gather_positions = tf.clip_by_value(gather_positions, 0, seq_lens_t - 1) else: pad_mask = tf.logical_or( # shape like gather_positions tf.greater(gather_positions, x.batch_shape[1] - 1), tf.less(gather_positions, 0)) gather_positions = tf.clip_by_value(gather_positions, 0, x.batch_shape[1] - 1) + if isinstance(self.size, LayerBase): + pad_mask = tf.logical_or(tf.greater(gather_positions, tf.expand_dims(start_t + size_t - 1, -1)), pad_mask) + pad_mask_data = gather_positions_data.copy_template( + name="%s_gather_positions" % self.name, + dtype="bool") + pad_mask_data.placeholder = pad_mask gather_positions_data.placeholder = gather_positions position = InternalLayer( network=self.network, @@ -944,14 +960,18 @@ def __init__(self, start, size, min_size=None, **kwargs): # the gradient flow would go into wrong frames # and might lead to unexpected behavior. # So to be on the safe side, we do the masking here. - pad_mask = expand_multiple_dims(pad_mask, [-1] * (len(placeholder.shape) - len(pad_mask.shape))) + pad_mask_data = pad_mask_data.copy_compatible_to(gather_layer.output, check_sparse=False, check_dtype=False) + pad_mask = pad_mask_data.placeholder self.output.placeholder = where_bc(pad_mask, tf.zeros_like(placeholder), placeholder) def get_dep_layers(self): """ :rtype: list[LayerBase] """ - return super(SliceNdLayer, self).get_dep_layers() + [self.start] + dep_layers = super(SliceNdLayer, self).get_dep_layers() + [self.start] + if isinstance(self.size, LayerBase): + dep_layers += [self.size] + return dep_layers @classmethod def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, **kwargs): @@ -959,13 +979,15 @@ def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, **kwarg :param str name: :param list[LayerBase] sources: :param LayerBase|None start: - :param int|None size: + :param int|LayerBase|None size: :rtype: Data """ from ..util.data import DimensionTag - start_data = start.output.copy_as_batch_major() - input_data = sources[0].output.copy_as_batch_major() + start_data = start.output.copy() + input_data = sources[0].output.copy() gather_positions_data = start_data.copy_template(name="%s_gather_positions" % name) + if isinstance(size, LayerBase): + size = None # size might be None here in which case we set the dyn_size in __init__ tag = DimensionTag( kind=DimensionTag.Types.Spatial, @@ -991,6 +1013,8 @@ def transform_config_dict(cls, d, network, get_layer): """ super(SliceNdLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer) d["start"] = get_layer(d["start"]) + if isinstance(d["size"], str): + d["size"] = get_layer(d["size"]) class GatherLayer(_ConcatInputLayer): diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py index cc5589a48..4aa72a062 100644 --- a/tests/test_TFNetworkLayer.py +++ b/tests/test_TFNetworkLayer.py @@ -2820,6 +2820,60 @@ def test_SliceNdLayer_multidimensional_start(): numpy.testing.assert_equal(orig_seq[t2], segments[b, t, t2]) +def test_SliceNdLayer_multidimensional_size(): + with make_scope() as session: + n_out = 5 + n_batch = 3 + max_seq_len = 10 + config = Config({ + "debug_print_layer_output_template": True, + "extern_data": { + "data": {"dim": n_out}, + "classes": {"dim": n_out, "sparse": True} + }}) + net = TFNetwork(config=config, train_flag=True) + net.construct_from_dict({ + "output": { + "class": "rec", "from": "data:data", "unit": { + "const1": {"class": "constant", "value": 1}, + "start": {"class": "reinterpret_data", "from": "prev:choice", "set_sparse": False}, + "size": {"class": "combine", "from": ["const1", "start"], "kind": "add"}, + "slices": {"class": "slice_nd", "from": "base:data:data", "start": "start", "size": "size"}, + "output": {"class": "reduce", "from": "slices", "mode": "max", "axes": "dyn:-1"}, + "prob": {"class": "softmax", "from": "data:source", "target": "classes", "loss": "ce"}, + 'choice': { + 'class': 'choice', 'target': "classes", 'beam_size': 3, 'from': "prob", "input_type": "prob", + "initial_output": 0,}}}}) + session.run(tf_compat.v1.global_variables_initializer()) + output_layer = net.layers["output"] + starts = output_layer.cell.output_layers_net.layers["start"].output.get_placeholder_as_batch_major() + sizes = output_layer.cell.output_layers_net.layers["size"].output.get_placeholder_as_batch_major() + segments = output_layer.cell.output_layers_net.layers["slices"].output.get_placeholder_as_batch_major() + feed = make_feed_dict(net.extern_data.data.values(), n_batch=n_batch, n_time=max_seq_len, same_time=True) + starts = session.run(starts, feed_dict=feed) + sizes = session.run(sizes, feed_dict=feed) + segments = session.run(segments, feed_dict=feed) + seq_lens = feed[net.extern_data.data["data"].size_placeholder[0]] + input_data = feed[net.extern_data.data["data"].placeholder] + max_size = numpy.amax(sizes) + max_size = max(max_size, 0) + assert segments.shape == (n_batch, max_seq_len, max_size, n_out) + for b in range(n_batch): + for t in range(max_seq_len): + s = starts[b, t] + size = sizes[b, t] + end = min(s + size, seq_lens[b]) + orig_seq = input_data[b, s:end] + if len(orig_seq) < max_size: + orig_seq = numpy.pad(orig_seq, [(0, max_size - len(orig_seq)), (0, 0)], "constant") + elif len(orig_seq) > max_size: + orig_seq = orig_seq[:max_size] + assert orig_seq.shape == (max_size, n_out) + orig_seq = numpy.where((numpy.arange(s, s + max_size) >= seq_lens[b])[:, None], 0.0, orig_seq) + for t2 in range(max_size): + numpy.testing.assert_equal(orig_seq[t2], segments[b, t, t2]) + + def test_SliceNdLayer_set_tag_on_size_tensor(): with make_scope(): n_out = 5