Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions returnn/tf/layers/rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1886,9 +1886,7 @@ def get_layer(name):
prev_layer = prev_layers[layer_name]
assert layer.output.batch_shape == prev_layer.output.batch_shape
assert layer.output.batch_dim_axis == prev_layer.output.batch_dim_axis
assert sorted(layer.output.size_placeholder.keys()) == sorted(prev_layer.output.size_placeholder.keys())
for i in range(len(layer.output.size_placeholder)):
assert layer.output.get_size_dim_tag(i) == prev_layer.output.get_size_dim_tag(i)
assert layer.output.get_dyn_size_tags() == prev_layer.output.get_dyn_size_tags()

def get_prev_template_layer(self, layer_name):
"""
Expand Down
13 changes: 10 additions & 3 deletions returnn/tf/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def is_dynamic(self):
:return: whether the dim is not static. usually means that it has seq lengths
:rtype: bool
"""
return self.dimension is not None
return self.dimension is None and not self.is_batch_dim()

def can_be_used_as_dim(self):
"""
Expand Down Expand Up @@ -5412,13 +5412,20 @@ def get_time_dim_tag(self):
assert self.time_dim_axis is not None
return self.get_dim_tag(self.time_dim_axis)

def get_dyn_size_tags(self):
"""
:return: all dim tags with dynamic size
:rtype: list[Dim]
"""
return [dim_tag for dim_tag in self._dim_tags if dim_tag.is_dynamic()]

def get_size_dim_tag(self, number):
"""
:param int number: index in sorted(size_placeholder.keys())
:rtype: Dim
"""
axis_wo_batch = sorted(self.size_placeholder.keys())[number]
return self.get_dim_tag(self.get_batch_axis(axis_wo_batch))
dyn_size_tags = self.get_dyn_size_tags()
return dyn_size_tags[number]

def get_batch_shape_dim_tags(self):
"""
Expand Down