diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 0cee99415..b857c266f 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -842,9 +842,13 @@ def get_out_data_from_opts( class SliceNdLayer(_ConcatInputLayer): """ - This takes out a slice-range from some axis, + This takes out a slice-range from the time axis, e.g. ``x[start:start + size]``. - This layers allows a different start slice point for each batch, + If the input is of shape (B,T,F) and start is of shape (B,), + then the output will be of shape (B,size,F). + If the input is of shape (B,T,F) and start is of shape (B,T), + then the output will be of shape (B,T,size,F). + This layer allows a different start slice point for each batch, in contrast to :class:`SliceLayer`, and the start is variable. See also :class:`GatherNdLayer`. :class:`PrefixInTimeLayer` can recover the original shape (by zero-padding). @@ -854,44 +858,93 @@ class SliceNdLayer(_ConcatInputLayer): def __init__(self, start, size, min_size=None, **kwargs): """ - :param LayerBase start: + :param LayerBase start: (B,...) :param int|None size: if None, it uses the max possible size, and it becomes a dynamic axis - :param int|None min_size: if size is None, but we want to have a min-size, set this + :param int|None min_size: if size is None, but we want to have a min-size """ super(SliceNdLayer, self).__init__(**kwargs) - from returnn.tf.util.basic import slice_nd, where_bc, expand_multiple_dims, DimensionTag + from returnn.tf.util.basic import where_bc, expand_multiple_dims x = self.input_data.copy_as_batch_major() - assert x.time_dim_axis == 1, "currently only time-axis==1 supported" - seq_lens = x.get_sequence_lengths() if x.is_time_axis_dynamic() else None + seq_lens = x.get_sequence_lengths() if x.is_time_axis_dynamic() else None # (B,) or None self.start = start - assert start.output.have_batch_axis() and start.output.batch_shape == (None,) - start = start.output.get_placeholder_as_batch_major() + start_data = start.output.copy_as_batch_major() # e.g. (B,) or (B,T) + start_t = start_data.placeholder if size is None: + if min_size is None: + min_size = 0 if seq_lens is None: - size = tf.maximum(tf.reduce_max(x.batch_shape[1] - start), 0) + assert isinstance(x.batch_shape[x.time_dim_axis], int) + size = tf.maximum(tf.reduce_max(x.batch_shape[x.time_dim_axis] - start_t), min_size) # scalar else: - size = tf.maximum(tf.reduce_max(seq_lens - start), 0) - if min_size is not None: - size = tf.maximum(size, min_size) - self.size = size - start = tf.expand_dims(start, axis=1) # (B, T) - slices = slice_nd(x.placeholder, start=tf.cast(start, tf.int32), size=size) # (B,size, ...) + # make seq_lens compatible with start_t + seq_lens = expand_multiple_dims( # e.g. (B,) or (B,1) + x=seq_lens, + axes=[-1] * (len(start_t.shape) - len(seq_lens.shape))) + size = tf.maximum(tf.reduce_max(seq_lens - start_t), min_size) # scalar + # for each start index in start_data, we want to gather a slice + # therefore, the output's first axes are the same as the ones from start_data + # and the next axis will therefore be the slice axis + slice_tag = self.output.dim_tags[start_data.batch_ndim] + assert slice_tag.description.startswith("sliced-time:") + if not isinstance(size, int): + # in this case, size is not known before runtime and becomes dynamic and we need to set dyn_size + if seq_lens is None: + dyn_size = tf.maximum(x.batch_shape[x.time_dim_axis] - start_t, min_size) # (B,) or (B,T) + else: + dyn_size = tf.maximum(seq_lens - start_t, min_size) # (B,) or (B,T) + dyn_size_ext = Data( + name=("%s:dyn_size" % slice_tag.description), + dtype=Data.size_dtype, + placeholder=dyn_size, + dim_tags=start_data.dim_tags, + batch=slice_tag.batch, + beam=slice_tag.batch.beam if slice_tag.batch else self.output.beam, + control_flow_ctx=slice_tag.control_flow_ctx) + slice_tag.dyn_size_ext = dyn_size_ext + gather_positions_data = start_data.copy_template(name="%s_gather_positions" % self.name) + gather_positions_data = gather_positions_data.copy_add_dim_by_tag( + slice_tag, + unbroadcast=True, + axis=start_data.batch_ndim) + # [start+0, start+1, ...] + gather_positions = tf.expand_dims(start_t, -1) + tf.range(0, size) # e.g. (B, size) or (B, T, size) if seq_lens is not None: - mask = tf.greater_equal(tf.range(size)[None, :] + start, seq_lens[:, None]) # (B,T) - mask = expand_multiple_dims(mask, list(range(2, x.batch_ndim))) - slices = where_bc(mask, tf.zeros_like(slices), slices) - size_placeholder = x.size_placeholder.copy() - if isinstance(size, tf.Tensor): - size_placeholder[0] = tf.maximum(seq_lens - tf.reshape(start, tf.shape(seq_lens)), 0) - tag = DimensionTag( - description="sliced-time:%s" % self.get_absolute_name(), - kind=DimensionTag.Types.Spatial, batch=self.output.batch) - tag.set_tag_on_size_tensor(size_placeholder[0]) + # broadcast from (B,) to the shape of the indices + seq_lens = expand_multiple_dims( # e.g. (B,1) or (B,1,1) + x=seq_lens, + axes=[-1] * (len(gather_positions.shape) - len(seq_lens.shape))) + pad_mask = tf.logical_or( # shape like gather_positions + tf.greater(gather_positions, seq_lens - 1), + tf.less(gather_positions, 0)) + gather_positions = tf.clip_by_value(gather_positions, 0, seq_lens - 1) else: - assert isinstance(size, int) - size_placeholder.pop(0, None) # static time axis - self.output.size_placeholder = size_placeholder - self.output.placeholder = slices + pad_mask = tf.logical_or( # shape like gather_positions + tf.greater(gather_positions, x.batch_shape[1] - 1), + tf.less(gather_positions, 0)) + gather_positions = tf.clip_by_value(gather_positions, 0, x.batch_shape[1] - 1) + gather_positions_data.placeholder = gather_positions + position = InternalLayer( + network=self.network, + name="%s_internal" % gather_positions_data.name, + output=gather_positions_data) + gather_layer = GatherLayer( + name="%s_gather" % self.name, + network=self.network, + output=self.output, + sources=self.sources, + position=position, + axis=x.get_time_dim_tag()) + placeholder = gather_layer.output.placeholder + # In principle, the padded frames are being ignored + # (unless get_padding_info_dict_ref et al are used). + # However, you can still end up with gradients for them + # in unexpected ways. + # Due to our gather implementation, + # the gradient flow would go into wrong frames + # and might lead to unexpected behavior. + # So to be on the safe side, we do the masking here. + pad_mask = expand_multiple_dims(pad_mask, [-1] * (len(placeholder.shape) - len(pad_mask.shape))) + self.output.placeholder = where_bc(pad_mask, tf.zeros_like(placeholder), placeholder) def get_dep_layers(self): """ @@ -909,11 +962,24 @@ def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, **kwarg :rtype: Data """ from ..util.data import DimensionTag - input_data = get_concat_sources_data_template(sources).copy_as_batch_spatial_major() - if start: - input_data.beam = SearchBeam.get_combined_beam(input_data.beam, start.output.beam) - new_dim_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="%s:slice_nd" % name, dimension=size) - return input_data.copy_template_replace_dim_tag(axis=1, new_dim_tag=new_dim_tag, name="%s_output" % name) + start_data = start.output.copy_as_batch_major() + input_data = sources[0].output.copy_as_batch_major() + gather_positions_data = start_data.copy_template(name="%s_gather_positions" % name) + # size might be None here in which case we set the dyn_size in __init__ + tag = DimensionTag( + kind=DimensionTag.Types.Spatial, + description="sliced-time:%s" % name, + dimension=size) + gather_positions_data = gather_positions_data.copy_add_dim_by_tag(tag, unbroadcast=True, axis=start_data.batch_ndim) + position = InternalLayer( + network=sources[0].network, + name="%s_internal" % gather_positions_data.name, + output=gather_positions_data) + return GatherLayer.get_out_data_from_opts( + name="%s_gather" % name, + sources=sources, + position=position, + axis=input_data.get_time_dim_tag()) @classmethod def transform_config_dict(cls, d, network, get_layer): diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py index b9ab7164d..5c752b982 100644 --- a/tests/test_TFNetworkLayer.py +++ b/tests/test_TFNetworkLayer.py @@ -2662,6 +2662,54 @@ def test_SliceNdLayer_dyn_size(): numpy.testing.assert_equal(orig_seq[t], out[b, t]) +def test_SliceNdLayer_multidimensional_start(): + with make_scope() as session: + n_out = 5 + n_batch = 3 + max_seq_len = 10 + config = Config({ + "debug_print_layer_output_template": True, + "extern_data": { + "data": {"dim": n_out}, + "classes": {"dim": n_out, "sparse": True} + }}) + net = TFNetwork(config=config, train_flag=True) + net.construct_from_dict({ + "output": { + "class": "rec", "from": "data:data", "unit": { + "start": {"class": "copy", "from": "prev:choice"}, + "slices": {"class": "slice_nd", "from": "base:data:data", "start": "start", "size": None}, + "output": {"class": "reduce", "from": "slices", "mode": "max", "axes": "dyn:-1"}, + "prob": {"class": "softmax", "from": "data:source", "target": "classes", "loss": "ce"}, + 'choice': { + 'class': 'choice', 'target': "classes", 'beam_size': 3, 'from': "prob", "input_type": "prob", + "initial_output": 0,}}}}) + session.run(tf_compat.v1.global_variables_initializer()) + output_layer = net.layers["output"] + starts = output_layer.cell.output_layers_net.layers["start"].output.get_placeholder_as_batch_major() + segments = output_layer.cell.output_layers_net.layers["slices"].output.get_placeholder_as_batch_major() + feed = make_feed_dict(net.extern_data.data.values(), n_batch=n_batch, n_time=max_seq_len, same_time=True) + starts = session.run(starts, feed_dict=feed) + segments = session.run(segments, feed_dict=feed) + seq_lens = feed[net.extern_data.data["data"].size_placeholder[0]] + input_data = feed[net.extern_data.data["data"].placeholder] + max_size = numpy.amax(seq_lens[:, None] - starts) + max_size = max(max_size, 0) + assert segments.shape == (n_batch, max_seq_len, max_size, n_out) + for b in range(n_batch): + for t in range(max_seq_len): + s = starts[b, t] + orig_seq = input_data[b, s:] + if len(orig_seq) < max_size: + orig_seq = numpy.pad(orig_seq, [(0, max_size - len(orig_seq)), (0, 0)], "constant") + elif len(orig_seq) > max_size: + orig_seq = orig_seq[:max_size] + assert orig_seq.shape == (max_size, n_out) + orig_seq = numpy.where((numpy.arange(s, s + max_size) >= seq_lens[b])[:, None], 0.0, orig_seq) + for t2 in range(max_size): + numpy.testing.assert_equal(orig_seq[t2], segments[b, t, t2]) + + def test_WindowLayer_output_placeholder(): with make_scope() as session: net = TFNetwork(extern_data=ExternData())