@@ -2619,7 +2619,7 @@ def get_axes(self, exclude_time=False, exclude_batch=False, exclude_feature=Fals
26192619
26202620 def get_axes_from_description (self , axes , allow_int = True ):
26212621 """
2622- :param int|list[int]|str|list[str] |None axes: one axis or multiple axis, or none.
2622+ :param int|list[int]|str|list[str|DimensionTag]|DimensionTag |None axes: one axis or multiple axis, or none.
26232623 This is counted with batch-dim, which by default is axis 0 (see enforce_batch_dim_axis).
26242624 It also accepts the special tokens "B"|"batch", "spatial", "spatial_except_time", or "F"|"feature",
26252625 and more (see the code).
@@ -2630,11 +2630,13 @@ def get_axes_from_description(self, axes, allow_int=True):
26302630 """
26312631 if axes is None or axes == "" :
26322632 return []
2633+ if isinstance (axes , DimensionTag ):
2634+ return [i for (i , tag ) in self .dim_tags if tag == axes ]
26332635 if not allow_int :
26342636 assert not isinstance (axes , int )
26352637 assert isinstance (axes , (str , int , list , tuple ))
26362638 if isinstance (axes , (list , tuple )):
2637- assert all ([a is None or isinstance (a , (str , int )) for a in axes ])
2639+ assert all ([a is None or isinstance (a , (str , int , DimensionTag )) for a in axes ])
26382640 if not allow_int :
26392641 assert all ([not isinstance (a , int ) for a in axes ])
26402642 if isinstance (axes , str ):
@@ -2731,12 +2733,13 @@ def get_axes_from_description(self, axes, allow_int=True):
27312733
27322734 def get_axis_from_description (self , axis , allow_int = True ):
27332735 """
2734- :param int|str axis:
2736+ :param int|str|DimensionTag axis:
27352737 :param bool allow_int:
27362738 :return: axis, counted with batch-dim
27372739 :rtype: int
27382740 """
27392741 axes = self .get_axes_from_description (axis , allow_int = allow_int )
2742+ assert axes , "%s: %r axis not found" % (self , axis )
27402743 assert len (axes ) == 1 , "%r: %r is not a unique axis but %r" % (self , axis , axes )
27412744 return axes [0 ]
27422745
0 commit comments