Skip to content

Commit 719eba8

Browse files
committed
Data.get_runtime_sanity_check_op extend
1 parent 3339d7a commit 719eba8

File tree

1 file changed: +21 −11 lines changed

returnn/tf/util/data.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
def get_runtime_sanity_check_op(self):
  """
  Build a TF op which, when executed, runtime-checks that the actual shape of
  ``self.placeholder`` is consistent with the metadata of this object:
  rank, known static dims, the batch dim, and the dynamic sizes of the dim tags.

  :return: grouped ``tf.Assert`` ops; raises at session runtime on any mismatch
  :rtype: tf.Operation
  """
  checks = []
  with tf.name_scope("runtime_sanity_check"):
    shape = tf.shape(self.placeholder)
    # Dynamic batch dim taken from the actual tensor; 1 if there is no batch axis.
    batch_dim = shape[self.batch_dim_axis] if self.have_batch_axis() else 1
    rank = tf.rank(self.placeholder)
    # `data` is the diagnostic payload which tf.Assert prints on failure.
    data = ["Data.get_runtime_sanity_check_op:", str(self), "shape", shape]
    for i, tag in enumerate(self.dim_tags):
      if tag.dyn_size is not None:
        data += [
          "dyn_size[%i] (%s)" % (i, tag), tag.dyn_size, ".shape", tf.shape(tag.dyn_size)]
    checks += [tf.Assert(tf.equal(rank, self.batch_ndim), data + ["-> invalid rank"])]
    if self.have_batch_axis():
      # Cross-check the tensor's batch dim against the batch info
      # (presumably the global/declared batch dim — confirm against get_batch_dim()).
      batch_dim_via_info = self.get_batch_dim()
      checks += [
        tf.Assert(tf.equal(batch_dim, batch_dim_via_info), data + ["-> invalid batch dim info", batch_dim_via_info])]
    for i in range(self.batch_ndim):
      if self.batch_shape[i] is not None:
        # Static dim is known -> the runtime dim must match it exactly.
        checks += [tf.Assert(tf.equal(shape[i], self.batch_shape[i]), data + ["-> invalid shape[%i]" % i])]
      dyn_size_ext = self.dim_tags[i].dyn_size_ext
      if dyn_size_ext and dyn_size_ext.placeholder is not None:
        dyn_size = dyn_size_ext.placeholder
        if dyn_size_ext.have_batch_axis() and self.have_batch_axis():
          # The dyn-size tensor must cover the same batch dim as the data itself.
          checks += [tf.Assert(
            tf.equal(tf.shape(dyn_size)[dyn_size_ext.batch_dim_axis], batch_dim),
            data + ["-> invalid axis %i tag dyn size batch dim" % i])]
        checks += [tf.Assert(
          # Note: in almost all cases, we have equality here.
          # However, not strictly in all cases, e.g. DecideLayer, maybe some others...
          tf.logical_or(
            tf.less_equal(tf.reduce_max(dyn_size), shape[i]),
            # In other rare cases, this might be a broadcast dim
            # (e.g. as initial values of att weights for a rec loop).
            tf.equal(1, shape[i])),
          data + ["-> invalid shape[%i] or max(dyn_size[%i])" % (i, i)])]
        # Recursively sanity-check the dyn-size Data object itself as well.
        checks += [dyn_size_ext.get_runtime_sanity_check_op()]
  return tf.group(*checks)

13811391
def get_placeholder_kwargs(self, with_batch=True):

Comments (0)