Commit d4c8b92

iindyk authored and tfx-copybara committed
Slice transformed data batches into smaller chunks if their size exceeds 200MB.
With TFXIO inputs, Transform does not control the input batch size; the output batch size follows from the input batch size plus the preprocessing_fn logic. Since a preprocessing_fn can increase the size of the data, aggressively batched inputs can produce very large output batches that cause trouble downstream.

PiperOrigin-RevId: 551903601
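
The change handles this by recursively bisecting any oversized output `pa.RecordBatch` by rows until every chunk fits under the limit; the full implementation is in the impl.py diff below. As a rough standalone sketch of the idea (simplified, with a hypothetical 1KB cap instead of the commit's 200MB constant), the behaviour looks like this:

    from typing import Iterable

    import pyarrow as pa

    _MAX_BYTES = 1 << 10  # Hypothetical 1KB cap for illustration; the commit uses 200MB.


    def slice_large_batch(batch: pa.RecordBatch) -> Iterable[pa.RecordBatch]:
      """Recursively halves a batch by rows until every chunk fits the cap."""
      if batch.nbytes > _MAX_BYTES and batch.num_rows >= 2:
        mid = batch.num_rows // 2
        yield from slice_large_batch(batch.slice(0, mid))  # First half of the rows.
        yield from slice_large_batch(batch.slice(mid))     # Remaining rows.
      else:
        yield batch


    batch = pa.RecordBatch.from_arrays(
        [pa.array(range(1024), pa.int64())], names=['x']
    )  # 1024 int64 values, about 8KB.
    chunks = list(slice_large_batch(batch))
    print(len(chunks), {c.num_rows for c in chunks})  # 8 chunks of 128 rows each.

Because the recursion halves by row count, chunk sizes land on the largest power-of-two split of the rows that fits under the cap.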
1 parent d59b84d commit d4c8b92

File tree: 3 files changed (+137, −15 lines)

RELEASE.md

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@
 
 * `approximate_vocabulary` now returns tokens with the same frequency in
   reverse lexicographical order (similarly to `tft.vocabulary`).
+* Transformed data batches are now sliced into smaller chunks if their size
+  exceeds 200MB.
 * Depends on `pyarrow>=10,<11`.
 * Depends on `apache-beam>=2.47,<3`.
 * Depends on `numpy>=1.22.0`.

tensorflow_transform/beam/impl.py

Lines changed: 49 additions & 6 deletions
@@ -126,6 +126,20 @@
     fn_api_runner.FnApiRunner: _FIXED_PARALLELISM_TF_CONFIG,
 }
 
+# Batches larger than this will be sliced into smaller chunks. This size limit
+# must be at least as strict as the following constraints:
+# 1. The number of elements in each individual array of the batch must be at
+#    most 2^31 - 1. Beam's `pa.RecordBatch` PCoder does not support larger
+#    sizes (even though the produced containers such as LargeListArray and
+#    LargeBinaryArray do).
+# 2. The serialized size of the batch must be less than 2GB. Beam's shuffle
+#    stage wraps serialized batches into a proto for materialization, and 2GB
+#    is the proto size limit.
+# We set a much stricter limit than the above to additionally improve handling
+# of the outputs by distributing the data over a larger number of (still
+# reasonably big) batches.
+_MAX_TRANSFORMED_BATCH_BYTES_SIZE = 200 << 10 << 10  # 200MB
+
 # TODO(b/68154497): pylint: disable=no-value-for-parameter
 
 
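To put the 200MB constant in context (illustrative arithmetic only; these names are not defined in impl.py): it sits roughly an order of magnitude below the 2GB proto limit, and even a batch of pure float32 data at that size holds about 52 million values, far under the 2^31 - 1 per-array element cap.

    max_batch_bytes = 200 << 10 << 10        # 209_715_200 bytes (200 MiB).
    proto_limit_bytes = 2 << 10 << 10 << 10  # 2 GiB serialized-proto limit.
    max_array_elements = 2**31 - 1           # Per-array element cap in the coder.

    print(proto_limit_bytes // max_batch_bytes)       # 10x headroom on serialized size.
    print(max_batch_bytes // 4)                       # ~52M float32 values per batch,
    print(max_batch_bytes // 4 < max_array_elements)  # True: well under the element cap.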
@@ -412,6 +426,34 @@ def _warn_about_tf_compat_v1():
       'Features such as tf.function may not work as intended.')
 
 
+def _maybe_slice_large_record_batch(
+    record_batch: pa.RecordBatch,
+) -> Iterable[pa.RecordBatch]:
+  """Slices large batches into smaller chunks."""
+  if record_batch.nbytes > _MAX_TRANSFORMED_BATCH_BYTES_SIZE:
+    if record_batch.num_rows < 2:
+      logging.warning(
+          'Transformed data row may be too large: %d bytes. '
+          'Consider reshaping outputs to distribute elements over a larger '
+          'number of rows to allow automatic slicing.',
+          record_batch.nbytes,
+      )
+      yield record_batch
+      return
+    # Note that slicing is a zero-copy operation, so the produced batches will
+    # still share memory with the original one up to the materialization
+    # boundary.
+    mid_point = record_batch.num_rows // 2
+    yield from _maybe_slice_large_record_batch(
+        record_batch.slice(offset=0, length=mid_point)
+    )
+    yield from _maybe_slice_large_record_batch(
+        record_batch.slice(offset=mid_point)
+    )
+  else:
+    yield record_batch
+
+
 def _convert_to_record_batch(
     batch_dict: Dict[str, Union[common_types.TensorValueType, pa.Array]],
     converter: tensor_to_arrow.TensorsToRecordBatchConverter,
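The zero-copy comment in `_maybe_slice_large_record_batch` can be checked directly with pyarrow: a sliced RecordBatch reports only the bytes it references but still points at the parent's buffers. A small illustration (not part of the commit):

    import pyarrow as pa

    batch = pa.RecordBatch.from_arrays(
        [pa.array(range(1_000_000), pa.int64())], names=['x']
    )
    half = batch.slice(0, batch.num_rows // 2)

    # The slice only accounts for the bytes it references...
    print(batch.nbytes, half.nbytes)  # 8000000 4000000
    # ...while its data buffer is the parent's buffer itself, i.e. no copy is made.
    print(half.column(0).buffers()[1].address == batch.column(0).buffers()[1].address)  # True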
@@ -420,8 +462,8 @@ def _convert_to_record_batch(
         TensorAdapterConfig, dataset_metadata.DatasetMetadata
     ],
     validate_varlen_sparse_values: bool = False,
-) -> Tuple[pa.RecordBatch, Dict[str, pa.Array]]:
-  """Convert batches of ndarrays to pyarrow.RecordBatch."""
+) -> Iterable[Tuple[pa.RecordBatch, Dict[str, pa.Array]]]:
+  """Converts a batch of ndarrays to pyarrow.RecordBatches."""
 
   # Making a copy of batch_dict because mutating PCollection elements is not
   # allowed.
@@ -466,9 +508,10 @@ def _convert_to_record_batch(
       arrow_columns.append(data)
     else:
       unary_passthrough_features[key] = data
-
-  return pa.RecordBatch.from_arrays(
-      arrow_columns, schema=arrow_schema), unary_passthrough_features
+  for record_batch in _maybe_slice_large_record_batch(
+      pa.RecordBatch.from_arrays(arrow_columns, schema=arrow_schema)
+  ):
+    yield record_batch, unary_passthrough_features
 
 
 def _transformed_batch_to_instance_dicts(
@@ -1545,7 +1588,7 @@ def expand(self, dataset_and_transform_fn):
         )
     )
 
-    output_data = output_batches | 'ConvertToRecordBatch' >> beam.Map(
+    output_data = output_batches | 'ConvertToRecordBatch' >> beam.FlatMap(
        _convert_to_record_batch,
        converter=beam.pvalue.AsSingleton(converter_pcol),
        passthrough_keys=Context.get_passthrough_keys(),
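
Switching from `beam.Map` to `beam.FlatMap` is what lets a single input batch fan out into several output elements now that `_convert_to_record_batch` is a generator. A toy illustration of the difference (not from the commit):

    import apache_beam as beam


    def _halve(values):
      # A generator: may yield more than one output element per input element.
      mid = len(values) // 2
      yield values[:mid]
      yield values[mid:]


    with beam.Pipeline() as pipeline:
      _ = (
          pipeline
          | beam.Create([[1, 2, 3, 4]])
          # FlatMap flattens the yielded chunks into separate PCollection
          # elements; Map would emit one generator object per input instead.
          | 'Split' >> beam.FlatMap(_halve)
          | beam.Map(print)  # Prints [1, 2] and [3, 4].
      )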

tensorflow_transform/beam/impl_output_record_batches_test.py

Lines changed: 86 additions & 9 deletions
@@ -23,7 +23,9 @@
 from tensorflow_transform.beam import impl_test
 from tensorflow_transform.beam import tft_unit
 from tensorflow_transform.tf_metadata import schema_utils
-from tfx_bsl.tfxio.tensor_adapter import TensorAdapterConfig
+from tfx_bsl.tfxio import tensor_adapter
+
+_LARGE_BATCH_SIZE = 1 << 10
 
 
 class BeamImplOutputRecordBatchesTest(impl_test.BeamImplTest):
@@ -91,9 +93,11 @@ def testConvertToRecordBatchPassthroughData(self):
         (passthrough_key4, batch_dict[passthrough_key4].type)
     ])
     # Note that we only need `input_metadata.arrow_schema`.
-    input_metadata = TensorAdapterConfig(arrow_schema, {})
-    record_batch, unary_features = impl._convert_to_record_batch(
-        batch_dict, converter, passthrough_keys, input_metadata)
+    input_metadata = tensor_adapter.TensorAdapterConfig(arrow_schema, {})
+    converted = list(impl._convert_to_record_batch(
+        batch_dict, converter, passthrough_keys, input_metadata))
+    self.assertLen(converted, 1)
+    record_batch, unary_features = converted[0]
     expected_record_batch = {
         'a': [[100], [1], [10]],
         passthrough_key1: [[1], None, [0]]
@@ -115,11 +119,84 @@ def testConvertToRecordBatchPassthroughData(self):
                                         pa.large_list(pa.int64()))
     input_metadata.arrow_schema = input_metadata.arrow_schema.append(
         pa.field(passthrough_key5, batch_dict[passthrough_key5].type))
-    with self.assertRaisesRegexp(
-        ValueError, 'Cannot pass-through data when '
-        'input and output batch sizes are different'):
-      _ = impl._convert_to_record_batch(batch_dict, converter, passthrough_keys,
-                                        input_metadata)
+    with self.assertRaisesRegex(
+        ValueError,
+        'Cannot pass-through data when '
+        'input and output batch sizes are different',
+    ):
+      _ = list(
+          impl._convert_to_record_batch(
+              batch_dict, converter, passthrough_keys, input_metadata
+          )
+      )
+
+  @tft_unit.named_parameters(
+      dict(
+          testcase_name='NoPassthroughData',
+          passthrough_data={},
+          expected_unary_features={},
+      ),
+      dict(
+          testcase_name='WithPassthroughData',
+          passthrough_data={
+              '__passthrough_with_batch_length__': pa.array(
+                  [[1]] * _LARGE_BATCH_SIZE, pa.large_list(pa.int64())
+              ),
+              '__passthrough_with_one_value__': pa.array(
+                  [None], pa.large_list(pa.float32())
+              ),
+          },
+          expected_unary_features={
+              '__passthrough_with_one_value__': pa.array(
+                  [None], pa.large_list(pa.float32())
+              ),
+          },
+      ),
+  )
+  def testConvertToLargeRecordBatch(
+      self, passthrough_data, expected_unary_features
+  ):
+    """Tests slicing of large transformed batches during conversion."""
+    # Any Beam test pipeline handling elements this large crashes the program
+    # with OOM (even with 28GB memory available), so we test the conversion
+    # pretty narrowly.
+
+    # 2^31 elements in total.
+    num_values = 1 << 21
+    batch_dict = {
+        'a': np.zeros([_LARGE_BATCH_SIZE, num_values], np.float32),
+        **passthrough_data,
+    }
+    schema = schema_utils.schema_from_feature_spec(
+        {'a': tf.io.FixedLenFeature([num_values], tf.float32)}
+    )
+    converter = impl_helper.make_tensor_to_arrow_converter(schema)
+    arrow_schema = pa.schema(
+        [
+            ('a', pa.large_list(pa.float32())),
+        ]
+        + [(key, value.type) for key, value in passthrough_data.items()]
+    )
+    input_metadata = tensor_adapter.TensorAdapterConfig(arrow_schema, {})
+    actual_num_rows = 0
+    actual_num_batches = 0
+    # Features are either going to be in the `record_batch` or in
+    # `unary_features`.
+    record_batch_features = set(batch_dict.keys()) - set(
+        expected_unary_features.keys()
+    )
+    for record_batch, unary_features in impl._convert_to_record_batch(
+        batch_dict, converter, set(passthrough_data.keys()), input_metadata
+    ):
+      self.assertEqual(set(record_batch.schema.names), record_batch_features)
+      self.assertEqual(unary_features, expected_unary_features)
+      self.assertLessEqual(
+          record_batch.nbytes, impl._MAX_TRANSFORMED_BATCH_BYTES_SIZE
+      )
+      actual_num_rows += record_batch.num_rows
+      actual_num_batches += 1
+    self.assertEqual(actual_num_rows, _LARGE_BATCH_SIZE)
+    self.assertGreater(actual_num_batches, 1)
 
 
 if __name__ == '__main__':
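
For a sense of scale in the test above (illustrative arithmetic, not part of the commit): the input is 1 << 10 rows of 1 << 21 float32 values, i.e. 2^31 values or 8GiB of feature data, so satisfying the 200MB bound forces many output chunks, which is what `assertGreater(actual_num_batches, 1)` checks.

    rows = 1 << 10                            # _LARGE_BATCH_SIZE
    values_per_row = 1 << 21                  # num_values
    total_bytes = rows * values_per_row * 4   # float32 -> 8 GiB of feature data.
    max_chunk = 200 << 10 << 10               # _MAX_TRANSFORMED_BATCH_BYTES_SIZE
    print(total_bytes // (1 << 30))           # 8 (GiB).
    print(-(-total_bytes // max_chunk))       # 41: minimum number of chunks by size alone.
    # The halving strategy overshoots that a bit: each row is 8MiB, and halving
    # 1024 rows stops at 16-row chunks (~128MiB), i.e. 64 output batches.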
