Skip to content

Commit 0268447

Browse files
authored
RangeInAxisLayer, keepdims unsupported, not needed (#641)
Fix #639 Fix #638
1 parent 7b30233 commit 0268447

File tree

1 file changed

+11
-30
lines changed

1 file changed

+11
-30
lines changed

returnn/tf/layers/basic.py

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2032,12 +2032,12 @@ class RangeInAxisLayer(LayerBase):
20322032
recurrent = True # if axis=="T", the time-dim order matters
20332033

20342034
# noinspection PyUnusedLocal
2035-
def __init__(self, axis, dtype="int32", unbroadcast=False, keepdims=True, sparse=False, **kwargs):
2035+
def __init__(self, axis, dtype="int32", unbroadcast=False, keepdims=False, sparse=False, **kwargs):
20362036
"""
20372037
:param str axis:
20382038
:param str dtype:
2039-
:param bool unbroadcast:
2040-
:param bool keepdims:
2039+
:param bool unbroadcast: DEPRECATED, unsupported, and not needed
2040+
:param bool keepdims: DEPRECATED, unsupported, and not needed
20412041
:param bool sparse:
20422042
"""
20432043
super(RangeInAxisLayer, self).__init__(**kwargs)
@@ -2048,47 +2048,28 @@ def __init__(self, axis, dtype="int32", unbroadcast=False, keepdims=True, sparse
20482048
source_shape = get_shape(source.placeholder)
20492049
out = tf.range(0, source_shape[axis], dtype=dtype)
20502050
if unbroadcast:
2051-
assert keepdims
2051+
raise Exception("%s: do not use unbroadcast")
20522052
if keepdims:
2053-
out_shape = [
2054-
source_shape[i]
2055-
if (i == axis or i == self.output.batch_dim_axis)
2056-
else 1
2057-
for i in range(self.output.batch_ndim)]
2058-
out = tf.reshape(out, out_shape) # add missing axes (keep_dims)
2059-
if unbroadcast:
2060-
out = out + tf.zeros(source_shape, dtype=out.dtype)
2053+
raise Exception("%s: do not use keepdims")
20612054
self.output.placeholder = out
20622055

20632056
@classmethod
2064-
def get_out_data_from_opts(cls,
2065-
name, sources, axis, dtype="int32", unbroadcast=False, keepdims=True, sparse=False,
2066-
**kwargs):
2057+
def get_out_data_from_opts(cls, name, sources, axis, dtype="int32", sparse=False, **kwargs):
20672058
"""
20682059
:param str name:
20692060
:param list[LayerBase] sources:
20702061
:param str axis:
20712062
:param str dtype:
2072-
:param bool unbroadcast:
2073-
:param bool keepdims:
20742063
:param bool sparse:
20752064
"""
2076-
from ..util.data import DimensionTag
20772065
assert len(sources) == 1, "%s layer %r requires single source" % (cls, name)
20782066
source = sources[0].output
20792067
axis = source.get_axis_from_description(axis)
2080-
if keepdims:
2081-
data_opts = source.get_kwargs(include_special_axes=True)
2082-
dim_tags = [
2083-
tag if (i == axis or tag.is_batch_dim() or unbroadcast)
2084-
else DimensionTag(kind=tag.kind, description="%s_keep%i" % (name, i), dimension=1)
2085-
for i, tag in enumerate(source.dim_tags)]
2086-
else:
2087-
data_opts = source.get_kwargs(include_special_axes=False)
2088-
dim_tags = [source.dim_tags[axis]]
2089-
if not dim_tags[0].is_batch_dim():
2090-
data_opts.pop("batch", None)
2091-
data_opts.pop("beam", None)
2068+
data_opts = source.get_kwargs(include_special_axes=False)
2069+
dim_tags = [source.dim_tags[axis]]
2070+
if not dim_tags[0].is_batch_dim():
2071+
data_opts.pop("batch", None)
2072+
data_opts.pop("beam", None)
20922073
data_opts["name"] = "%s_output" % name
20932074
data_opts["dim_tags"] = dim_tags
20942075
data_opts["dtype"] = dtype

0 commit comments

Comments
 (0)