Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 126 additions & 29 deletions returnn/tf/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,14 @@ def __init__(self, kind=Types.Unspecified, description=None,
if dyn_size_ext:
assert batch == dyn_size_ext.batch
self.dyn_size_ext = dyn_size_ext # type: typing.Optional[Data]
if dyn_size is not None:
assert not dyn_size_ext
self.dyn_size = dyn_size
self._dyn_size_same = set() # type: typing.Set[tf.Tensor]
self._undefined = undefined
# We can have different tag variants per batch info (e.g. with beam), or per control flow ctx.
# They each have same_as = self. The same_base should have the base (global) batch info.
self._same_for_batch_ctx = {} # type: typing.Dict[typing.Tuple[BatchInfo,typing.Optional[ControlFlowContext]],DimensionTag] # nopep8
if dyn_size is not None:
assert not dyn_size_ext
self.dyn_size = dyn_size

def __repr__(self):
return "DimensionTag{%s}" % self.short_repr()
Expand Down Expand Up @@ -149,13 +149,54 @@ def _can_use_in_ctx(self, ctx):
return False
return True

def get_for_batch_ctx(self, batch, ctx):
def _validate_in_current_graph(self):
"""
:rtype: bool
"""
tensor = None
if self.batch:
batch_base = self.batch.get_global_base()
if batch_base.is_global_batch():
tensor = batch_base.get_global_batch_dim().size
if not isinstance(tensor, tf.Tensor):
if self.dyn_size_ext and self.dyn_size_ext.placeholder is not None:
tensor = self.dyn_size_ext.placeholder
if isinstance(tensor, tf.Tensor):
g = tf_compat.v1.get_default_graph()
if tensor.graph is not g: # maybe from an earlier run which reuses the dim tag
# Reset and cleanup.
self.dyn_size_ext = None
same_base = self.get_same_base()
same_base._same_for_batch_ctx.pop((self.batch, self.control_flow_ctx), None)
self.batch = None # it is invalid in the new graph
self.control_flow_ctx = None # also invalid
return False
return True

def _maybe_update(self):
if self.is_batch_dim():
return
if isinstance(self.dimension, int):
return
if self.dyn_size_ext:
return
if not self.batch:
return
# Check if we can find more in
same = self.get_for_batch_ctx(self.batch, self.control_flow_ctx, allow_none=True)
if self is same or not same or not same.dyn_size_ext:
return
self.dyn_size_ext = same.dyn_size_ext

def get_for_batch_ctx(self, batch, ctx, allow_none=False):
"""
:param BatchInfo batch:
:param ControlFlowContext|None ctx:
:rtype: DimensionTag
:param bool allow_none:
:rtype: DimensionTag|None
"""
if self.batch == batch and self.control_flow_ctx == ctx:
if self.batch == batch and self.control_flow_ctx == ctx and self.dyn_size_ext:
self._validate_in_current_graph()
return self
if self.is_batch_dim():
# We ignore the ctx for the batch dim currently.
Expand All @@ -169,6 +210,7 @@ def get_for_batch_ctx(self, batch, ctx):
if batch.is_broadcast():
return self # just leave as-is. should not matter.
same_base = self.get_same_base()
same_base._validate_in_current_graph()
# Might be uninitialized in some cases. Assume batch is global.
if not same_base.batch:
batch_base = batch.get_global_base()
Expand All @@ -182,27 +224,24 @@ def get_for_batch_ctx(self, batch, ctx):
if same_base.dyn_size_ext:
assert same_base.batch == same_base.dyn_size_ext.batch
assert same_base.control_flow_ctx == same_base.dyn_size_ext.control_flow_ctx
tag = same_base._same_for_batch_ctx.get((batch, ctx), None)
if tag:
return tag
if same_base.batch == batch and same_base._can_use_in_ctx(ctx):
return same_base
for ctx_ in ControlFlowContext.abs_ctx_stack_with_root(ctx):
tag = same_base._same_for_batch_ctx.get((batch, ctx_), None)
if tag and tag._can_use_in_ctx(ctx):
if tag and tag._can_use_in_ctx(ctx) and tag._validate_in_current_graph():
return tag
if same_base.batch == batch and same_base._can_use_in_ctx(ctx) and same_base.dyn_size_ext:
return same_base
# Ok, nothing matching found.
dyn_size_ext = None
# Maybe we have sth with the base batch without beam which we can extend.
if batch.copy_remove_beam() == batch.get_global_base() and batch.beam:
batch_base = batch.get_global_base()
base_can_use_in_ctx = None
if same_base.batch == batch_base and same_base._can_use_in_ctx(ctx):
if same_base.batch == batch_base and same_base._can_use_in_ctx(ctx) and same_base.dyn_size_ext:
base_can_use_in_ctx = same_base
else:
for ctx_ in ControlFlowContext.abs_ctx_stack_with_root(ctx):
tag = same_base._same_for_batch_ctx.get((batch_base, ctx_), None)
if tag and tag._can_use_in_ctx(ctx):
if tag and tag._can_use_in_ctx(ctx) and tag._validate_in_current_graph() and tag.dyn_size_ext:
base_can_use_in_ctx = tag
break
if base_can_use_in_ctx and base_can_use_in_ctx.dyn_size_ext:
Expand All @@ -224,6 +263,8 @@ def get_for_batch_ctx(self, batch, ctx):
name=get_valid_scope_name_from_str("%s_identity_for_beam_%s" % (dyn_size_ext.name, batch.beam.name)))
dyn_size_ext.placeholder._RETURNN_dyn_size_beam = batch.beam
dyn_size_ext.placeholder._RETURNN_beam_expanded_base_data = beam_expanded_base_data
if not dyn_size_ext and allow_none:
return None
dim_tag = DimensionTag(
kind=self.kind, description=self.description, dimension=self.dimension,
batch=batch, control_flow_ctx=dyn_size_ext.control_flow_ctx if dyn_size_ext else ctx,
Expand All @@ -235,6 +276,34 @@ def get_for_batch_ctx(self, batch, ctx):
same_base._same_for_batch_ctx[(dim_tag.batch, dim_tag.control_flow_ctx)] = dim_tag
return dim_tag

def set_dyn_size_ext_for_batch_ctx(self, batch, ctx, dyn_size_ext):
"""
:param BatchInfo batch:
:param ControlFlowContext|None ctx:
:param Data dyn_size_ext:
"""
same = self.get_for_batch_ctx(batch, ctx)
same.dyn_size_ext = dyn_size_ext
self._maybe_update()

def get_dyn_size_ext_for_batch_ctx(self, batch, ctx):
"""
:param BatchInfo|None batch:
:param ControlFlowContext|None ctx:
:rtype: Data|None
"""
if not batch and self.batch:
# Assume global batch.
batch = self.batch.get_global_base()
if not batch:
# This is usually not valid. However, this case can happen early at initialization.
assert batch == self.batch and ctx == self.control_flow_ctx
return self.dyn_size_ext
same = self.get_for_batch_ctx(batch, ctx, allow_none=True)
if not same:
return None
return same.dyn_size_ext

@property
def dyn_size(self):
"""
Expand Down Expand Up @@ -507,6 +576,8 @@ def declare_same_as(self, other):
"""
:param DimensionTag other:
"""
self._maybe_update()
self._validate_in_current_graph()
if self is other:
return
other_same_base = other.get_same_base()
Expand All @@ -517,40 +588,66 @@ def declare_same_as(self, other):
assert not self_same_as.same_as
if self_same_as is other_same_base:
return
other_same_base._merge_same_for_batch_ctx_dict(self_same_as)
self_same_as.same_as = other_same_base
self_same_as._same_as_tb = traceback.extract_stack()
if self_same_as.dyn_size_ext is None:
self_same_as.dyn_size_ext = other_same_base.dyn_size_ext
elif other_same_base.dyn_size_ext is None:
other_same_base.dyn_size_ext = self_same_as.dyn_size_ext
if self.dyn_size_ext is None and self_same_as.dyn_size_ext:
self.dyn_size_ext = self_same_as.dyn_size_ext.copy_extend_with_beam(self.batch.beam if self.batch else None)
if self_same_as.dyn_size_ext is None or not self_same_as._validate_in_current_graph():
self_same_as.dyn_size_ext = other_same_base.get_dyn_size_ext_for_batch_ctx(
self_same_as.batch, self_same_as.control_flow_ctx)
elif other_same_base.dyn_size_ext is None or not other_same_base._validate_in_current_graph():
other_same_base.dyn_size_ext = self_same_as.get_dyn_size_ext_for_batch_ctx(
other_same_base.batch, other_same_base.control_flow_ctx)
if (self.dyn_size_ext is None or not self._validate_in_current_graph()) and self_same_as.dyn_size_ext:
self.dyn_size_ext = self_same_as.get_dyn_size_ext_for_batch_ctx(self.batch, self.control_flow_ctx)
other_same_base._merge_same_for_batch_ctx_dict(self)
self.same_as = other_same_base
self._same_as_tb = traceback.extract_stack()
self._maybe_update()
if self.dyn_size is not None and other_same_base.dyn_size is not None:
if self.dyn_size is not other_same_base.dyn_size:
if self.batch == other_same_base.batch:
if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx:
# Note: Instead of making this a warning, we could also enforce this at some point.
# The user should be able to fix `extern_data` in the config such that this is correct in the first place.
# Also, in addition to this warning, we might want to add some runtime check on the eq of the dyn sizes.
print(
"Warning: assuming dim tags are same with different size placeholders: %r vs %r" % (self, other_same_base))
"Warning: assuming dim tags are same with different size placeholders: %r vs %r" % (
self.dyn_size, other_same_base.dyn_size))
# If we have a defined source, and this is a dynamic spatial axis, and it was undefined before,
# maybe we can overtake the size_placeholder now.
if self.same_as.dyn_size is not None and self.src_data:
assert isinstance(self.src_axis, int)
# Maybe it changed in the meanwhile, so check.
if self.src_data.get_dim_tag(self.src_axis).description == self.description:
self.src_data.size_placeholder[
self.src_data.get_batch_axis_excluding_batch(self.src_axis)] = self.same_as.dyn_size
tag = self.src_data.get_dim_tag(self.src_axis)
if tag.description == self.description and (not tag.dyn_size_ext or not tag._validate_in_current_graph()):
tag.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx(tag.batch, tag.control_flow_ctx)
# If others dyn_size is None but we have a dyn_size, maybe update others dyn_size.
if self.dyn_size is not None and self.same_as.dyn_size is not self.dyn_size:
# Could be unset if it comes from the config, or from prev graph creation.
# This is important such that self.can_compare() is sane.
if self.same_as.dyn_size is None or self.same_as.dyn_size.graph is not self.dyn_size.graph:
self.same_as.dyn_size_ext = self.dyn_size_ext
if not self.dyn_size_ext and other.dyn_size_ext:
self.dyn_size_ext = other.dyn_size_ext.copy()
if self.same_as.dyn_size is None or not self.same_as._validate_in_current_graph():
self.same_as.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx(
self.same_as.batch, self.same_as.control_flow_ctx)
if (not self.dyn_size_ext or not self._validate_in_current_graph()) and other.dyn_size_ext:
self.dyn_size_ext = other.get_dyn_size_ext_for_batch_ctx(self.batch, self.control_flow_ctx)

def _merge_same_for_batch_ctx_dict(self, other):
"""
:param DimensionTag other:
"""
self._validate_in_current_graph()
for _, dim in list(self._same_for_batch_ctx.items()):
assert isinstance(dim, DimensionTag)
dim._validate_in_current_graph()
for key, dim in other._same_for_batch_ctx.items():
if not dim._validate_in_current_graph():
continue
self_dim = self._same_for_batch_ctx.get(key, None)
if self_dim and (self_dim.dyn_size_ext or not dim.dyn_size_ext):
continue # keep ours
if not dim.dyn_size_ext:
continue # undefined, do not overtake
self._same_for_batch_ctx[key] = dim
other._same_for_batch_ctx.clear() # we only want to have it once

@classmethod
def get_existing_tag_from_collection(cls, other, tags, is_equal_opts=None):
Expand Down