Skip to content

Commit d3c82ea

Browse files
Googlertensorflower-gardener
authored andcommitted
Fortify access to ScopedTFGraph with a context manager.
Such that the backing C object it is never destroyed in the middle of a python function that uses it. It is unclear which functions that uses _c_graph can be called on the problematic 'use-during-deletion' path, thus I changed them all to use the fortified .get() API mechanically. While at it, I added a tensor_util API to reduce number of protected access violations caused by TF_TryEvaluateConstant. It is unclear how the C++ evaluate constant relates to the _ConstantValue implemented in tensor_util.py. At least the names are different, and they are now in the same file, hopefully making a future consolidation easier. PiperOrigin-RevId: 455227569
1 parent a6d3239 commit d3c82ea

File tree

1 file changed

+9
-13
lines changed

1 file changed

+9
-13
lines changed

tensorflow_probability/python/internal/prefer_static.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,9 @@
2525
from tensorflow_probability.python.internal import tensorshape_util
2626
from tensorflow_probability.python.internal.backend import numpy as nptf
2727

28-
# Try catch required to avoid breaking Probability opensource presubmits.
29-
# TODO(amitpatankar): Remove this once tf-nightly has latest code.
30-
# pylint: disable=g-import-not-at-top
31-
try:
32-
from tensorflow.python.client import pywrap_tf_session as c_api # pylint: disable=g-direct-tensorflow-import
33-
except ImportError:
34-
from tensorflow.python import pywrap_tensorflow as c_api # pylint: disable=g-direct-tensorflow-import
35-
28+
from tensorflow.python.client import pywrap_tf_session as c_api # pylint: disable=g-direct-tensorflow-import
3629
from tensorflow.python.framework import ops # pylint: disable=g-direct-tensorflow-import
30+
from tensorflow.python.framework import tensor_util # pylint: disable=g-direct-tensorflow-import
3731
from tensorflow.python.ops import control_flow_ops # pylint: disable=g-direct-tensorflow-import
3832
from tensorflow.python.util import tf_inspect # pylint: disable=g-direct-tensorflow-import
3933

@@ -114,14 +108,16 @@ def _get_static_value(pred):
114108
if tf.is_tensor(pred):
115109
pred_value = tf.get_static_value(tf.convert_to_tensor(pred))
116110

117-
# TODO(jamieas): remove the dependency on `pywrap_tensorflow`.
118111
# Explicitly check for ops.Tensor, to avoid an AttributeError
119112
# when requesting `KerasTensor.graph`.
120-
# pylint: disable=protected-access
121113
if pred_value is None and isinstance(pred, ops.Tensor):
122-
pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph,
123-
pred._as_tf_output())
124-
# pylint: enable=protected-access
114+
if hasattr(tensor_util, 'try_evaluate_constant'):
115+
pred_value = tensor_util.try_evaluate_constant(pred)
116+
else:
117+
# TODO(feyu): remove this branch after try_evaluate_constant is in
118+
# tf-nightly.
119+
pred_value = c_api.TF_TryEvaluateConstant_wrapper(
120+
pred.graph._c_graph, pred._as_tf_output()) # pylint: disable=protected-access
125121
return pred_value
126122
return pred
127123

0 commit comments

Comments
 (0)