Skip to content

Commit f368012

Browse files
zoyahav and tfx-copybara
authored and committed
n/a
PiperOrigin-RevId: 522298075
1 parent e50a154 commit f368012

File tree

10 files changed

+113
-83
lines changed

10 files changed

+113
-83
lines changed

examples/simple_example_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
"""Tests for simple_example."""
1515

1616
import tensorflow as tf
17-
from tensorflow_transform.beam import tft_unit
1817
import simple_example
18+
from tensorflow_transform.beam import tft_unit
1919

2020

2121
_EXPECTED_TRANSFORMED_OUTPUT = [

examples/simple_sequence_example_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
"""Tests for simple_example."""
1515

1616
import tensorflow as tf
17-
from tensorflow_transform.beam import tft_unit
1817
import simple_sequence_example
18+
from tensorflow_transform.beam import tft_unit
1919

2020
_EXPECTED_TRANSFORMED_OUTPUT = [{
2121
'transformed_seq_int_feature$ragged_values': [

tensorflow_transform/beam/analysis_graph_builder_test.py

Lines changed: 63 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from tensorflow_transform import tf2_utils
2424
from tensorflow_transform.beam import analysis_graph_builder
2525
from tensorflow_transform.beam import analyzer_cache
26-
from tensorflow_transform import test_case
26+
from tensorflow_transform.beam import tft_unit
2727
# TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple`
2828
# once the Spark issue is resolved.
2929
from tfx_bsl.types import tfx_namedtuple
@@ -396,17 +396,21 @@ def __new__(cls):
396396
]
397397

398398

399-
class AnalysisGraphBuilderTest(test_case.TransformTestCase):
399+
class AnalysisGraphBuilderTest(tft_unit.TransformTestCase):
400400

401-
@test_case.named_parameters(
402-
*test_case.cross_named_parameters(_ANALYZE_TEST_CASES, [
403-
dict(testcase_name='tf_compat_v1', use_tf_compat_v1=True),
404-
dict(testcase_name='tf2', use_tf_compat_v1=False)
405-
]))
401+
@tft_unit.named_parameters(
402+
*tft_unit.cross_named_parameters(
403+
_ANALYZE_TEST_CASES,
404+
[
405+
dict(testcase_name='tf_compat_v1', use_tf_compat_v1=True),
406+
dict(testcase_name='tf2', use_tf_compat_v1=False),
407+
],
408+
)
409+
)
406410
def test_build(self, feature_spec, preprocessing_fn, expected_dot_graph_str,
407411
expected_dot_graph_str_tf2, use_tf_compat_v1):
408412
if not use_tf_compat_v1:
409-
test_case.skip_if_not_tf2('Tensorflow 2.x required')
413+
tft_unit.skip_if_not_tf2('Tensorflow 2.x required')
410414
specs = (
411415
feature_spec if use_tf_compat_v1 else
412416
impl_helper.get_type_specs_from_feature_specs(feature_spec))
@@ -430,48 +434,54 @@ def test_build(self, feature_spec, preprocessing_fn, expected_dot_graph_str,
430434
second=(expected_dot_graph_str
431435
if use_tf_compat_v1 else expected_dot_graph_str_tf2))
432436

433-
@test_case.named_parameters(*test_case.cross_named_parameters(
434-
[
435-
dict(
436-
testcase_name='one_dataset_cached_single_phase',
437-
preprocessing_fn=_preprocessing_fn_with_one_analyzer,
438-
full_dataset_keys=['a', 'b'],
439-
cached_dataset_keys=['a'],
440-
expected_dataset_keys=['b'],
441-
),
442-
dict(
443-
testcase_name='all_datasets_cached_single_phase',
444-
preprocessing_fn=_preprocessing_fn_with_one_analyzer,
445-
full_dataset_keys=['a', 'b'],
446-
cached_dataset_keys=['a', 'b'],
447-
expected_dataset_keys=[],
448-
),
449-
dict(
450-
testcase_name='mixed_single_phase',
451-
preprocessing_fn=lambda d: dict( # pylint: disable=g-long-lambda
452-
list(_preprocessing_fn_with_chained_ptransforms(d).items()) +
453-
list(_preprocessing_fn_with_one_analyzer(d).items())),
454-
full_dataset_keys=['a', 'b'],
455-
cached_dataset_keys=['a', 'b'],
456-
expected_dataset_keys=['a', 'b'],
457-
),
458-
dict(
459-
testcase_name='multi_phase',
460-
preprocessing_fn=_preprocessing_fn_with_two_phases,
461-
full_dataset_keys=['a', 'b'],
462-
cached_dataset_keys=['a', 'b'],
463-
expected_dataset_keys=['a', 'b'],
464-
)
465-
],
466-
[
467-
dict(testcase_name='tf_compat_v1', use_tf_compat_v1=True),
468-
dict(testcase_name='tf2', use_tf_compat_v1=False)
469-
]))
437+
@tft_unit.named_parameters(
438+
*tft_unit.cross_named_parameters(
439+
[
440+
dict(
441+
testcase_name='one_dataset_cached_single_phase',
442+
preprocessing_fn=_preprocessing_fn_with_one_analyzer,
443+
full_dataset_keys=['a', 'b'],
444+
cached_dataset_keys=['a'],
445+
expected_dataset_keys=['b'],
446+
),
447+
dict(
448+
testcase_name='all_datasets_cached_single_phase',
449+
preprocessing_fn=_preprocessing_fn_with_one_analyzer,
450+
full_dataset_keys=['a', 'b'],
451+
cached_dataset_keys=['a', 'b'],
452+
expected_dataset_keys=[],
453+
),
454+
dict(
455+
testcase_name='mixed_single_phase',
456+
preprocessing_fn=lambda d: dict( # pylint: disable=g-long-lambda
457+
list(
458+
_preprocessing_fn_with_chained_ptransforms(d).items()
459+
)
460+
+ list(_preprocessing_fn_with_one_analyzer(d).items())
461+
),
462+
full_dataset_keys=['a', 'b'],
463+
cached_dataset_keys=['a', 'b'],
464+
expected_dataset_keys=['a', 'b'],
465+
),
466+
dict(
467+
testcase_name='multi_phase',
468+
preprocessing_fn=_preprocessing_fn_with_two_phases,
469+
full_dataset_keys=['a', 'b'],
470+
cached_dataset_keys=['a', 'b'],
471+
expected_dataset_keys=['a', 'b'],
472+
),
473+
],
474+
[
475+
dict(testcase_name='tf_compat_v1', use_tf_compat_v1=True),
476+
dict(testcase_name='tf2', use_tf_compat_v1=False),
477+
],
478+
)
479+
)
470480
def test_get_analysis_dataset_keys(self, preprocessing_fn, full_dataset_keys,
471481
cached_dataset_keys, expected_dataset_keys,
472482
use_tf_compat_v1):
473483
if not use_tf_compat_v1:
474-
test_case.skip_if_not_tf2('Tensorflow 2.x required')
484+
tft_unit.skip_if_not_tf2('Tensorflow 2.x required')
475485
full_dataset_keys = list(
476486
map(analyzer_cache.DatasetKey, full_dataset_keys))
477487
cached_dataset_keys = map(analyzer_cache.DatasetKey, cached_dataset_keys)
@@ -499,18 +509,16 @@ def test_get_analysis_dataset_keys(self, preprocessing_fn, full_dataset_keys,
499509
full_dataset_keys,
500510
input_cache,
501511
force_tf_compat_v1=use_tf_compat_v1))
502-
503-
dot_string = nodes.get_dot_graph([analysis_graph_builder._ANALYSIS_GRAPH
504-
]).to_string()
505-
self.WriteRenderedDotFile(dot_string)
512+
self.DebugPublishLatestsRenderedTFTGraph()
506513
self.assertCountEqual(expected_dataset_keys, dataset_keys)
507514

508-
@test_case.named_parameters(
515+
@tft_unit.named_parameters(
509516
dict(testcase_name='tf_compat_v1', use_tf_compat_v1=True),
510-
dict(testcase_name='tf2', use_tf_compat_v1=False))
517+
dict(testcase_name='tf2', use_tf_compat_v1=False),
518+
)
511519
def test_get_analysis_cache_entry_keys(self, use_tf_compat_v1):
512520
if not use_tf_compat_v1:
513-
test_case.skip_if_not_tf2('Tensorflow 2.x required')
521+
tft_unit.skip_if_not_tf2('Tensorflow 2.x required')
514522
full_dataset_keys = map(analyzer_cache.DatasetKey, ['a', 'b'])
515523
def preprocessing_fn(inputs):
516524
return {'x': tft.scale_to_0_1(inputs['x'])}
@@ -531,10 +539,7 @@ def mocked_make_cache_entry_key(_):
531539
specs,
532540
full_dataset_keys,
533541
force_tf_compat_v1=use_tf_compat_v1))
534-
535-
dot_string = nodes.get_dot_graph([analysis_graph_builder._ANALYSIS_GRAPH
536-
]).to_string()
537-
self.WriteRenderedDotFile(dot_string)
542+
self.DebugPublishLatestsRenderedTFTGraph()
538543
self.assertCountEqual(cache_entry_keys, [mocked_cache_entry_key])
539544

