@@ -1661,42 +1661,50 @@ class LengthLayer(LayerBase):
16611661  layer_class  =  "length" 
16621662
16631663  # noinspection PyUnusedLocal 
1664-   def  __init__ (self , add_time_axis = False , dtype = "int32" , sparse = False , ** kwargs ):
1664+   def  __init__ (self , axis = "T" ,  add_time_axis = False , dtype = "int32" , sparse = False , ** kwargs ):
16651665    """ 
1666+     :param str|DimensionTag axis: 
16661667    :param bool add_time_axis: 
16671668    :param str dtype: 
16681669    :param bool sparse: 
16691670    """ 
16701671    super (LengthLayer , self ).__init__ (** kwargs )
16711672    assert  len (self .sources ) ==  1 , "%s: expects one source"  %  self 
1672-     out  =  tf .cast (self .sources [0 ].output .get_sequence_lengths (), dtype )
1673+     source  =  self .sources [0 ].output 
1674+     axis  =  source .get_axis_from_description (axis , allow_int = False )
1675+     dim  =  source .dim_tags [axis ]
16731676    if  add_time_axis :
1674-       out  =  tf .expand_dims (out , axis = self .output .time_dim_axis )
1675-     self .output .placeholder  =  out 
1677+       self .output .placeholder  =  tf .expand_dims (dim .dyn_size , axis = self .output .time_dim_axis )
1678+     else :
1679+       self .output .placeholder  =  dim .dyn_size_ext .placeholder 
16761680
16771681  @classmethod  
1678-   def  get_out_data_from_opts (cls , name , sources , add_time_axis = False , dtype = "int32" , sparse = False , ** kwargs ):
1682+   def  get_out_data_from_opts (cls , name , sources , axis = "T" ,  add_time_axis = False , dtype = "int32" , sparse = False , ** kwargs ):
16791683    """ 
16801684    :param str name: 
16811685    :param list[LayerBase] sources: 
1686+     :param str|DimensionTag axis: 
16821687    :param bool add_time_axis: 
16831688    :param str dtype: 
16841689    :param bool sparse: 
16851690    :rtype: Data 
16861691    """ 
1692+     assert  len (sources ) ==  1 
1693+     source  =  sources [0 ].output 
1694+     axis  =  source .get_axis_from_description (axis , allow_int = False )
1695+     dim  =  source .dim_tags [axis ]
16871696    if  add_time_axis :
1688-       shape  =  (1 ,)
1689-       time_dim_axis  =  1 
1690-     else :
1691-       shape  =  ()
1692-       time_dim_axis  =  None 
1693-     return  Data (
1694-       name = "%s_length"  %  name ,
1695-       shape = shape ,
1696-       batch_dim_axis = 0 ,
1697-       time_dim_axis = time_dim_axis ,
1698-       dtype = dtype ,
1699-       sparse = sparse , dim = None  if  sparse  else  NotSpecified )
1697+       assert  dim .dyn_size_ext  and  dim .dyn_size_ext .have_batch_axis () and  dim .dyn_size_ext .batch_ndim  ==  1   # [B] 
1698+       return  Data (
1699+         name = "%s_length"  %  name ,
1700+         shape = [1 ], batch_dim_axis = 0 , time_dim_axis = 1 ,
1701+         dtype = dtype , sparse = sparse , dim = None  if  sparse  else  NotSpecified )
1702+     if  not  dim .dyn_size_ext :  # yet undefined 
1703+       return  Data (
1704+         name = "%s_length"  %  name ,
1705+         shape = (), batch_dim_axis = 0 , time_dim_axis = None ,
1706+         dtype = dtype , sparse = sparse , dim = None  if  sparse  else  NotSpecified )
1707+     return  dim .dyn_size_ext 
17001708
17011709
17021710class  SoftmaxOverSpatialLayer (_ConcatInputLayer ):
0 commit comments