diff --git a/returnn/tf/layers/base.py b/returnn/tf/layers/base.py
index d955962a6..3258ae264 100644
--- a/returnn/tf/layers/base.py
+++ b/returnn/tf/layers/base.py
@@ -393,6 +393,13 @@ def fixup_out_data(cls, output, network):
         extern_data.init_batch_info()  # this should create it and also set it
         assert output.batch
       output.batch = output.batch.copy_set_beam(output.beam)
+    if output.control_flow_ctx != network.get_control_flow_ctx():
+      x = output.placeholder
+      output = output.copy_template_set_ctx(network.get_control_flow_ctx())
+      if x is not None:
+        # Some layers might just copy the input. But the input might have buggy ctx.
+        # Just leave the placeholder as-is. Most layers should anyway reset this.
+        output.placeholder = x
     return output
 
   def get_full_ctx_name(self):
@@ -1730,7 +1737,7 @@ def opt_get_layer(layer_name):
         # Don't return layer, could be inside loop and that wont work.
         output = net.layers[layer_name].output.copy_template()
         if not output.have_time_axis() and with_time_dim:
-          output = output.copy_template_adding_time_dim()
+          output = output.copy_template_adding_time_dim().copy_template_set_ctx(network.get_control_flow_ctx())
       if not output:
         layer_desc_ = net.layers_desc[layer_name].copy()
         class_name_ = layer_desc_.pop("class")
diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py
index fa43fa268..5c6805950 100644
--- a/returnn/tf/layers/rec.py
+++ b/returnn/tf/layers/rec.py
@@ -406,8 +406,10 @@ def get_out_data_from_opts(cls, network, unit, _time_dim_tag=None, sources=(), i
     out = None
     if isinstance(unit, _SubnetworkRecCell):  # subnetwork
       subnet = unit
-      sub_out = subnet.layer_data_templates["output"].output.copy_template_adding_time_dim(
-        name="%s_output" % kwargs["name"], time_dim_axis=0)
+      sub_out = (
+        subnet.layer_data_templates["output"].output
+        .copy_template_adding_time_dim(name="%s_output" % kwargs["name"], time_dim_axis=0)
+        .copy_template_set_ctx(network.get_control_flow_ctx()))
       if out:
         assert sub_out.dim == out.dim
         assert sub_out.shape == out.shape
@@ -993,6 +995,10 @@ def __init__(self, net_dict, source_data, time_dim_tag, rec_layer_name, parent_n
     self.parent_net = parent_net
     self.net_dict = safe_deep_copy(net_dict)
     from returnn.tf.network import TFNetwork, ExternData, LossHolder
+    from returnn.tf.util.data import ControlFlowContext
+    control_flow_ctx = ControlFlowContext(
+      kind=ControlFlowContext.Types.Loop, outer_ctx=parent_net.get_control_flow_ctx())
+    control_flow_ctx.loop_spatial_dim = time_dim_tag
     self.net = TFNetwork(
       name="%s/%s(rec-subnet)" % (parent_net.name, rec_layer_name),
       extern_data=ExternData(),
@@ -1000,6 +1006,7 @@ def __init__(self, net_dict, source_data, time_dim_tag, rec_layer_name, parent_n
       search_flag=parent_net.search_flag,
       eval_flag=False,
       inside_rec_time_dim=time_dim_tag,
+      control_flow_ctx=control_flow_ctx,
       absolute_name_prefix="%s%s/" % (parent_net.get_absolute_name_prefix(), rec_layer_name),
       parent_net=parent_net)
     self.net.is_root_in_ctx = True
@@ -1007,7 +1014,7 @@ def __init__(self, net_dict, source_data, time_dim_tag, rec_layer_name, parent_n
     self.source_data = source_data
     if source_data:
       self.net.extern_data.data["source"] = (
-        source_data.copy_template_excluding_time_dim())
+        source_data.copy_template_excluding_time_dim().copy_template_set_ctx(control_flow_ctx))
     self.time_dim_tag = time_dim_tag
     self._time_dim_tags = {time_dim_tag}  # type: typing.Set[DimensionTag]
     if source_data:
@@ -1020,7 +1027,7 @@ def __init__(self, net_dict, source_data, time_dim_tag, rec_layer_name, parent_n
         # These are just templates. You can use them as possible targets for dimension information,
         # but not as actual sources or targets.
         # Note: We maybe should check data.is_same_time_dim()...
-        self.net.extern_data.data[key] = data.copy_template_excluding_time_dim()
+        self.net.extern_data.data[key] = data.copy_template_excluding_time_dim().copy_template_set_ctx(control_flow_ctx)
     self.layer_data_templates = {}  # type: typing.Dict[str,_TemplateLayer]
     self.prev_layers_needed = set()  # type: typing.Set[str]
     self.prev_layer_templates = {}  # type: typing.Dict[str,_TemplateLayer]
@@ -1545,7 +1552,7 @@ def get_input_moved_out(name):
           self.parent_rec_layer, self.parent_rec_layer.output.get_time_dim_tag(),
           layer, layer.output.get_time_dim_tag())
         return layer
-      output = layer.output.copy_template_excluding_time_dim()
+      output = layer.output.copy_template_excluding_time_dim().copy_template_set_ctx(self.net.control_flow_ctx)
       with tf.name_scope("%s_moved_input" % name.replace(":", "_")):
         if prev:
           output.placeholder = tf.cond(
@@ -2513,7 +2520,8 @@ def cond(i, net_vars, acc_tas, seq_len_info=None):
             self.parent_rec_layer, input_beam, output_beam,
             self.parent_rec_layer.sources, self.parent_rec_layer.target))
         assert output_template.output.batch.beam == output_beam
-        time_dim_tag = time_dim_tag.get_for_batch(output_template.output.batch)
+        time_dim_tag = time_dim_tag.get_for_batch_ctx(
+          batch=output_template.output.batch, ctx=self.net.control_flow_ctx)
         assert time_dim_tag.dyn_size is not None
         seq_len = time_dim_tag.dyn_size
       else:
@@ -2772,7 +2780,7 @@ def get_choice_seq(choice_base):
           latest_batch = (
             latest_layer_choice.output.batch
            or self.parent_rec_layer.output.batch.copy_set_beam(latest_layer_choice.output.beam))
-          tag = tag.get_for_batch(latest_batch)
+          tag = tag.get_for_batch_ctx(batch=latest_batch, ctx=self.net.control_flow_ctx)
           assert tag.dyn_size is not None
           assert tag.batch == latest_batch and tag.batch.beam == latest_layer_choice.output.beam
           seq_len = tag.dyn_size
@@ -3216,7 +3224,10 @@ def get_loop_acc_layer(name):
         acc_ta, latest_layer_choice_name, search_choices, resolved_seq_len = self._opt_search_resolve(
           layer_name=name, acc_ta=acc_ta, final_net_vars=final_net_vars, seq_len=seq_len,
           search_choices_cache=search_choices_cache)
-        output = self.layer_data_templates[name].output.copy_template_adding_time_dim(time_dim_axis=0)
+        output = (
+          self.layer_data_templates[name].output
+          .copy_template_adding_time_dim(time_dim_axis=0)
+          .copy_template_set_ctx(self.parent_net.get_control_flow_ctx()))
         if latest_layer_choice_name:
           output.beam = self.net.layers[latest_layer_choice_name].search_choices.get_beam_info()
         elif search_choices:
@@ -3303,7 +3314,10 @@ def get_layer(name):
     for name, search_choices in search_choices_cache.items():
       if name not in self.output_layers_net.layers:
         # Create dummy layer.
-        output = self.layer_data_templates[name].output.copy_template_adding_time_dim(time_dim_axis=0)
+        output = (
+          self.layer_data_templates[name].output
+          .copy_template_adding_time_dim(time_dim_axis=0)
+          .copy_template_set_ctx(self.output_layers_net.get_control_flow_ctx()))
         output.beam = search_choices.get_beam_info()
         layer = InternalLayer(name=name, network=self.output_layers_net, output=output)
         self.output_layers_net.layers[name] = layer
@@ -3350,7 +3364,8 @@ def __init__(self, network, name, construct_stack=None, cell=None):
       output=Data(
         name="dummy_initial_template_data",
         batch_dim_axis=0, time_dim_axis=None,
-        shape=()),  # (B,). no time-dim
+        shape=(),
+        control_flow_ctx=network.get_control_flow_ctx()),  # (B,). no time-dim
       name=name, network=network)
     self.output.size_placeholder = {}  # must be initialized
     self.layer_class = ":uninitialized-template"