540545
def test_duplicate_label_error(self):
@@ -575,4 +580,4 @@ class _Analyzer(
575580

576581

577582
if __name__ == '__main__':
578-
test_case.main()
583+
tft_unit.main()

tensorflow_transform/beam/beam_nodes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ def _get_tensor_type_name(self, tensor):
103103
return 'Tensor'
104104
elif isinstance(tensor, tf.SparseTensor):
105105
return 'SparseTensor'
106+
elif isinstance(tensor, tf.RaggedTensor):
107+
return 'RaggedTensor'
106108
raise ValueError('Got a {}, expected a Tensor or SparseTensor'.format(
107109
type(tensor)))
108110

tensorflow_transform/beam/impl_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
from tensorflow_transform import pretrained_models
3131
from tensorflow_transform import schema_inference
3232
import tensorflow_transform.beam as tft_beam
33-
from tensorflow_transform.beam import tft_unit
3433
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
34+
from tensorflow_transform.beam import tft_unit
3535
from tfx_bsl.tfxio import tensor_adapter
3636

3737
from google.protobuf import text_format

tensorflow_transform/beam/tft_beam_io/beam_metadata_io_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
import apache_beam as beam
2020
import tensorflow as tf
2121
from tensorflow_transform import output_wrapper
22-
from tensorflow_transform.beam import tft_unit
2322
from tensorflow_transform.beam.tft_beam_io import beam_metadata_io
23+
from tensorflow_transform.beam import tft_unit
2424
from tensorflow_transform.beam.tft_beam_io import test_metadata
2525
import tensorflow_transform.test_case as tft_test_case
2626
from tensorflow_transform.tf_metadata import metadata_io

tensorflow_transform/beam/tft_beam_io/transform_fn_io_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,10 @@
1818
import apache_beam as beam
1919
from apache_beam.testing import util as beam_test_util
2020
import tensorflow as tf
21-
2221
import tensorflow_transform as tft
23-
from tensorflow_transform.beam import tft_unit
2422
from tensorflow_transform.beam.tft_beam_io import beam_metadata_io
2523
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
24+
from tensorflow_transform.beam import tft_unit
2625
from tensorflow_transform.beam.tft_beam_io import test_metadata
2726
from tensorflow_transform.tf_metadata import metadata_io
2827

tensorflow_transform/beam/tft_unit.py

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
import os
1717
import tempfile
1818
from typing import Dict, Iterable, List, Optional, Tuple
19+
from absl import logging
1920

2021
import apache_beam as beam
2122
import pyarrow as pa
2223
import tensorflow as tf
2324
import tensorflow_transform as tft
24-
from tensorflow_transform.beam import impl as beam_impl
25+
import tensorflow_transform.beam as tft_beam
2526
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
2627
from tensorflow_transform import test_case
2728
from tensorflow_transform.beam import test_helpers
@@ -333,38 +334,47 @@ def assertAnalyzeAndTransformResults(self,
333334
temp_dir = temp_dir or tempfile.mkdtemp(
334335
prefix=self._testMethodName, dir=self.get_temp_dir())
335336
with beam_pipeline or self._makeTestPipeline() as pipeline:
336-
with beam_impl.Context(
337+
with tft_beam.Context(
337338
temp_dir=temp_dir,
338339
desired_batch_size=desired_batch_size,
339-
force_tf_compat_v1=force_tf_compat_v1):
340+
force_tf_compat_v1=force_tf_compat_v1,
341+
):
340342
source_ptransform = (
341343
input_data if isinstance(input_data, beam.PTransform) else
342344
beam.Create(input_data, reshuffle=False))
343345
input_data = pipeline | 'CreateInput' >> source_ptransform
344346
if test_data is None:
345347
(transformed_data, transformed_metadata), transform_fn = (
346-
(input_data, input_metadata)
347-
| beam_impl.AnalyzeAndTransformDataset(
348-
preprocessing_fn,
349-
output_record_batches=output_record_batches))
348+
input_data,
349+
input_metadata,
350+
) | tft_beam.AnalyzeAndTransformDataset(
351+
preprocessing_fn, output_record_batches=output_record_batches
352+
)
350353
else:
351-
transform_fn = ((input_data, input_metadata)
352-
| beam_impl.AnalyzeDataset(preprocessing_fn))
354+
transform_fn = (input_data, input_metadata) | tft_beam.AnalyzeDataset(
355+
preprocessing_fn
356+
)
353357
test_data = pipeline | 'CreateTest' >> beam.Create(test_data)
354358
transformed_data, transformed_metadata = (
355-
((test_data, input_metadata), transform_fn)
356-
| beam_impl.TransformDataset(
357-
output_record_batches=output_record_batches))
359+
(test_data, input_metadata),
360+
transform_fn,
361+
) | tft_beam.TransformDataset(
362+
output_record_batches=output_record_batches
363+
)
358364

