Skip to content

Commit f1d1605

Browse files
tf-transform-teamtfx-copybara
authored andcommitted
Enable the experimental_debug_stripper in TF2 when saving the transform_fn
PiperOrigin-RevId: 608950695
1 parent fb7688c commit f1d1605

File tree

7 files changed

+131
-57
lines changed

7 files changed

+131
-57
lines changed

RELEASE.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
* Explicitly use Keras 2 or `tf_keras`` if Keras 3 is installed.
1414
* Added python 3.11 support.
1515
* Depends on `tensorflow>=2.15.0,<3`.
16+
* Enable passing `tf.saved_model.SaveOptions` to model saving functionality.
1617

1718
## Breaking Changes
1819

tensorflow_transform/beam/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import uuid
2222

2323
import apache_beam as beam
24+
import tensorflow as tf
2425
from tensorflow_transform import common_types
2526
from tensorflow_transform import nodes
2627
from tfx_bsl.telemetry import util
@@ -165,6 +166,7 @@ class ExtraArgs:
165166
cache_pcoll_dict: Optional[Dict[str, beam.PCollection]]
166167
preprocessing_fn: Any
167168
analyzers_fingerprint: Mapping[str, Any]
169+
save_options: tf.saved_model.SaveOptions
168170

169171
def __init__(self, extra_args):
170172
self._extra_args = extra_args

tensorflow_transform/beam/context.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ class Context:
4242
force_tf_compat_v1: (Optional) If True, TFT's public APIs
4343
(e.g. AnalyzeDataset) will use Tensorflow in compat.v1 mode irrespective
4444
of installed version of Tensorflow. Defaults to `False`.
45+
save_options: (Optional) If set, the tf.saved_model.SaveOptions to save
46+
the transform_fn with. Only applies for TF2.
4547
4648
Note that the temp dir should be accessible to worker jobs, e.g. if running
4749
with the Cloud Dataflow runner, the temp dir should be on GCS and should have
@@ -56,6 +58,7 @@ class _State:
5658
passthrough_keys: Optional[Iterable[str]] = None
5759
use_deep_copy_optimization: Optional[bool] = None
5860
force_tf_compat_v1: Optional[bool] = None
61+
save_options: Optional[tf.saved_model.SaveOptions] = None
5962

6063
@classmethod
6164
def make_empty(cls):
@@ -80,7 +83,8 @@ def __init__(self,
8083
desired_batch_size: Optional[int] = None,
8184
passthrough_keys: Optional[Iterable[str]] = None,
8285
use_deep_copy_optimization: Optional[bool] = None,
83-
force_tf_compat_v1: Optional[bool] = None):
86+
force_tf_compat_v1: Optional[bool] = None,
87+
save_options: Optional[tf.saved_model.SaveOptions] = None):
8488
state = getattr(self._thread_local, 'state', None)
8589
if not state:
8690
self._thread_local.state = self._StateStack()
@@ -92,6 +96,7 @@ def __init__(self,
9296
self._passthrough_keys = passthrough_keys
9397
self._use_deep_copy_optimization = use_deep_copy_optimization
9498
self._force_tf_compat_v1 = force_tf_compat_v1
99+
self._save_options = save_options
95100

96101
def __enter__(self):
97102
# Previous State's properties are inherited if not explicitly specified.
@@ -110,7 +115,8 @@ def __enter__(self):
110115
last_frame.use_deep_copy_optimization,
111116
force_tf_compat_v1=self._force_tf_compat_v1
112117
if self._force_tf_compat_v1 is not None else
113-
last_frame.force_tf_compat_v1))
118+
last_frame.force_tf_compat_v1,
119+
save_options=self._save_options or last_frame.save_options))
114120

115121
def __exit__(self, *exn_info):
116122
self._thread_local.state.frames.pop()
@@ -175,3 +181,12 @@ def get_use_tf_compat_v1(cls) -> bool:
175181
"""Computes use_tf_compat_v1 from TF environment and force_tf_compat_v1."""
176182
force_tf_compat_v1 = cls._get_force_tf_compat_v1()
177183
return tf2_utils.use_tf_compat_v1(force_tf_compat_v1)
184+
185+
@classmethod
186+
def get_save_options(cls) -> Optional[tf.saved_model.SaveOptions]:
187+
"""Retrieves a user set save_options, None if not set."""
188+
state = cls._get_topmost_state_frame()
189+
if state.save_options is not None:
190+
tf.compat.v1.logging.info('Using save_options: %s', state.save_options)
191+
return state.save_options
192+
return None

tensorflow_transform/beam/impl.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ def expand(self, inputs):
651651
def _create_v2_saved_model(tensor_replacement_map, base_temp_dir,
652652
preprocessing_fn, input_signature,
653653
baseline_analyzers_fingerprint,
654-
output_keys_to_name_map):
654+
output_keys_to_name_map, save_options):
655655
"""Writes out a SavedModelV2 with preprocessing_fn traced using tf.function.
656656
657657
The SavedModel written contains a method called `transform_fn` that
@@ -669,6 +669,7 @@ def _create_v2_saved_model(tensor_replacement_map, base_temp_dir,
669669
paths that define its fingerprint.
670670
output_keys_to_name_map: A map from output dictionary keys to the names of
671671
the tensors that they represent.
672+
save_options: The tf.saved_model.SaveOptions to save the model with.
672673
673674
Returns:
674675
Path to which SavedModel was written.
@@ -678,7 +679,8 @@ def _create_v2_saved_model(tensor_replacement_map, base_temp_dir,
678679
input_signature, base_temp_dir,
679680
baseline_analyzers_fingerprint,
680681
tensor_replacement_map,
681-
output_keys_to_name_map)
682+
output_keys_to_name_map,
683+
save_options)
682684
return saved_model_dir
683685

684686

@@ -695,6 +697,7 @@ def __init__(self, operation, extra_args):
695697
self._input_signature = extra_args.input_specs
696698
self._output_signature = operation.output_signature
697699
self._analyzers_fingerprint = extra_args.analyzers_fingerprint
700+
self._save_options = extra_args.save_options
698701

699702
def _maybe_get_output_tensor_names_dict(self):
700703
# output_signature will contain CompositeTensors only if this is the final
@@ -719,7 +722,7 @@ def expand(self, inputs):
719722
| 'CreateSavedModel' >> beam.Map(
720723
_create_v2_saved_model, self._base_temp_dir, self._preprocessing_fn,
721724
self._input_signature, self._analyzers_fingerprint,
722-
self._maybe_get_output_tensor_names_dict())
725+
self._maybe_get_output_tensor_names_dict(), self._save_options)
723726
| 'Count' >>
724727
beam_common.IncrementCounter(_CREATE_SAVED_MODEL_COUNTER_NAME))
725728

@@ -988,6 +991,7 @@ def __init__(self, preprocessing_fn, pipeline=None):
988991
"""
989992
self._preprocessing_fn = preprocessing_fn
990993
self.pipeline = pipeline
994+
self._save_options = Context.get_save_options()
991995
self._use_tf_compat_v1 = Context.get_use_tf_compat_v1()
992996
if self._use_tf_compat_v1:
993997
_warn_about_tf_compat_v1()
@@ -1155,7 +1159,8 @@ def expand(self, dataset):
11551159
use_tf_compat_v1=self._use_tf_compat_v1,
11561160
cache_pcoll_dict=dataset_cache_dict,
11571161
preprocessing_fn=self._preprocessing_fn,
1158-
analyzers_fingerprint=analyzers_fingerprint)
1162+
analyzers_fingerprint=analyzers_fingerprint,
1163+
save_options=self._save_options)
11591164

11601165
(transform_fn_future, cache_value_nodes,
11611166
detached_sideeffect_leafs) = analysis_graph_builder.build(

tensorflow_transform/impl_helper.py

Lines changed: 76 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -434,103 +434,138 @@ def trace_preprocessing_function(preprocessing_fn,
434434
0. the graph representing the traced `preprocessing_fn`
435435
1. the graph's structured inputs
436436
2. the graph's structured outputs
437-
438437
"""
439438
if use_tf_compat_v1:
440439
return _trace_preprocessing_fn_v1(preprocessing_fn, input_specs)
441440
else:
442-
return _trace_preprocessing_fn_v2(preprocessing_fn, input_specs,
443-
base_temp_dir)
441+
return _trace_preprocessing_fn_v2(
442+
preprocessing_fn, input_specs, base_temp_dir
443+
)
444444

445445

446446
def _trace_and_write_transform_fn(
447447
saved_model_dir: str,
448-
preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
449-
Mapping[str, common_types.TensorType]],
450-
input_signature: Mapping[str, tf.TypeSpec], base_temp_dir: Optional[str],
448+
preprocessing_fn: Callable[
449+
[Mapping[str, common_types.TensorType]],
450+
Mapping[str, common_types.TensorType],
451+
],
452+
input_signature: Mapping[str, tf.TypeSpec],
453+
base_temp_dir: Optional[str],
451454
tensor_replacement_map: Optional[Dict[str, tf.Tensor]],
452-
output_keys_to_name_map: Optional[Dict[str,
453-
str]]) -> function.ConcreteFunction:
455+
output_keys_to_name_map: Optional[Dict[str, str]],
456+
save_options: Optional[tf.saved_model.SaveOptions],
457+
) -> function.ConcreteFunction:
454458
"""Trace `preprocessing_fn` and serialize to a SavedModel."""
455459
tf_graph_context = graph_context.TFGraphContext(
456460
module_to_export=tf.Module(),
457461
temp_dir=base_temp_dir,
458-
evaluated_replacements=tensor_replacement_map)
462+
evaluated_replacements=tensor_replacement_map,
463+
)
459464
transform_fn = get_traced_transform_fn(
460465
preprocessing_fn,
461466
input_signature,
462467
tf_graph_context,
463-
output_keys_to_name_map=output_keys_to_name_map)
468+
output_keys_to_name_map=output_keys_to_name_map,
469+
)
464470
return saved_transform_io_v2.write_v2_saved_model(
465-
tf_graph_context.module_to_export, transform_fn, 'transform_fn',
466-
saved_model_dir)
471+
tf_graph_context.module_to_export,
472+
transform_fn,
473+
'transform_fn',
474+
saved_model_dir,
475+
save_options,
476+
)
467477

468478

469479
def _trace_and_get_metadata(
470480
concrete_transform_fn: function.ConcreteFunction,
471481
structured_inputs: Mapping[str, common_types.TensorType],
472-
preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
473-
Mapping[str, common_types.TensorType]],
482+
preprocessing_fn: Callable[
483+
[Mapping[str, common_types.TensorType]],
484+
Mapping[str, common_types.TensorType],
485+
],
474486
base_temp_dir: Optional[str],
475-
tensor_replacement_map: Optional[Dict[str, tf.Tensor]]
487+
tensor_replacement_map: Optional[Dict[str, tf.Tensor]],
476488
) -> dataset_metadata.DatasetMetadata:
477489
"""Compute and return metadata for the outputs of `concrete_transform_fn`."""
478490
tf_graph_context = graph_context.TFGraphContext(
479491
module_to_export=tf.Module(),
480492
temp_dir=base_temp_dir,
481-
evaluated_replacements=tensor_replacement_map)
493+
evaluated_replacements=tensor_replacement_map,
494+
)
482495
concrete_metadata_fn = schema_inference.get_traced_metadata_fn(
483496
preprocessing_fn,
484497
structured_inputs,
485498
tf_graph_context,
486-
evaluate_schema_overrides=True)
499+
evaluate_schema_overrides=True,
500+
)
487501
return dataset_metadata.DatasetMetadata(
488502
schema=schema_inference.infer_feature_schema_v2(
489503
concrete_transform_fn.structured_outputs,
490504
concrete_metadata_fn,
491-
evaluate_schema_overrides=True))
505+
evaluate_schema_overrides=True,
506+
)
507+
)
492508

493509

494510
def _validate_analyzers_fingerprint(
495-
baseline_analyzers_fingerprint: Mapping[str,
496-
graph_tools.AnalyzersFingerprint],
497-
graph: tf.Graph, structured_inputs: Mapping[str, common_types.TensorType]):
511+
baseline_analyzers_fingerprint: Mapping[
512+
str, graph_tools.AnalyzersFingerprint
513+
],
514+
graph: tf.Graph,
515+
structured_inputs: Mapping[str, common_types.TensorType],
516+
):
498517
"""Validates analyzers fingerprint in `graph` is same as baseline."""
499518
analyzers_fingerprint = graph_tools.get_analyzers_fingerprint(
500-
graph, structured_inputs)
519+
graph, structured_inputs
520+
)
501521
error_msg = (
502522
'The order of analyzers in your `preprocessing_fn` appears to be '
503523
'non-deterministic. This can be fixed either by changing your '
504524
'`preprocessing_fn` such that tf.Transform analyzers are encountered '
505525
'in a deterministic order or by passing a unique name to each '
506-
'analyzer API call.')
526+
'analyzer API call.'
527+
)
507528
for analyzer in analyzers_fingerprint:
508529
if analyzer not in baseline_analyzers_fingerprint:
509-
prefix_msg = (f'Analyzer node ({analyzer}) not found in '
510-
f'{baseline_analyzers_fingerprint.keys()}. ')
530+
prefix_msg = (
531+
f'Analyzer node ({analyzer}) not found in '
532+
f'{baseline_analyzers_fingerprint.keys()}. '
533+
)
511534
raise RuntimeError(prefix_msg + error_msg)
512-
if (baseline_analyzers_fingerprint[analyzer].source_keys !=
513-
analyzers_fingerprint[analyzer].source_keys):
535+
if (
536+
baseline_analyzers_fingerprint[analyzer].source_keys
537+
!= analyzers_fingerprint[analyzer].source_keys
538+
):
514539
raise RuntimeError(error_msg)
515540

516-
if (baseline_analyzers_fingerprint[analyzer].unique_path_hash !=
517-
analyzers_fingerprint[analyzer].unique_path_hash):
541+
if (
542+
baseline_analyzers_fingerprint[analyzer].unique_path_hash
543+
!= analyzers_fingerprint[analyzer].unique_path_hash
544+
):
518545
logging.warning(
519-
'Analyzer (%s) node\'s cache key varies on repeated tracing.'
546+
"Analyzer (%s) node's cache key varies on repeated tracing."
520547
' This warning is safe to ignore if you either specify `name` for all'
521548
' analyzers or if the order in which they are invoked is'
522-
' deterministic. If not, please file a bug with details.', analyzer)
549+
' deterministic. If not, please file a bug with details.',
550+
analyzer,
551+
)
523552

524553

525554
def trace_and_write_v2_saved_model(
526555
saved_model_dir: str,
527-
preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
528-
Mapping[str, common_types.TensorType]],
529-
input_signature: Mapping[str, tf.TypeSpec], base_temp_dir: Optional[str],
530-
baseline_analyzers_fingerprint: Mapping[str,
531-
graph_tools.AnalyzersFingerprint],
556+
preprocessing_fn: Callable[
557+
[Mapping[str, common_types.TensorType]],
558+
Mapping[str, common_types.TensorType],
559+
],
560+
input_signature: Mapping[str, tf.TypeSpec],
561+
base_temp_dir: Optional[str],
562+
baseline_analyzers_fingerprint: Mapping[
563+
str, graph_tools.AnalyzersFingerprint
564+
],
532565
tensor_replacement_map: Optional[Dict[str, tf.Tensor]],
533-
output_keys_to_name_map: Optional[Dict[str, str]]):
566+
output_keys_to_name_map: Optional[Dict[str, str]],
567+
save_options: Optional[tf.saved_model.SaveOptions],
568+
):
534569
"""Writes out a SavedModelV2 with preprocessing_fn traced using tf.function.
535570
536571
The SavedModel written contains a method called `transform_fn` that
@@ -549,6 +584,7 @@ def trace_and_write_v2_saved_model(
549584
evaluated replacement tensors.
550585
output_keys_to_name_map: A map from output dictionary keys to the names of
551586
the tensors that they represent.
587+
save_options: The options to use when saving the saved_model.
552588
553589
Returns:
554590
A tuple containing a pair of `tf.ConcreteFunction`s:
@@ -562,7 +598,7 @@ def trace_and_write_v2_saved_model(
562598
"""
563599
concrete_transform_fn = _trace_and_write_transform_fn(
564600
saved_model_dir, preprocessing_fn, input_signature, base_temp_dir,
565-
tensor_replacement_map, output_keys_to_name_map)
601+
tensor_replacement_map, output_keys_to_name_map, save_options)
566602
structured_inputs = tf2_utils.get_structured_inputs_from_func_graph(
567603
concrete_transform_fn.graph)
568604
_validate_analyzers_fingerprint(baseline_analyzers_fingerprint,
@@ -632,7 +668,8 @@ def analyze_in_place(preprocessing_fn, force_tf_compat_v1, feature_specs,
632668
input_signature=type_specs,
633669
base_temp_dir=None,
634670
tensor_replacement_map=None,
635-
output_keys_to_name_map=None)
671+
output_keys_to_name_map=None,
672+
save_options=None)
636673
_assert_no_analyzers_in_graph(concrete_transform_fn.graph)
637674
structured_inputs = tf2_utils.get_structured_inputs_from_func_graph(
638675
concrete_transform_fn.graph)

tensorflow_transform/saved/saved_transform_io_v2.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
"""Utility functions to save and load from SavedModels in TF 2.x."""
1515

16-
from typing import Any, Dict, Iterable, Mapping, Tuple, Union
16+
from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union
1717

1818
import tensorflow as tf
1919
from tensorflow_transform import annotators
@@ -534,9 +534,15 @@ def write_v2_saved_model(
534534
tf_function: tf.types.experimental.GenericFunction,
535535
name: str,
536536
saved_model_dir: str,
537+
save_options: Optional[tf.saved_model.SaveOptions] = None,
537538
) -> function.ConcreteFunction:
538539
"""Writes `tf_function` under attr `name` of `module` to `saved_model_dir`."""
539540
concrete_fn = trace_and_update_module(
540-
module, tf_function, name, strip_control_dependencies=False)
541-
tf.saved_model.save(module, saved_model_dir)
541+
module, tf_function, name, strip_control_dependencies=False
542+
)
543+
tf.saved_model.save(
544+
module,
545+
saved_model_dir,
546+
options=save_options,
547+
)
542548
return concrete_fn

0 commit comments

Comments
 (0)