
Commit 09ac306

zoyahav authored and tfx-copybara committed
Implementing uncached datasets to allow for incremental cache build up.
PiperOrigin-RevId: 519975152
1 parent a74cdbe commit 09ac306

File tree

8 files changed: +301 −89 lines changed


RELEASE.md

Lines changed: 3 additions & 0 deletions

@@ -10,6 +10,9 @@
 * New experimental APIs added for annotating sparse output tensors:
   `tft.experimental.annotate_sparse_output_shape` and
   `tft.experimental.annotate_true_sparse_output`.
+* `DatasetKey.non_cacheable` added to allow for some datasets to not produce
+  cache. This may be useful for gradual cache generation when operating on a
+  large rolling range of datasets.
 
 ## Bug Fixes and Other Changes
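To make the release note concrete, here is a minimal sketch of how incremental cache build-up might look with this API. It assumes, based on the note above and the `is_cached` field introduced in this commit, that `non_cacheable` is a method returning a copy of the key that will not produce cache; the span names are hypothetical.

```python
from tensorflow_transform.beam import analyzer_cache

# A rolling range of daily spans (hypothetical names). Old spans will not
# recur in future analysis runs, so producing cache for them is wasted work.
span_names = [f'span-{i:02d}' for i in range(10)]
keys = [analyzer_cache.DatasetKey(name) for name in span_names]

# Assumption: `non_cacheable()` returns a DatasetKey marked to skip cache
# production (per this commit's release note). Only the newest three spans
# stay cacheable, so the cache builds up gradually run over run.
keys = [
    key if i >= len(keys) - 3 else key.non_cacheable()
    for i, key in enumerate(keys)
]
```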

tensorflow_transform/beam/analysis_graph_builder.py

Lines changed: 72 additions & 21 deletions

@@ -16,8 +16,11 @@
 import collections
 import hashlib
 
+from typing import Dict, Mapping, Collection, Optional, Tuple
+
 import tensorflow as tf
 from tensorflow_transform import analyzer_nodes
+from tensorflow_transform import common_types
 from tensorflow_transform import graph_tools
 from tensorflow_transform import impl_helper
 from tensorflow_transform import nodes
@@ -35,6 +38,11 @@
 _ANALYSIS_GRAPH = None
 
 
+_IntermediateCacheType = Dict[
+    Tuple[analyzer_cache.DatasetKey, str], analyzer_cache.DatasetCache
+]
+
+
 def _tensor_name(tensor):
   """Get a name of a tensor without trailing ":0" when relevant."""
   # tensor.name is unicode in Python 3 and bytes in Python 2 so convert to
@@ -152,8 +160,14 @@ class _OptimizeVisitor(nodes.Visitor):
     input data, according to the `is_partitionable` annotation.
   """
 
-  def __init__(self, dataset_keys, cache_dict, tensor_keys_to_paths,
-               cache_output_nodes, num_phases):
+  def __init__(
+      self,
+      dataset_keys: Collection[analyzer_cache.DatasetKey],
+      cache_dict: Optional[analyzer_cache.BeamAnalysisCache],
+      tensor_keys_to_paths: Mapping[str, str],
+      cache_output_nodes: _IntermediateCacheType,
+      num_phases: int,
+  ):
     """Init method for _OptimizeVisitor.
 
     Args:
@@ -367,7 +381,7 @@ def _apply_operation_on_fine_grained_view(self, operation_def,
     (op_output,) = nodes.OperationNode(
         operation_def._replace(label=f'{operation_def.label}[{infix}]'),
         value_nodes).outputs
-    if operation_def.cache_coder:
+    if operation_def.cache_coder and dataset_key.is_cached:
       self._dataset_has_cache_misses[dataset_key] = True
       encode_cache = nodes.apply_operation(
           analyzer_nodes.EncodeCache,
@@ -418,8 +432,17 @@ def validate_value(self, value):
     ), (f'{value.fine_grained_view.keys()} != {self._sorted_dataset_keys}')
 
 
-def _perform_cache_optimization(saved_model_future, dataset_keys,
-                                tensor_keys_to_paths, cache_dict, num_phases):
+def _perform_cache_optimization(
+    saved_model_future: nodes.ValueNode,
+    dataset_keys: Collection[analyzer_cache.DatasetKey],
+    tensor_keys_to_paths: Dict[str, str],
+    cache_dict: Optional[analyzer_cache.BeamAnalysisCache],
+    num_phases: int,
+) -> Tuple[
+    Tuple[nodes.ValueNode],
+    Optional[_IntermediateCacheType],
+    Collection[nodes.ValueNode],
+]:
   """Performs cache optimization on the given graph."""
   cache_output_nodes = {}
   optimize_visitor = _OptimizeVisitor(dataset_keys or {}, cache_dict,
@@ -526,14 +549,38 @@ def get_analysis_cache_entry_keys(preprocessing_fn,
   _, cache_dict = _build_analysis_graph_for_inspection(preprocessing_fn, specs,
                                                        dataset_keys, {},
                                                        force_tf_compat_v1)
-  return set([cache_key for _, cache_key in cache_dict.keys()])
+  result = set()
+  for dataset_cache in cache_dict.values():
+    result.update(dataset_cache.keys())
+  return result
 
 
-def build(graph,
-          input_signature,
-          output_signature,
-          dataset_keys=None,
-          cache_dict=None):
+AnalysisCache = Mapping[
+    analyzer_cache.DatasetKey, Mapping[str, nodes.ValueNode]
+]
+
+
+def _format_output_cache(
+    cache_value_nodes: _IntermediateCacheType,
+) -> Optional[AnalysisCache]:
+  """Triggers dataset cache encoding and composes analysis cache output."""
+  if cache_value_nodes is None:
+    return None
+  cache_dict = collections.defaultdict(dict)
+  for (dataset_key, cache_key), value_node in cache_value_nodes.items():
+    cache_dict[dataset_key][cache_key] = value_node
+  return cache_dict
+
+
+def build(
+    graph: tf.Graph,
+    input_signature: Mapping[str, common_types.TensorType],
+    output_signature: Mapping[str, common_types.TensorType],
+    dataset_keys: Optional[Collection[analyzer_cache.DatasetKey]] = None,
+    cache_dict: Optional[analyzer_cache.BeamAnalysisCache] = None,
+) -> Tuple[
+    nodes.ValueNode, Optional[AnalysisCache], Collection[nodes.ValueNode]
+]:
   """Returns a list of `Phase`s describing how to execute the pipeline.
 
   The default graph is assumed to contain some `Analyzer`s which must be
@@ -567,18 +614,19 @@ def preprocessing_fn(input)
 
   Args:
     graph: A `tf.Graph`.
-    input_signature: A dict whose keys are strings and values are `Tensor`s or
-      `SparseTensor`s.
-    output_signature: A dict whose keys are strings and values are `Tensor`s or
-      `SparseTensor`s.
-    dataset_keys: (Optional) A set of strings which are dataset keys, they
-      uniquely identify these datasets across analysis runs.
+    input_signature: A dict whose keys are strings and values are `Tensor`s,
+      `SparseTensor`s, or `RaggedTensor`s.
+    output_signature: A dict whose keys are strings and values are `Tensor`s,
+      `SparseTensor`s, or `RaggedTensor`s.
+    dataset_keys: (Optional) A set of `DatasetKey`s, which uniquely identify
+      these datasets across analysis runs.
     cache_dict: (Optional) A cache dictionary.
 
   Returns:
-    A pair of:
-      * list of `Phase`s
+    A tuple of:
+      * A SavedModel future node.
       * A dictionary of output cache `ValueNode`s.
+      * Side-effect leaf nodes.
 
   Raises:
     ValueError: if the graph cannot be analyzed.
@@ -690,5 +738,8 @@ def preprocessing_fn(input)
 
   global _ANALYSIS_GRAPH
   _ANALYSIS_GRAPH = optimized_saved_model_future
-  return (optimized_saved_model_future, output_cache_value_nodes,
-          detached_sideeffect_leafs)
+  return (
+      optimized_saved_model_future,
+      _format_output_cache(output_cache_value_nodes),
+      detached_sideeffect_leafs,
+  )
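The shape change in `build`'s cache output can be illustrated with a small standalone sketch: the intermediate mapping keyed by `(dataset_key, cache_key)` pairs is regrouped into a per-dataset mapping, mirroring `_format_output_cache` above. Plain strings stand in for the real `DatasetKey` and `ValueNode` objects.

```python
import collections
from typing import Dict, Mapping, Optional, Tuple


def format_output_cache(
    cache_value_nodes: Optional[Dict[Tuple[str, str], object]],
) -> Optional[Mapping[str, Mapping[str, object]]]:
  """Regroups {(dataset_key, cache_key): node} into {dataset_key: {cache_key: node}}."""
  if cache_value_nodes is None:
    return None
  cache_dict = collections.defaultdict(dict)
  for (dataset_key, cache_key), value_node in cache_value_nodes.items():
    cache_dict[dataset_key][cache_key] = value_node
  return cache_dict


# Example: two datasets sharing the same cache entry key end up as two
# per-dataset sub-dictionaries (names here are illustrative only).
intermediate = {
    ('span-0', 'vocab_hash'): 'node_a',
    ('span-1', 'vocab_hash'): 'node_b',
}
assert format_output_cache(intermediate) == {
    'span-0': {'vocab_hash': 'node_a'},
    'span-1': {'vocab_hash': 'node_b'},
}
```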

tensorflow_transform/beam/analysis_graph_builder_test.py

Lines changed: 17 additions & 15 deletions

@@ -22,6 +22,7 @@
 from tensorflow_transform import nodes
 from tensorflow_transform import tf2_utils
 from tensorflow_transform.beam import analysis_graph_builder
+from tensorflow_transform.beam import analyzer_cache
 from tensorflow_transform import test_case
 # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple`
 # once the Spark issue is resolved.
@@ -74,7 +75,7 @@ def _plus_one(x):
 directed=True;
 node [shape=Mrecord];
 "CreateSavedModelForAnalyzerInputs[Phase0]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x/mean_and_var/Cast_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/div_no_nan', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/div_no_nan_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/zeros', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase0]}"];
-"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
+"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset', is_cached=True)|label: ExtractInputForSavedModel[FlattenedDataset]}"];
 "ApplySavedModel[Phase0]" [label="{ApplySavedModel|phase: 0|label: ApplySavedModel[Phase0]|partitionable: True}"];
 "CreateSavedModelForAnalyzerInputs[Phase0]" -> "ApplySavedModel[Phase0]";
 "ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase0]";
@@ -99,7 +100,7 @@ def _plus_one(x):
 directed=True;
 node [shape=Mrecord];
 "CreateSavedModelForAnalyzerInputs[Phase0]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x/mean_and_var/Cast_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/div_no_nan', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/div_no_nan_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/zeros', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase0]}"];
-"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
+"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset', is_cached=True)|label: ExtractInputForSavedModel[FlattenedDataset]}"];
 "ApplySavedModel[Phase0]" [label="{ApplySavedModel|phase: 0|label: ApplySavedModel[Phase0]|partitionable: True}"];
 "CreateSavedModelForAnalyzerInputs[Phase0]" -> "ApplySavedModel[Phase0]";
 "ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase0]";
@@ -144,7 +145,7 @@ def _preprocessing_fn_with_table(inputs):
 directed=True;
 node [shape=Mrecord];
 "CreateSavedModelForAnalyzerInputs[Phase0]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x/boolean_mask/GatherV2', \"Tensor\<shape: [None], \<dtype: 'string'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase0]}"];
-"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
+"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset', is_cached=True)|label: ExtractInputForSavedModel[FlattenedDataset]}"];
 "ApplySavedModel[Phase0]" [label="{ApplySavedModel|phase: 0|label: ApplySavedModel[Phase0]|partitionable: True}"];
 "CreateSavedModelForAnalyzerInputs[Phase0]" -> "ApplySavedModel[Phase0]";
 "ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase0]";
@@ -178,7 +179,7 @@ def _preprocessing_fn_with_table(inputs):
 directed=True;
 node [shape=Mrecord];
 "CreateSavedModelForAnalyzerInputs[Phase0]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x/boolean_mask/GatherV2', \"Tensor\<shape: [None], \<dtype: 'string'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase0]}"];
-"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
+"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset', is_cached=True)|label: ExtractInputForSavedModel[FlattenedDataset]}"];
 "ApplySavedModel[Phase0]" [label="{ApplySavedModel|phase: 0|label: ApplySavedModel[Phase0]|partitionable: True}"];
 "CreateSavedModelForAnalyzerInputs[Phase0]" -> "ApplySavedModel[Phase0]";
 "ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase0]";
@@ -227,7 +228,7 @@ def _preprocessing_fn_with_two_phases(inputs):
 directed=True;
 node [shape=Mrecord];
 "CreateSavedModelForAnalyzerInputs[Phase0]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x/mean_and_var/Cast_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/div_no_nan', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/div_no_nan_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/zeros', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase0]}"];
-"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
+"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset', is_cached=True)|label: ExtractInputForSavedModel[FlattenedDataset]}"];
 "ApplySavedModel[Phase0]" [label="{ApplySavedModel|phase: 0|label: ApplySavedModel[Phase0]|partitionable: True}"];
 "CreateSavedModelForAnalyzerInputs[Phase0]" -> "ApplySavedModel[Phase0]";
 "ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase0]";
@@ -272,7 +273,7 @@ def _preprocessing_fn_with_two_phases(inputs):
 directed=True;
 node [shape=Mrecord];
 "CreateSavedModelForAnalyzerInputs[Phase0]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x/mean_and_var/Cast_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/div_no_nan', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/div_no_nan_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/zeros', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase0]}"];
-"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
+"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset', is_cached=True)|label: ExtractInputForSavedModel[FlattenedDataset]}"];
 "ApplySavedModel[Phase0]" [label="{ApplySavedModel|phase: 0|label: ApplySavedModel[Phase0]|partitionable: True}"];
 "CreateSavedModelForAnalyzerInputs[Phase0]" -> "ApplySavedModel[Phase0]";
 "ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase0]";
@@ -349,7 +350,7 @@ def __new__(cls):
 directed=True;
 node [shape=Mrecord];
 "CreateSavedModelForAnalyzerInputs[Phase0]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('inputs/inputs/x_copy', \"Tensor\<shape: [None], \<dtype: 'int64'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase0]}"];
-"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
+"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset', is_cached=True)|label: ExtractInputForSavedModel[FlattenedDataset]}"];
 "ApplySavedModel[Phase0]" [label="{ApplySavedModel|phase: 0|label: ApplySavedModel[Phase0]|partitionable: True}"];
 "CreateSavedModelForAnalyzerInputs[Phase0]" -> "ApplySavedModel[Phase0]";
 "ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase0]";
@@ -369,7 +370,7 @@ def __new__(cls):
 directed=True;
 node [shape=Mrecord];
 "CreateSavedModelForAnalyzerInputs[Phase0]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('inputs_copy', \"Tensor\<shape: [None], \<dtype: 'int64'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase0]}"];
-"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
+"ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset', is_cached=True)|label: ExtractInputForSavedModel[FlattenedDataset]}"];
 "ApplySavedModel[Phase0]" [label="{ApplySavedModel|phase: 0|label: ApplySavedModel[Phase0]|partitionable: True}"];
 "CreateSavedModelForAnalyzerInputs[Phase0]" -> "ApplySavedModel[Phase0]";
 "ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase0]";
@@ -471,16 +472,17 @@ def test_get_analysis_dataset_keys(self, preprocessing_fn, full_dataset_keys,
                                      use_tf_compat_v1):
     if not use_tf_compat_v1:
       test_case.skip_if_not_tf2('Tensorflow 2.x required')
-    full_dataset_keys = [
-        analysis_graph_builder.analyzer_cache.DatasetKey(k)
-        for k in full_dataset_keys
-    ]
+    full_dataset_keys = list(
+        map(analyzer_cache.DatasetKey, full_dataset_keys))
+    cached_dataset_keys = map(analyzer_cache.DatasetKey, cached_dataset_keys)
+    expected_dataset_keys = map(
+        analyzer_cache.DatasetKey, expected_dataset_keys)
     # We force all dataset keys with entries in the cache dict to have a cache
     # hit.
     mocked_cache_entry_key = b'M'
     input_cache = {
-        key: analysis_graph_builder.analyzer_cache.DatasetCache(
-            {mocked_cache_entry_key: 'C'}, None) for key in cached_dataset_keys
+        key: analyzer_cache.DatasetCache({mocked_cache_entry_key: 'C'}, None)
+        for key in cached_dataset_keys
     }
     feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)}
     specs = (
@@ -509,7 +511,7 @@ def test_get_analysis_dataset_keys(self, preprocessing_fn, full_dataset_keys,
   def test_get_analysis_cache_entry_keys(self, use_tf_compat_v1):
     if not use_tf_compat_v1:
       test_case.skip_if_not_tf2('Tensorflow 2.x required')
-    full_dataset_keys = ['a', 'b']
+    full_dataset_keys = map(analyzer_cache.DatasetKey, ['a', 'b'])
     def preprocessing_fn(inputs):
       return {'x': tft.scale_to_0_1(inputs['x'])}
     mocked_cache_entry_key = 'A'
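The updated dot-graph expectations and key construction above both follow from `DatasetKey` now carrying an `is_cached` field. Below is a rough sketch of the shape implied by this diff; the real class lives in `analyzer_cache` and may differ in detail (for instance, `non_cacheable` may also rewrite the key string).

```python
from typing import NamedTuple


class DatasetKey(NamedTuple):
  """Sketch of the shape implied by this diff; not the real implementation."""
  key: str
  is_cached: bool = True  # new field; defaults to producing cache

  def non_cacheable(self) -> 'DatasetKey':
    # Assumption: returns a copy of this key that will not produce cache.
    return self._replace(is_cached=False)


# Matches the updated expected label in the dot graphs above:
assert repr(DatasetKey('FlattenedDataset')) == \
    "DatasetKey(key='FlattenedDataset', is_cached=True)"
```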
