|
25 | 25 | from tensorflow_probability.python.internal import tensorshape_util
|
26 | 26 | from tensorflow_probability.python.internal.backend import numpy as nptf
|
27 | 27 |
|
28 |
| -# Try catch required to avoid breaking Probability opensource presubmits. |
29 |
| -# TODO(amitpatankar): Remove this once tf-nightly has latest code. |
30 |
| -# pylint: disable=g-import-not-at-top |
31 |
| -try: |
32 |
| - from tensorflow.python.client import pywrap_tf_session as c_api # pylint: disable=g-direct-tensorflow-import |
33 |
| -except ImportError: |
34 |
| - from tensorflow.python import pywrap_tensorflow as c_api # pylint: disable=g-direct-tensorflow-import |
35 |
| - |
| 28 | +from tensorflow.python.client import pywrap_tf_session as c_api # pylint: disable=g-direct-tensorflow-import |
36 | 29 | from tensorflow.python.framework import ops # pylint: disable=g-direct-tensorflow-import
|
| 30 | +from tensorflow.python.framework import tensor_util # pylint: disable=g-direct-tensorflow-import |
37 | 31 | from tensorflow.python.ops import control_flow_ops # pylint: disable=g-direct-tensorflow-import
|
38 | 32 | from tensorflow.python.util import tf_inspect # pylint: disable=g-direct-tensorflow-import
|
39 | 33 |
|
@@ -114,14 +108,16 @@ def _get_static_value(pred):
|
114 | 108 | if tf.is_tensor(pred):
|
115 | 109 | pred_value = tf.get_static_value(tf.convert_to_tensor(pred))
|
116 | 110 |
|
117 |
| - # TODO(jamieas): remove the dependency on `pywrap_tensorflow`. |
118 | 111 | # Explicitly check for ops.Tensor, to avoid an AttributeError
|
119 | 112 | # when requesting `KerasTensor.graph`.
|
120 |
| - # pylint: disable=protected-access |
121 | 113 | if pred_value is None and isinstance(pred, ops.Tensor):
|
122 |
| - pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph, |
123 |
| - pred._as_tf_output()) |
124 |
| - # pylint: enable=protected-access |
| 114 | + if hasattr(tensor_util, 'try_evaluate_constant'): |
| 115 | + pred_value = tensor_util.try_evaluate_constant(pred) |
| 116 | + else: |
| 117 | + # TODO(feyu): remove this branch after try_evaluate_constant is in |
| 118 | + # tf-nightly. |
| 119 | + pred_value = c_api.TF_TryEvaluateConstant_wrapper( |
| 120 | + pred.graph._c_graph, pred._as_tf_output()) # pylint: disable=protected-access |
125 | 121 | return pred_value
|
126 | 122 | return pred
|
127 | 123 |
|
|
0 commit comments