5050from tensorflow_metadata .proto .v0 import statistics_pb2
5151
5252
# The combiner accumulates tables from the upstream and merges them when
# certain conditions are met. A merged table allows better vectorized
# processing, but we have to pay for copying and the RAM to contain the
# merged table. If the total byte size of accumulated tables exceeds this
# threshold a merge will be forced to avoid consuming too much memory.
_MERGE_TABLE_BYTE_SIZE_THRESHOLD = 20 << 20  # 20MiB
59+
60+
5361@beam .typehints .with_input_types (pa .Table )
5462@beam .typehints .with_output_types (statistics_pb2 .DatasetFeatureStatisticsList )
5563class GenerateStatisticsImpl (beam .PTransform ):
@@ -505,7 +513,8 @@ def extract_output(self, accumulator: List[float]
class _CombinerStatsGeneratorsCombineFnAcc(object):
  """Accumulator for _CombinerStatsGeneratorsCombineFn.

  Holds the not-yet-merged input pa.Tables together with running row and
  byte counters, so the CombineFn can decide when to merge them into one
  table and feed the underlying combiner stats generators.
  """

  __slots__ = ['partial_accumulators', 'input_tables', 'curr_batch_size',
               'curr_byte_size']

  def __init__(self, partial_accumulators: List[Any]):
    # Partial accumulator states of the underlying CombinerStatsGenerators.
    self.partial_accumulators = partial_accumulators
    # Input pa.Tables accumulated but not yet processed.
    self.input_tables = []
    # Current batch size: total number of rows across input_tables.
    self.curr_batch_size = 0
    # Current total byte size of all the pa.Tables accumulated.
    self.curr_byte_size = 0
518529
519530@beam .typehints .with_input_types (pa .Table )
@@ -544,7 +555,7 @@ class _CombinerStatsGeneratorsCombineFn(beam.CombineFn):
544555 """
545556
546557 __slots__ = ['_generators' , '_desired_batch_size' , '_combine_batch_size' ,
547- '_num_compacts' , '_num_instances' ]
558+ '_combine_byte_size' , ' _num_compacts' , '_num_instances' ]
548559
549560 # This needs to be large enough to allow for efficient merging of
550561 # accumulators in the individual stats generators, but shouldn't be too large
@@ -569,6 +580,8 @@ def __init__(
569580 # Metrics
570581 self ._combine_batch_size = beam .metrics .Metrics .distribution (
571582 constants .METRICS_NAMESPACE , 'combine_batch_size' )
583+ self ._combine_byte_size = beam .metrics .Metrics .distribution (
584+ constants .METRICS_NAMESPACE , 'combine_byte_size' )
572585 self ._num_compacts = beam .metrics .Metrics .counter (
573586 constants .METRICS_NAMESPACE , 'num_compacts' )
574587 self ._num_instances = beam .metrics .Metrics .counter (
@@ -596,6 +609,20 @@ def create_accumulator(self
596609 return _CombinerStatsGeneratorsCombineFnAcc (
597610 [g .create_accumulator () for g in self ._generators ])
598611
612+ def _should_do_batch (self , accumulator : _CombinerStatsGeneratorsCombineFnAcc ,
613+ force : bool ) -> bool :
614+ curr_batch_size = accumulator .curr_batch_size
615+ if force and curr_batch_size > 0 :
616+ return True
617+
618+ if curr_batch_size >= self ._desired_batch_size :
619+ return True
620+
621+ if accumulator .curr_byte_size >= _MERGE_TABLE_BYTE_SIZE_THRESHOLD :
622+ return True
623+
624+ return False
625+
599626 def _maybe_do_batch (
600627 self ,
601628 accumulator : _CombinerStatsGeneratorsCombineFnAcc ,
@@ -610,9 +637,9 @@ def _maybe_do_batch(
610637 force: Force computation of stats even if accumulator has less examples
611638 than the batch size.
612639 """
613- batch_size = accumulator . curr_batch_size
614- if ( force and batch_size > 0 ) or batch_size >= self ._desired_batch_size :
615- self ._combine_batch_size .update (batch_size )
640+ if self . _should_do_batch ( accumulator , force ):
641+ self ._combine_batch_size . update ( accumulator . curr_batch_size )
642+ self ._combine_byte_size .update (accumulator . curr_byte_size )
616643 if len (accumulator .input_tables ) == 1 :
617644 arrow_table = accumulator .input_tables [0 ]
618645 else :
@@ -622,6 +649,7 @@ def _maybe_do_batch(
622649 accumulator .partial_accumulators )
623650 del accumulator .input_tables [:]
624651 accumulator .curr_batch_size = 0
652+ accumulator .curr_byte_size = 0
625653
626654 def add_input (
627655 self , accumulator : _CombinerStatsGeneratorsCombineFnAcc ,
@@ -630,6 +658,7 @@ def add_input(
630658 accumulator .input_tables .append (input_table )
631659 num_rows = input_table .num_rows
632660 accumulator .curr_batch_size += num_rows
661+ accumulator .curr_byte_size += table_util .TotalByteSize (input_table )
633662 self ._maybe_do_batch (accumulator )
634663 self ._num_instances .inc (num_rows )
635664 return accumulator
@@ -657,6 +686,7 @@ def merge_accumulators(
657686 for acc in batched_accumulators :
658687 result .input_tables .extend (acc .input_tables )
659688 result .curr_batch_size += acc .curr_batch_size
689+ result .curr_byte_size += acc .curr_byte_size
660690 self ._maybe_do_batch (result )
661691 batched_partial_accumulators .append (acc .partial_accumulators )
662692
0 commit comments