Skip to content

Commit 1c82cf3

Browse files
zoyahavtfx-copybara
authored andcommitted
Continue to remove TF1 related logic from TFT
PiperOrigin-RevId: 493101025
1 parent 7ca243a commit 1c82cf3

File tree

10 files changed

+13
-89
lines changed

10 files changed

+13
-89
lines changed

tensorflow_transform/analyzers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def _apply_cacheable_combiner_per_key(
200200
def _apply_cacheable_combiner_per_key_large(
201201
combiner: analyzer_nodes.Combiner, key_vocabulary_filename: str,
202202
*tensor_inputs: common_types.TensorType
203-
) -> Union[tf.Tensor, common_types.Asset]:
203+
) -> Union[tf.Tensor, tf.saved_model.Asset]:
204204
"""Similar to above but saves the combined result to a file."""
205205
input_values_node = analyzer_nodes.get_input_tensors_value_nodes(
206206
tensor_inputs)
@@ -1072,7 +1072,7 @@ def _mean_and_var_per_key(
10721072
output_dtype: Optional[tf.DType] = None,
10731073
key_vocabulary_filename: Optional[str] = None
10741074
) -> Union[Tuple[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor,
1075-
common_types.Asset]:
1075+
tf.saved_model.Asset]:
10761076
"""`mean_and_var` by group, specified by key.
10771077
10781078
Args:

tensorflow_transform/beam/impl_test.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
import contextlib
1817
import itertools
1918
import math
2019
import os
@@ -562,9 +561,6 @@ def preprocessing_fn(inputs):
562561
expected_metadata)
563562

564563
def testPyFuncs(self):
565-
if not tft_unit.is_tf_api_version_1():
566-
raise unittest.SkipTest('Test disabled when TF 2.x behavior enabled.')
567-
568564
def my_multiply(x, y):
569565
return x*y
570566

@@ -628,14 +624,11 @@ def preprocessing_fn(inputs):
628624
})
629625
self.assertAnalyzeAndTransformResults(
630626
input_data, input_metadata, preprocessing_fn, expected_data,
631-
expected_metadata)
627+
expected_metadata, force_tf_compat_v1=True)
632628

633629
def testAssertsNoReturnPyFunc(self):
634630
# Asserts that apply_pyfunc raises an exception if the passed function does
635631
# not return anything.
636-
if not tft_unit.is_tf_api_version_1():
637-
raise unittest.SkipTest('Test disabled when TF 2.x behavior enabled.')
638-
639632
self._SkipIfOutputRecordBatches()
640633

641634
def bad_func():
@@ -684,7 +677,8 @@ def preprocessing_fn(inputs):
684677
preprocessing_fn,
685678
expected_data,
686679
expected_metadata,
687-
desired_batch_size=batch_size)
680+
desired_batch_size=batch_size,
681+
force_tf_compat_v1=True)
688682

689683
def testWithUnicode(self):
690684
def preprocessing_fn(inputs):
@@ -4714,12 +4708,6 @@ def testEmptySchema(self):
47144708
preprocessing_fn=lambda inputs: inputs) # pyformat: disable
47154709

47164710
def testLoadKerasModelInPreprocessingFn(self):
4717-
4718-
if tft_unit.is_tf_api_version_1():
4719-
raise unittest.SkipTest(
4720-
'`tft.make_and_track_object` is only supported when TF2 behavior is '
4721-
'enabled.')
4722-
47234711
def _create_model(features, target):
47244712
inputs = [
47254713
tf.keras.Input(shape=(1,), name=f, dtype=tf.float32) for f in features
@@ -4797,11 +4785,8 @@ def preprocessing_fn(inputs):
47974785
'f3': 1
47984786
}]
47994787

4800-
with contextlib.ExitStack() as stack:
4801-
if not tft_unit.is_tf_api_version_1():
4802-
stack.enter_context(
4803-
self.assertRaisesRegex(
4804-
RuntimeError, 'analyzers.*appears to be non-deterministic'))
4788+
with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises
4789+
RuntimeError, 'analyzers.*appears to be non-deterministic'):
48054790
self.assertAnalyzeAndTransformResults(input_data, input_metadata,
48064791
preprocessing_fn, expected_outputs)
48074792

tensorflow_transform/beam/tft_unit.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
cross_parameters = test_case.cross_parameters
3838
named_parameters = test_case.named_parameters
3939
cross_named_parameters = test_case.cross_named_parameters
40-
is_tf_api_version_1 = test_case.is_tf_api_version_1
4140
is_external_environment = test_case.is_external_environment
4241
skip_if_not_tf2 = test_case.skip_if_not_tf2
4342
SkipTest = test_case.SkipTest

tensorflow_transform/beam/vocabulary_tfrecord_gzip_integration_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ def setUp(self):
3939
mock_is_vocabulary_tfrecord_supported.side_effect = lambda: True
4040

4141
if (tft_unit.is_external_environment() and
42-
not tf_utils.is_vocabulary_tfrecord_supported() or
43-
tft_unit.is_tf_api_version_1()):
42+
not tf_utils.is_vocabulary_tfrecord_supported()):
4443
raise unittest.SkipTest('Test requires async DatasetInitializer')
4544
super().setUp()
4645

tensorflow_transform/common_types.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,6 @@
2121

2222
from tensorflow_metadata.proto.v0 import schema_pb2
2323

24-
# TODO(b/160294509): Stop using tracking.TrackableAsset when TF1.15 support is
25-
# dropped.
26-
if hasattr(tf.saved_model, 'Asset'):
27-
Asset = tf.saved_model.Asset # pylint: disable=invalid-name
28-
else:
29-
from tensorflow.python.training.tracking import tracking # pylint: disable=g-direct-tensorflow-import, g-import-not-at-top
30-
Asset = tracking.TrackableAsset # pylint: disable=invalid-name
31-
3224
# TODO(b/185719271): Define BucketBoundariesType at module level of mappers.py.
3325
BucketBoundariesType = Union[tf.Tensor, Iterable[Union[int, float]]]
3426

@@ -53,7 +45,7 @@
5345
tf.compat.v1.ragged.RaggedTensorValue]
5446
TensorValueType = Union[tf.Tensor, np.ndarray, SparseTensorValueType,
5547
RaggedTensorValueType]
56-
TemporaryAnalyzerOutputType = Union[tf.Tensor, Asset]
48+
TemporaryAnalyzerOutputType = Union[tf.Tensor, tf.saved_model.Asset]
5749
VocabularyFileFormatType = Literal['text', 'tfrecord_gzip']
5850

5951

tensorflow_transform/graph_tools_test.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,6 @@
3232
mock = tf.compat.v1.test.mock
3333

3434

35-
def _skip_if_external_environment_or_v1_api(reason):
36-
test_case.skip_if_external_environment(reason)
37-
if test_case.is_tf_api_version_1():
38-
raise test_case.SkipTest(reason)
39-
40-
4135
def _create_lookup_table_from_file(filename):
4236
initializer = tf.lookup.TextFileInitializer(
4337
filename,
@@ -1098,7 +1092,7 @@ class GraphToolsTestUniquePath(test_case.TransformTestCase):
10981092
}),
10991093
dict(
11001094
testcase_name='_y_function_of_x_with_tf_while',
1101-
skip_test_check_fn=_skip_if_external_environment_or_v1_api,
1095+
skip_test_check_fn=test_case.skip_if_external_environment,
11021096
create_graph_fn=_create_graph_with_tf_function_while,
11031097
feeds=['x'],
11041098
replaced_tensors_ready={'x': False},

tensorflow_transform/output_wrapper.py

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from tensorflow_transform.tf_metadata import schema_utils
3131

3232
# pylint: disable=g-direct-tensorflow-import
33-
from tensorflow.python import tf2
3433
from tensorflow.python.framework import ops
3534
from tensorflow.tools.docs import doc_controls
3635
# pylint: enable=g-direct-tensorflow-import
@@ -427,40 +426,8 @@ def post_transform_statistics_path(self) -> str:
427426
self._transform_output_dir, self.POST_TRANSFORM_FEATURE_STATS_PATH)
428427

429428

430-
# TODO(zoyahav): Use register_keras_serializable directly once we no longer support
431-
# TF<2.1.
432-
def _maybe_register_keras_serializable(package):
433-
if hasattr(tf.keras.utils, 'register_keras_serializable'):
434-
return tf.keras.utils.register_keras_serializable(package=package)
435-
else:
436-
return lambda cls: cls
437-
438-
439-
def _check_tensorflow_version():
440-
"""Check that we're using a compatible TF version.
441-
442-
Raises a warning if either Tensorflow version is less that 2.0 or TF 2.x is
443-
not enabled.
444-
445-
If TF 2.x is enabled, but version is < TF 2.3, raises a warning to indicate
446-
that resources may not be initialized.
447-
"""
448-
major, minor, _ = tf.version.VERSION.split('.')
449-
if not (int(major) >= 2 and tf2.enabled()):
450-
tf.compat.v1.logging.warning(
451-
'Tensorflow version (%s) found. TransformFeaturesLayer is supported '
452-
'only for TF 2.x with TF 2.x behaviors enabled and may not work as '
453-
'intended.', tf.version.VERSION)
454-
elif int(major) == 2 and int(minor) < 3:
455-
# TODO(varshaan): Log a more specific warning.
456-
tf.compat.v1.logging.warning(
457-
'Tensorflow version (%s) found. TransformFeaturesLayer may not work '
458-
'as intended if the SavedModel contains an initialization op.',
459-
tf.version.VERSION)
460-
461-
462429
# TODO(b/162055065): Possibly switch back to inherit from Layer when possible.
463-
@_maybe_register_keras_serializable(package='TensorFlowTransform')
430+
@tf.keras.utils.register_keras_serializable(package='TensorFlowTransform')
464431
class TransformFeaturesLayer(tf.keras.Model):
465432
"""A Keras layer for applying a tf.Transform output to input layers."""
466433

@@ -478,7 +445,6 @@ def __init__(self,
478445
self._loaded_saved_model_graph = None
479446
# TODO(b/160294509): Use tf.compat.v1 when we stop supporting TF 1.15.
480447
if ops.executing_eagerly_outside_functions():
481-
_check_tensorflow_version()
482448
# The model must be tracked by assigning to an attribute of the Keras
483449
# layer. Hence, we track the attributes of _saved_model_loader here as
484450
# well.

tensorflow_transform/saved/saved_transform_io_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ def _variable_creator(next_creator, **kwargs):
522522
initializers.append(resource._initializer) # pylint: disable=protected-access
523523
module.initializers = initializers
524524
module.assets = [
525-
common_types.Asset(asset_filepath) for asset_filepath in
525+
tf.saved_model.Asset(asset_filepath) for asset_filepath in
526526
concrete_fn.graph.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
527527
]
528528
return concrete_fn

tensorflow_transform/test_case.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,6 @@
3535
SkipTest = unittest.SkipTest
3636

3737

38-
def is_tf_api_version_1():
39-
return hasattr(tf, 'Session')
40-
41-
4238
def cross_named_parameters(*args):
4339
"""Cross a list of lists of dicts suitable for @named_parameters.
4440
@@ -239,8 +235,7 @@ def skip_if_external_environment(reason):
239235