@@ -5226,7 +5241,7 @@ def decide(cls, src, output=None, owner=None, name=None, length_normalization=Fa
       for i, size in src_data.size_placeholder.items():
         tag = DimensionTag.get_tag_from_size_tensor(size)
         assert tag
-        tag = tag.get_for_batch(output.batch)
+        tag = tag.get_for_batch_ctx(batch=output.batch, ctx=output.control_flow_ctx)
         if tag.dyn_size is None:
           size = tf.reshape(size, [batch_dim, beam_size])  # (batch, beam)
           size = tf.gather_nd(size, indices=beam_idxs_ext)  # (batch,)
@@ -7071,8 +7086,11 @@ def _create_template(cls, name, network, sources, masked_from, unit,
     # We don't care about the right masked input here, but just about deriving the right output shape.
     if masked_from:
       if network.is_inside_rec_layer(inside_loop=True):
-        source_data = masked_from.output.copy_template_excluding_time_dim(
-          name="%s_%s_masked_input_frame" % (masked_from.output.name, name))
+        source_data = (
+          masked_from.output
+          .copy_template_excluding_time_dim(
+            name="%s_%s_masked_input_frame" % (masked_from.output.name, name))
+          .copy_template_set_ctx(network.get_control_flow_ctx()))
       else:
         source_data = masked_from.output.copy_template(
           name="%s_%s_masked_input" % (masked_from.output.name, name))
@@ -7347,7 +7365,7 @@ def get_out_data_from_opts(cls, name, network, sources, mask, **kwargs):
         # thus when we unroll it to get into the loop, the RecLayer would have kept it as-is,
         # i.e. it should still have that time-dim-axis.
         # Maybe we should do some extra checks if that is like we assume, but for now, just assume that.
-        return out.copy_template_excluding_time_dim()
+        return out.copy_template_excluding_time_dim().copy_template_set_ctx(network.get_control_flow_ctx())
       return out
     assert out.have_time_axis()
     out = out.copy_as_time_major()
diff --git a/returnn/tf/network.py b/returnn/tf/network.py
index d17109ba0..9be6ad4a9 100644
--- a/returnn/tf/network.py
+++ b/returnn/tf/network.py
@@ -354,6 +354,7 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None,
                train_flag=None, eval_flag=None, search_flag=None,
                parent_layer=None, parent_net=None, extra_parent_net=None, extra_name_prefix=None,
                inside_rec_time_dim=None, over_rec_time_dim=None, over_rec_time_dim_subs=None,
+               control_flow_ctx=None,
                absolute_name_prefix=None, name=None):
     """
     :param returnn.config.Config config: only needed to init extern_data if not specified explicitly
@@ -370,6 +371,7 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None,
     :param DimensionTag|None inside_rec_time_dim: dim tag of outer rec layer, when run inside the loop (not optimized)
     :param DimensionTag|None over_rec_time_dim: dim tag of outer rec layer, when optimized out of the loop
     :param set[DimensionTag]|None over_rec_time_dim_subs: outer rec layer, out of loop, potential shorter
+    :param returnn.tf.util.data.ControlFlowContext control_flow_ctx:
     :param str|None absolute_name_prefix:
     :param str name: only for debugging
     """
@@ -432,6 +434,7 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None,
     self._inside_rec_time_dim = inside_rec_time_dim
     self._over_rec_time_dim = over_rec_time_dim
     self._over_rec_time_dim_subs = over_rec_time_dim_subs
+    self.control_flow_ctx = control_flow_ctx
     self.extra_parent_net = extra_parent_net
     self.extra_name_prefix = extra_name_prefix
     self.extra_deps_in_extra = False
@@ -519,6 +522,17 @@ def get_root_ctx_network(self):
         break
     return net, "".join(reversed(path))
 
+  def get_control_flow_ctx(self):
+    """
+    :rtype: returnn.tf.util.data.ControlFlowContext|None
+    """
+    net = self
+    while net:
+      if net.control_flow_ctx:
+        return net.control_flow_ctx
+      net = net.parent_net
+    return None
+
   def is_extra_internal_template_construction(self):
     """
     :rtype: LayerBase|None
diff --git a/returnn/tf/util/basic.py b/returnn/tf/util/basic.py
index a7314f737..856ec7af1 100644
--- a/returnn/tf/util/basic.py
+++ b/returnn/tf/util/basic.py
@@ -4585,7 +4585,7 @@ def _maybe_to_base_seq_len(v):
       base_out_tag.set_tag_on_size_tensor(base_out_seq_len)
     assert base_out_tag.batch
-    out_tag = base_out_tag.get_for_batch(in_tag.batch)
+    out_tag = base_out_tag.get_for_batch_ctx(batch=in_tag.batch, ctx=in_tag.control_flow_ctx)
     assert out_tag.dyn_size is not None
     return out_tag.dyn_size
 
diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py
index 3457d0883..0fd3d212b 100644
--- a/returnn/tf/util/data.py
+++ b/returnn/tf/util/data.py
@@ -41,7 +41,7 @@ class Types:
   def __init__(self, kind=Types.Unspecified, description=None,
                dimension=None,
                dyn_size=None, dyn_size_ext=None,
-               batch=None,
+               batch=None, control_flow_ctx=None,
                src_data=None, src_axis=None):
     """
     :param str|None kind:
@@ -50,6 +50,7 @@ def __init__(self, kind=Types.Unspecified, description=None,
     :param tf.Tensor|None dyn_size: e.g. seq_len, (batch,)
     :param Data|None dyn_size_ext: seq_len or extended
     :param BatchInfo|None batch: for batch-dim, or dynamic dims per batch
+    :param ControlFlowContext|None control_flow_ctx:
     :param Data|None src_data:
     :param int|None src_axis:
     """
@@ -63,6 +64,7 @@ def __init__(self, kind=Types.Unspecified, description=None,
     if not batch and dyn_size_ext:
       batch = dyn_size_ext.batch
     self.batch = batch
+    self.control_flow_ctx = control_flow_ctx
     self.src_data = src_data
     self.src_axis = src_axis
     if dyn_size_ext and not dyn_size_ext.batch and batch:
@@ -74,17 +76,9 @@ def __init__(self, kind=Types.Unspecified, description=None,
       assert not dyn_size_ext
       self.dyn_size = dyn_size
     self._dyn_size_same = set()  # type: typing.Set[tf.Tensor]
-    # We can have different tag variants per batch info (e.g. with beam).
+    # We can have different tag variants per batch info (e.g. with beam), or per control flow ctx.
     # They each have same_as = self. The same_base should have the base (global) batch info.
-    self._same_for_batch = {}  # type: typing.Dict[BatchInfo,DimensionTag]
-    # When we have some dynamic size, this dynamic size could be inside a loop (RecLayer),
-    # and different per each loop frame.
-    # In that case, we can not access it from outside (except when we accumulate it).
-    # We expect that the same_base is the dim tag inside the loop, which has this set.
-    self.per_spatial_frame = None  # type: typing.Optional[DimensionTag]
-    # When we accumulate the dynamic sizes, this results in this dim tag.
-    # It has same_as set to self.
-    self.per_spatial_frame_accumulated = None  # type: typing.Optional[DimensionTag]
+    self._same_for_batch_ctx = {}  # type: typing.Dict[typing.Tuple[BatchInfo,typing.Optional[ControlFlowContext]],DimensionTag]  # nopep8
 
   def __repr__(self):
     return "DimensionTag{%s}" % self.short_repr()
@@ -121,12 +115,13 @@ def copy(self, kind=None):
     tag._same_as_tb = traceback.extract_stack()
     return tag
 
-  def get_for_batch(self, batch):
+  def get_for_batch_ctx(self, batch, ctx):
     """
     :param BatchInfo batch:
+    :param ControlFlowContext|None ctx:
     :rtype: DimensionTag
     """
-    if self.batch == batch:
+    if self.batch == batch and self.control_flow_ctx == ctx:
       return self
     if self.is_batch_dim():
       # For now, just create another copy in case of batch dim tag.
@@ -150,11 +145,19 @@ def get_for_batch(self, batch):
       same_base.batch = batch_base
     if same_base.dyn_size_ext:
       assert same_base.batch == same_base.dyn_size_ext.batch
-    if same_base.batch == batch:
-      return same_base
-    if batch in same_base._same_for_batch:
-      return same_base._same_for_batch[batch]
+      assert same_base.control_flow_ctx == same_base.dyn_size_ext.control_flow_ctx
+    if same_base.batch == batch and ControlFlowContext.is_parent_or_same(same_base.control_flow_ctx, ctx):
+      if same_base.dyn_size_ext or same_base.control_flow_ctx == ctx:
+        return same_base
+    # If any parent ctx exists, and is defined, use it.
+    for ctx_ in ControlFlowContext.abs_ctx_stack_with_root(ctx):
+      tag = same_base._same_for_batch_ctx.get((batch, ctx_), None)
+      if tag and (tag.dyn_size_ext or tag.control_flow_ctx == ctx):
+        return tag
     if batch.copy_remove_beam() == batch.get_global_base() and batch.beam and same_base.dyn_size_ext:
+      # The same_base has some dyn size without any beam nor control flow context.
+      # We can expand it to the current beam.
+      assert not same_base.control_flow_ctx
       dyn_size_ext = same_base.dyn_size_ext.copy_extend_with_beam(batch.beam)
       assert dyn_size_ext.batch == batch
       beam_expanded_base_data = getattr(dyn_size_ext.placeholder, "_RETURNN_beam_expanded_base_data", None)
@@ -177,12 +180,13 @@ def get_for_batch(self, batch):
       dyn_size_ext = None
     dim_tag = DimensionTag(
       kind=self.kind, description=self.description, dimension=self.dimension,
-      batch=batch, dyn_size_ext=dyn_size_ext)
+      batch=batch, control_flow_ctx=dyn_size_ext.control_flow_ctx if dyn_size_ext else ctx,
+      dyn_size_ext=dyn_size_ext)
     dim_tag.same_as = same_base
     dim_tag._same_as_tb = traceback.extract_stack()
     if dyn_size_ext:
       dim_tag.set_tag_on_size_tensor(dyn_size_ext.placeholder, batch=batch)
-    same_base._same_for_batch[batch] = dim_tag
+    same_base._same_for_batch_ctx[(dim_tag.batch, dim_tag.control_flow_ctx)] = dim_tag
     return dim_tag
 
   @property
@@ -210,7 +214,7 @@ def dyn_size(self, dyn_size):
     self.dyn_size_ext = Data(
       name=("%s:dyn_size" % self.description) if self.description else dyn_size.op.name,
       dtype=Data.size_dtype, placeholder=dyn_size, shape=(), batch_dim_axis=0,
-      batch=self.batch, beam=beam)
+      batch=self.batch, beam=beam, control_flow_ctx=self.control_flow_ctx)
     other = DimensionTag.get_tag_from_size_tensor(dyn_size)
     if other:
       self.declare_same_as(other)
@@ -278,7 +282,7 @@ def set_tag_on_size_tensor(self, x, batch=None, same_as_before=False):
     # If we already have another dyn size set or different batch, create a new DimensionTag instance.
     if self.batch and batch and self.batch != batch:
       assert not same_as_before  # it cannot be the same when it is another batch...
-      new_dim_tag = self.get_for_batch(batch)
+      new_dim_tag = self.get_for_batch_ctx(batch=batch, ctx=self.control_flow_ctx)
       new_dim_tag.set_tag_on_size_tensor(x, batch=batch)
       return new_dim_tag
     if self.dyn_size is not None and self.dyn_size is not x:
@@ -1255,7 +1259,8 @@ def __init__(self, name,
                dim_tags=None,
                same_dim_tags_as=None,
                batch=None,
-               beam=None):
+               beam=None,
+               control_flow_ctx=None):
     """
     :param str name:
     :param tuple[int|None]|list[int|None] shape: including time-dim (can be None). excluding batch-dim.
@@ -1284,6 +1289,7 @@ def __init__(self, name,
     :param BatchInfo|None batch:
     :param SearchBeam|None beam: the batch-dim could be extended by a beam-size,
       such that it represents the merged dims [batch, beam_size].
+    :param ControlFlowContext|None control_flow_ctx:
     """
     assert isinstance(name, str)
     assert dtype is None or isinstance(dtype, str)
@@ -1301,6 +1307,7 @@ def __init__(self, name,
       assert batch.beam == beam
     self._batch = batch
     self._beam = beam
+    self.control_flow_ctx = control_flow_ctx
     if isinstance(dim_tags, (tuple, list)):
       # We do a couple of sanity checks, and maybe set special axes attribs.
       shape_ = tuple(tag.dimension for tag in dim_tags if not tag.is_batch_dim())
@@ -1561,6 +1568,8 @@ def get_kwargs(self, include_special_axes=True):
       keys += ["batch"]
     if self.beam is not None:
       keys += ["beam"]
+    if self.control_flow_ctx:
+      keys += ["control_flow_ctx"]
     if not self.available_for_inference:
       keys += ["available_for_inference"]
     return {key: getattr(self, key) for key in keys}
@@ -1665,9 +1674,10 @@ def __hash__(self):
     return id(self)
 
   def _adapt_batch_consistent_dim_tags(self):
-    if not self.batch:
+    if not self.batch:  # uninitialized
       return
-    self._dim_tags = tuple(tag.get_for_batch(self.batch) for tag in self._dim_tags)
+    self._dim_tags = tuple(
+      tag.get_for_batch_ctx(batch=self.batch, ctx=self.control_flow_ctx) for tag in self._dim_tags)
 
   def copy(self, name=None):
     """
@@ -2382,6 +2392,16 @@ def copy_template_new_dim_tags(self, new_dim_tags, name=None, keep_special_axes=
       opts["name"] = name
     return Data(**opts)
 
+  def copy_template_set_ctx(self, ctx):
+    """
+    :param ControlFlowContext ctx:
+    :return: new Data instance
+    :rtype: Data
+    """
+    kwargs = self.get_kwargs()
+    kwargs["control_flow_ctx"] = ctx
+    return Data(**kwargs)
+
   def _get_variable_dim_pattern(self):
     """
     :return: tuple with bools specifying which dims of the shape (excluding batch-dim) are of variable length.
@@ -3074,7 +3094,7 @@ def set_dynamic_size(self, axis, sizes):
     if self.beam and getattr(sizes, "_RETURNN_dyn_size_beam", None) != self.beam:
       tag = DimensionTag.get_tag_from_size_tensor(sizes)
       assert tag and self.batch
-      tag = tag.get_for_batch(self.batch)
+      tag = tag.get_for_batch_ctx(batch=self.batch, ctx=self.control_flow_ctx)
       assert tag.dyn_size is not None
       sizes = tag.dyn_size
 
@@ -3961,3 +3981,145 @@ def _default_feature_dim_axis(batch_dim_axis, time_dim_axis, batch_shape, sparse
   if static_axes:
     return static_axes[-1]
   return axes[-1]
+
+
+class ControlFlowContext:
+  """
+  This is a simple wrapper around the TF ControlFlowContext, i.e. tf.while_loop or tf.cond.
+
+  We have this wrapper to refer to a context which might not exist yet (e.g. at template construction time).
+  Also, we might want to store additional information, such as the spatial dim tag of the loop.
+  """
+
+  class Types:
+    """
+    Possible types of context.
+ """ + Loop = "loop" + Cond = "cond" + + def __init__(self, kind, outer_ctx=None): + """ + :param str kind: from ControlFlowContext.Types + :param ControlFlowContext outer_ctx: + """ + self.kind = kind + self._outer_ctx = outer_ctx + from tensorflow.python.ops.control_flow_ops import ControlFlowContext as TFControlFlowCtx + self._tf_control_flow_ctx = None # type: typing.Optional[TFControlFlowCtx] + self._loop_spatial_dim = None # type: typing.Optional[DimensionTag] + + def __repr__(self): + return "ControlFlowContext{%s}" % self.repr_inner() + + def repr_inner(self): + """ + :rtype: str + """ + return "/".join(ctx._repr_single() for ctx in self.abs_ctx_stack()) + + def _repr_single(self): + """ + :rtype: str + """ + s = self.kind + if self.is_loop() and self.loop_spatial_dim: + s += "(%s)" % self.loop_spatial_dim.short_repr() + return s + + def abs_ctx_stack(self): + """ + :rtype: list[ControlFlowContext] + :return: chain of ctx, last is self + """ + chain = [] + ctx = self + while ctx: + chain.append(ctx) + ctx = ctx.outer_ctx + chain.reverse() + return chain + + @classmethod + def abs_ctx_stack_with_root(cls, ctx): + """ + :param ControlFlowContext|None ctx: + :rtype: list[ControlFlowContext|None] + :return: chain of ctx, last is self, first is None + """ + ls = [None] # type: typing.List[typing.Optional[ControlFlowContext]] + if ctx: + ls += ctx.abs_ctx_stack() + return ls + + @classmethod + def is_parent_or_same(cls, parent, child): + """ + :param ControlFlowContext|None parent: + :param ControlFlowContext|None child: + :rtype: bool + """ + if parent == child: + return True + if not parent: + return True # parent is root + if not child: + return False # child is root but parent is not + while child: + if child == parent: + return True + child = child.outer_ctx + return False + + def is_loop(self): + """ + :rtype: bool + """ + return self.kind == self.Types.Loop + + def is_cond(self): + """ + :rtype: bool + """ + return self.kind == self.Types.Cond + + @property + def outer_ctx(self): + """ + :rtype: ControlFlowContext|None + """ + return self._outer_ctx + + @property + def tf_control_flow_ctx(self): + """ + :rtype: tensorflow.python.ops.control_flow_ops.ControlFlowContext|None + """ + return self._tf_control_flow_ctx + + @tf_control_flow_ctx.setter + def tf_control_flow_ctx(self, ctx): + """ + :param tensorflow.python.ops.control_flow_ops.ControlFlowContext ctx: + """ + if self.is_loop(): + assert ctx.IsWhileContext() + if self.is_cond(): + assert ctx.IsCondContext() + self._tf_control_flow_ctx = ctx + + @property + def loop_spatial_dim(self): + """ + :rtype: DimensionTag|None + """ + assert self.is_loop() + return self._loop_spatial_dim + + @loop_spatial_dim.setter + def loop_spatial_dim(self, dim): + """ + :param DimensionTag dim: + """ + assert self.is_loop() + self._loop_spatial_dim = dim