Skip to content

Commit 2cdf7f5

Browse files
zoyahavtfx-copybara
authored andcommitted
Adds tft_beam.EncodeTransformedDatasetAsSerializedExamples which can be used to easily encode transformed data in either record batch or instance dict format.
PiperOrigin-RevId: 482514132
1 parent 3f0d9a9 commit 2cdf7f5

File tree

8 files changed

+90
-81
lines changed

8 files changed

+90
-81
lines changed

RELEASE.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
* TensorRepresentations in schema used for
1818
`schema_utils.schema_as_feature_spec` can now share name with their source
1919
features.
20+
* Introduced `tft_beam.EncodeTransformedDataset` which can be used to easily
21+
encode transformed data in preparation for materialization.
2022

2123
## Bug Fixes and Other Changes
2224

examples/census_example_common.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import tensorflow.compat.v2 as tf
2323
import tensorflow_transform as tft
2424
import tensorflow_transform.beam as tft_beam
25-
from tfx_bsl.public.tfxio import RecordBatchToExamplesEncoder
2625
from tfx_bsl.public import tfxio
2726

2827
CATEGORICAL_FEATURE_KEYS = [
@@ -210,16 +209,11 @@ def preprocessing_fn(inputs):
210209
raw_dataset | tft_beam.AnalyzeAndTransformDataset(
211210
preprocessing_fn, output_record_batches=True))
212211

213-
# Transformed metadata is not necessary for encoding.
214-
transformed_data, _ = transformed_dataset
215-
216212
# Extract transformed RecordBatches, encode and write them to the given
217213
# directory.
218-
coder = RecordBatchToExamplesEncoder()
219214
_ = (
220-
transformed_data
221-
| 'EncodeTrainData' >>
222-
beam.FlatMapTuple(lambda batch, _: coder.encode(batch))
215+
transformed_dataset
216+
| 'EncodeTrainData' >> tft_beam.EncodeTransformedDataset()
223217
| 'WriteTrainData' >> beam.io.WriteToTFRecord(
224218
os.path.join(working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE)))
225219

@@ -243,15 +237,11 @@ def preprocessing_fn(inputs):
243237
(raw_test_dataset, transform_fn)
244238
| tft_beam.TransformDataset(output_record_batches=True))
245239

246-
# Transformed metadata is not necessary for encoding.
247-
transformed_test_data, _ = transformed_test_dataset
248-
249240
# Extract transformed RecordBatches, encode and write them to the given
250241
# directory.
251242
_ = (
252-
transformed_test_data
253-
| 'EncodeTestData' >>
254-
beam.FlatMapTuple(lambda batch, _: coder.encode(batch))
243+
transformed_test_dataset
244+
| 'EncodeTestData' >> tft_beam.EncodeTransformedDataset()
255245
| 'WriteTestData' >> beam.io.WriteToTFRecord(
256246
os.path.join(working_dir, TRANSFORMED_TEST_DATA_FILEBASE)))
257247

examples/sentiment_example.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,32 +189,28 @@ def preprocessing_fn(inputs):
189189
LABEL_KEY: inputs[LABEL_KEY]
190190
}
191191

