returnn/tf/layers/basic.py (11 additions, 0 deletions)
@@ -1341,6 +1341,17 @@ def __init__(self, position, axis, **kwargs):
    # (BatchAxes.., InputAxesBeforeGatherAxis, PositionAxes.., InputAxesAfterGatherAxis..)
    self.output.placeholder = tf.gather(params=params, indices=indices, axis=gather_axis, batch_dims=batch_dims)

    if input_data.dim_tags[old_gather_axis].is_batch_dim():
      for axis in self.output.size_placeholder:
        new_size = tf.gather(params=self.output.size_placeholder[axis], indices=position_data.placeholder)
Member:
You assume that position_data is of shape [new-batch]?

Contributor Author:
Right, in the case I have in mind, yes. But for the failing test case this is different, and we need to take that into account.

Member:
What is it in that case?

Contributor Author:
There it's of shape [B,T,F]; however, in the input, B and T are packed:

>>> input_data
Data{'flat_output', [B&Packed{'time'},F|F'feature'(5)]}
>>> self.output
Data{'output_output', [B,T|'time'[B],'other-spatial'(7),F|F'feature'(5)]}
>>> position_data
Data{'indices_flat_output', [B,T|'time'[B],F|'other-spatial'(7)], dtype='int32'}

Member:
Which test case is that? The one you added, test_rand_indices?
Why is position_data of this shape? As described, it should have some new-batch dim in it, right? Or basically just the shape [new-batch], when you gather into the batch dim? It definitely should not have the old batch dim in its shape.
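For reference, a minimal sketch in plain TensorFlow (outside RETURNN, with made-up tensors) of the shapes involved when gathering into the batch dim: the position vector carries the new batch dim B', and the per-sequence lengths have to be gathered the same way as the values.

import tensorflow as tf

values = tf.random.normal([3, 4, 5])               # [B, T, F], old batch dim B = 3
position = tf.constant([0, 2], dtype=tf.int32)     # [B'], new batch dim B' = 2
gathered = tf.gather(values, position, axis=0)     # [B', T, F] = [2, 4, 5]

seq_lens = tf.constant([4, 2, 3], dtype=tf.int32)  # [B], per-sequence lengths
new_seq_lens = tf.gather(seq_lens, position)       # [B'] = [4, 3]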

Contributor Author:
I see that I need to assign it for output. But it should come from position_data, right?

Member:
Why is it None for position_data? I don't mean in the test case, I mean in the real case which motivated this test case. In the real case, you would not have such an InternalLayer.

Member:
It should never be done if the data has a batch dim, unless something is wrong. If that is what happens in the test case, then the test case is buggy.

Contributor Author:
Right, this is about the test case. However, in the case that I'm interested in, input_data.batch == position_data.batch is still True. This is probably because I'm using an EvalLayer to get the batch indices from a 0/1 vector with shape (B,), and that EvalLayer does not set the output correctly. Then we would need a layer which does that correctly, right?
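To illustrate only the shape change (a minimal plain-TF sketch, not the actual config): tf.where on a boolean mask of shape [B] returns the indices of the True entries, i.e. a tensor with a new, smaller batch dim B' <= B.

import tensorflow as tf

mask = tf.constant([1, 0, 1]) > 0          # [B] = [3], 0/1 vector turned into a boolean mask
idx = tf.squeeze(tf.where(mask), axis=-1)  # [B'] = [2], values [0, 2], dtype int64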

Member:
An EvalLayer should never change the shape. If it does, and you are not very careful in setting the output data, then yes, this is a bug in your config.

        from ..util.data import Dim
        Dim(
          kind=Dim.Types.Spatial, description="%s_gather_axis" % self.name,
          dyn_size=new_size, batch=self.output.batch,
Member:
You're missing dyn_size_ext here.

Member:
You should not set dyn_size when it is non-standard. Set dyn_size_ext instead.

          src_data=self.output, src_axis=axis, auto_generated=True)
        self.output.size_placeholder[axis] = new_size
Member:
You don't use the Dim object you created?
Instead of assigning size_placeholder, I think it would be better to set the newly created dim tag.

Contributor Author:
What is the usual way to set dim tags? I can't just reassign self.output.dim_tags. declare_same_as is used elsewhere, but I'm not sure if it applies here.

Member:
See most other layers. Usually you set dim_tags in get_out_data_from_opts. You should not assign a new dim tag in __init__. In __init__, you would just assign the dyn_size_ext or dyn_size_ext.placeholder of a dim tag which was previously created in get_out_data_from_opts.
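A rough sketch of the suggested pattern, assuming the usual Dim/Data helpers behave as in other layers. The names new_dim, out and time_axis are illustrative; name, self, position_data and new_size come from the layer code above. This is not the actual PR code:

from returnn.tf.util.data import Dim, Data

# In get_out_data_from_opts: for a dynamic axis whose sizes will change after gathering
# into the batch dim (e.g. the time axis), create a new dim tag (no sizes yet)
# and put it into the output template.
new_dim = Dim(
  kind=Dim.Types.Spatial, description="%s_gather_axis" % name, dimension=None, auto_generated=True)
out = out.copy_template_replace_dim_tag(axis=time_axis, new_dim_tag=new_dim)

# In __init__: once the gathered sizes are computed (new_size above), fill in dyn_size_ext,
# a Data wrapping the sizes (here assumed to be of shape [new-batch]), instead of dyn_size;
# the new dim tag is then already part of self.output.
new_dim.dyn_size_ext = Data(
  name="%s_new_size" % self.name, dtype="int32",
  dim_tags=[position_data.dim_tags[0]], placeholder=new_size)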



  @classmethod
  def _get_common_input_position_axes(cls, input_data, position_data, old_gather_axis):
    """
tests/test_TFNetworkLayer.py (101 additions, 0 deletions)
@@ -5456,6 +5456,51 @@ def test_GatherLayer_broadcast_dim():
})


def test_GatherLayer_batch_dim():
  with make_scope() as session:
    import numpy as np
    net = TFNetwork(extern_data=ExternData())
    batch_dim, time_dim, feature_dim = 3, 4, 2
    # [B, T, F]
    random = np.random.RandomState(42)
    values_seqs = random.rand(batch_dim, time_dim, feature_dim).astype('float32')
    values_size = np.array([4, 2, 3])
    values_placeholder = tf.constant(values_seqs, dtype=tf.float32)
    values_size_placeholder = {0: tf.constant(values_size, dtype=tf.int32)}
    values = InternalLayer(
      name="values", network=net,
      output=Data(
        name="values",
        batch_dim_axis=0, time_dim_axis=1, feature_dim_axis=2,
        shape=[None, feature_dim],
        placeholder=values_placeholder,
        size_placeholder=values_size_placeholder,
      ))
    position_np = np.array([0, 2])
    position = InternalLayer(
      name="position", network=net,
      output=Data(
        name="position",
        placeholder=tf.constant(position_np, dtype=tf.int64),
        batch_dim_axis=0, shape=[], dtype="int64",
      ))
Contributor Author:
@albertz do I need to change the creation of position in order to make it have a different batch axis dim tag here?

Member:
Yes. It's actually not so simple because of the special treatment of the batch dim tag. I'm not sure it's really possible currently.

In practice, in your real code, how would you end up with position?

Contributor Author:
> In practice, in your real code, how would you end up with position?

Do you mean what dim tag I get there?

>>> position.output.dim_tags[0].description
'batch:position'

So it's not actually the global batch dim. I was just confused because I got

>>> position.output.dim_tags[0] == values.output.dim_tags[0]
True

but this is because the check does not cover this case, see comment here: #1089 (comment)

    values.output.sanity_check()
    position.output.sanity_check()

    # should become [B', T, F]
    layer = GatherLayer(
      name="gather", network=net,
      sources=[values], position=position, axis="B",
      output=GatherLayer.get_out_data_from_opts(
        name="gather", sources=[values], position=position, axis="B"))
    layer.output.sanity_check()
    out_seqs, out_size = session.run([layer.output.placeholder, layer.output.size_placeholder.as_dict()])
    assert isinstance(out_seqs, numpy.ndarray)

    np.testing.assert_equal(values_seqs[position_np, :], out_seqs)
    np.testing.assert_equal(values_size[position_np], out_size[0])


def test_SliceNdLayer():
  n_batch = 5
  n_time = 7
@@ -9213,6 +9258,62 @@ def _get_mask_eval_layer(source, **_kwargs):
assert numpy.isfinite(loss_v)


def test_supervised_multilingual_training():
  n_batch = 3
  out_dim = 5

  net_dict = {
    "encoder": {"class": "copy", "from": "data"},

    "loss_language_0": {
      "class": "subnetwork",
      "from": "data:classes",
      "is_output_layer": True,
      "subnetwork": {
        # create language id indices based on where we see non-silence in the alignment.
        # this should be done in the dataset in a neat way.
        "aux_0": {"class": "gather", "position": 0, "axis": "T", "from": "data"},
        "aux_1": {"class": "compare", "kind": "greater", "value": 2, "from": "aux_0"},
        "idx": {  # shape (B',) with B' < B (B' is the number of utterances in the batch for this language)
          "class": "eval",
          "eval": "tf.squeeze(tf.where(source(0)), axis=-1)",
          "out_type": {"dtype": "int64"},
          "from": "aux_1",
        },

        # gather targets and encoder outputs
        "tgt": {"class": "gather", "from": "data", "axis": "B", "position": "idx"},  # B', T (sparse)
        "enc_raw": {"class": "gather", "from": "base:encoder", "axis": "B", "position": "idx"},  # B', T, F
        "enc": {"class": "reinterpret_data", "size_base": "tgt", "from": "enc_raw"},  # B', T, F
Member:
Why is this needed?

"logits": {"class": "linear", "n_out": out_dim, "from": "enc"}, # B', T, n_out
"output": {
"class": "sparse_softmax_cross_entropy_with_logits",
"logits": "logits",
"targets": "tgt",
"loss": "as_is",
},
},
},
}

with make_scope() as session:
config = Config({
"extern_data": {
"data": {"dim": 7, "batch_dim_axis": 0, "time_dim_axis": 1, "feature_dim_axis": 2},
"classes": {"dim": out_dim, "sparse": True}},
"debug_print_layer_output_shape": True, # "debug_print_layer_output": True,
})
net = TFNetwork(config=config, train_flag=True)
net.construct_from_dict(net_dict)
loss = net.get_total_loss()

tf_compat.v1.set_random_seed(1)
net.initialize_params(session)
loss_v = session.run(loss, feed_dict=make_feed_dict(net.extern_data, same_time=True, n_batch=n_batch))
print("loss:", loss_v)
assert numpy.isfinite(loss_v)


if __name__ == "__main__":
  try:
    better_exchook.install()