|
51 | 51 | from tensorflow_metadata.proto.v0 import statistics_pb2 |
52 | 52 |
|
53 | 53 |
|
54 | | -# The combiner accumulates record batches from the upstream and merge them when |
55 | | -# certain conditions are met. A merged record batch would allow better |
56 | | -# vectorized processing, # but we have to pay for copying and the RAM to contain |
57 | | -# the merged record batch. If the total byte size of accumulated record batches |
58 | | -# exceeds this threshold a merge will be forced to avoid consuming too much |
59 | | -# memory. |
60 | | -_MERGE_RECORD_BATCH_BYTE_SIZE_THRESHOLD = 20 << 20 # 20MiB |
61 | | - |
62 | | - |
63 | 54 | @beam.typehints.with_input_types(pa.RecordBatch) |
64 | 55 | @beam.typehints.with_output_types(statistics_pb2.DatasetFeatureStatisticsList) |
65 | 56 | class GenerateStatisticsImpl(beam.PTransform): |
@@ -576,7 +567,15 @@ class _CombinerStatsGeneratorsCombineFn(beam.CombineFn): |
576 | 567 | # accumulators in the individual stats generators, but shouldn't be too large |
577 | 568 | # as it also acts as cap on the maximum memory usage of the computation. |
578 | 569 | # TODO(b/73789023): Ideally we should automatically infer the batch size. |
579 | | - _DEFAULT_DESIRED_MERGE_ACCUMULATOR_BATCH_SIZE = 100 |
| 570 | + _DESIRED_MERGE_ACCUMULATOR_BATCH_SIZE = 100 |
| 571 | + |
| 572 | + # The combiner accumulates record batches from the upstream and merges them |
| 573 | + # when certain conditions are met. A merged record batch would allow better |
| 574 | + # vectorized processing, but we have to pay for copying and the RAM to |
| 575 | + # contain the merged record batch. If the total byte size of accumulated |
| 576 | + # record batches exceeds this threshold, a merge will be forced to avoid |
| 577 | + # consuming too much memory. |
| 578 | + _MERGE_RECORD_BATCH_BYTE_SIZE_THRESHOLD = 20 << 20 # 20MiB |
580 | 579 |
|
581 | 580 | def __init__( |
582 | 581 | self, |
@@ -633,7 +632,8 @@ def _should_do_batch(self, accumulator: _CombinerStatsGeneratorsCombineFnAcc, |
633 | 632 | if curr_batch_size >= self._desired_batch_size: |
634 | 633 | return True |
635 | 634 |
|
636 | | - if accumulator.curr_byte_size >= _MERGE_RECORD_BATCH_BYTE_SIZE_THRESHOLD: |
| 635 | + if (accumulator.curr_byte_size >= |
| 636 | + self._MERGE_RECORD_BATCH_BYTE_SIZE_THRESHOLD): |
637 | 637 | return True |
638 | 638 |
|
639 | 639 | return False |
@@ -690,7 +690,7 @@ def merge_accumulators( |
690 | 690 | # Repeatedly take the next N from `accumulators` (an iterator). |
691 | 691 | # If there are less than N remaining, all is taken. |
692 | 692 | batched_accumulators = list(itertools.islice( |
693 | | - accumulators, self._DEFAULT_DESIRED_MERGE_ACCUMULATOR_BATCH_SIZE)) |
| 693 | + accumulators, self._DESIRED_MERGE_ACCUMULATOR_BATCH_SIZE)) |
694 | 694 | if not batched_accumulators: |
695 | 695 | break |
696 | 696 |
|
|
0 commit comments