@@ -1354,28 +1354,38 @@ def get_runtime_sanity_check_op(self):
13541354 checks = []
13551355 with tf .name_scope ("runtime_sanity_check" ):
13561356 shape = tf .shape (self .placeholder )
1357+ batch_dim = shape [self .batch_dim_axis ] if self .have_batch_axis () else 1
13571358 rank = tf .rank (self .placeholder )
1358- data = [str (self ), "shape" , shape ]
1359+ data = ["Data.get_runtime_sanity_check_op:" , str (self ), "shape" , shape ]
13591360 for i , tag in enumerate (self .dim_tags ):
13601361 if tag .dyn_size is not None :
1361- data += ["dyn_size[%i]" % i , tag .dyn_size , ".shape" , tf .shape (tag .dyn_size )]
1362+ data += [
1363+ "dyn_size[%i] (%s)" % (i , tag ), tag .dyn_size , ".shape" , tf .shape (tag .dyn_size )]
13621364 checks += [tf .Assert (tf .equal (rank , self .batch_ndim ), data + ["-> invalid rank" ])]
1365+ if self .have_batch_axis ():
1366+ batch_dim_via_info = self .get_batch_dim ()
1367+ checks += [
1368+ tf .Assert (tf .equal (batch_dim , batch_dim_via_info ), data + ["-> invalid batch dim info" , batch_dim_via_info ])]
13631369 for i in range (self .batch_ndim ):
13641370 if self .batch_shape [i ] is not None :
13651371 checks += [tf .Assert (tf .equal (shape [i ], self .batch_shape [i ]), data + ["-> invalid shape[%i]" % i ])]
1366- dyn_size = self .dim_tags [i ].dyn_size
1367- if dyn_size is not None :
1372+ dyn_size_ext = self .dim_tags [i ].dyn_size_ext
1373+ if dyn_size_ext and dyn_size_ext .placeholder is not None :
1374+ dyn_size = dyn_size_ext .placeholder
1375+ if dyn_size_ext .have_batch_axis () and self .have_batch_axis ():
1376+ checks += [tf .Assert (
1377+ tf .equal (tf .shape (dyn_size )[dyn_size_ext .batch_dim_axis ], batch_dim ),
1378+ data + ["-> invalid axis %i tag dyn size batch dim" % i ])]
13681379 checks += [tf .Assert (
13691380 # Note: in almost all cases, we have equality here.
13701381 # However, not strictly in all cases, e.g. DecideLayer, maybe some others...
1371- tf .less_equal (tf .reduce_max (dyn_size ), shape [i ]),
1382+ tf .logical_or (
1383+ tf .less_equal (tf .reduce_max (dyn_size ), shape [i ]),
1384+ # In other rare cases, this might be a broadcast dim
1385+ # (e.g. as initial values of att weights for a rec loop).
1386+ tf .equal (1 , shape [i ])),
13721387 data + ["-> invalid shape[%i] or max(dyn_size[%i])" % (i , i )])]
1373- batch_dim = shape [self .batch_dim_axis ] if self .have_batch_axis () else 1
1374- for i , tag in enumerate (self .dim_tags ):
1375- if tag .dyn_size is not None :
1376- checks += [tf .Assert (
1377- tf .reduce_all (tf .equal (tf .shape (tag .dyn_size ), [batch_dim ])),
1378- data + ["-> invalid shape(dyn_size[%i]) or invalid batch dim" % i , batch_dim ])]
1388+ checks += [dyn_size_ext .get_runtime_sanity_check_op ()]
13791389 return tf .group (* checks )
13801390
13811391 def get_placeholder_kwargs (self , with_batch = True ):
0 commit comments