@@ -210,7 +210,8 @@ def __init__(self, inputs=None, outputs=None, name=None):
210210 # check type of inputs and outputs
211211 check_order = ['inputs' , 'outputs' ]
212212 for co , check_argu in enumerate ([inputs , outputs ]):
213- if isinstance (check_argu , tf_ops ._TensorLike ) or tf_ops .is_dense_tensor_like (check_argu ):
213+ if isinstance (check_argu ,
214+ (tf .Tensor , tf .SparseTensor , tf .Variable )) or tf_ops .is_dense_tensor_like (check_argu ):
214215 pass
215216 elif isinstance (check_argu , list ):
216217 if len (check_argu ) == 0 :
@@ -219,8 +220,9 @@ def __init__(self, inputs=None, outputs=None, name=None):
219220 "It should be either Tensor or a list of Tensor."
220221 )
221222 for idx in range (len (check_argu )):
222- if not isinstance (check_argu [idx ], tf_ops ._TensorLike ) or not tf_ops .is_dense_tensor_like (
223- check_argu [idx ]):
223+ if not isinstance (check_argu [idx ],
224+ (tf .Tensor , tf .SparseTensor , tf .Variable )) or not tf_ops .is_dense_tensor_like (
225+ check_argu [idx ]):
224226 raise TypeError (
225227 "The argument `%s` should be either Tensor or a list of Tensor " % (check_order [co ]) +
226228 "but the %s[%d] is detected as %s" % (check_order [co ], idx , type (check_argu [idx ]))
0 commit comments