Skip to content

Commit 8f91037

Browse files
paulgc17
tfx-copybara
authored and committed
Minor refactoring
PiperOrigin-RevId: 311664311
1 parent ed4aaec commit 8f91037

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

tensorflow_data_validation/statistics/stats_impl.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,6 @@
5151
from tensorflow_metadata.proto.v0 import statistics_pb2
5252

5353

54-
# The combiner accumulates record batches from the upstream and merge them when
55-
# certain conditions are met. A merged record batch would allow better
56-
# vectorized processing, # but we have to pay for copying and the RAM to contain
57-
# the merged record batch. If the total byte size of accumulated record batches
58-
# exceeds this threshold a merge will be forced to avoid consuming too much
59-
# memory.
60-
_MERGE_RECORD_BATCH_BYTE_SIZE_THRESHOLD = 20 << 20 # 20MiB
61-
62-
6354
@beam.typehints.with_input_types(pa.RecordBatch)
6455
@beam.typehints.with_output_types(statistics_pb2.DatasetFeatureStatisticsList)
6556
class GenerateStatisticsImpl(beam.PTransform):
@@ -576,7 +567,15 @@ class _CombinerStatsGeneratorsCombineFn(beam.CombineFn):
576567
# accumulators in the individual stats generators, but shouldn't be too large
577568
# as it also acts as cap on the maximum memory usage of the computation.
578569
# TODO(b/73789023): Ideally we should automatically infer the batch size.
579-
_DEFAULT_DESIRED_MERGE_ACCUMULATOR_BATCH_SIZE = 100
570+
_DESIRED_MERGE_ACCUMULATOR_BATCH_SIZE = 100
571+
572+
# The combiner accumulates record batches from the upstream and merges them
573+
# when certain conditions are met. A merged record batch would allow better
574+
# vectorized processing, but we have to pay for copying and the RAM to
575+
# contain the merged record batch. If the total byte size of accumulated
576+
# record batches exceeds this threshold a merge will be forced to avoid
577+
# consuming too much memory.
578+
_MERGE_RECORD_BATCH_BYTE_SIZE_THRESHOLD = 20 << 20 # 20MiB
580579

581580
def __init__(
582581
self,
@@ -633,7 +632,8 @@ def _should_do_batch(self, accumulator: _CombinerStatsGeneratorsCombineFnAcc,
633632
if curr_batch_size >= self._desired_batch_size:
634633
return True
635634

636-
if accumulator.curr_byte_size >= _MERGE_RECORD_BATCH_BYTE_SIZE_THRESHOLD:
635+
if (accumulator.curr_byte_size >=
636+
self._MERGE_RECORD_BATCH_BYTE_SIZE_THRESHOLD):
637637
return True
638638

639639
return False
@@ -690,7 +690,7 @@ def merge_accumulators(
690690
# Repeatedly take the next N from `accumulators` (an iterator).
691691
# If there are less than N remaining, all is taken.
692692
batched_accumulators = list(itertools.islice(
693-
accumulators, self._DEFAULT_DESIRED_MERGE_ACCUMULATOR_BATCH_SIZE))
693+
accumulators, self._DESIRED_MERGE_ACCUMULATOR_BATCH_SIZE))
694694
if not batched_accumulators:
695695
break
696696

0 commit comments

Comments (0)