diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py
index f256fafda..636062eaa 100644
--- a/returnn/tf/util/data.py
+++ b/returnn/tf/util/data.py
@@ -1478,28 +1478,41 @@ def get_runtime_sanity_check_op(self):
     checks = []
     with tf.name_scope("runtime_sanity_check"):
       shape = tf.shape(self.placeholder)
+      batch_dim = shape[self.batch_dim_axis] if self.have_batch_axis() else 1
       rank = tf.rank(self.placeholder)
-      data = [str(self), "shape", shape]
+      data = ["Data.get_runtime_sanity_check_op:", str(self), "shape", shape]
       for i, tag in enumerate(self.dim_tags):
         if tag.dyn_size is not None:
-          data += ["dyn_size[%i]" % i, tag.dyn_size, ".shape", tf.shape(tag.dyn_size)]
+          data += [
+            "dyn_size[%i] (%s)" % (i, tag), tag.dyn_size, ".shape", tf.shape(tag.dyn_size)]
       checks += [tf.Assert(tf.equal(rank, self.batch_ndim), data + ["-> invalid rank"])]
+      if self.have_batch_axis():
+        batch_dim_via_info = self.get_batch_dim()
+        checks += [
+          tf.Assert(tf.equal(batch_dim, batch_dim_via_info), data + ["-> invalid batch dim info", batch_dim_via_info])]
       for i in range(self.batch_ndim):
         if self.batch_shape[i] is not None:
           checks += [tf.Assert(tf.equal(shape[i], self.batch_shape[i]), data + ["-> invalid shape[%i]" % i])]
-        dyn_size = self.dim_tags[i].dyn_size
-        if dyn_size is not None:
+        dyn_size_ext = self.dim_tags[i].dyn_size_ext
+        if dyn_size_ext and dyn_size_ext.placeholder is not None:
+          dyn_size = dyn_size_ext.placeholder
+          if dyn_size_ext.have_batch_axis() and self.have_batch_axis():
+            checks += [tf.Assert(
+              tf.equal(tf.shape(dyn_size)[dyn_size_ext.batch_dim_axis], batch_dim),
+              data + ["-> invalid axis %i tag dyn size batch dim" % i])]
           checks += [tf.Assert(
             # Note: in almost all cases, we have equality here.
             # However, not strictly in all cases, e.g. DecideLayer, maybe some others...
-            tf.less_equal(tf.reduce_max(dyn_size), shape[i]),
+            # But that should not be more than 1 less.
+            tf.logical_or(
+              tf.logical_and(
+                tf.less_equal(tf.reduce_max(dyn_size), shape[i]),
+                tf.greater_equal(tf.reduce_max(dyn_size), shape[i] - 1)),
+              # In other rare cases, this might be a broadcast dim
+              # (e.g. as initial values of att weights for a rec loop).
+              tf.equal(1, shape[i])),
             data + ["-> invalid shape[%i] or max(dyn_size[%i])" % (i, i)])]
-      batch_dim = shape[self.batch_dim_axis] if self.have_batch_axis() else 1
-      for i, tag in enumerate(self.dim_tags):
-        if tag.dyn_size is not None:
-          checks += [tf.Assert(
-            tf.reduce_all(tf.equal(tf.shape(tag.dyn_size), [batch_dim])),
-            data + ["-> invalid shape(dyn_size[%i]) or invalid batch dim" % i, batch_dim])]
+          checks += [dyn_size_ext.get_runtime_sanity_check_op()]
     return tf.group(*checks)
 
   def get_placeholder_kwargs(self, with_batch=True):
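
Not part of the diff: a standalone sketch of the check pattern the new hunk installs, written in plain TensorFlow (compat.v1 graph mode) rather than through RETURNN's Data class; the tensor names, shapes, and feed values below are made up for illustration. It shows the per-axis asserts: the dyn-size tensor's batch dim must match the placeholder's batch dim, and max(dyn_size) must match the corresponding placeholder dim (allowing it to be at most 1 smaller, or the dim to be a broadcast dim of 1).

# Illustrative sketch only; not RETURNN code.
import numpy
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

x = tf.placeholder(tf.float32, shape=(None, None, 5))  # [B, T, 5]
dyn_size = tf.placeholder(tf.int32, shape=(None,))     # per-sequence lengths, [B]

shape = tf.shape(x)
time_dim = shape[1]
checks = [
  # dyn_size must have one entry per batch entry.
  tf.Assert(
    tf.equal(tf.shape(dyn_size)[0], shape[0]),
    ["invalid dyn_size batch dim", tf.shape(dyn_size), "vs shape", shape]),
  # max(dyn_size) must match the time dim (or be 1 less), unless the dim broadcasts.
  tf.Assert(
    tf.logical_or(
      tf.logical_and(
        tf.less_equal(tf.reduce_max(dyn_size), time_dim),
        tf.greater_equal(tf.reduce_max(dyn_size), time_dim - 1)),
      tf.equal(1, time_dim)),
    ["invalid max(dyn_size) for time dim", dyn_size, "vs shape", shape]),
]
check_op = tf.group(*checks)

with tf.Session() as session:
  session.run(
    check_op,
    feed_dict={
      x: numpy.zeros((2, 7, 5), dtype="float32"),
      dyn_size: numpy.array([7, 5], dtype="int32")})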