GatherLayer on batch axis #1089
base: master
Changes from all commits
a983d80
3dc0018
94c8476
09c357a
151f498
@@ -1341,6 +1341,19 @@ def __init__(self, position, axis, **kwargs):

```python
    # (BatchAxes.., InputAxesBeforeGatherAxis, PositionAxes.., InputAxesAfterGatherAxis..)
    self.output.placeholder = tf.gather(params=params, indices=indices, axis=gather_axis, batch_dims=batch_dims)

    if input_data.dim_tags[old_gather_axis].is_batch_dim():
      for dim_tag in self.output.dim_tags:
        if dim_tag.is_spatial_dim():
          axis = self.output.get_batch_axis_excluding_batch(self.output.get_axis_by_tag_name(dim_tag.description))
          new_size = tf.gather(params=self.output.size_placeholder[axis], indices=position_data.placeholder)
          from ..util.data import Dim
          Dim(
            kind=Dim.Types.Spatial, description="%s_gather_axis" % self.name,
            dyn_size=new_size, batch=self.output.batch,
            src_data=self.output, src_axis=axis, auto_generated=True)
          self.output.size_placeholder[axis] = new_size

  @classmethod
  def _get_common_input_position_axes(cls, input_data, position_data, old_gather_axis):
    """
```

Review comment on the `new_size = tf.gather(...)` line (truncated in this capture): You should not access …

Review comment on the `dyn_size=new_size, batch=self.output.batch` line (truncated): You should not assign …

Review comment on the `self.output.size_placeholder[axis] = new_size` line (truncated): You should not assign …
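For intuition, here is a minimal self-contained sketch (plain TensorFlow, outside of RETURNN; the variable names are made up) of what gathering on the batch axis means for both the values and the per-sequence lengths, mirroring what the hunk above does with `size_placeholder`:

```python
import numpy as np
import tensorflow as tf

# Values of shape [B, T, F] with one dynamic length per sequence, shape [B].
values = tf.constant(np.arange(3 * 4 * 2, dtype=np.float32).reshape(3, 4, 2))
seq_lens = tf.constant([4, 2, 3], dtype=tf.int32)

# Positions select a sub-batch B' out of the batch B.
position = tf.constant([0, 2], dtype=tf.int64)

# Gathering on axis 0 (the batch axis) picks whole sequences...
gathered = tf.gather(params=values, indices=position, axis=0)  # [B', T, F]
# ...so the dynamic sizes of the remaining spatial axes must be gathered
# with the same indices to stay consistent with the new batch.
gathered_lens = tf.gather(params=seq_lens, indices=position)  # [B'] == [4, 3]

assert gathered.shape == (2, 4, 2)
```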
@@ -5456,6 +5456,53 @@ def test_GatherLayer_broadcast_dim():

```python
    })


def test_GatherLayer_batch_dim():
  from returnn.tf.util.data import batch_dim
  with make_scope() as session:
    import numpy as np
    net = TFNetwork(extern_data=ExternData())
    b_dim, time_dim, feature_dim = 3, 4, 2
    # [B, T, F]
    random = np.random.RandomState(42)
    values_seqs = random.rand(b_dim, time_dim, feature_dim).astype('float32')
    values_size = np.array([4, 2, 3])
    values_placeholder = tf.constant(values_seqs, dtype=tf.float32)
    values_size_placeholder = {0: tf.constant(values_size, dtype=tf.int32)}
    values = InternalLayer(
      name="values", network=net,
      output=Data(
        name="values",
        batch_dim_axis=0, time_dim_axis=1, feature_dim_axis=2,
        shape=[None, feature_dim],
        placeholder=values_placeholder,
        size_placeholder=values_size_placeholder,
        dim_tags=(batch_dim, SpatialDim("time"), FeatureDim("feature", feature_dim)),
      ))
    position_np = np.array([0, 2])
    position = InternalLayer(
      name="position", network=net,
      output=Data(
        name="position",
        placeholder=tf.constant(position_np, dtype=tf.int64),
        batch_dim_axis=0, shape=[], dtype="int64",
      ))
```
Review comment on lines +5482 to +5488 (the thread is partially truncated in this capture):

@albertz do I need to change the creation of …

Reply: Yes. It's actually not so simple because of the special treatment of the batch dim tag. I'm not sure it's really possible currently. In practice, in your real code, how would you end up with …

Reply: So it's not actually the global batch dim. I was just confused because I got … but this is because the check does not cover this case, see comment here: #1089 (comment)
```python
    values.output.sanity_check()
    position.output.sanity_check()

    # should become [B', T, F]
    layer = GatherLayer(
      name="gather", network=net,
      sources=[values], position=position, axis="B",
      output=GatherLayer.get_out_data_from_opts(
        name="gather", sources=[values], position=position, axis="B"))
    layer.output.sanity_check()
    out_seqs, out_size = session.run([layer.output.placeholder, layer.output.size_placeholder.as_dict()])
    assert isinstance(out_seqs, numpy.ndarray)

    np.testing.assert_equal(values_seqs[position_np, :], out_seqs)
    np.testing.assert_equal(values_size[position_np], out_size[0])


def test_SliceNdLayer():
  n_batch = 5
  n_time = 7
```
@@ -9213,6 +9260,62 @@ def _get_mask_eval_layer(source, **_kwargs):

```python
  assert numpy.isfinite(loss_v)


def test_supervised_multilingual_training():
  n_batch = 3
  out_dim = 5

  net_dict = {
    "encoder": {"class": "copy", "from": "data"},

    "loss_language_0": {
      "class": "subnetwork",
      "from": "data:classes",
      "is_output_layer": True,
      "subnetwork": {
        # create language id indices based on where we see non-silence in the alignment.
        # this should be done in the dataset in a neat way.
        "aux_0": {"class": "gather", "position": 0, "axis": "T", "from": "data"},
        "aux_1": {"class": "compare", "kind": "greater", "value": 2, "from": "aux_0"},
        "idx": {  # shape (B',) with B' < B (B' is the number of utterances in the batch for this language)
          "class": "eval",
          "eval": "tf.squeeze(tf.where(source(0)), axis=-1)",
          "out_type": {"dtype": "int64"},
          "from": "aux_1",
        },

        # gather targets and encoder outputs
        "tgt": {"class": "gather", "from": "data", "axis": "B", "position": "idx"},  # B', T (sparse)
        "enc_raw": {"class": "gather", "from": "base:encoder", "axis": "B", "position": "idx"},  # B', T, F
        "enc": {"class": "reinterpret_data", "size_base": "tgt", "from": "enc_raw"},  # B', T, F

        "logits": {"class": "linear", "n_out": out_dim, "from": "enc"},  # B', T, n_out
        "output": {
          "class": "sparse_softmax_cross_entropy_with_logits",
          "logits": "logits",
          "targets": "tgt",
          "loss": "as_is",
        },
      },
    },
  }

  with make_scope() as session:
    config = Config({
      "extern_data": {
        "data": {"dim": 7, "batch_dim_axis": 0, "time_dim_axis": 1, "feature_dim_axis": 2},
        "classes": {"dim": out_dim, "sparse": True}},
      "debug_print_layer_output_shape": True,  # "debug_print_layer_output": True,
    })
    net = TFNetwork(config=config, train_flag=True)
    net.construct_from_dict(net_dict)
    loss = net.get_total_loss()

    tf_compat.v1.set_random_seed(1)
    net.initialize_params(session)
    loss_v = session.run(loss, feed_dict=make_feed_dict(net.extern_data, same_time=True, n_batch=n_batch))
    print("loss:", loss_v)
    assert numpy.isfinite(loss_v)
```

Review comment on the `"enc"` (`reinterpret_data`) line: Why is this needed?
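As a side note, the `idx` trick above (boolean mask → indices → gather on the batch axis) can be reproduced outside RETURNN; a minimal sketch in plain TensorFlow, with made-up example values:

```python
import tensorflow as tf

# A boolean mask marking which sequences in the batch belong to the
# current language (made-up example values).
mask = tf.constant([True, False, True])  # [B]
encoder = tf.random.normal([3, 4, 7])    # [B, T, F]
targets = tf.constant([[1, 2, 0, 0], [3, 3, 1, 0], [4, 0, 2, 2]])  # [B, T]

# tf.where on a 1-D mask returns indices of shape [B', 1]; squeeze to [B'].
idx = tf.squeeze(tf.where(mask), axis=-1)  # == [0, 2]

# Gather on the batch axis to form the per-language sub-batch, B' < B.
sub_encoder = tf.gather(encoder, idx, axis=0)  # [B', T, F]
sub_targets = tf.gather(targets, idx, axis=0)  # [B', T]
```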
```python
if __name__ == "__main__":
  try:
    better_exchook.install()
```
A further review comment, apparently on the dim-tag loop in the `__init__` hunk above, survives only in fragments in this capture:

This is:

```python
for axis, dim_tag in enumerate(self.output.dim_tags)
```

… `get_axis_by_tag_name` and `dim_tag.description` … `dim_tag.dyn_size_ext`
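Reading the fragments together, the suggestion is presumably to iterate with `enumerate` instead of looking the axis up by tag name, and to read the sizes from `dim_tag.dyn_size_ext` rather than `size_placeholder`. A speculative sketch of that rewrite (a guess at the reviewer's intent, not their verbatim proposal):

```python
# Speculative rewrite of the loop from the __init__ hunk above.
if input_data.dim_tags[old_gather_axis].is_batch_dim():
  for axis, dim_tag in enumerate(self.output.dim_tags):
    if dim_tag.is_spatial_dim() and dim_tag.dyn_size_ext is not None:
      # The dim tag carries its own dynamic size, so there is no need for
      # get_axis_by_tag_name / dim_tag.description / size_placeholder.
      new_size = tf.gather(
        params=dim_tag.dyn_size_ext.placeholder, indices=position_data.placeholder)
      # ... then create the new Dim with dyn_size=new_size as in the hunk above.
```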