@@ -8543,12 +8543,7 @@ def __init__(self, new_dim, **kwargs):
85438543 """
85448544 super (CumConcatLayer , self ).__init__ (** kwargs )
85458545 assert self .network .is_inside_rec_layer ()
8546- out_axis = None
8547- for a , tag in enumerate (self .output .dim_tags ):
8548- if tag == new_dim :
8549- out_axis = a
8550- break
8551- assert out_axis is not None
8546+ out_axis = self .output .get_axis_from_description (new_dim )
85528547
85538548 if self .network .is_inside_rec_layer (inside_loop = True ):
85548549 current_data = self .input_data .copy_compatible_to (self .output , unbroadcast = False )
@@ -8584,27 +8579,34 @@ def get_out_data_from_opts(cls, name, network, sources, new_dim, **kwargs):
85848579 :rtype: Data
85858580 """
85868581 rec_layer = network .get_rec_parent_layer (inside_loop = False )
8587- assert rec_layer , "This must be inside the loop"
8582+ assert rec_layer , "CumConcatLayer %r must be used inside a RecLayer" % name
8583+ new_dim_base = new_dim .get_same_base ()
8584+ if new_dim_base .per_spatial_frame is None :
8585+ new_dim_base .per_spatial_frame = rec_layer .time_dim_tag
8586+ else :
8587+ assert new_dim_base .per_spatial_frame == rec_layer .time_dim_tag
8588+
85888589 input_data = get_concat_sources_data_template (sources , name = "%s_output" % name )
85898590 if network .is_inside_rec_layer (inside_loop = True ):
85908591 # Currently SelectSearchSourcesLayer assumes that all rec_vars_outputs are batch-major.
85918592 # Therefore we here copy the input as batch-major, and then add the time axis at axis 1.
85928593 # In the future, when SelectSearchSourcesLayer has support for this, we can change this to operate on axis 0,
85938594 # which should be more efficient
85948595 out = input_data .copy_as_batch_major ()
8595- out = out .copy_add_dim_by_tag (new_dim , unbroadcast = True , axis = 1 )
8596- # TODO set new_dim per spatial frame ...
8596+ out = out .copy_add_dim_by_tag (new_dim_base , unbroadcast = True , axis = 1 )
85978597 return out
85988598
85998599 else : # outside loop
8600- out = input_data .copy_as_batch_major ()
8601- rec_time = rec_layer .output .get_time_dim_tag ()
8602- _matches = [i for (i , tag ) in enumerate (out .dim_tags ) if tag == rec_time ]
8603- assert len (_matches ) == 1
8604- out = out .copy_move_axis (_matches [0 ], 1 )
8605- # TODO use separate new_dim outside loop ...
8606- out = out .copy_template_replace_dim_tag (axis = 1 , new_dim_tag = new_dim )
8607- return out
8600+ if not new_dim_base .per_spatial_frame_accumulated :
8601+ new_dim_accum = DimensionTag (
8602+ kind = new_dim_base .kind , description = "%s:accumulated" % name )
8603+ new_dim_accum .same_as = new_dim_base
8604+ new_dim_base .per_spatial_frame_accumulated = new_dim_accum
8605+ else :
8606+ new_dim_accum = new_dim_base .per_spatial_frame_accumulated
8607+ # Assume that the input has the time dim from the rec layer.
8608+ axis = input_data .get_axis_from_description (rec_layer .time_dim_tag )
8609+ return input_data .copy_template_replace_dim_tag (axis = axis , new_dim_tag = new_dim_accum )
86088610
86098611 # noinspection PyMethodOverriding
86108612 @classmethod
0 commit comments