Skip to content

Commit 9648b4c

Browse files
committed
Data get axis, support dim tags
1 parent 66a36b0 commit 9648b4c

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

returnn/tf/util/data.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2619,7 +2619,7 @@ def get_axes(self, exclude_time=False, exclude_batch=False, exclude_feature=Fals
26192619

26202620
def get_axes_from_description(self, axes, allow_int=True):
26212621
"""
2622-
:param int|list[int]|str|list[str]|None axes: one axis or multiple axis, or none.
2622+
:param int|list[int]|str|list[str|DimensionTag]|DimensionTag|None axes: one axis or multiple axis, or none.
26232623
This is counted with batch-dim, which by default is axis 0 (see enforce_batch_dim_axis).
26242624
It also accepts the special tokens "B"|"batch", "spatial", "spatial_except_time", or "F"|"feature",
26252625
and more (see the code).
@@ -2630,11 +2630,13 @@ def get_axes_from_description(self, axes, allow_int=True):
26302630
"""
26312631
if axes is None or axes == "":
26322632
return []
2633+
if isinstance(axes, DimensionTag):
2634+
return [i for (i, tag) in self.dim_tags if tag == axes]
26332635
if not allow_int:
26342636
assert not isinstance(axes, int)
26352637
assert isinstance(axes, (str, int, list, tuple))
26362638
if isinstance(axes, (list, tuple)):
2637-
assert all([a is None or isinstance(a, (str, int)) for a in axes])
2639+
assert all([a is None or isinstance(a, (str, int, DimensionTag)) for a in axes])
26382640
if not allow_int:
26392641
assert all([not isinstance(a, int) for a in axes])
26402642
if isinstance(axes, str):
@@ -2731,12 +2733,13 @@ def get_axes_from_description(self, axes, allow_int=True):
27312733

27322734
def get_axis_from_description(self, axis, allow_int=True):
27332735
"""
2734-
:param int|str axis:
2736+
:param int|str|DimensionTag axis:
27352737
:param bool allow_int:
27362738
:return: axis, counted with batch-dim
27372739
:rtype: int
27382740
"""
27392741
axes = self.get_axes_from_description(axis, allow_int=allow_int)
2742+
assert axes, "%s: %r axis not found" % (self, axis)
27402743
assert len(axes) == 1, "%r: %r is not a unique axis but %r" % (self, axis, axes)
27412744
return axes[0]
27422745

0 commit comments

Comments
 (0)