Skip to content

Commit 210cd38

Browse files
committed
CumConcatLayer more wip
1 parent 305b057 commit 210cd38

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

returnn/tf/layers/rec.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8543,12 +8543,7 @@ def __init__(self, new_dim, **kwargs):
85438543
"""
85448544
super(CumConcatLayer, self).__init__(**kwargs)
85458545
assert self.network.is_inside_rec_layer()
8546-
out_axis = None
8547-
for a, tag in enumerate(self.output.dim_tags):
8548-
if tag == new_dim:
8549-
out_axis = a
8550-
break
8551-
assert out_axis is not None
8546+
out_axis = self.output.get_axis_from_description(new_dim)
85528547

85538548
if self.network.is_inside_rec_layer(inside_loop=True):
85548549
current_data = self.input_data.copy_compatible_to(self.output, unbroadcast=False)
@@ -8584,27 +8579,34 @@ def get_out_data_from_opts(cls, name, network, sources, new_dim, **kwargs):
85848579
:rtype: Data
85858580
"""
85868581
rec_layer = network.get_rec_parent_layer(inside_loop=False)
8587-
assert rec_layer, "This must be inside the loop"
8582+
assert rec_layer, "CumConcatLayer %r must be used inside a RecLayer" % name
8583+
new_dim_base = new_dim.get_same_base()
8584+
if new_dim_base.per_spatial_frame is None:
8585+
new_dim_base.per_spatial_frame = rec_layer.time_dim_tag
8586+
else:
8587+
assert new_dim_base.per_spatial_frame == rec_layer.time_dim_tag
8588+
85888589
input_data = get_concat_sources_data_template(sources, name="%s_output" % name)
85898590
if network.is_inside_rec_layer(inside_loop=True):
85908591
# Currently SelectSearchSourcesLayer assumes that all rec_vars_outputs are batch-major.
85918592
# Therefore we here copy the input as batch-major, and then add the time axis at axis 1.
85928593
# In the future, when SelectSearchSourcesLayer has support for this, we can change this to operate on axis 0,
85938594
# which should be more efficient
85948595
out = input_data.copy_as_batch_major()
8595-
out = out.copy_add_dim_by_tag(new_dim, unbroadcast=True, axis=1)
8596-
# TODO set new_dim per spatial frame ...
8596+
out = out.copy_add_dim_by_tag(new_dim_base, unbroadcast=True, axis=1)
85978597
return out
85988598

85998599
else: # outside loop
8600-
out = input_data.copy_as_batch_major()
8601-
rec_time = rec_layer.output.get_time_dim_tag()
8602-
_matches = [i for (i, tag) in enumerate(out.dim_tags) if tag == rec_time]
8603-
assert len(_matches) == 1
8604-
out = out.copy_move_axis(_matches[0], 1)
8605-
# TODO use separate new_dim outside loop ...
8606-
out = out.copy_template_replace_dim_tag(axis=1, new_dim_tag=new_dim)
8607-
return out
8600+
if not new_dim_base.per_spatial_frame_accumulated:
8601+
new_dim_accum = DimensionTag(
8602+
kind=new_dim_base.kind, description="%s:accumulated" % name)
8603+
new_dim_accum.same_as = new_dim_base
8604+
new_dim_base.per_spatial_frame_accumulated = new_dim_accum
8605+
else:
8606+
new_dim_accum = new_dim_base.per_spatial_frame_accumulated
8607+
# Assume that the input has the time dim from the rec layer.
8608+
axis = input_data.get_axis_from_description(rec_layer.time_dim_tag)
8609+
return input_data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=new_dim_accum)
86088610

86098611
# noinspection PyMethodOverriding
86108612
@classmethod

0 commit comments

Comments
 (0)