Skip to content

Commit bab1a8f

Browse files
aayooushtfx-copybara
authored andcommitted
Adds an example for DatasetTFXIO usage with TFT.
PiperOrigin-RevId: 590666871
1 parent 617c5dd commit bab1a8f

File tree

2 files changed

+130
-0
lines changed

2 files changed

+130
-0
lines changed

examples/dataset_tfxio_example.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2023 Google Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Simple Example of DatasetTFXIO usage."""
15+
16+
import pprint
17+
import tempfile
18+
19+
from absl import app
20+
import apache_beam as beam
21+
import tensorflow as tf
22+
import tensorflow_transform as tft
23+
import tensorflow_transform.beam.impl as tft_beam
24+
from tfx_bsl.tfxio import dataset_tfxio
25+
26+
27+
def _print_record_batch(data):
28+
pprint.pprint(data.to_pydict())
29+
30+
31+
def _preprocessing_fn(inputs):
32+
return {
33+
'x_centered': tf.cast(inputs['feature0'], tf.float32) - tft.mean(
34+
inputs['feature0']
35+
),
36+
'x_scaled': tft.scale_by_min_max(inputs['feature0']),
37+
}
38+
39+
40+
def _make_tfxio() -> dataset_tfxio.DatasetTFXIO:
41+
"""Make DatasetTFXIO."""
42+
num_elements = 9
43+
batch_size = 2
44+
dataset = tf.data.Dataset.range(num_elements).batch(batch_size)
45+
46+
return dataset_tfxio.DatasetTFXIO(dataset=dataset)
47+
48+
49+
def main(args):
50+
del args
51+
52+
input_tfxio = _make_tfxio()
53+
54+
# User-Defined Processing Pipeline
55+
with beam.Pipeline() as pipeline:
56+
with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
57+
raw_dataset = (
58+
pipeline | 'ReadRecordBatch' >> input_tfxio.BeamSource(batch_size=5),
59+
input_tfxio.TensorAdapterConfig(),
60+
)
61+
(transformed_data, _), _ = (
62+
raw_dataset
63+
| 'AnalyzeAndTransform'
64+
>> tft_beam.AnalyzeAndTransformDataset(
65+
_preprocessing_fn, output_record_batches=True
66+
)
67+
)
68+
transformed_data = transformed_data | 'ExtractRecordBatch' >> beam.Keys()
69+
_ = transformed_data | 'PrintTransformedData' >> beam.Map(
70+
_print_record_batch
71+
)
72+
73+
74+
if __name__ == '__main__':
75+
app.run(main)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2023 Google Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for dataset_tfxio."""
15+
16+
import tensorflow as tf
17+
import dataset_tfxio_example
18+
from tensorflow_transform.beam import tft_unit
19+
20+
21+
_EXPECTED_TRANSFORMED_OUTPUT = [
22+
{'x_scaled': 0.0, 'x_centered': -4.0},
23+
{'x_scaled': 0.125, 'x_centered': -3.0},
24+
{'x_scaled': 0.25, 'x_centered': -2.0},
25+
{'x_scaled': 0.375, 'x_centered': -1.0},
26+
{'x_scaled': 0.5, 'x_centered': 0.0},
27+
{'x_scaled': 0.625, 'x_centered': 1.0},
28+
{'x_scaled': 0.75, 'x_centered': 2.0},
29+
{'x_scaled': 0.875, 'x_centered': 3.0},
30+
{'x_scaled': 1.0, 'x_centered': 4.0},
31+
]
32+
33+
34+
class SimpleMainTest(tf.test.TestCase):
35+
36+
def testMainDoesNotCrash(self):
37+
tft_unit.skip_if_not_tf2('Tensorflow 2.x required.')
38+
dataset_tfxio_example.main('')
39+
40+
41+
class SimpleProcessingTest(tft_unit.TransformTestCase):
42+
43+
# Asserts equal for each element. (Does not check batchwise.)
44+
def test_preprocessing_fn(self):
45+
tfxio = dataset_tfxio_example._make_tfxio()
46+
self.assertAnalyzeAndTransformResults(
47+
tfxio.BeamSource(),
48+
tfxio.TensorAdapterConfig(),
49+
dataset_tfxio_example._preprocessing_fn,
50+
_EXPECTED_TRANSFORMED_OUTPUT,
51+
)
52+
53+
54+
if __name__ == '__main__':
55+
tf.test.main()

0 commit comments

Comments
 (0)