@@ -173,29 +173,33 @@ def __init__(self, dataset_keys, cache_dict, tensor_keys_to_paths,
     self._cache_dict = cache_dict
     self._tensor_keys_to_paths = tensor_keys_to_paths
     self._dataset_has_cache_misses = collections.defaultdict(bool)
+    self._num_encode_cache_nodes = 0
+    self._num_decode_cache_nodes = 0
     self.cache_output_nodes = cache_output_nodes
     self._num_phases = num_phases
 
   def _validate_operation_def(self, operation_def):
     if operation_def.cache_coder is not None:
       if not operation_def.is_partitionable:
         raise ValueError(
-            'Non partitionable OperationDefs cannot be cacheable: {}'.format(
-                operation_def.label))
+            'Non partitionable OperationDefs cannot be cacheable: '
+            f'{operation_def.label}'
+        )
     if operation_def.is_partitionable or operation_def.cache_coder is not None:
       if operation_def.num_outputs != 1:
         raise ValueError(
-            'Cacheable OperationDefs must have exactly 1 output: {}'.format(
-                operation_def.label))
+            'Cacheable OperationDefs must have exactly 1 output: '
+            f'{operation_def.label}'
+        )
 
   def get_detached_sideeffect_leafs(self):
     """Returns a list of sideeffect leaf nodes after the visit is done."""
     # If this is a multi-phase analysis, then all datasets have to be read
     # anyway, and so we'll not instrument full cache coverage for this case.
     if self._num_phases > 1:
       return []
-    result = []
-    for (dataset_idx, dataset_key) in enumerate(self._sorted_dataset_keys):
+    dataset_keys_with_decoded_cache = []
+    for dataset_key in self._sorted_dataset_keys:
       # Default to True here, if the dataset_key is not in the cache misses map
       # then treat it like it does have cache misses because it has not been
       # visited in the optimization traversal.
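The two checks in `_validate_operation_def` above encode the invariants that make caching sound: a cacheable op must be partitionable (its per-dataset results are computed independently and merged later), and it must have exactly one output (so a cache entry maps one-to-one to a value). A minimal standalone sketch of the same rules, using a hypothetical `FakeOpDef` in place of the real `OperationDef`:

```python
import collections

# Hypothetical stand-in for OperationDef; only the fields the check uses.
FakeOpDef = collections.namedtuple(
    'FakeOpDef', ['label', 'cache_coder', 'is_partitionable', 'num_outputs'])


def validate(op):
  # Caching stores per-dataset results, so the op must be computable
  # independently on each dataset partition.
  if op.cache_coder is not None and not op.is_partitionable:
    raise ValueError(
        f'Non partitionable OperationDefs cannot be cacheable: {op.label}')
  # A cache entry holds a single value, so the op must have one output.
  if ((op.is_partitionable or op.cache_coder is not None)
      and op.num_outputs != 1):
    raise ValueError(
        f'Cacheable OperationDefs must have exactly 1 output: {op.label}')


validate(FakeOpDef('Sum', cache_coder=object(),
                   is_partitionable=True, num_outputs=1))  # passes
```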
@@ -207,12 +211,18 @@ def get_detached_sideeffect_leafs(self):
       cache_dict = self._cache_dict or {}
       dataset_cache_entries = cache_dict.get(dataset_key, None)
       if dataset_cache_entries is not None and dataset_cache_entries.metadata:
-        node = nodes.apply_operation(
-            analyzer_nodes.InstrumentDatasetCache,
-            dataset_key=dataset_key,
-            label=f'InstrumentDatasetCache[AnalysisIndex{dataset_idx}]')
-        result.append(node)
-    return result
+        dataset_keys_with_decoded_cache.append(dataset_key)
+    if (dataset_keys_with_decoded_cache or self._num_encode_cache_nodes or
+        self._num_decode_cache_nodes):
+      return [
+          nodes.apply_operation(
+              analyzer_nodes.InstrumentDatasetCache,
+              input_cache_dataset_keys=dataset_keys_with_decoded_cache,
+              num_encode_cache=self._num_encode_cache_nodes,
+              num_decode_cache=self._num_decode_cache_nodes,
+              label='InstrumentDatasetCache')
+      ]
+    return []
 
   def _make_next_hashed_path(self, parent_hashed_paths, operation_def):
     # Making a copy of parent_hashed_paths.
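The hunk above replaces the per-dataset `InstrumentDatasetCache[AnalysisIndex{i}]` leaves with a single aggregated leaf carrying the dataset keys whose cache is decodable plus the running encode/decode counters. A simplified sketch of the new decision, with plain data standing in for graph nodes (it omits the cache-miss `continue` that falls between these hunks; `Entry` is illustrative):

```python
import collections

Entry = collections.namedtuple('Entry', ['metadata'])  # illustrative only


def instrument_cache_leaf(sorted_dataset_keys, cache_dict,
                          num_encode, num_decode):
  """Returns one summary record, or None when there is nothing to report."""
  keys_with_decoded_cache = [
      key for key in sorted_dataset_keys
      if cache_dict.get(key) is not None and cache_dict[key].metadata
  ]
  if keys_with_decoded_cache or num_encode or num_decode:
    # One aggregated node instead of one node per dataset.
    return dict(input_cache_dataset_keys=keys_with_decoded_cache,
                num_encode_cache=num_encode,
                num_decode_cache=num_decode)
  return None


instrument_cache_leaf(['a', 'b'], {'a': Entry(metadata=True)}, 0, 0)
# -> {'input_cache_dataset_keys': ['a'],
#     'num_encode_cache': 0, 'num_decode_cache': 0}
```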
@@ -269,7 +279,7 @@ def visit(self, operation_def, input_values):
         next_inputs = nodes.apply_multi_output_operation(
             beam_nodes.Flatten,
             *disaggregated_input_values,
-            label='FlattenCache[{}]'.format(operation_def.label))
+            label=f'FlattenCache[{operation_def.label}]')
       else:
         # Parent operation output is not cacheable, therefore we can just use
         # a flattened view.
@@ -341,30 +351,31 @@ def _apply_operation_on_fine_grained_view(self, operation_def,
 
     for (dataset_idx, dataset_key) in enumerate(self._sorted_dataset_keys):
       # We use an index for the label in order to make beam labels more stable.
-      infix = 'AnalysisIndex{}'.format(dataset_idx)
+      infix = f'AnalysisIndex{dataset_idx}'
       if (operation_def.cache_coder and self._cache_dict.get(
           dataset_key, {}).get(cache_entry_key) is not None):
         self._dataset_has_cache_misses[dataset_key] |= False
         decode_cache = analyzer_nodes.DecodeCache(
             dataset_key,
             cache_entry_key,
             coder=operation_def.cache_coder,
-            label='DecodeCache[{}][{}]'.format(operation_def.label, infix))
+            label=f'DecodeCache[{operation_def.label}][{infix}]')
         (op_output,) = nodes.OperationNode(decode_cache, tuple()).outputs
+        self._num_decode_cache_nodes += 1
       else:
         value_nodes = tuple(v[dataset_key] for v in fine_grained_views)
         (op_output,) = nodes.OperationNode(
-            operation_def._replace(
-                label='{}[{}]'.format(operation_def.label, infix)),
+            operation_def._replace(label=f'{operation_def.label}[{infix}]'),
             value_nodes).outputs
         if operation_def.cache_coder:
           self._dataset_has_cache_misses[dataset_key] = True
           encode_cache = nodes.apply_operation(
               analyzer_nodes.EncodeCache,
               op_output,
               coder=operation_def.cache_coder,
-              label='EncodeCache[{}][{}]'.format(operation_def.label, infix))
+              label=f'EncodeCache[{operation_def.label}][{infix}]')
           self.cache_output_nodes[(dataset_key, cache_entry_key)] = encode_cache
+          self._num_encode_cache_nodes += 1
       result_fine_grained_view[dataset_key] = op_output
 
     return result_fine_grained_view
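This loop is the core of the cache optimization: per dataset, either decode a previously cached result (a hit, now counted by `_num_decode_cache_nodes`) or run the operation and encode its output for the next run (a miss, counted by `_num_encode_cache_nodes`). A standalone sketch of the same bookkeeping with a dict as the cache and a plain function as the operation (all names hypothetical):

```python
import collections


def run_with_cache(dataset_keys, cache, entry_key, compute):
  has_misses = collections.defaultdict(bool)
  num_decode = num_encode = 0
  results, new_entries = {}, {}
  for key in dataset_keys:
    cached = cache.get(key, {}).get(entry_key)
    if cached is not None:
      has_misses[key] |= False  # touch the key; keep any earlier miss
      results[key] = cached
      num_decode += 1
    else:
      has_misses[key] = True    # this dataset must be (re)read
      results[key] = compute(key)
      new_entries[(key, entry_key)] = results[key]  # "EncodeCache" analogue
      num_encode += 1
  return results, new_entries, has_misses, num_decode, num_encode


run_with_cache(['a', 'b'], {'a': {'sum': 3}}, 'sum', lambda k: 0)
# -> results for both keys, one decode ('a'), one encode ('b')
```

The `|= False` mirrors the line above: it registers the dataset in the miss map without overwriting a `True` recorded by an earlier cache-missing operation.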
@@ -377,16 +388,15 @@ def _visit_apply_savedmodel_operation(self, operation_def, upstream_views):
 
     fine_grained_view = collections.OrderedDict()
     for (dataset_idx, dataset_key) in enumerate(self._sorted_dataset_keys):
-      infix = 'AnalysisIndex{}'.format(dataset_idx)
+      infix = f'AnalysisIndex{dataset_idx}'
       input_node = nodes.apply_operation(
           beam_nodes.ExtractInputForSavedModel,
           dataset_key=dataset_key,
-          label='ExtractInputForSavedModel[{}]'.format(infix))
+          label=f'ExtractInputForSavedModel[{infix}]')
       # We use an index for the label in order to make beam labels more stable.
       (fine_grained_view[dataset_key],) = (
           nodes.OperationNode(
-              operation_def._replace(
-                  label='{}[{}]'.format(operation_def.label, infix)),
+              operation_def._replace(label=f'{operation_def.label}[{infix}]'),
               (saved_model_path_upstream_view.flattened_view,
                input_node)).outputs)
 
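Both loops derive the label infix from the dataset's position in `_sorted_dataset_keys` rather than from the key itself. Since the keys are sorted, the index is deterministic, so Beam stage names stay stable even when keys embed volatile details such as file paths or spans. An illustrative comparison:

```python
sorted_dataset_keys = sorted(['span-2023-01-02', 'span-2023-01-01'])
stable = [f'ApplySavedModel[AnalysisIndex{i}]'
          for i, _ in enumerate(sorted_dataset_keys)]
# ['ApplySavedModel[AnalysisIndex0]', 'ApplySavedModel[AnalysisIndex1]']
# versus key-derived labels, which would change whenever key names change.
```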
@@ -404,8 +414,8 @@ def validate_value(self, value):
     assert isinstance(value, _OptimizationView), value
     if value.fine_grained_view:
       assert set(value.fine_grained_view.keys()) == set(
-          self._sorted_dataset_keys), ('{} != {}'.format(
-              value.fine_grained_view.keys(), self._sorted_dataset_keys))
+          self._sorted_dataset_keys
+      ), (f'{value.fine_grained_view.keys()} != {self._sorted_dataset_keys}')
 
 
 def _perform_cache_optimization(saved_model_future, dataset_keys,
@@ -593,7 +603,7 @@ def preprocessing_fn(input)
       label='ExtractInputForSavedModel[FlattenedDataset]')
 
   while not all(sink_tensors_ready.values()):
-    infix = 'Phase{}'.format(phase)
+    infix = f'Phase{phase}'
     # Determine which table init ops are ready to run in this phase
     # Determine which keys of pending_tensor_replacements are ready to run
     # in this phase, based on whether their dependencies are ready.
@@ -610,14 +620,14 @@ def preprocessing_fn(input)
         *tensor_bindings,
         table_initializers=tuple(graph_analyzer.ready_table_initializers),
         output_signature=intermediate_output_signature,
-        label='CreateSavedModelForAnalyzerInputs[{}]'.format(infix))
+        label=f'CreateSavedModelForAnalyzerInputs[{infix}]')
 
     extracted_values_dict = nodes.apply_operation(
         beam_nodes.ApplySavedModel,
         saved_model_future,
         extracted_input_node,
         phase=phase,
-        label='ApplySavedModel[{}]'.format(infix))
+        label=f'ApplySavedModel[{infix}]')
 
     translate_visitor.phase = phase
     translate_visitor.intermediate_output_signature = (
@@ -643,7 +653,7 @@ def preprocessing_fn(input)
               dtype_enum=tensor.dtype.as_datatype_enum,
               is_asset_filepath=is_asset_filepath,
               label=analyzer_nodes.sanitize_label(
-                  'CreateTensorBinding[{}]'.format(name))))
+                  f'CreateTensorBinding[{name}]')))
       sink_tensors_ready[hashable_tensor] = True
 
   analyzers_input_signature.update(intermediate_output_signature)
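The last three hunks sit in the phased analysis driver: each phase builds a SavedModel containing everything computable so far, applies it to produce analyzer inputs, and binds finished analyzer outputs back into the graph as constants until every sink tensor is ready. A schematic of that fixed-point loop in plain Python (function names are illustrative, not the module's API):

```python
def run_phases(sink_tensors, analyzers_ready_in_phase):
  """Iterate until every sink tensor has a bound value."""
  sink_tensors_ready = {t: False for t in sink_tensors}
  phase = 0
  while not all(sink_tensors_ready.values()):
    # Analyzers whose dependencies were bound in earlier phases can run now.
    ready = analyzers_ready_in_phase(phase, sink_tensors_ready)
    if not ready:
      raise AssertionError('no progress in phase %d' % phase)
    for tensor in ready:
      sink_tensors_ready[tensor] = True  # CreateTensorBinding analogue
    phase += 1
  return phase
```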