79 | 79 | from tensorflow_transform.tf_metadata import dataset_metadata |
80 | 80 | from tensorflow_transform.tf_metadata import metadata_io |
81 | 81 | from tensorflow_transform.tf_metadata import schema_utils |
| 82 | +from tfx_bsl.coders import example_coder |
82 | 83 | from tfx_bsl.telemetry import collection as telemetry |
83 | 84 | from tfx_bsl.telemetry import util as telemetry_util |
84 | 85 | from tfx_bsl.tfxio import tensor_representation_util |
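For context, the newly imported `example_coder` (from tfx_bsl) is what turns Arrow `RecordBatch`es into serialized `tf.train.Example` protos. A minimal sketch of the coder on its own, outside the pipeline (the feature name and Arrow types here are illustrative, not taken from this change):

```python
import pyarrow as pa
from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx_bsl.coders import example_coder

# Illustrative schema and batch; in the change below, the schema is deferred
# and arrives as a side input.
schema = text_format.Parse(
    'feature { name: "x_scaled" type: FLOAT }', schema_pb2.Schema())
batch = pa.RecordBatch.from_pydict({
    'x_scaled': pa.array([[-1.22], [0.0], [1.22]],
                         type=pa.large_list(pa.float32())),
})
coder = example_coder.RecordBatchToExamplesEncoder(schema)
serialized = coder.encode(batch)  # -> list of serialized tf.train.Example bytes
```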
@@ -1499,10 +1500,13 @@ def expand(self, dataset_and_transform_fn): |
1499 | 1500 | deferred_schema = ( |
1500 | 1501 | output_metadata.deferred_metadata |
1501 | 1502 | | 'GetDeferredSchema' >> beam.Map(lambda m: m.schema)) |
| 1503 | + output_dataset_metadata = output_metadata.dataset_metadata |
1502 | 1504 | else: |
1503 | 1505 | deferred_schema = ( |
1504 | 1506 | self.pipeline |
1505 | 1507 | | 'CreateDeferredSchema' >> beam.Create([output_metadata.schema])) |
| 1508 | + output_dataset_metadata = output_metadata |
| 1509 | + output_dataset_metadata._output_record_batches = self._output_record_batches # pylint: disable=protected-access |
1506 | 1510 | |
1507 | 1511 | # Increment input metrics. |
1508 | 1512 | _ = ( |
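The `_output_record_batches` flag stashed on the metadata above is what lets the downstream `EncodeTransformedDataset` (added below) tell which of the two `TransformDataset` output representations it received. A sketch of the two element shapes, with illustrative values:

```python
import pyarrow as pa

# With output_record_batches=True, each element is a
# (pa.RecordBatch, unary_passthrough_features) tuple; the encoder below
# keeps elem[0] and drops the passthrough features.
record_batch_element = (
    pa.RecordBatch.from_pydict({
        'x_scaled': pa.array([[-1.2247448]], type=pa.large_list(pa.float32())),
    }),
    {},  # passthrough features (empty here)
)

# With output_record_batches=False (instance dicts), each element is a
# single feature-name -> value mapping.
instance_dict_element = {'x_scaled': -1.2247448}
```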
@@ -1573,3 +1577,67 @@ def expand(self, dataset_and_transform_fn): |
1573 | 1577 | _clear_shared_state_after_barrier(self.pipeline, output_data) |
1574 | 1578 | |
1575 | 1579 | return (output_data, output_metadata) |
| 1580 | + |
| 1581 | + |
| 1582 | +class EncodeTransformedDataset(beam.PTransform): |
| 1583 | + """Encodes transformed data into serialized tf.Examples. |
| 1584 | + |
| 1585 | + This should operate on the output of `TransformDataset` and can encode either |
| 1586 | + record batch or instance dict data. |
| 1587 | + The expected input is a (transformed_data, transformed_metadata) tuple. |
| 1588 | +
|
| 1589 | + Example use: |
| 1590 | + |
| 1591 | + >>> def preprocessing_fn(inputs): |
| 1592 | + ... return {'x_scaled': tft.scale_to_z_score(inputs['x'], name='x')} |
| 1593 | + >>> raw_data = [dict(x=1), dict(x=2), dict(x=3)] |
| 1594 | + >>> feature_spec = dict(x=tf.io.FixedLenFeature([], tf.int64)) |
| 1595 | + >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) |
| 1596 | + >>> output_path = os.path.join(tempfile.mkdtemp(), 'result') |
| 1597 | + >>> with beam.Pipeline() as p: |
| 1598 | + ... with tft_beam.Context(temp_dir=tempfile.mkdtemp()): |
| 1599 | + ... data_pcoll = p | beam.Create(raw_data) |
| 1600 | + ... transformed_dataset, transform_fn = ( |
| 1601 | + ... (data_pcoll, raw_data_metadata) |
| 1602 | + ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) |
| 1603 | + ... _ = ( |
| 1604 | + ... transformed_dataset |
| 1605 | + ... | tft_beam.EncodeTransformedDataset() |
| 1606 | + ... | beam.io.WriteToTFRecord(output_path, shard_name_template='')) |
| 1607 | + >>> result_feature_spec = {'x_scaled': tf.io.FixedLenFeature([], tf.float32)} |
| 1608 | + >>> list(tf.data.TFRecordDataset([output_path]) |
| 1609 | + ... .map(lambda x: tf.io.parse_example(x, result_feature_spec)) |
| 1610 | + ... .as_numpy_iterator()) |
| 1611 | + [{'x_scaled': -1.2247448}, {'x_scaled': 0.0}, {'x_scaled': 1.2247448}] |
| 1612 | + """ |
| 1613 | + |
| 1614 | + def _extract_input_pvalues(self, transformed_data_and_metadata): |
| 1615 | + # This method lets Beam know that the metadata is not a PValue. |
| 1616 | + return transformed_data_and_metadata, [transformed_data_and_metadata[0]] |
| 1617 | + |
| 1618 | + def expand(self, transformed_data_and_metadata): |
| 1619 | + |
| 1620 | + transformed_data, transformed_metadata = transformed_data_and_metadata |
| 1621 | + |
| 1622 | + deferred_schema = ( |
| 1623 | + transformed_metadata.deferred_metadata |
| 1624 | + | 'GetDeferredSchema' >> beam.Map(lambda m: m.schema)) |
| 1625 | + |
| 1626 | + if transformed_metadata.dataset_metadata._output_record_batches: # pylint: disable=protected-access |
| 1627 | + transformed_data_coder_pcol = ( |
| 1628 | + deferred_schema | 'RecordBatchToExamplesEncoder' >> beam.Map( |
| 1629 | + example_coder.RecordBatchToExamplesEncoder)) |
| 1630 | + encode_ptransform = 'EncodeRecordBatches' >> beam.FlatMap( |
| 1631 | + # elem is (record_batch, passthrough_features); encode only the batch. |
| 1632 | + lambda elem, coder: coder.encode(elem[0]), |
| 1633 | + coder=beam.pvalue.AsSingleton(transformed_data_coder_pcol)) |
| 1634 | + else: |
| 1635 | + transformed_data_coder_pcol = ( |
| 1636 | + deferred_schema |
| 1637 | + | 'ExampleProtoCoder' >> beam.Map( |
| 1638 | + example_proto_coder.ExampleProtoCoder)) |
| 1639 | + encode_ptransform = 'EncodeInstances' >> beam.Map( |
| 1640 | + lambda data, data_coder: data_coder.encode(data), |
| 1641 | + data_coder=beam.pvalue.AsSingleton(transformed_data_coder_pcol)) |
| 1642 | + |
| 1643 | + return transformed_data | encode_ptransform |
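The docstring above demonstrates the instance-dict path; for completeness, a hedged sketch of the record-batch path (this assumes `AnalyzeAndTransformDataset` accepts an `output_record_batches` argument, which is what populates the flag this change threads through the metadata):

```python
import os
import tempfile

import apache_beam as beam
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam


def preprocessing_fn(inputs):
  return {'x_scaled': tft.scale_to_z_score(inputs['x'], name='x')}


raw_data = [dict(x=1), dict(x=2), dict(x=3)]
raw_data_metadata = tft.DatasetMetadata.from_feature_spec(
    {'x': tf.io.FixedLenFeature([], tf.int64)})
output_path = os.path.join(tempfile.mkdtemp(), 'result')

with beam.Pipeline() as p:
  with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
    transformed_dataset, _ = (
        (p | beam.Create(raw_data), raw_data_metadata)
        | tft_beam.AnalyzeAndTransformDataset(
            preprocessing_fn, output_record_batches=True))
    # EncodeTransformedDataset picks the RecordBatch encoder because
    # TransformDataset stashed _output_record_batches on the metadata.
    _ = (
        transformed_dataset
        | tft_beam.EncodeTransformedDataset()
        | beam.io.WriteToTFRecord(output_path, shard_name_template=''))
```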