returnn/tf/util/data.py (35 changes: 24 additions & 11 deletions)
@@ -1478,28 +1478,41 @@ def get_runtime_sanity_check_op(self):
     checks = []
     with tf.name_scope("runtime_sanity_check"):
       shape = tf.shape(self.placeholder)
+      batch_dim = shape[self.batch_dim_axis] if self.have_batch_axis() else 1
       rank = tf.rank(self.placeholder)
-      data = [str(self), "shape", shape]
+      data = ["Data.get_runtime_sanity_check_op:", str(self), "shape", shape]
       for i, tag in enumerate(self.dim_tags):
         if tag.dyn_size is not None:
-          data += ["dyn_size[%i]" % i, tag.dyn_size, ".shape", tf.shape(tag.dyn_size)]
+          data += [
+            "dyn_size[%i] (%s)" % (i, tag), tag.dyn_size, ".shape", tf.shape(tag.dyn_size)]
       checks += [tf.Assert(tf.equal(rank, self.batch_ndim), data + ["-> invalid rank"])]
+      if self.have_batch_axis():
+        batch_dim_via_info = self.get_batch_dim()
+        checks += [
+          tf.Assert(tf.equal(batch_dim, batch_dim_via_info), data + ["-> invalid batch dim info", batch_dim_via_info])]
       for i in range(self.batch_ndim):
         if self.batch_shape[i] is not None:
           checks += [tf.Assert(tf.equal(shape[i], self.batch_shape[i]), data + ["-> invalid shape[%i]" % i])]
-        dyn_size = self.dim_tags[i].dyn_size
-        if dyn_size is not None:
+        dyn_size_ext = self.dim_tags[i].dyn_size_ext
+        if dyn_size_ext and dyn_size_ext.placeholder is not None:
+          dyn_size = dyn_size_ext.placeholder
+          if dyn_size_ext.have_batch_axis() and self.have_batch_axis():
+            checks += [tf.Assert(
+              tf.equal(tf.shape(dyn_size)[dyn_size_ext.batch_dim_axis], batch_dim),
+              data + ["-> invalid axis %i tag dyn size batch dim" % i])]
           checks += [tf.Assert(
             # Note: in almost all cases, we have equality here.
             # However, not strictly in all cases, e.g. DecideLayer, maybe some others...
-            tf.less_equal(tf.reduce_max(dyn_size), shape[i]),
+            # But that should not be more than 1 less.
+            tf.logical_or(
+              tf.logical_and(
+                tf.less_equal(tf.reduce_max(dyn_size), shape[i]),
+                tf.greater_equal(tf.reduce_max(dyn_size), shape[i] - 1)),
+              # In other rare cases, this might be a broadcast dim
+              # (e.g. as initial values of att weights for a rec loop).
+              tf.equal(1, shape[i])),
             data + ["-> invalid shape[%i] or max(dyn_size[%i])" % (i, i)])]
-      batch_dim = shape[self.batch_dim_axis] if self.have_batch_axis() else 1
-      for i, tag in enumerate(self.dim_tags):
-        if tag.dyn_size is not None:
-          checks += [tf.Assert(
-            tf.reduce_all(tf.equal(tf.shape(tag.dyn_size), [batch_dim])),
-            data + ["-> invalid shape(dyn_size[%i]) or invalid batch dim" % i, batch_dim])]
+          checks += [dyn_size_ext.get_runtime_sanity_check_op()]
     return tf.group(*checks)
 
   def get_placeholder_kwargs(self, with_batch=True):
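For context beyond the diff: the relaxed length check above accepts max(dyn_size) equal to shape[i] or shape[i] - 1 (the DecideLayer case noted in the comments), or shape[i] == 1 for a broadcast dim. Below is a minimal sketch of that condition in isolation, assuming TF2 eager mode and dummy tensors; the helper name length_check_ok is hypothetical and not part of the PR.

import tensorflow as tf

def length_check_ok(dyn_size, dim):
  # Hypothetical illustration of the condition inside the new tf.Assert.
  # dyn_size: int32 vector of per-sequence lengths; dim: scalar size of the axis.
  max_len = tf.reduce_max(dyn_size)
  in_range = tf.logical_and(
    tf.less_equal(max_len, dim),         # never longer than the axis
    tf.greater_equal(max_len, dim - 1))  # at most 1 shorter, e.g. DecideLayer
  is_broadcast = tf.equal(dim, 1)        # e.g. initial att weights in a rec loop
  return tf.logical_or(in_range, is_broadcast)

print(length_check_ok(tf.constant([3, 5, 4]), tf.constant(5)).numpy())  # True
print(length_check_ok(tf.constant([3, 5, 4]), tf.constant(7)).numpy())  # False: more than 1 short

In the actual method this condition is wrapped in tf.Assert, and the new recursion into dyn_size_ext.get_runtime_sanity_check_op() takes over the role of the dropped final loop, which had only compared tf.shape(tag.dyn_size) against the batch dim.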