
Commit a206811

zoyahav authored and tfx-copybara committed
Combine a few beam metrics to reduce the number of counters in a pipeline.
PiperOrigin-RevId: 469712221
1 parent fe2d383 commit a206811
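
Context for the change: per-dataset and per-label counter increments ('analysis_input_bytes', 'analysis_input_bytes_from_cache', 'cache_entries_encoded', 'cache_entries_decoded') are replaced by values aggregated once per pipeline. A minimal, hedged sketch of that pattern (not code from this commit; the helper name and the 'tfx.Transform' namespace are illustrative assumptions): per-dataset values are flattened, summed globally, and a single counter is incremented with the total.

# Illustrative sketch only: combine per-dataset counter updates into one
# pipeline-level counter increment.
import apache_beam as beam
from apache_beam.metrics import Metrics


def _increment_counter(total, name, namespace='tfx.Transform'):
  # A single increment carrying the aggregated value.
  Metrics.counter(namespace, name).inc(total)


with beam.Pipeline() as p:
  per_dataset_bytes = [
      p | f'CreateBytes{i}' >> beam.Create([size])
      for i, size in enumerate([100, 250, 75])
  ]
  _ = (
      per_dataset_bytes
      | 'FlattenBytes' >> beam.Flatten()
      | 'SumBytes' >> beam.CombineGlobally(sum)
      | 'Instrument' >> beam.Map(_increment_counter,
                                 'analysis_input_bytes_from_cache'))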

File tree

5 files changed: +92 additions, −51 deletions


tensorflow_transform/analyzer_nodes.py

Lines changed: 7 additions & 3 deletions

@@ -1025,12 +1025,16 @@ def is_partitionable(self):


 class InstrumentDatasetCache(
-    tfx_namedtuple.namedtuple('InstrumentDatasetCache',
-                              ['dataset_key', 'label']), nodes.OperationDef):
+    tfx_namedtuple.namedtuple('InstrumentDatasetCache', [
+        'input_cache_dataset_keys', 'num_encode_cache', 'num_decode_cache',
+        'label'
+    ]), nodes.OperationDef):
   """OperationDef instrumenting cached datasets.

   Fields:
-    dataset_key: A dataset key.
+    input_cache_dataset_keys: Dataset keys for which there is input cache.
+    num_encode_cache: Number of cache entries encoded.
+    num_decode_cache: Number of cache entries decoded.
     label: A unique label for this operation.
   """
   __slots__ = ()
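
For illustration only, the widened node can be approximated with a plain collections.namedtuple (standing in for tfx_namedtuple and omitting the nodes.OperationDef mixin); the field values below are made up:

import collections

# Stand-in for the tfx_namedtuple-based OperationDef above; fields match.
InstrumentDatasetCacheFields = collections.namedtuple(
    'InstrumentDatasetCache',
    ['input_cache_dataset_keys', 'num_encode_cache', 'num_decode_cache',
     'label'])

example = InstrumentDatasetCacheFields(
    input_cache_dataset_keys=['dataset_a', 'dataset_b'],  # hypothetical keys
    num_encode_cache=4,
    num_decode_cache=2,
    label='InstrumentDatasetCache')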

tensorflow_transform/beam/analysis_graph_builder.py

Lines changed: 38 additions & 28 deletions

@@ -173,29 +173,33 @@ def __init__(self, dataset_keys, cache_dict, tensor_keys_to_paths,
     self._cache_dict = cache_dict
     self._tensor_keys_to_paths = tensor_keys_to_paths
     self._dataset_has_cache_misses = collections.defaultdict(bool)
+    self._num_encode_cache_nodes = 0
+    self._num_decode_cache_nodes = 0
     self.cache_output_nodes = cache_output_nodes
     self._num_phases = num_phases

   def _validate_operation_def(self, operation_def):
     if operation_def.cache_coder is not None:
       if not operation_def.is_partitionable:
         raise ValueError(
-            'Non partitionable OperationDefs cannot be cacheable: {}'.format(
-                operation_def.label))
+            'Non partitionable OperationDefs cannot be cacheable: '
+            f'{operation_def.label}'
+        )
     if operation_def.is_partitionable or operation_def.cache_coder is not None:
       if operation_def.num_outputs != 1:
         raise ValueError(
-            'Cacheable OperationDefs must have exactly 1 output: {}'.format(
-                operation_def.label))
+            'Cacheable OperationDefs must have exactly 1 output: '
+            f'{operation_def.label}'
+        )

   def get_detached_sideeffect_leafs(self):
     """Returns a list of sideeffect leaf nodes after the visit is done."""
     # If this is a multi-phase analysis, then all datasets have to be read
     # anyway, and so we'll not instrument full cache coverage for this case.
     if self._num_phases > 1:
       return []
-    result = []
-    for (dataset_idx, dataset_key) in enumerate(self._sorted_dataset_keys):
+    dataset_keys_with_decoded_cache = []
+    for dataset_key in self._sorted_dataset_keys:
       # Default to True here, if the dataset_key is not in the cache misses map
       # then treat it like it does have cache misses because it has not been
       # visited in the optimization traversal.
@@ -207,12 +211,18 @@ def get_detached_sideeffect_leafs(self):
       cache_dict = self._cache_dict or {}
       dataset_cache_entries = cache_dict.get(dataset_key, None)
       if dataset_cache_entries is not None and dataset_cache_entries.metadata:
-        node = nodes.apply_operation(
-            analyzer_nodes.InstrumentDatasetCache,
-            dataset_key=dataset_key,
-            label=f'InstrumentDatasetCache[AnalysisIndex{dataset_idx}]')
-        result.append(node)
-    return result
+        dataset_keys_with_decoded_cache.append(dataset_key)
+    if (dataset_keys_with_decoded_cache or self._num_encode_cache_nodes or
+        self._num_decode_cache_nodes):
+      return [
+          nodes.apply_operation(
+              analyzer_nodes.InstrumentDatasetCache,
+              input_cache_dataset_keys=dataset_keys_with_decoded_cache,
+              num_encode_cache=self._num_encode_cache_nodes,
+              num_decode_cache=self._num_decode_cache_nodes,
+              label='InstrumentDatasetCache')
+      ]
+    return []

   def _make_next_hashed_path(self, parent_hashed_paths, operation_def):
     # Making a copy of parent_hashed_paths.
@@ -269,7 +279,7 @@ def visit(self, operation_def, input_values):
         next_inputs = nodes.apply_multi_output_operation(
             beam_nodes.Flatten,
             *disaggregated_input_values,
-            label='FlattenCache[{}]'.format(operation_def.label))
+            label=f'FlattenCache[{operation_def.label}]')
       else:
         # Parent operation output is not cacheable, therefore we can just use
         # a flattened view.
@@ -341,30 +351,31 @@ def _apply_operation_on_fine_grained_view(self, operation_def,

     for (dataset_idx, dataset_key) in enumerate(self._sorted_dataset_keys):
       # We use an index for the label in order to make beam labels more stable.
-      infix = 'AnalysisIndex{}'.format(dataset_idx)
+      infix = f'AnalysisIndex{dataset_idx}'
       if (operation_def.cache_coder and self._cache_dict.get(
           dataset_key, {}).get(cache_entry_key) is not None):
         self._dataset_has_cache_misses[dataset_key] |= False
         decode_cache = analyzer_nodes.DecodeCache(
             dataset_key,
             cache_entry_key,
             coder=operation_def.cache_coder,
-            label='DecodeCache[{}][{}]'.format(operation_def.label, infix))
+            label=f'DecodeCache[{operation_def.label}][{infix}]')
         (op_output,) = nodes.OperationNode(decode_cache, tuple()).outputs
+        self._num_decode_cache_nodes += 1
       else:
         value_nodes = tuple(v[dataset_key] for v in fine_grained_views)
         (op_output,) = nodes.OperationNode(
-            operation_def._replace(
-                label='{}[{}]'.format(operation_def.label, infix)),
+            operation_def._replace(label=f'{operation_def.label}[{infix}]'),
             value_nodes).outputs
         if operation_def.cache_coder:
           self._dataset_has_cache_misses[dataset_key] = True
           encode_cache = nodes.apply_operation(
               analyzer_nodes.EncodeCache,
               op_output,
               coder=operation_def.cache_coder,
-              label='EncodeCache[{}][{}]'.format(operation_def.label, infix))
+              label=f'EncodeCache[{operation_def.label}][{infix}]')
           self.cache_output_nodes[(dataset_key, cache_entry_key)] = encode_cache
+          self._num_encode_cache_nodes += 1
       result_fine_grained_view[dataset_key] = op_output

     return result_fine_grained_view
@@ -377,16 +388,15 @@ def _visit_apply_savedmodel_operation(self, operation_def, upstream_views):

     fine_grained_view = collections.OrderedDict()
     for (dataset_idx, dataset_key) in enumerate(self._sorted_dataset_keys):
-      infix = 'AnalysisIndex{}'.format(dataset_idx)
+      infix = f'AnalysisIndex{dataset_idx}'
       input_node = nodes.apply_operation(
           beam_nodes.ExtractInputForSavedModel,
           dataset_key=dataset_key,
-          label='ExtractInputForSavedModel[{}]'.format(infix))
+          label=f'ExtractInputForSavedModel[{infix}]')
       # We use an index for the label in order to make beam labels more stable.
       (fine_grained_view[dataset_key],) = (
          nodes.OperationNode(
-              operation_def._replace(
-                  label='{}[{}]'.format(operation_def.label, infix)),
+              operation_def._replace(label=f'{operation_def.label}[{infix}]'),
              (saved_model_path_upstream_view.flattened_view,
               input_node)).outputs)

@@ -404,8 +414,8 @@ def validate_value(self, value):
     assert isinstance(value, _OptimizationView), value
     if value.fine_grained_view:
       assert set(value.fine_grained_view.keys()) == set(
-          self._sorted_dataset_keys), ('{} != {}'.format(
-              value.fine_grained_view.keys(), self._sorted_dataset_keys))
+          self._sorted_dataset_keys
+      ), (f'{value.fine_grained_view.keys()} != {self._sorted_dataset_keys}')


 def _perform_cache_optimization(saved_model_future, dataset_keys,
@@ -593,7 +603,7 @@ def preprocessing_fn(input)
       label='ExtractInputForSavedModel[FlattenedDataset]')

   while not all(sink_tensors_ready.values()):
-    infix = 'Phase{}'.format(phase)
+    infix = f'Phase{phase}'
     # Determine which table init ops are ready to run in this phase
     # Determine which keys of pending_tensor_replacements are ready to run
     # in this phase, based in whether their dependencies are ready.
@@ -610,14 +620,14 @@ def preprocessing_fn(input)
         *tensor_bindings,
         table_initializers=tuple(graph_analyzer.ready_table_initializers),
         output_signature=intermediate_output_signature,
-        label='CreateSavedModelForAnalyzerInputs[{}]'.format(infix))
+        label=f'CreateSavedModelForAnalyzerInputs[{infix}]')

     extracted_values_dict = nodes.apply_operation(
         beam_nodes.ApplySavedModel,
         saved_model_future,
         extracted_input_node,
         phase=phase,
-        label='ApplySavedModel[{}]'.format(infix))
+        label=f'ApplySavedModel[{infix}]')

     translate_visitor.phase = phase
     translate_visitor.intermediate_output_signature = (
@@ -643,7 +653,7 @@ def preprocessing_fn(input)
             dtype_enum=tensor.dtype.as_datatype_enum,
             is_asset_filepath=is_asset_filepath,
             label=analyzer_nodes.sanitize_label(
-                'CreateTensorBinding[{}]'.format(name))))
+                f'CreateTensorBinding[{name}]')))
     sink_tensors_ready[hashable_tensor] = True

   analyzers_input_signature.update(intermediate_output_signature)
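
The builder above now tallies encode/decode cache nodes while visiting the graph and emits a single detached instrumentation leaf at the end. A stripped-down sketch of that tally-then-summarize pattern, independent of the TFT node machinery (the class and method names below are made up):

class CacheTally:
  """Counts cache encode/decode decisions during a graph traversal."""

  def __init__(self):
    self.num_encode = 0
    self.num_decode = 0
    self.datasets_with_cache = []

  def visit(self, dataset_key, has_cached_value):
    if has_cached_value:
      self.num_decode += 1
      self.datasets_with_cache.append(dataset_key)
    else:
      self.num_encode += 1

  def summary_leafs(self):
    # One summary node for the whole pipeline instead of one per dataset.
    if self.num_encode or self.num_decode or self.datasets_with_cache:
      return [{
          'input_cache_dataset_keys': self.datasets_with_cache,
          'num_encode_cache': self.num_encode,
          'num_decode_cache': self.num_decode,
      }]
    return []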

tensorflow_transform/beam/analyzer_impls.py

Lines changed: 35 additions & 16 deletions

@@ -1325,30 +1325,51 @@ def __init__(self, operation, extra_args):
   def expand(self, inputs):
     pcoll, = inputs

-    return (pcoll
-            | 'Encode' >> beam.Map(self._coder.encode_cache)
-            | 'Count' >> common.IncrementCounter('cache_entries_encoded'))
+    return pcoll | 'Encode' >> beam.Map(self._coder.encode_cache)


 @common.register_ptransform(analyzer_nodes.InstrumentDatasetCache)
 @beam.typehints.with_input_types(beam.pvalue.PBegin)
 @beam.typehints.with_output_types(None)
 class _InstrumentDatasetCacheImpl(beam.PTransform):
-  """Instruments datasets not read due to cache hit."""
+  """Instruments pipeline analysis cache usage."""

   def __init__(self, operation, extra_args):
-    self._metadata_pcoll = (
-        extra_args.cache_pcoll_dict[operation.dataset_key].metadata)
+    self.pipeline = extra_args.pipeline
+    self._metadata_pcolls = tuple(extra_args.cache_pcoll_dict[k].metadata
+                                  for k in operation.input_cache_dataset_keys)
+    self._num_encode_cache = operation.num_encode_cache
+    self._num_decode_cache = operation.num_decode_cache

-  def _make_and_increment_counter(self, metadata):
-    if metadata:
-      beam.metrics.Metrics.counter(common.METRICS_NAMESPACE,
-                                   'analysis_input_bytes_from_cache').inc(
-                                       metadata.dataset_size)
+  def _make_and_increment_counter(self, value, name):
+    beam.metrics.Metrics.counter(common.METRICS_NAMESPACE, name).inc(value)

   def expand(self, pbegin):
-    return (self._metadata_pcoll | 'InstrumentCachedInputBytes' >> beam.Map(
-        self._make_and_increment_counter))
+    if self._num_encode_cache > 0:
+      _ = (
+          pbegin
+          | 'CreateSoleCacheEncodeInstrument' >> beam.Create(
+              [self._num_encode_cache])
+          | 'InstrumentCacheEncode' >> beam.Map(
+              self._make_and_increment_counter, 'cache_entries_encoded'))
+    if self._num_decode_cache > 0:
+      _ = (
+          self.pipeline
+          | 'CreateSoleCacheDecodeInstrument' >> beam.Create(
+              [self._num_decode_cache])
+          | 'InstrumentCacheDecode' >> beam.Map(
+              self._make_and_increment_counter, 'cache_entries_decoded'))
+    if self._metadata_pcolls:
+      # Instruments datasets not read due to cache hit.
+      _ = (
+          self._metadata_pcolls | beam.Flatten(pipeline=self.pipeline)
+          | 'ExtractCachedInputBytes' >>
+          beam.Map(lambda m: m.dataset_size if m else 0)
+          | 'SumCachedInputBytes' >> beam.CombineGlobally(sum)
+          | 'InstrumentCachedInputBytes' >> beam.Map(
+              self._make_and_increment_counter,
+              'analysis_input_bytes_from_cache'))
+    return pbegin | 'CreateSoleEmptyOutput' >> beam.Create([])


 @common.register_ptransform(analyzer_nodes.DecodeCache)
@@ -1366,9 +1387,7 @@ def __init__(self, operation, extra_args):
   def expand(self, pbegin):
     del pbegin  # unused

-    return (self._cache_pcoll
-            | 'Decode' >> beam.Map(self._coder.decode_cache)
-            | 'Count' >> common.IncrementCounter('cache_entries_decoded'))
+    return self._cache_pcoll | 'Decode' >> beam.Map(self._coder.decode_cache)


 @common.register_ptransform(analyzer_nodes.AddKey)
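
A self-contained sketch of the "sole element" instrumentation idiom used in _InstrumentDatasetCacheImpl.expand above: a one-element Create carries a count computed at graph-construction time into exactly one counter increment at run time. The standalone pipeline setup and the 'tfx.Transform' namespace are assumptions here, not part of the library code:

import apache_beam as beam
from apache_beam.metrics import Metrics

NUM_ENCODE_CACHE = 4  # Precomputed while constructing the analysis graph.


def _inc(value, name):
  Metrics.counter('tfx.Transform', name).inc(value)


with beam.Pipeline() as p:
  # One element per metric, so the counter is bumped exactly once.
  _ = (
      p
      | 'CreateSoleCacheEncodeInstrument' >> beam.Create([NUM_ENCODE_CACHE])
      | 'InstrumentCacheEncode' >> beam.Map(_inc, 'cache_entries_encoded'))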

tensorflow_transform/beam/cached_impl_test.py

Lines changed: 1 addition & 0 deletions

@@ -418,6 +418,7 @@ def is_partitionable(self):
 "FakeChainableCacheable[x/cacheable2][AnalysisIndex0]" -> "EncodeCache[FakeChainableCacheable[x/cacheable2]][AnalysisIndex0]";
 "EncodeCache[FakeChainableCacheable[x/cacheable2]][AnalysisIndex1]" [label="{EncodeCache|coder: Not-a-coder-but-thats-ok!|label: EncodeCache[FakeChainableCacheable[x/cacheable2]][AnalysisIndex1]|partitionable: True}"];
 "FakeChainableCacheable[x/cacheable2][AnalysisIndex1]" -> "EncodeCache[FakeChainableCacheable[x/cacheable2]][AnalysisIndex1]";
+InstrumentDatasetCache [label="{InstrumentDatasetCache|input_cache_dataset_keys: []|num_encode_cache: 4|num_decode_cache: 0|label: InstrumentDatasetCache|partitionable: True}"];
 }
 """)

tensorflow_transform/beam/impl.py

Lines changed: 11 additions & 4 deletions

@@ -1097,16 +1097,23 @@ def expand(self, dataset):
              telemetry.TrackRecordBatchBytes(beam_common.METRICS_NAMESPACE,
                                              'analysis_input_bytes'))
     else:
+      bytes_per_dataset = []
       for idx, key in enumerate(sorted(input_values_pcoll_dict.keys())):
         infix = f'AnalysisIndex{idx}'
         if input_values_pcoll_dict[key] is not None:
+          bytes_per_dataset.append(input_values_pcoll_dict[key]
+                                   | f'ExtractInputBytes[{infix}]' >>
+                                   telemetry.ExtractRecordBatchBytes())
           dataset_metrics[key] = (
-              input_values_pcoll_dict[key]
-              | f'InstrumentInputBytes[AnalysisPCollDict][{infix}]' >>
-              telemetry.TrackRecordBatchBytes(beam_common.METRICS_NAMESPACE,
-                                              'analysis_input_bytes')
+              bytes_per_dataset[-1]
               | f'ConstructMetadata[{infix}]' >> beam.Map(
                   analyzer_cache.DatasetCacheMetadata))
+      _ = (
+          bytes_per_dataset
+          | 'FlattenAnalysisBytes' >> beam.Flatten(pipeline=pipeline)
+          | 'InstrumentInputBytes[AnalysisPCollDict]' >>
+          telemetry.IncrementCounter(beam_common.METRICS_NAMESPACE,
+                                     'analysis_input_bytes'))

     # Gather telemetry on types of input features.
     _ = (