2222from tensorflow_transform import nodes
2323from tensorflow_transform import tf2_utils
2424from tensorflow_transform .beam import analysis_graph_builder
25+ from tensorflow_transform .beam import analyzer_cache
2526from tensorflow_transform import test_case
2627# TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple`
2728# once the Spark issue is resolved.
@@ -74,7 +75,7 @@ def _plus_one(x):
7475directed=True;
7576node [shape=Mrecord];
7677"CreateSavedModelForAnalyzerInputs[Phase0]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x/mean_and_var/Cast_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/div_no_nan', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/div_no_nan_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/zeros', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase0]}"];
77- "ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
78+ "ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset', is_cached=True )|label: ExtractInputForSavedModel[FlattenedDataset]}"];
7879"ApplySavedModel[Phase0]" [label="{ApplySavedModel|phase: 0|label: ApplySavedModel[Phase0]|partitionable: True}"];
7980"CreateSavedModelForAnalyzerInputs[Phase0]" -> "ApplySavedModel[Phase0]";
8081"ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase0]";
@@ -99,7 +100,7 @@ def _plus_one(x):
99100directed=True;
100101node [shape=Mrecord];
101102"CreateSavedModelForAnalyzerInputs[Phase0]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x/mean_and_var/Cast_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/div_no_nan', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/div_no_nan_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/zeros', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase0]}"];
102- "ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
103+ "ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset', is_cached=True )|label: ExtractInputForSavedModel[FlattenedDataset]}"];
103104"ApplySavedModel[Phase0]" [label="{ApplySavedModel|phase: 0|label: ApplySavedModel[Phase0]|partitionable: True}"];
104105"CreateSavedModelForAnalyzerInputs[Phase0]" -> "ApplySavedModel[Phase0]";
105106"ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase0]";
@@ -144,7 +145,7 @@ def _preprocessing_fn_with_table(inputs):
144145directed=True;
145146node [shape=Mrecord];
146147"CreateSavedModelForAnalyzerInputs[Phase0]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x/boolean_mask/GatherV2', \"Tensor\<shape: [None], \<dtype: 'string'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase0]}"];
147- "ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
148+ "ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset', is_cached=True )|label: ExtractInputForSavedModel[FlattenedDataset]}"];
148149"ApplySavedModel[Phase0]" [label="{ApplySavedModel|phase: 0|label: ApplySavedModel[Phase0]|partitionable: True}"];
149150"CreateSavedModelForAnalyzerInputs[Phase0]" -> "ApplySavedModel[Phase0]";
150151"ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase0]";
@@ -178,7 +179,7 @@ def _preprocessing_fn_with_table(inputs):
178179directed=True;
179180node [shape=Mrecord];
180181"CreateSavedModelForAnalyzerInputs[Phase0]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x/boolean_mask/GatherV2', \"Tensor\<shape: [None], \<dtype: 'string'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase0]}"];
181- "ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
182+ "ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset', is_cached=True )|label: ExtractInputForSavedModel[FlattenedDataset]}"];
182183"ApplySavedModel[Phase0]" [label="{ApplySavedModel|phase: 0|label: ApplySavedModel[Phase0]|partitionable: True}"];
183184"CreateSavedModelForAnalyzerInputs[Phase0]" -> "ApplySavedModel[Phase0]";
184185"ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase0]";
@@ -227,7 +228,7 @@ def _preprocessing_fn_with_two_phases(inputs):
227228directed=True;
228229node [shape=Mrecord];
229230"CreateSavedModelForAnalyzerInputs[Phase0]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x/mean_and_var/Cast_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/div_no_nan', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/div_no_nan_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/zeros', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase0]}"];
230- "ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
231+ "ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset', is_cached=True )|label: ExtractInputForSavedModel[FlattenedDataset]}"];
231232"ApplySavedModel[Phase0]" [label="{ApplySavedModel|phase: 0|label: ApplySavedModel[Phase0]|partitionable: True}"];
232233"CreateSavedModelForAnalyzerInputs[Phase0]" -> "ApplySavedModel[Phase0]";
233234"ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase0]";
@@ -272,7 +273,7 @@ def _preprocessing_fn_with_two_phases(inputs):
272273directed=True;
273274node [shape=Mrecord];
274275"CreateSavedModelForAnalyzerInputs[Phase0]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x/mean_and_var/Cast_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/div_no_nan', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/div_no_nan_1', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\"), ('x/mean_and_var/zeros', \"Tensor\<shape: [], \<dtype: 'float32'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase0]}"];
275- "ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
276+ "ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset', is_cached=True )|label: ExtractInputForSavedModel[FlattenedDataset]}"];
276277"ApplySavedModel[Phase0]" [label="{ApplySavedModel|phase: 0|label: ApplySavedModel[Phase0]|partitionable: True}"];
277278"CreateSavedModelForAnalyzerInputs[Phase0]" -> "ApplySavedModel[Phase0]";
278279"ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase0]";
@@ -349,7 +350,7 @@ def __new__(cls):
349350directed=True;
350351node [shape=Mrecord];
351352"CreateSavedModelForAnalyzerInputs[Phase0]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('inputs/inputs/x_copy', \"Tensor\<shape: [None], \<dtype: 'int64'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase0]}"];
352- "ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
353+ "ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset', is_cached=True )|label: ExtractInputForSavedModel[FlattenedDataset]}"];
353354"ApplySavedModel[Phase0]" [label="{ApplySavedModel|phase: 0|label: ApplySavedModel[Phase0]|partitionable: True}"];
354355"CreateSavedModelForAnalyzerInputs[Phase0]" -> "ApplySavedModel[Phase0]";
355356"ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase0]";
@@ -369,7 +370,7 @@ def __new__(cls):
369370directed=True;
370371node [shape=Mrecord];
371372"CreateSavedModelForAnalyzerInputs[Phase0]" [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('inputs_copy', \"Tensor\<shape: [None], \<dtype: 'int64'\>\>\")])|label: CreateSavedModelForAnalyzerInputs[Phase0]}"];
372- "ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset')|label: ExtractInputForSavedModel[FlattenedDataset]}"];
373+ "ExtractInputForSavedModel[FlattenedDataset]" [label="{ExtractInputForSavedModel|dataset_key: DatasetKey(key='FlattenedDataset', is_cached=True )|label: ExtractInputForSavedModel[FlattenedDataset]}"];
373374"ApplySavedModel[Phase0]" [label="{ApplySavedModel|phase: 0|label: ApplySavedModel[Phase0]|partitionable: True}"];
374375"CreateSavedModelForAnalyzerInputs[Phase0]" -> "ApplySavedModel[Phase0]";
375376"ExtractInputForSavedModel[FlattenedDataset]" -> "ApplySavedModel[Phase0]";
@@ -471,16 +472,17 @@ def test_get_analysis_dataset_keys(self, preprocessing_fn, full_dataset_keys,
471472 use_tf_compat_v1 ):
472473 if not use_tf_compat_v1 :
473474 test_case .skip_if_not_tf2 ('Tensorflow 2.x required' )
474- full_dataset_keys = [
475- analysis_graph_builder .analyzer_cache .DatasetKey (k )
476- for k in full_dataset_keys
477- ]
475+ full_dataset_keys = list (
476+ map (analyzer_cache .DatasetKey , full_dataset_keys ))
477+ cached_dataset_keys = map (analyzer_cache .DatasetKey , cached_dataset_keys )
478+ expected_dataset_keys = map (
479+ analyzer_cache .DatasetKey , expected_dataset_keys )
478480 # We force all dataset keys with entries in the cache dict will have a cache
479481 # hit.
480482 mocked_cache_entry_key = b'M'
481483 input_cache = {
482- key : analysis_graph_builder . analyzer_cache .DatasetCache (
483- { mocked_cache_entry_key : 'C' }, None ) for key in cached_dataset_keys
484+ key : analyzer_cache .DatasetCache ({ mocked_cache_entry_key : 'C' }, None )
485+ for key in cached_dataset_keys
484486 }
485487 feature_spec = {'x' : tf .io .FixedLenFeature ([], tf .float32 )}
486488 specs = (
@@ -509,7 +511,7 @@ def test_get_analysis_dataset_keys(self, preprocessing_fn, full_dataset_keys,
509511 def test_get_analysis_cache_entry_keys (self , use_tf_compat_v1 ):
510512 if not use_tf_compat_v1 :
511513 test_case .skip_if_not_tf2 ('Tensorflow 2.x required' )
512- full_dataset_keys = ['a' , 'b' ]
514+ full_dataset_keys = map ( analyzer_cache . DatasetKey , ['a' , 'b' ])
513515 def preprocessing_fn (inputs ):
514516 return {'x' : tft .scale_to_0_1 (inputs ['x' ])}
515517 mocked_cache_entry_key = 'A'
0 commit comments