@@ -2032,12 +2032,12 @@ class RangeInAxisLayer(LayerBase):
   recurrent = True  # if axis=="T", the time-dim order matters
 
   # noinspection PyUnusedLocal
-  def __init__(self, axis, dtype="int32", unbroadcast=False, keepdims=True, sparse=False, **kwargs):
+  def __init__(self, axis, dtype="int32", unbroadcast=False, keepdims=False, sparse=False, **kwargs):
     """
     :param str axis:
     :param str dtype:
-    :param bool unbroadcast:
-    :param bool keepdims:
+    :param bool unbroadcast: DEPRECATED, unsupported, and not needed
+    :param bool keepdims: DEPRECATED, unsupported, and not needed
     :param bool sparse:
     """
     super(RangeInAxisLayer, self).__init__(**kwargs)
@@ -2048,47 +2048,28 @@ def __init__(self, axis, dtype="int32", unbroadcast=False, keepdims=True, sparse
     source_shape = get_shape(source.placeholder)
     out = tf.range(0, source_shape[axis], dtype=dtype)
     if unbroadcast:
-      assert keepdims
+      raise Exception("%s: do not use unbroadcast")
     if keepdims:
-      out_shape = [
-        source_shape[i]
-        if (i == axis or i == self.output.batch_dim_axis)
-        else 1
-        for i in range(self.output.batch_ndim)]
-      out = tf.reshape(out, out_shape)  # add missing axes (keep_dims)
-      if unbroadcast:
-        out = out + tf.zeros(source_shape, dtype=out.dtype)
+      raise Exception("%s: do not use keepdims")
     self.output.placeholder = out
 
   @classmethod
-  def get_out_data_from_opts(cls,
-                             name, sources, axis, dtype="int32", unbroadcast=False, keepdims=True, sparse=False,
-                             **kwargs):
+  def get_out_data_from_opts(cls, name, sources, axis, dtype="int32", sparse=False, **kwargs):
     """
     :param str name:
     :param list[LayerBase] sources:
     :param str axis:
     :param str dtype:
-    :param bool unbroadcast:
-    :param bool keepdims:
     :param bool sparse:
     """
-    from ..util.data import DimensionTag
     assert len(sources) == 1, "%s layer %r requires single source" % (cls, name)
     source = sources[0].output
     axis = source.get_axis_from_description(axis)
-    if keepdims:
-      data_opts = source.get_kwargs(include_special_axes=True)
-      dim_tags = [
-        tag if (i == axis or tag.is_batch_dim() or unbroadcast)
-        else DimensionTag(kind=tag.kind, description="%s_keep%i" % (name, i), dimension=1)
-        for i, tag in enumerate(source.dim_tags)]
-    else:
-      data_opts = source.get_kwargs(include_special_axes=False)
-      dim_tags = [source.dim_tags[axis]]
-      if not dim_tags[0].is_batch_dim():
-        data_opts.pop("batch", None)
-        data_opts.pop("beam", None)
+    data_opts = source.get_kwargs(include_special_axes=False)
+    dim_tags = [source.dim_tags[axis]]
+    if not dim_tags[0].is_batch_dim():
+      data_opts.pop("batch", None)
+      data_opts.pop("beam", None)
     data_opts["name"] = "%s_output" % name
     data_opts["dim_tags"] = dim_tags
     data_opts["dtype"] = dtype