240236

241237
def skip_if_not_tf2(reason):
242-
major, _, _ = tf.version.VERSION.split('.')
243-
if not (int(major) >= 2 and tf2.enabled()) or is_tf_api_version_1():
238+
if not tf2.enabled():
244239
raise unittest.SkipTest(reason)
245240

246241

tensorflow_transform/tf_utils_test.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2292,9 +2292,6 @@ def test_split_vocabulary_entries(self):
22922292
self.assertAllEqual(self.evaluate(keys), np.array(expected_keys))
22932293
self.assertAllEqual(self.evaluate(values), np.array(expected_values))
22942294

2295-
@unittest.skipIf(
2296-
test_case.is_tf_api_version_1(),
2297-
'TFRecord vocabulary dataset tests require TF API version>1')
22982295
def test_read_tfrecord_vocabulary_dataset(self):
22992296
vocab_file = os.path.join(self.get_temp_dir(), 'vocab.tfrecord.gz')
23002297
contents = [b'a', b'b', b'c']
@@ -2346,9 +2343,6 @@ def test_read_tfrecord_vocabulary_dataset(self):
23462343
return_indicator_as_value=True,
23472344
has_indicator=True),
23482345
])
2349-
@unittest.skipIf(
2350-
test_case.is_tf_api_version_1(),
2351-
'TFRecord vocabulary dataset tests require TF API version>1')
23522346
def test_make_tfrecord_vocabulary_dataset(self, contents, expected, key_dtype,
23532347
value_dtype,
23542348
return_indicator_as_value,

0 commit comments

Comments
 (0)