Skip to content

Commit a74cdbe

Browse files
zoyahavtfx-copybara
authored andcommitted
Switching from checking if func_graph.FuncGraph to tf.inside_function().
PiperOrigin-RevId: 519751306
1 parent e171594 commit a74cdbe

File tree

3 files changed

+8
-27
lines changed

3 files changed

+8
-27
lines changed

tensorflow_transform/analyzer_nodes.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
from tfx_bsl.types import tfx_namedtuple
4141

4242
# pylint: disable=g-direct-tensorflow-import
43-
from tensorflow.python.framework import func_graph
4443
from tensorflow.python.framework import ops
4544
# pylint: disable=g-enable-tensorflow-import
4645

@@ -303,8 +302,7 @@ def bind_future_as_tensor(
303302
tensor_info: TensorInfo,
304303
name: Optional[str] = None) -> common_types.TemporaryAnalyzerOutputType:
305304
"""Bind a future value as a tensor."""
306-
# TODO(b/165884902): Use tf.inside_function after dropping TF 2.3 support.
307-
if isinstance(ops.get_default_graph(), func_graph.FuncGraph):
305+
if tf.inside_function():
308306
# If the default graph is a `FuncGraph`, tf.function was used to trace the
309307
# preprocessing fn.
310308
return _bind_future_as_tensor_v2(future, tensor_info, name)

tensorflow_transform/annotators.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,7 @@
2323

2424
import tensorflow as tf
2525
from tensorflow_transform.graph_context import TFGraphContext
26-
27-
# pylint: disable=g-direct-tensorflow-import
28-
from tensorflow.python.framework import func_graph
29-
from tensorflow.python.framework import ops
30-
# pylint: disable=g-import-not-at-top
31-
try:
32-
# Moved in TensorFlow 2.10.
33-
from tensorflow.python.trackable import base
34-
except ImportError:
35-
from tensorflow.python.training.tracking import base
36-
# pylint: enable=g-direct-tensorflow-import, g-import-not-at-top
26+
from tensorflow.python.trackable import base # pylint: disable=g-direct-tensorflow-import
3727

3828
__all__ = ['annotate_asset', 'make_and_track_object']
3929

@@ -170,8 +160,7 @@ def make_and_track_object(trackable_factory_callable: Callable[[],
170160
creation is lifted out to the eager context using `tf.init_scope`.
171161
"""
172162
# pyformat: enable
173-
# TODO(b/165884902): Use tf.inside_function after dropping TF 1.15 support.
174-
if not isinstance(ops.get_default_graph(), func_graph.FuncGraph):
163+
if not tf.inside_function():
175164
raise ValueError('This API should only be invoked inside the user defined '
176165
'`preprocessing_fn` with TF2 behaviors enabled and '
177166
'`force_tf_compat_v1=False`. ')

tensorflow_transform/tf_utils.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
# pylint: disable=g-direct-tensorflow-import
2828
from tensorflow.python.framework import composite_tensor
29-
from tensorflow.python.framework import func_graph
3029
from tensorflow.python.framework import ops
3130
from tensorflow.python.ops import lookup_ops
3231
from tensorflow.python.util import object_identity
@@ -652,14 +651,12 @@ def make_tfrecord_vocabulary_lookup_initializer(filename_tensor,
652651
return_indicator_as_value=False,
653652
has_indicator=False):
654653
"""Makes a lookup table initializer from a compressed tfrecord file."""
655-
graph = ops.get_default_graph()
656654
with contextlib.ExitStack() as stack:
657-
# TODO(b/165884902): Use tf.inside_function after dropping TF 2.3 support.
658655
# If filename_tensor is a graph tensor (e.g. temporary analyzer output), the
659656
# following operation cannot be lifted to init scope. Hence, check it is an
660657
# eager tensor or a string constant.
661-
if isinstance(graph, func_graph.FuncGraph) and isinstance(
662-
filename_tensor, (ops.EagerTensor, str)):
658+
if (tf.inside_function() and
659+
isinstance(filename_tensor, (ops.EagerTensor, str))):
663660
# Lift the dataset creation out of graph construction to avoid
664661
# repeated initialization in TF2.
665662
stack.enter_context(tf.init_scope())
@@ -668,8 +665,7 @@ def make_tfrecord_vocabulary_lookup_initializer(filename_tensor,
668665
value_dtype,
669666
return_indicator_as_value,
670667
has_indicator)
671-
# TODO(b/165884902): Use tf.inside_function after dropping TF 2.3 support.
672-
if isinstance(graph, func_graph.FuncGraph):
668+
if tf.inside_function():
673669
annotators.track_object(dataset, name=None)
674670
return _DatasetInitializerCompat(dataset)
675671

@@ -1642,16 +1638,14 @@ def construct_and_lookup_table(construct_table_callable: Callable[
16421638
A tuple of the result from looking x up in a table and the table's size.
16431639
16441640
"""
1645-
graph = ops.get_default_graph()
16461641
# If table is lifted into an initialization scope, add a control dependency
16471642
# on the graph tensor used to track this analyzer in
16481643
# `analyzer_nodes.TENSOR_REPLACEMENTS`.
16491644
asset_filepath, control_dependency = (
16501645
_get_asset_analyzer_output_and_control_dependency(asset_filepath))
16511646
with contextlib.ExitStack() as stack:
1652-
# TODO(b/165884902): Use tf.inside_function after dropping TF 2.3 support.
1653-
if isinstance(graph, func_graph.FuncGraph) and isinstance(
1654-
asset_filepath, (ops.EagerTensor, str)):
1647+
if (tf.inside_function() and
1648+
isinstance(asset_filepath, (ops.EagerTensor, str))):
16551649
# Lift the table initialization out of graph construction to avoid
16561650
# repeated initialization in TF2.
16571651
stack.enter_context(tf.init_scope())

0 commit comments

Comments
 (0)