Merged
Changes from all commits (31 commits)
4b2307e
Data beam/batch encapsulation, adapt dyn sizes, dim tags per batch
albertz Sep 2, 2021
6a84085
Data small static batch dim tag fixes
albertz Aug 29, 2021
4927c61
DimensionTag same_as extended logic, same size, better batch logic
albertz Aug 30, 2021
8abafe4
Data set_dynamic_size adopt potential new dim tag
albertz Aug 30, 2021
d5c6ca8
DimensionTag set_tag_on_size_tensor reset, extended exception text
albertz Aug 31, 2021
f04a94c
Data.set_dynamic_size fix and cleanup outdated size beam tiling logic
albertz Aug 31, 2021
438253e
DimensionTag.set_tag_on_size_tensor exception less verbose graph
albertz Aug 31, 2021
fd43640
DimensionTag for batch beam tiling fix multiple beams
albertz Sep 1, 2021
bd83de7
DimensionTag get_for_batch size identity fix control flow ctx
albertz Sep 1, 2021
68cab2c
Data expand beam set base data attrib on tile
albertz Sep 2, 2021
34d44fe
TF util new_seq_len cleanup and fixes for better dim tag batch logic
albertz Aug 31, 2021
e768327
Layer base get out data, use dim tags and correct batch, beam
albertz Aug 28, 2021
0389cba
UnflattenNdLayer size dim tag, use same_as_before
albertz Aug 30, 2021
aaea739
CondLayer size dim tag, use same_as_before
albertz Aug 30, 2021
0afe8b9
DecideLayer fix out dim tag
albertz Aug 30, 2021
47296fe
DotLayer, fixes for dim tags, batch
albertz Aug 27, 2021
178320a
select search sources, skip beam expanded tensors
albertz Sep 2, 2021
fe85269
SelectSearchSourcesLayer, fix transform tag batch
albertz Aug 30, 2021
8c28370
Rec layer out data, cleanup and fix, no _post_init_output
albertz Aug 30, 2021
2ee9e41
RecLayer target for time-dim-tag only when not search
albertz Aug 30, 2021
77924e6
Rec subnet time-dim-tag test consistency to template
albertz Aug 30, 2021
3447102
Rec subnet, opt search resolve, fix seq-len without end
albertz Aug 30, 2021
33c9075
Rec subnet, output, fix seq len tile transpose for beam expand
albertz Aug 30, 2021
446bc34
Rec subnet, output, fixed seq len, tag use same_as_before after check
albertz Aug 30, 2021
20b9a6c
Rec subnet, output, fix time dim tag with dyn end
albertz Aug 31, 2021
099b44b
Rec subnet output sizes loop acc layer, fix beam tiling, cleanup
albertz Aug 31, 2021
bec61fc
Rec subnet, cleanup and fix outdated size_placeholder fixup code
albertz Aug 31, 2021
f4f7005
Rec subnet, better check for same time
albertz Aug 31, 2021
0d80173
Rec subnet, output layers moved out, fix time dim tag
albertz Aug 31, 2021
2b91629
Rec subnet, use correct resolved time dim
albertz Aug 31, 2021
1e447a9
Rec subnet resolve directly set _RETURNN_dyn_size_beam
albertz Sep 1, 2021
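
The common thread of these commits is that dynamic sequence lengths are tracked via DimensionTag objects that carry batch and beam information, rather than via loose size_placeholder tensors. As a minimal sketch of the pattern (not code from this PR; it only reuses calls that appear in the diffs below, and the helper name is made up), re-tagging a new size tensor so it describes the same dimension as an old one looks roughly like this:

from returnn.tf.util.data import DimensionTag

def retag_size_tensor(output, old_size, new_size):
  """
  Hypothetical helper, for illustration only: carry the dim tag of old_size
  over to new_size, keeping the batch info of the given output Data.
  This mirrors the pattern used in CondLayer and UnflattenNdLayer below.
  """
  old_tag = DimensionTag.get_tag_from_size_tensor(old_size)
  assert old_tag  # the old size tensor is expected to already carry a tag
  # same_as_before=True states that new_size describes the same dimension as old_size
  old_tag.set_tag_on_size_tensor(new_size, batch=output.batch, same_as_before=True)
  return new_size
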
46 changes: 19 additions & 27 deletions returnn/tf/layers/base.py
@@ -11,7 +11,7 @@
from returnn.util.basic import NotSpecified, CollectionReadCheckCovered, BehaviorVersion
import returnn.tf.compat as tf_compat
import returnn.tf.util.basic as tf_util
from returnn.tf.util.data import Data, SearchBeam
from returnn.tf.util.data import Data
from returnn.tf.util.basic import OutputWithActivation, CustomUpdate, reuse_name_scope
from returnn.log import log

@@ -250,6 +250,7 @@ def _base_get_out_data_from_opts(cls, network, name, out_type=None, n_out=NotSpe
:return: Data template (placeholder not set)
:rtype: Data
"""
from ..util.data import DimensionTag
if callable(out_type):
return out_type(
network=network, name=name, n_out=n_out, target=target, size_target=size_target, sources=sources, loss=loss,
@@ -268,9 +269,8 @@ def _base_get_out_data_from_opts(cls, network, name, out_type=None, n_out=NotSpe
network=network, mark_data_key_as_used=False).dim
if n_out is not NotSpecified:
assert out_type["dim"] == n_out
sources_data = None
if sources and sources[0]:
sources_data = sources[0].output.copy_template()
sources_data_list = [src.output for src in sources if src]
sources_data = Data.get_common_data(sources_data_list, ignore_feature_dim=True) if sources_data_list else None
if sources_data and not sources_data.sparse and not out_type.get("sparse", False):
out_type.setdefault("dtype", sources_data.dtype)
# You are supposed to set self.output.{batch_dim_axis,time_dim_axis} explicitly,
@@ -291,38 +291,30 @@ def _base_get_out_data_from_opts(cls, network, name, out_type=None, n_out=NotSpe
if "shape" not in out_type and "dim_tags" not in out_type:
if sources_data:
if out_type.get("sparse", False):
out_type.setdefault("shape", sources_data.shape_sparse)
out_type["dim_tags"] = sources_data.dim_tags_sparse
else: # not sparse
feature_dim_axis = out_type.get("feature_dim_axis", NotSpecified)
if feature_dim_axis is NotSpecified:
if sources_data.feature_dim_axis is not None:
feature_dim_axis = sources_data.feature_dim_axis
else:
feature_dim_axis = -1
if sources_data.shape:
default_shape = list(sources_data.shape_dense)
if sources_data.batch_dim_axis is not None:
default_shape.insert(sources_data.batch_dim_axis, None)
default_shape[feature_dim_axis] = out_type.get("dim", None)
if out_type.get("batch_dim_axis") is not None:
default_shape.pop(out_type.get("batch_dim_axis"))
else: # source is scalar
if out_type.get("dim") or out_type.get("feature_dim_axis") is not None:
default_shape = (out_type.get("dim"),)
dim = out_type.get("dim", None)
dim_tags = list(sources_data.dim_tags_sparse)
feature_dim_tag = DimensionTag(
kind=DimensionTag.Types.Feature, description="%s:feature-dense" % name, dimension=dim)
if feature_dim_axis in (NotSpecified, None):
if sources_data.feature_dim_axis is None:
feature_dim_axis = len(dim_tags)
else:
default_shape = ()
out_type.setdefault("shape", tuple(default_shape))
feature_dim_axis = sources_data.feature_dim_axis
dim_tags.insert(feature_dim_axis, feature_dim_tag)
out_type["dim_tags"] = dim_tags
elif network.is_inside_rec_layer():
if out_type.get("sparse", False):
out_type.setdefault("shape", ())
else:
out_type.setdefault("shape", (out_type.get("dim", None),))
# Note: No special handling for feature_dim_axis here for now...
beam = None
for src in sources:
if src: # might be None if template construction
beam = SearchBeam.get_combined_beam(beam, src.output.beam)
out_type.setdefault("beam", beam)
if sources_data and sources_data.batch:
out_type.setdefault("batch", sources_data.batch)
if sources_data and sources_data.beam:
out_type.setdefault("beam", sources_data.beam)
output = Data(**out_type)
cls._post_init_output(
output=output, network=network, target=target, size_target=size_target, _target_layers=_target_layers,
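The rewritten _base_get_out_data_from_opts above builds the output template from dim tags instead of a plain shape tuple: it keeps all non-feature dim tags of the (common) source data and inserts a freshly created feature DimensionTag at the feature axis. A condensed sketch of that dense branch, assuming a single dense source (illustration only; the function and variable names are hypothetical):

from returnn.tf.util.data import Data, DimensionTag

def dense_out_template(name, source_data, dim):
  """
  Hypothetical condensation of the dense branch above: reuse the source
  dim tags (without its feature dim) and insert a new feature dim tag.
  """
  dim_tags = list(source_data.dim_tags_sparse)  # all dim tags except the feature dim
  feature_tag = DimensionTag(
    kind=DimensionTag.Types.Feature, description="%s:feature-dense" % name, dimension=dim)
  axis = source_data.feature_dim_axis
  if axis is None:
    axis = len(dim_tags)  # scalar source: append the feature dim at the end
  dim_tags.insert(axis, feature_tag)
  return Data(
    name="%s_output" % name, dim_tags=dim_tags, dtype=source_data.dtype,
    batch=source_data.batch, beam=source_data.beam)
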
59 changes: 15 additions & 44 deletions returnn/tf/layers/basic.py
@@ -436,6 +436,9 @@ def transform(v):
assert isinstance(v, (tf.Tensor, tf.TensorArray))
if isinstance(v, tf.Tensor) and v.get_shape().ndims == 0:
return v # leave scalars as-is
if isinstance(v, tf.Tensor) and getattr(v, "_RETURNN_beam_expanded_base_data", None):
# This tensor was just expanded by a beam. Selecting beams is not needed.
return v
for i, base_src_choices in enumerate(reversed(search_choices_seq)):
assert isinstance(base_src_choices, SearchChoices)
assert base_src_choices.src_beams is not None, (
@@ -452,7 +455,7 @@ def transform(v):
i, get_valid_scope_name_from_str(base_src_choices.owner.name),
len(search_choices_seq), get_valid_scope_name_from_str(search_choices.owner.name)))
if tag:
tag.set_tag_on_size_tensor(v, batch=self.output.batch)
tag.set_tag_on_size_tensor(v, batch=self.output.batch.copy_set_beam(base_src_choices.get_beam_info()))
self.used_search_choices_beams = True
return v

@@ -3128,7 +3131,7 @@ def __init__(self, sizes, num_axes, declare_same_sizes_as=None, **kwargs):
for i, other in declare_same_sizes_as.items():
assert 0 <= i < num_axes
other_dim_tag = other.output.get_size_dim_tag(0)
other_dim_tag.set_tag_on_size_tensor(size_placeholder[i], batch=self.output.batch)
other_dim_tag.set_tag_on_size_tensor(size_placeholder[i], batch=self.output.batch, same_as_before=True)
self.output.size_placeholder = size_placeholder

def get_dep_layers(self):
@@ -5487,6 +5490,7 @@ def get_out_data_from_opts(cls, name, sources, red1=-1, red2=-2, var1=-2, var2=-
:param bool add_var2_if_empty:
:rtype: Data
"""
from ..util.data import DimensionTag, BatchInfo
assert len(sources) == 2, "dot-layer %r: needs exactly two sources" % (name,)
# See __init__.
a_out = sources[0].output.copy()
@@ -5508,11 +5512,9 @@ def get_out_data_from_opts(cls, name, sources, red1=-1, red2=-2, var1=-2, var2=-
assert all(b_axis in map_a_to_b_rem_axes.values() for b_axis in b_rem_axes)
b_rem_axes = [map_a_to_b_rem_axes[a_axis] for a_axis in a_rem_axes]

a_shape = a_out.batch_shape
b_shape = b_out.batch_shape
a_rem_dims = [a_shape[i] for i in a_rem_axes]
a_var_dims = [a_shape[i] for i in a_var_axes]
b_var_dims = [b_shape[i] for i in b_var_axes]
a_rem_dims = [a_out.dim_tags[i] for i in a_rem_axes]
a_var_dims = [a_out.dim_tags[i] for i in a_var_axes]
b_var_dims = [b_out.dim_tags[i] for i in b_var_axes]

def find_axis(a_axis, b_axis):
"""
@@ -5533,49 +5535,18 @@ def find_axis(a_axis, b_axis):
return axis

time_dim_axis = find_axis(a_out.time_dim_axis, b_out.time_dim_axis)
batch_dim_axis = find_axis(a_out.batch_dim_axis, b_out.batch_dim_axis)
assert batch_dim_axis != NotSpecified or (a_out.batch_dim_axis is None and b_out.batch_dim_axis is None)

if not b_var_dims and add_var2_if_empty:
b_var_dims.append(1)

def get_batch_axis_excluding_batch(axis):
"""
:param int axis:
:rtype: int
"""
if batch_dim_axis is None:
return axis
assert axis != batch_dim_axis
if axis < batch_dim_axis:
return axis
return axis - 1

# Collect dynamic size info.
size_placeholder = {}
for axis1_wo_b in sorted(a_out.size_placeholder.keys()):
axis_out_wb = cls._axis1_to_output(a_out.get_batch_axis(axis1_wo_b), a_rem_axes=a_rem_axes, a_var_axes=a_var_axes)
if axis_out_wb is None:
continue
size_placeholder[get_batch_axis_excluding_batch(axis_out_wb)] = a_out.size_placeholder[axis1_wo_b]
for axis2_wo_b in sorted(b_out.size_placeholder.keys()):
axis_out_wb = cls._axis2_to_output(
b_out.get_batch_axis(axis2_wo_b), b_rem_axes=b_rem_axes, a_var_axes=a_var_axes, b_var_axes=b_var_axes)
if axis_out_wb is None or axis_out_wb in size_placeholder:
continue
size_placeholder[get_batch_axis_excluding_batch(axis_out_wb)] = b_out.size_placeholder[axis2_wo_b]

shape = list(a_rem_dims + a_var_dims + b_var_dims)
if batch_dim_axis is not None and batch_dim_axis is not NotSpecified:
shape.pop(batch_dim_axis)
b_var_dims.append(
DimensionTag(kind=DimensionTag.Types.Spatial, description="%s:dot:dummy-var2" % name, dimension=1))

dim_tags = list(a_rem_dims + a_var_dims + b_var_dims)
return Data(
name="%s_output" % name,
shape=tuple(shape),
batch_dim_axis=batch_dim_axis,
dim_tags=dim_tags,
time_dim_axis=time_dim_axis,
dtype=a_out.dtype,
size_placeholder=size_placeholder,
batch=BatchInfo.get_common_batch_info([src.batch for src in (a_out, b_out)]),
beam=SearchBeam.get_combined_beam(a_out.beam, b_out.beam))
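
With the output dims expressed as dim tags (a_rem_dims + a_var_dims + b_var_dims above), the old explicit batch_dim_axis and size_placeholder bookkeeping becomes unnecessary, since each tag already carries its dimension, dynamic sizes and batch info. For orientation, the parameters visible in the hunk header (red1, red2, var1, var2) are how such a dot layer is configured in a network definition; a hypothetical config entry using the defaults shown above (assuming the layer's config name is "dot"; this is not taken from the PR):

# hypothetical network-config entry, axis values are just the defaults from above
network = {
  "dot_out": {
    "class": "dot", "from": ["layer_a", "layer_b"],
    "red1": -1, "red2": -2,  # axes of the two inputs that are reduced (dot product)
    "var1": -2, "var2": -1,  # per-input axes that remain in the output
  },
}
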


@@ -6361,7 +6332,7 @@ def __init__(self, condition, true_layer, false_layer,
old_size = self.output.size_placeholder[i]
old_tag = DimensionTag.get_tag_from_size_tensor(old_size)
assert old_tag
old_tag.set_tag_on_size_tensor(size, batch=self.output.batch)
old_tag.set_tag_on_size_tensor(size, batch=self.output.batch, same_as_before=True)
self.output.size_placeholder[i] = size

def _cond_layer_return(self, layer):