diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py
index 91eccc986..60031a164 100644
--- a/returnn/tf/layers/basic.py
+++ b/returnn/tf/layers/basic.py
@@ -1341,6 +1341,19 @@ def __init__(self, position, axis, **kwargs):
     # (BatchAxes.., InputAxesBeforeGatherAxis, PositionAxes.., InputAxesAfterGatherAxis..)
     self.output.placeholder = tf.gather(params=params, indices=indices, axis=gather_axis, batch_dims=batch_dims)
+    if input_data.dim_tags[old_gather_axis].is_batch_dim():
+      for dim_tag in self.output.dim_tags:
+        if dim_tag.is_spatial_dim():
+          axis = self.output.get_batch_axis_excluding_batch(self.output.get_axis_by_tag_name(dim_tag.description))
+          new_size = tf.gather(params=self.output.size_placeholder[axis], indices=position_data.placeholder)
+          from ..util.data import Dim
+          Dim(
+            kind=Dim.Types.Spatial, description="%s_gather_axis" % self.name,
+            dyn_size=new_size, batch=self.output.batch,
+            src_data=self.output, src_axis=axis, auto_generated=True)
+          self.output.size_placeholder[axis] = new_size
+
+
 
   @classmethod
   def _get_common_input_position_axes(cls, input_data, position_data, old_gather_axis):
     """
diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py
index 952d473c5..63a874a03 100644
--- a/tests/test_TFNetworkLayer.py
+++ b/tests/test_TFNetworkLayer.py
@@ -5456,6 +5456,53 @@ def test_GatherLayer_broadcast_dim():
     })
 
 
+def test_GatherLayer_batch_dim():
+  from returnn.tf.util.data import batch_dim
+  with make_scope() as session:
+    import numpy as np
+    net = TFNetwork(extern_data=ExternData())
+    b_dim, time_dim, feature_dim = 3, 4, 2
+    # [B, T, F]
+    random = np.random.RandomState(42)
+    values_seqs = random.rand(b_dim, time_dim, feature_dim).astype('float32')
+    values_size = np.array([4, 2, 3])
+    values_placeholder = tf.constant(values_seqs, dtype=tf.float32)
+    values_size_placeholder = {0: tf.constant(values_size, dtype=tf.int32)}
+    values = InternalLayer(
+      name="values", network=net,
+      output=Data(
+        name="values",
+        batch_dim_axis=0, time_dim_axis=1, feature_dim_axis=2,
+        shape=[None, feature_dim],
+        placeholder=values_placeholder,
+        size_placeholder=values_size_placeholder,
+        dim_tags=(batch_dim, SpatialDim("time"), FeatureDim("feature", feature_dim)),
+      ))
+    position_np = np.array([0, 2])
+    position = InternalLayer(
+      name="position", network=net,
+      output=Data(
+        name="position",
+        placeholder=tf.constant(position_np, dtype=tf.int64),
+        batch_dim_axis=0, shape=[], dtype="int64",
+      ))
+    values.output.sanity_check()
+    position.output.sanity_check()
+
+    # should become [B', T, F]
+    layer = GatherLayer(
+      name="gather", network=net,
+      sources=[values], position=position, axis="B",
+      output=GatherLayer.get_out_data_from_opts(
+        name="gather", sources=[values], position=position, axis="B"))
+    layer.output.sanity_check()
+    out_seqs, out_size = session.run([layer.output.placeholder, layer.output.size_placeholder.as_dict()])
+    assert isinstance(out_seqs, numpy.ndarray)
+
+    np.testing.assert_equal(values_seqs[position_np, :], out_seqs)
+    np.testing.assert_equal(values_size[position_np], out_size[0])
+
+
 def test_SliceNdLayer():
   n_batch = 5
   n_time = 7
@@ -9213,6 +9260,62 @@ def _get_mask_eval_layer(source, **_kwargs):
   assert numpy.isfinite(loss_v)
 
 
+def test_supervised_multilingual_training():
+  n_batch = 3
+  out_dim = 5
+
+  net_dict = {
+    "encoder": {"class": "copy", "from": "data"},
+
+    "loss_language_0": {
+      "class": "subnetwork",
+      "from": "data:classes",
+      "is_output_layer": True,
+      "subnetwork": {
+        # create language id indices based on where we see non-silence in the alignment.
+        # this should be done in the dataset in a neat way.
+        "aux_0": {"class": "gather", "position": 0, "axis": "T", "from": "data"},
+        "aux_1": {"class": "compare", "kind": "greater", "value": 2, "from": "aux_0"},
+        "idx": {  # shape (B',) with B' < B (B' is the number of utterances in the batch for this language)
+          "class": "eval",
+          "eval": "tf.squeeze(tf.where(source(0)), axis=-1)",
+          "out_type": {"dtype": "int64"},
+          "from": "aux_1",
+        },
+
+        # gather targets and encoder outputs
+        "tgt": {"class": "gather", "from": "data", "axis": "B", "position": "idx"},  # B', T (sparse)
+        "enc_raw": {"class": "gather", "from": "base:encoder", "axis": "B", "position": "idx"},  # B', T, F
+        "enc": {"class": "reinterpret_data", "size_base": "tgt", "from": "enc_raw"},  # B', T, F
+        "logits": {"class": "linear", "n_out": out_dim, "from": "enc"},  # B', T, n_out
+        "output": {
+          "class": "sparse_softmax_cross_entropy_with_logits",
+          "logits": "logits",
+          "targets": "tgt",
+          "loss": "as_is",
+        },
+      },
+    },
+  }
+
+  with make_scope() as session:
+    config = Config({
+      "extern_data": {
+        "data": {"dim": 7, "batch_dim_axis": 0, "time_dim_axis": 1, "feature_dim_axis": 2},
+        "classes": {"dim": out_dim, "sparse": True}},
+      "debug_print_layer_output_shape": True,  # "debug_print_layer_output": True,
+    })
+    net = TFNetwork(config=config, train_flag=True)
+    net.construct_from_dict(net_dict)
+    loss = net.get_total_loss()
+
+    tf_compat.v1.set_random_seed(1)
+    net.initialize_params(session)
+    loss_v = session.run(loss, feed_dict=make_feed_dict(net.extern_data, same_time=True, n_batch=n_batch))
+    print("loss:", loss_v)
+    assert numpy.isfinite(loss_v)
+
+
 if __name__ == "__main__":
   try:
     better_exchook.install()
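
Note (not part of the patch): with this change, a "gather" layer over the batch axis also gathers the dynamic sequence lengths of the remaining spatial axes, so the reduced batch B' keeps consistent seq lengths for downstream layers. A minimal net-dict sketch mirroring the test above; the layer names "mask", "idx" and "sub_batch" are illustrative only and not part of this PR:

  # hypothetical net dict fragment; "mask" is assumed to be some boolean layer over the
  # batch, shape (B,), e.g. the output of a "compare" layer as in the multilingual test
  example_net_dict = {
    "idx": {
      "class": "eval", "from": "mask",
      "eval": "tf.squeeze(tf.where(source(0)), axis=-1)",
      "out_type": {"dtype": "int64"}},  # indices into the batch, shape (B',)
    # gather over the batch axis; with this patch the time-axis seq lengths are gathered too
    "sub_batch": {"class": "gather", "from": "data", "axis": "B", "position": "idx"},  # (B', T, F)
  }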