359365
# Write transform_fn so we can test its assets
360366
_ = transform_fn | transform_fn_io.WriteTransformFn(temp_dir)
361367

362368
transformed_data_path = os.path.join(temp_dir, 'transformed_data')
363369
if expected_data is not None:
364-
_ = ((transformed_data, transformed_metadata)
365-
| 'Encode' >> beam_impl.EncodeTransformedDataset()
366-
| 'Write' >> beam.io.tfrecordio.WriteToTFRecord(
367-
transformed_data_path, shard_name_template=''))
370+
_ = (
371+
(transformed_data, transformed_metadata)
372+
| 'Encode' >> tft_beam.EncodeTransformedDataset()
373+
| 'Write'
374+
>> beam.io.tfrecordio.WriteToTFRecord(
375+
transformed_data_path, shard_name_template=''
376+
)
377+
)
368378

369379
# TODO(ebreck) Log transformed_data somewhere.
370380
tf_transform_output = tft.TFTransformOutput(temp_dir)
@@ -406,3 +416,18 @@ def assertAnalyzeAndTransformResults(self,
406416
for filename, file_contents in expected_vocab_file_contents.items():
407417
full_filename = tf_transform_output.vocabulary_file_by_name(filename)
408418
self.AssertVocabularyContents(full_filename, file_contents)
419+
420+
def DebugPublishLatestsRenderedTFTGraph(
421+
self, output_file: Optional[str] = None
422+
):
423+
"""Outputs a rendered graph which may be used for debugging.
424+
425+
Requires adding the binary resource to the test target:
426+
data = ["//third_party/graphviz:dot_binary"]
427+
428+
Args:
429+
output_file: Path to output the rendered graph file.
430+
"""
431+
logging.info(
432+
'DebugPublishLatestsRenderedTFTGraph is not currently supported.'
433+
)

tensorflow_transform/beam/vocabulary_integration_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,12 @@
1818
import os
1919

2020
import apache_beam as beam
21-
2221
import tensorflow as tf
2322
import tensorflow_transform as tft
2423
from tensorflow_transform.beam import analyzer_impls
2524
from tensorflow_transform.beam import impl as beam_impl
26-
from tensorflow_transform.beam import tft_unit
2725
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
26+
from tensorflow_transform.beam import tft_unit
2827

2928
from tensorflow_metadata.proto.v0 import schema_pb2
3029

tensorflow_transform/beam/vocabulary_tfrecord_gzip_integration_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
# limitations under the License.
1515
"""Tests for tfrecord_gzip tft.vocabulary and tft.compute_and_apply_vocabulary."""
1616

17-
from tensorflow_transform.beam import tft_unit
1817
from tensorflow_transform.beam import vocabulary_integration_test
18+
from tensorflow_transform.beam import tft_unit
1919

2020

2121
class TFRecordVocabularyIntegrationTest(

0 commit comments

Comments
 (0)