192-
# Transformed metadata is not necessary for encoding.
193192
# The TFXIO output format is chosen for improved performance.
194-
(transformed_train_data, _), transform_fn = (
193+
transformed_train_data, transform_fn = (
195194
(train_data, tfxio_train_data.TensorAdapterConfig())
196195
| 'AnalyzeAndTransform' >> tft_beam.AnalyzeAndTransformDataset(
197196
preprocessing_fn, output_record_batches=True))
198197

199-
transformed_test_data, _ = (
198+
transformed_test_data = (
200199
((test_data, tfxio_test_data.TensorAdapterConfig()), transform_fn)
201200
|
202201
'Transform' >> tft_beam.TransformDataset(output_record_batches=True))
203202

204203
# Extract transformed RecordBatches, encode and write them to the given
205204
# directory.
206-
coder = tfxio.RecordBatchToExamplesEncoder()
207205
_ = (
208206
transformed_train_data
209-
| 'EncodeTrainData' >>
210-
beam.FlatMapTuple(lambda batch, _: coder.encode(batch))
207+
| 'EncodeTrainData' >> tft_beam.EncodeTransformedDataset()
211208
| 'WriteTrainData' >> beam.io.WriteToTFRecord(
212209
os.path.join(working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE)))
213210

214211
_ = (
215212
transformed_test_data
216-
| 'EncodeTestData' >>
217-
beam.FlatMapTuple(lambda batch, _: coder.encode(batch))
213+
| 'EncodeTestData' >> tft_beam.EncodeTransformedDataset()
218214
| 'WriteTestData' >> beam.io.WriteToTFRecord(
219215
os.path.join(working_dir, TRANSFORMED_TEST_DATA_FILEBASE)))
220216

tensorflow_transform/beam/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tensorflow_transform.beam.impl import AnalyzeAndTransformDataset
2525
from tensorflow_transform.beam.impl import AnalyzeDataset
2626
from tensorflow_transform.beam.impl import AnalyzeDatasetWithCache
27+
from tensorflow_transform.beam.impl import EncodeTransformedDataset
2728
from tensorflow_transform.beam.impl import TransformDataset
2829
from tensorflow_transform.beam.tft_beam_io import *
2930

tensorflow_transform/beam/impl.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
from tensorflow_transform.tf_metadata import dataset_metadata
8080
from tensorflow_transform.tf_metadata import metadata_io
8181
from tensorflow_transform.tf_metadata import schema_utils
82+
from tfx_bsl.coders import example_coder
8283
from tfx_bsl.telemetry import collection as telemetry
8384
from tfx_bsl.telemetry import util as telemetry_util
8485
from tfx_bsl.tfxio import tensor_representation_util
@@ -1499,10 +1500,13 @@ def expand(self, dataset_and_transform_fn):
14991500
deferred_schema = (
15001501
output_metadata.deferred_metadata
15011502
| 'GetDeferredSchema' >> beam.Map(lambda m: m.schema))
1503+
output_dataset_metadata = output_metadata.dataset_metadata
15021504
else:
15031505
deferred_schema = (
15041506
self.pipeline
15051507
| 'CreateDeferredSchema' >> beam.Create([output_metadata.schema]))
1508+
output_dataset_metadata = output_metadata
1509+
output_dataset_metadata._output_record_batches = self._output_record_batches # pylint: disable=protected-access
15061510

15071511
# Increment input metrics.
15081512
_ = (
@@ -1573,3 +1577,67 @@ def expand(self, dataset_and_transform_fn):
15731577
_clear_shared_state_after_barrier(self.pipeline, output_data)
15741578

15751579
return (output_data, output_metadata)
1580+
1581+
1582+
class EncodeTransformedDataset(beam.PTransform):
1583+
"""Encodes transformed data into serialized tf.Examples.
1584+
1585+
Should operate on the output of `TransformDataset`, this can operate on either
1586+
record batch or instance dict data.
1587+
The expected input is a (transformed_data, transformed_metadata) tuple.
1588+
1589+
Example use:
1590+
1591+
>>> def preprocessing_fn(inputs):
1592+
... return {'x_scaled': tft.scale_to_z_score(inputs['x'], name='x')}
1593+
>>> raw_data = [dict(x=1), dict(x=2), dict(x=3)]
1594+
>>> feature_spec = dict(x=tf.io.FixedLenFeature([], tf.int64))
1595+
>>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec)
1596+
>>> output_path = os.path.join(tempfile.mkdtemp(), 'result')
1597+
>>> with beam.Pipeline() as p:
1598+
... with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
1599+
... data_pcoll = p | beam.Create(raw_data)
1600+
... transformed_dataset, transform_fn = (
1601+
... (data_pcoll, raw_data_metadata)
1602+
... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
1603+
... _ = (
1604+
... transformed_dataset
1605+
... | tft_beam.EncodeTransformedDataset()
1606+
... | beam.io.WriteToTFRecord(output_path, shard_name_template=''))
1607+
>>> result_feature_spec ={'x_scaled': tf.io.FixedLenFeature([], tf.float32)}
1608+
>>> list(tf.data.TFRecordDataset([output_path])
1609+
... .map(lambda x: tf.io.parse_example(x, result_feature_spec))
1610+
... .as_numpy_iterator())
1611+
[{'x_scaled': -1.2247448}, {'x_scaled': 0.0}, {'x_scaled': 1.2247448}]
1612+
"""
1613+
1614+
def _extract_input_pvalues(self, transformed_data_and_metadata):
1615+
# This method lets beam know that metadata is not a pvalue.
1616+
return transformed_data_and_metadata, [transformed_data_and_metadata[0]]
1617+
1618+
def expand(self, transformed_data_and_metadata):
1619+
1620+
transformed_data, transformed_metadata = transformed_data_and_metadata
1621+
1622+
deferred_schema = (
1623+
transformed_metadata.deferred_metadata
1624+
| 'GetDeferredSchema' >> beam.Map(lambda m: m.schema))
1625+
1626+
if transformed_metadata.dataset_metadata._output_record_batches: # pylint: disable=protected-access
1627+
transformed_data_coder_pcol = (
1628+
deferred_schema | 'RecordBatchToExamplesEncoder' >> beam.Map(
1629+
example_coder.RecordBatchToExamplesEncoder))
1630+
encode_ptransform = 'EncodeRecordBatches' >> beam.FlatMap(
1631+
# Dropping passthrough features.
1632+
lambda elem, coder: coder.encode(elem[0]),
1633+
coder=beam.pvalue.AsSingleton(transformed_data_coder_pcol))
1634+
else:
1635+
transformed_data_coder_pcol = (
1636+
deferred_schema
1637+
| 'ExampleProtoCoder' >> beam.Map(
1638+
example_proto_coder.ExampleProtoCoder))
1639+
encode_ptransform = 'EncodeInstances' >> beam.Map(
1640+
lambda data, data_coder: data_coder.encode(data),
1641+
data_coder=beam.pvalue.AsSingleton(transformed_data_coder_pcol))
1642+
1643+
return transformed_data | encode_ptransform

tensorflow_transform/beam/impl_test.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import tensorflow_transform.beam as tft_beam
3535
from tensorflow_transform.beam import tft_unit
3636
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
37-
from tfx_bsl.coders import example_coder
3837
from tfx_bsl.tfxio import tensor_adapter
3938

4039
from google.protobuf import text_format
@@ -4418,27 +4417,11 @@ def test3dSparseWithTFXIO(self):
44184417
| tft_beam.AnalyzeAndTransformDataset(
44194418
lambda inputs: inputs,
44204419
output_record_batches=self._OutputRecordBatches()))
4421-
if self._OutputRecordBatches():
44224420

4423-
def record_batch_to_examples(data_batch):
4424-
# Ignore unary pass-through features.
4425-
record_batch, _ = data_batch
4426-
return example_coder.RecordBatchToExamples(record_batch)
4427-
4428-
transformed_and_serialized = (
4429-
transformed_data |
4430-
'EncodeTransformedData' >> beam.FlatMap(record_batch_to_examples))
4431-
else:
4432-
transformed_data_coder = tft.coders.ExampleProtoCoder(
4433-
transformed_metadata.schema)
4434-
transformed_and_serialized = (
4435-
transformed_data | 'EncodeTransformedData' >> beam.Map(
4436-
transformed_data_coder.encode))
4437-
4438-
_ = (
4439-
transformed_and_serialized
4440-
| 'Write' >> beam.io.WriteToTFRecord(
4441-
materialize_path, shard_name_template=''))
4421+
_ = ((transformed_data, transformed_metadata)
4422+
| 'Encode' >> tft_beam.EncodeTransformedDataset()
4423+
| 'Write' >> beam.io.WriteToTFRecord(
4424+
materialize_path, shard_name_template=''))
44424425
_ = (
44434426
transform_fn
44444427
| 'WriteTransformFn' >>

tensorflow_transform/beam/tft_unit.py

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import tensorflow as tf
2323
import tensorflow_transform as tft
2424
from tensorflow_transform.beam import impl as beam_impl
25-
from tensorflow_transform.beam.tft_beam_io import beam_metadata_io
2625
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
2726
from tensorflow_transform import test_case
2827
from tensorflow_transform.beam import test_helpers
@@ -230,7 +229,7 @@ def preprocessing_fn(inputs):
230229
try:
231230
output_tensor.set_shape(output_shape)
232231
except ValueError as e:
233-
raise ValueError('Error for key {}: {}'.format(key, str(e)))
232+
raise ValueError(f'Error for key {key}') from e
234233
# Add a batch dimension
235234
output_tensor = tf.expand_dims(output_tensor, 0)
236235
# Broadcast along the batch dimension
@@ -362,41 +361,10 @@ def assertAnalyzeAndTransformResults(self,
362361

363362
transformed_data_path = os.path.join(temp_dir, 'transformed_data')
364363
if expected_data is not None:
365-
if isinstance(transformed_metadata,
366-
beam_metadata_io.BeamDatasetMetadata):
367-
deferred_schema = (
368-
transformed_metadata.deferred_metadata
369-
| 'GetDeferredSchema' >> beam.Map(lambda m: m.schema))
370-
else:
371-
deferred_schema = (
372-
self.pipeline | 'CreateDeferredSchema' >> beam.Create(
373-
[transformed_metadata.schema]))
374-
375-
if output_record_batches:
376-
# Since we are using a deferred schema, obtain a pcollection
377-
# containing the data coder that will be created from it.
378-
transformed_data_coder_pcol = (
379-
deferred_schema | 'RecordBatchToExamplesEncoder' >> beam.Map(
380-
example_coder.RecordBatchToExamplesEncoder))
381-
382-
encode_ptransform = 'EncodeRecordBatches' >> beam.FlatMap(
383-
_encode_transformed_data_batch,
384-
coder=beam.pvalue.AsSingleton(transformed_data_coder_pcol))
385-
else:
386-
# Since we are using a deferred schema, obtain a pcollection
387-
# containing the data coder that will be created from it.
388-
transformed_data_coder_pcol = (
389-
deferred_schema
390-
| 'ExampleProtoCoder' >> beam.Map(tft.coders.ExampleProtoCoder))
391-
encode_ptransform = 'EncodeExamples' >> beam.Map(
392-
lambda data, data_coder: data_coder.encode(data),
393-
data_coder=beam.pvalue.AsSingleton(transformed_data_coder_pcol))
394-
395-
_ = (
396-
transformed_data
397-
| encode_ptransform
398-
| beam.io.tfrecordio.WriteToTFRecord(
399-
transformed_data_path, shard_name_template=''))
364+
_ = ((transformed_data, transformed_metadata)
365+
| 'Encode' >> beam_impl.EncodeTransformedDataset()
366+
| 'Write' >> beam.io.tfrecordio.WriteToTFRecord(
367+
transformed_data_path, shard_name_template=''))
400368

401369
# TODO(ebreck) Log transformed_data somewhere.
402370
tf_transform_output = tft.TFTransformOutput(temp_dir)

tensorflow_transform/tf_metadata/dataset_metadata.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class DatasetMetadata:
3737

3838
def __init__(self, schema: schema_pb2.Schema):
3939
self._schema = schema
40+
self._output_record_batches = True
4041

4142
@classmethod
4243
def from_feature_spec(
@@ -53,7 +54,7 @@ def schema(self) -> schema_pb2.Schema:
5354

5455
def __eq__(self, other):
5556
if isinstance(other, self.__class__):
56-
return self.__dict__ == other.__dict__
57+
return self.schema == other.schema
5758
return NotImplemented
5859

5960
def __ne__(self, other):

0 commit comments

Comments
 (0)