returnn/tf/layers/basic.py (13 additions, 0 deletions)

@@ -1341,6 +1341,19 @@ def __init__(self, position, axis, **kwargs):
    # (BatchAxes.., InputAxesBeforeGatherAxis, PositionAxes.., InputAxesAfterGatherAxis..)
    self.output.placeholder = tf.gather(params=params, indices=indices, axis=gather_axis, batch_dims=batch_dims)

    if input_data.dim_tags[old_gather_axis].is_batch_dim():
      for dim_tag in self.output.dim_tags:
        if dim_tag.is_spatial_dim():
          axis = self.output.get_batch_axis_excluding_batch(self.output.get_axis_by_tag_name(dim_tag.description))
Member:
This is:

  • way too complicated: you can simply do for axis, dim_tag in enumerate(self.output.dim_tags)
  • wrong: do not rely on get_axis_by_tag_name and dim_tag.description
  • not necessary: just use dim_tag.dyn_size_ext
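
A minimal sketch of the simplification suggested above (this is an assumption of what the reviewer means, not code from this PR; it presumes dyn_size_ext is a Data object whose placeholder holds the per-sequence sizes):

    for axis, dim_tag in enumerate(self.output.dim_tags):
      if not dim_tag.is_spatial_dim() or dim_tag.dyn_size_ext is None:
        continue  # only dynamic spatial axes carry sizes that need re-gathering
      # dyn_size_ext.placeholder is the [B]-shaped sizes tensor, so no
      # get_axis_by_tag_name / description lookup is needed
      new_size = tf.gather(params=dim_tag.dyn_size_ext.placeholder, indices=position_data.placeholder)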

          new_size = tf.gather(params=self.output.size_placeholder[axis], indices=position_data.placeholder)
Member:
You should not access size_placeholder but rather dim_tag.dyn_size_ext.

          from ..util.data import Dim
          Dim(
            kind=Dim.Types.Spatial, description="%s_gather_axis" % self.name,
            dyn_size=new_size, batch=self.output.batch,
Member:
You should not assign dyn_size but rather dyn_size_ext.

            src_data=self.output, src_axis=axis, auto_generated=True)
          self.output.size_placeholder[axis] = new_size
Member:
You should not assign size_placeholder but rather the dim tags.
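
Taken together, a hedged sketch of what these three comments point toward, continuing the loop from the sketch above (it assumes the Dim constructor accepts dyn_size_ext and that Data.copy_template and Data.copy_template_replace_dim_tag are available in this RETURNN version; these are assumptions, not code from this PR):

          # build the gathered sizes as a Data and attach them via dyn_size_ext
          new_size_ext = dim_tag.dyn_size_ext.copy_template(name="%s_gather_axis" % self.name)
          new_size_ext.placeholder = tf.gather(
            params=dim_tag.dyn_size_ext.placeholder, indices=position_data.placeholder)
          new_tag = Dim(
            kind=Dim.Types.Spatial, description="%s_gather_axis" % self.name,
            dyn_size_ext=new_size_ext, batch=self.output.batch, auto_generated=True)
          # swap the dim tag on the output instead of assigning size_placeholder directly
          out = self.output.copy_template_replace_dim_tag(axis=axis, new_dim_tag=new_tag)
          out.placeholder = self.output.placeholder
          self.output = out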



  @classmethod
  def _get_common_input_position_axes(cls, input_data, position_data, old_gather_axis):
    """
tests/test_TFNetworkLayer.py (103 additions, 0 deletions)

@@ -5456,6 +5456,53 @@ def test_GatherLayer_broadcast_dim():
  })


def test_GatherLayer_batch_dim():
  from returnn.tf.util.data import batch_dim
  with make_scope() as session:
    import numpy as np
    net = TFNetwork(extern_data=ExternData())
    b_dim, time_dim, feature_dim = 3, 4, 2
    # [B, T, F]
    random = np.random.RandomState(42)
    values_seqs = random.rand(b_dim, time_dim, feature_dim).astype('float32')
    values_size = np.array([4, 2, 3])
    values_placeholder = tf.constant(values_seqs, dtype=tf.float32)
    values_size_placeholder = {0: tf.constant(values_size, dtype=tf.int32)}
    values = InternalLayer(
      name="values", network=net,
      output=Data(
        name="values",
        batch_dim_axis=0, time_dim_axis=1, feature_dim_axis=2,
        shape=[None, feature_dim],
        placeholder=values_placeholder,
        size_placeholder=values_size_placeholder,
        dim_tags=(batch_dim, SpatialDim("time"), FeatureDim("feature", feature_dim)),
      ))
    position_np = np.array([0, 2])
    position = InternalLayer(
      name="position", network=net,
      output=Data(
        name="position",
        placeholder=tf.constant(position_np, dtype=tf.int64),
        batch_dim_axis=0, shape=[], dtype="int64",
      ))
Comment on lines +5482 to +5488
Contributor Author:
@albertz do I need to change the creation of position in order to make it have a different batch axis dim tag here?

Member:
Yes. It's actually not so simple because of the special treatment of the batch dim tag. I'm not sure it's really possible currently.

In practice, in your real code, how would you end up with position?

Contributor Author:
> In practice, in your real code, how would you end up with position?

Do you mean what dim tag I get there?

>>> position.output.dim_tags[0].description
'batch:position'

So it's not actually the global batch dim. I was just confused because I got

>>> position.output.dim_tags[0] == values.output.dim_tags[0]
True

but this is because the check does not cover this case, see comment here: #1089 (comment)

    values.output.sanity_check()
    position.output.sanity_check()

    # should become [B', T, F]
    layer = GatherLayer(
      name="gather", network=net,
      sources=[values], position=position, axis="B",
      output=GatherLayer.get_out_data_from_opts(
        name="gather", sources=[values], position=position, axis="B"))
    layer.output.sanity_check()
    out_seqs, out_size = session.run([layer.output.placeholder, layer.output.size_placeholder.as_dict()])
    assert isinstance(out_seqs, numpy.ndarray)

    np.testing.assert_equal(values_seqs[position_np, :], out_seqs)
    np.testing.assert_equal(values_size[position_np], out_size[0])


def test_SliceNdLayer():
  n_batch = 5
  n_time = 7
@@ -9213,6 +9260,62 @@ def _get_mask_eval_layer(source, **_kwargs):
  assert numpy.isfinite(loss_v)


def test_supervised_multilingual_training():
  n_batch = 3
  out_dim = 5

  net_dict = {
    "encoder": {"class": "copy", "from": "data"},

    "loss_language_0": {
      "class": "subnetwork",
      "from": "data:classes",
      "is_output_layer": True,
      "subnetwork": {
        # create language id indices based on where we see non-silence in the alignment.
        # this should be done in the dataset in a neat way.
        "aux_0": {"class": "gather", "position": 0, "axis": "T", "from": "data"},
        "aux_1": {"class": "compare", "kind": "greater", "value": 2, "from": "aux_0"},
        "idx": {  # shape (B',) with B' < B (B' is the number of utterances in the batch for this language)
          "class": "eval",
          "eval": "tf.squeeze(tf.where(source(0)), axis=-1)",
          "out_type": {"dtype": "int64"},
          "from": "aux_1",
        },
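        # Hedged aside (not part of the PR, added for intuition): if aux_1 is
        # [True, False, True], tf.where yields [[0], [2]] (int64, shape [B', 1]),
        # and the squeeze gives idx = [0, 2], i.e. the indices of the utterances
        # in the batch that belong to this language.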

        # gather targets and encoder outputs
        "tgt": {"class": "gather", "from": "data", "axis": "B", "position": "idx"},  # B', T (sparse)
        "enc_raw": {"class": "gather", "from": "base:encoder", "axis": "B", "position": "idx"},  # B', T, F
        "enc": {"class": "reinterpret_data", "size_base": "tgt", "from": "enc_raw"},  # B', T, F
Member:
Why is this needed?

"logits": {"class": "linear", "n_out": out_dim, "from": "enc"}, # B', T, n_out
"output": {
"class": "sparse_softmax_cross_entropy_with_logits",
"logits": "logits",
"targets": "tgt",
"loss": "as_is",
},
},
},
}

  with make_scope() as session:
    config = Config({
      "extern_data": {
        "data": {"dim": 7, "batch_dim_axis": 0, "time_dim_axis": 1, "feature_dim_axis": 2},
        "classes": {"dim": out_dim, "sparse": True}},
      "debug_print_layer_output_shape": True,  # "debug_print_layer_output": True,
    })
    net = TFNetwork(config=config, train_flag=True)
    net.construct_from_dict(net_dict)
    loss = net.get_total_loss()

    tf_compat.v1.set_random_seed(1)
    net.initialize_params(session)
    loss_v = session.run(loss, feed_dict=make_feed_dict(net.extern_data, same_time=True, n_batch=n_batch))
    print("loss:", loss_v)
    assert numpy.isfinite(loss_v)


if __name__ == "__main__":
try:
better_exchook.install()