diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 3d5b62109..10e7d494c 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -82,14 +82,14 @@ def __init__(self, kind=Types.Unspecified, description=None, if dyn_size_ext: assert batch == dyn_size_ext.batch self.dyn_size_ext = dyn_size_ext # type: typing.Optional[Data] - if dyn_size is not None: - assert not dyn_size_ext - self.dyn_size = dyn_size self._dyn_size_same = set() # type: typing.Set[tf.Tensor] self._undefined = undefined # We can have different tag variants per batch info (e.g. with beam), or per control flow ctx. # They each have same_as = self. The same_base should have the base (global) batch info. self._same_for_batch_ctx = {} # type: typing.Dict[typing.Tuple[BatchInfo,typing.Optional[ControlFlowContext]],DimensionTag] # nopep8 + if dyn_size is not None: + assert not dyn_size_ext + self.dyn_size = dyn_size def __repr__(self): return "DimensionTag{%s}" % self.short_repr() @@ -149,13 +149,54 @@ def _can_use_in_ctx(self, ctx): return False return True - def get_for_batch_ctx(self, batch, ctx): + def _validate_in_current_graph(self): + """ + :rtype: bool + """ + tensor = None + if self.batch: + batch_base = self.batch.get_global_base() + if batch_base.is_global_batch(): + tensor = batch_base.get_global_batch_dim().size + if not isinstance(tensor, tf.Tensor): + if self.dyn_size_ext and self.dyn_size_ext.placeholder is not None: + tensor = self.dyn_size_ext.placeholder + if isinstance(tensor, tf.Tensor): + g = tf_compat.v1.get_default_graph() + if tensor.graph is not g: # maybe from an earlier run which reuses the dim tag + # Reset and cleanup. + self.dyn_size_ext = None + same_base = self.get_same_base() + same_base._same_for_batch_ctx.pop((self.batch, self.control_flow_ctx), None) + self.batch = None # it is invalid in the new graph + self.control_flow_ctx = None # also invalid + return False + return True + + def _maybe_update(self): + if self.is_batch_dim(): + return + if isinstance(self.dimension, int): + return + if self.dyn_size_ext: + return + if not self.batch: + return + # Check if we can find more in + same = self.get_for_batch_ctx(self.batch, self.control_flow_ctx, allow_none=True) + if self is same or not same or not same.dyn_size_ext: + return + self.dyn_size_ext = same.dyn_size_ext + + def get_for_batch_ctx(self, batch, ctx, allow_none=False): """ :param BatchInfo batch: :param ControlFlowContext|None ctx: - :rtype: DimensionTag + :param bool allow_none: + :rtype: DimensionTag|None """ - if self.batch == batch and self.control_flow_ctx == ctx: + if self.batch == batch and self.control_flow_ctx == ctx and self.dyn_size_ext: + self._validate_in_current_graph() return self if self.is_batch_dim(): # We ignore the ctx for the batch dim currently. @@ -169,6 +210,7 @@ def get_for_batch_ctx(self, batch, ctx): if batch.is_broadcast(): return self # just leave as-is. should not matter. same_base = self.get_same_base() + same_base._validate_in_current_graph() # Might be uninitialized in some cases. Assume batch is global. if not same_base.batch: batch_base = batch.get_global_base() @@ -182,27 +224,24 @@ def get_for_batch_ctx(self, batch, ctx): if same_base.dyn_size_ext: assert same_base.batch == same_base.dyn_size_ext.batch assert same_base.control_flow_ctx == same_base.dyn_size_ext.control_flow_ctx - tag = same_base._same_for_batch_ctx.get((batch, ctx), None) - if tag: - return tag - if same_base.batch == batch and same_base._can_use_in_ctx(ctx): - return same_base for ctx_ in ControlFlowContext.abs_ctx_stack_with_root(ctx): tag = same_base._same_for_batch_ctx.get((batch, ctx_), None) - if tag and tag._can_use_in_ctx(ctx): + if tag and tag._can_use_in_ctx(ctx) and tag._validate_in_current_graph(): return tag + if same_base.batch == batch and same_base._can_use_in_ctx(ctx) and same_base.dyn_size_ext: + return same_base # Ok, nothing matching found. dyn_size_ext = None # Maybe we have sth with the base batch without beam which we can extend. if batch.copy_remove_beam() == batch.get_global_base() and batch.beam: batch_base = batch.get_global_base() base_can_use_in_ctx = None - if same_base.batch == batch_base and same_base._can_use_in_ctx(ctx): + if same_base.batch == batch_base and same_base._can_use_in_ctx(ctx) and same_base.dyn_size_ext: base_can_use_in_ctx = same_base else: for ctx_ in ControlFlowContext.abs_ctx_stack_with_root(ctx): tag = same_base._same_for_batch_ctx.get((batch_base, ctx_), None) - if tag and tag._can_use_in_ctx(ctx): + if tag and tag._can_use_in_ctx(ctx) and tag._validate_in_current_graph() and tag.dyn_size_ext: base_can_use_in_ctx = tag break if base_can_use_in_ctx and base_can_use_in_ctx.dyn_size_ext: @@ -224,6 +263,8 @@ def get_for_batch_ctx(self, batch, ctx): name=get_valid_scope_name_from_str("%s_identity_for_beam_%s" % (dyn_size_ext.name, batch.beam.name))) dyn_size_ext.placeholder._RETURNN_dyn_size_beam = batch.beam dyn_size_ext.placeholder._RETURNN_beam_expanded_base_data = beam_expanded_base_data + if not dyn_size_ext and allow_none: + return None dim_tag = DimensionTag( kind=self.kind, description=self.description, dimension=self.dimension, batch=batch, control_flow_ctx=dyn_size_ext.control_flow_ctx if dyn_size_ext else ctx, @@ -235,6 +276,34 @@ def get_for_batch_ctx(self, batch, ctx): same_base._same_for_batch_ctx[(dim_tag.batch, dim_tag.control_flow_ctx)] = dim_tag return dim_tag + def set_dyn_size_ext_for_batch_ctx(self, batch, ctx, dyn_size_ext): + """ + :param BatchInfo batch: + :param ControlFlowContext|None ctx: + :param Data dyn_size_ext: + """ + same = self.get_for_batch_ctx(batch, ctx) + same.dyn_size_ext = dyn_size_ext + self._maybe_update() + + def get_dyn_size_ext_for_batch_ctx(self, batch, ctx): + """ + :param BatchInfo|None batch: + :param ControlFlowContext|None ctx: + :rtype: Data|None + """ + if not batch and self.batch: + # Assume global batch. + batch = self.batch.get_global_base() + if not batch: + # This is usually not valid. However, this case can happen early at initialization. + assert batch == self.batch and ctx == self.control_flow_ctx + return self.dyn_size_ext + same = self.get_for_batch_ctx(batch, ctx, allow_none=True) + if not same: + return None + return same.dyn_size_ext + @property def dyn_size(self): """ @@ -507,6 +576,8 @@ def declare_same_as(self, other): """ :param DimensionTag other: """ + self._maybe_update() + self._validate_in_current_graph() if self is other: return other_same_base = other.get_same_base() @@ -517,40 +588,66 @@ def declare_same_as(self, other): assert not self_same_as.same_as if self_same_as is other_same_base: return + other_same_base._merge_same_for_batch_ctx_dict(self_same_as) self_same_as.same_as = other_same_base self_same_as._same_as_tb = traceback.extract_stack() - if self_same_as.dyn_size_ext is None: - self_same_as.dyn_size_ext = other_same_base.dyn_size_ext - elif other_same_base.dyn_size_ext is None: - other_same_base.dyn_size_ext = self_same_as.dyn_size_ext - if self.dyn_size_ext is None and self_same_as.dyn_size_ext: - self.dyn_size_ext = self_same_as.dyn_size_ext.copy_extend_with_beam(self.batch.beam if self.batch else None) + if self_same_as.dyn_size_ext is None or not self_same_as._validate_in_current_graph(): + self_same_as.dyn_size_ext = other_same_base.get_dyn_size_ext_for_batch_ctx( + self_same_as.batch, self_same_as.control_flow_ctx) + elif other_same_base.dyn_size_ext is None or not other_same_base._validate_in_current_graph(): + other_same_base.dyn_size_ext = self_same_as.get_dyn_size_ext_for_batch_ctx( + other_same_base.batch, other_same_base.control_flow_ctx) + if (self.dyn_size_ext is None or not self._validate_in_current_graph()) and self_same_as.dyn_size_ext: + self.dyn_size_ext = self_same_as.get_dyn_size_ext_for_batch_ctx(self.batch, self.control_flow_ctx) + other_same_base._merge_same_for_batch_ctx_dict(self) self.same_as = other_same_base self._same_as_tb = traceback.extract_stack() + self._maybe_update() if self.dyn_size is not None and other_same_base.dyn_size is not None: if self.dyn_size is not other_same_base.dyn_size: - if self.batch == other_same_base.batch: + if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx: # Note: Instead of making this a warning, we could also enforce this at some point. # The user should be able to fix `extern_data` in the config such that this is correct in the first place. # Also, in addition to this warning, we might want to add some runtime check on the eq of the dyn sizes. print( - "Warning: assuming dim tags are same with different size placeholders: %r vs %r" % (self, other_same_base)) + "Warning: assuming dim tags are same with different size placeholders: %r vs %r" % ( + self.dyn_size, other_same_base.dyn_size)) # If we have a defined source, and this is a dynamic spatial axis, and it was undefined before, # maybe we can overtake the size_placeholder now. if self.same_as.dyn_size is not None and self.src_data: assert isinstance(self.src_axis, int) # Maybe it changed in the meanwhile, so check. - if self.src_data.get_dim_tag(self.src_axis).description == self.description: - self.src_data.size_placeholder[ - self.src_data.get_batch_axis_excluding_batch(self.src_axis)] = self.same_as.dyn_size + tag = self.src_data.get_dim_tag(self.src_axis) + if tag.description == self.description and (not tag.dyn_size_ext or not tag._validate_in_current_graph()): + tag.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx(tag.batch, tag.control_flow_ctx) # If others dyn_size is None but we have a dyn_size, maybe update others dyn_size. if self.dyn_size is not None and self.same_as.dyn_size is not self.dyn_size: # Could be unset if it comes from the config, or from prev graph creation. # This is important such that self.can_compare() is sane. - if self.same_as.dyn_size is None or self.same_as.dyn_size.graph is not self.dyn_size.graph: - self.same_as.dyn_size_ext = self.dyn_size_ext - if not self.dyn_size_ext and other.dyn_size_ext: - self.dyn_size_ext = other.dyn_size_ext.copy() + if self.same_as.dyn_size is None or not self.same_as._validate_in_current_graph(): + self.same_as.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx( + self.same_as.batch, self.same_as.control_flow_ctx) + if (not self.dyn_size_ext or not self._validate_in_current_graph()) and other.dyn_size_ext: + self.dyn_size_ext = other.get_dyn_size_ext_for_batch_ctx(self.batch, self.control_flow_ctx) + + def _merge_same_for_batch_ctx_dict(self, other): + """ + :param DimensionTag other: + """ + self._validate_in_current_graph() + for _, dim in list(self._same_for_batch_ctx.items()): + assert isinstance(dim, DimensionTag) + dim._validate_in_current_graph() + for key, dim in other._same_for_batch_ctx.items(): + if not dim._validate_in_current_graph(): + continue + self_dim = self._same_for_batch_ctx.get(key, None) + if self_dim and (self_dim.dyn_size_ext or not dim.dyn_size_ext): + continue # keep ours + if not dim.dyn_size_ext: + continue # undefined, do not overtake + self._same_for_batch_ctx[key] = dim + other._same_for_batch_ctx.clear() # we only want to have it once @classmethod def get_existing_tag_from_collection(cls, other, tags, is_equal_opts=None):