@@ -434,103 +434,138 @@ def trace_preprocessing_function(preprocessing_fn,
434434 0. the graph representing the traced `preprocessing_fn`
435435 1. the graph's structured inputs
436436 2. the graph's structured outputs
437-
438437 """
439438 if use_tf_compat_v1 :
440439 return _trace_preprocessing_fn_v1 (preprocessing_fn , input_specs )
441440 else :
442- return _trace_preprocessing_fn_v2 (preprocessing_fn , input_specs ,
443- base_temp_dir )
441+ return _trace_preprocessing_fn_v2 (
442+ preprocessing_fn , input_specs , base_temp_dir
443+ )
444444
445445
def _trace_and_write_transform_fn(
    saved_model_dir: str,
    preprocessing_fn: Callable[
        [Mapping[str, common_types.TensorType]],
        Mapping[str, common_types.TensorType],
    ],
    input_signature: Mapping[str, tf.TypeSpec],
    base_temp_dir: Optional[str],
    tensor_replacement_map: Optional[Dict[str, tf.Tensor]],
    output_keys_to_name_map: Optional[Dict[str, str]],
    save_options: Optional[tf.saved_model.SaveOptions],
) -> function.ConcreteFunction:
  """Traces `preprocessing_fn` and writes the result out as a SavedModel.

  Args:
    saved_model_dir: Forwarded to `write_v2_saved_model` as the location to
      write to.
    preprocessing_fn: The user function to trace.
    input_signature: Type specs handed to `get_traced_transform_fn` together
      with `preprocessing_fn`.
    base_temp_dir: Used as `temp_dir` of the graph context the trace runs in.
    tensor_replacement_map: Used as `evaluated_replacements` of that graph
      context.
    output_keys_to_name_map: Forwarded unchanged to
      `get_traced_transform_fn`.
    save_options: Forwarded unchanged to `write_v2_saved_model`.

  Returns:
    The value returned by `saved_transform_io_v2.write_v2_saved_model`.
  """
  # A fresh tf.Module is created per call; it is read back from the context
  # below rather than reused directly.
  context = graph_context.TFGraphContext(
      module_to_export=tf.Module(),
      temp_dir=base_temp_dir,
      evaluated_replacements=tensor_replacement_map,
  )
  traced_fn = get_traced_transform_fn(
      preprocessing_fn,
      input_signature,
      context,
      output_keys_to_name_map=output_keys_to_name_map,
  )
  # The traced function is saved under the fixed name 'transform_fn'.
  write_args = (
      context.module_to_export,
      traced_fn,
      'transform_fn',
      saved_model_dir,
      save_options,
  )
  return saved_transform_io_v2.write_v2_saved_model(*write_args)
467477
468478
def _trace_and_get_metadata(
    concrete_transform_fn: function.ConcreteFunction,
    structured_inputs: Mapping[str, common_types.TensorType],
    preprocessing_fn: Callable[
        [Mapping[str, common_types.TensorType]],
        Mapping[str, common_types.TensorType],
    ],
    base_temp_dir: Optional[str],
    tensor_replacement_map: Optional[Dict[str, tf.Tensor]],
) -> dataset_metadata.DatasetMetadata:
  """Builds the `DatasetMetadata` for the outputs of `concrete_transform_fn`.

  Args:
    concrete_transform_fn: Traced transform function; only its
      `structured_outputs` are read here.
    structured_inputs: Inputs handed to the metadata trace.
    preprocessing_fn: The user function that is traced again, this time via
      `schema_inference.get_traced_metadata_fn`.
    base_temp_dir: Used as `temp_dir` of the graph context.
    tensor_replacement_map: Used as `evaluated_replacements` of the graph
      context.

  Returns:
    A `DatasetMetadata` whose schema comes from
    `schema_inference.infer_feature_schema_v2`.
  """
  context = graph_context.TFGraphContext(
      module_to_export=tf.Module(),
      temp_dir=base_temp_dir,
      evaluated_replacements=tensor_replacement_map,
  )
  # Both the metadata trace and the schema inference are run with schema
  # overrides evaluated.
  metadata_fn = schema_inference.get_traced_metadata_fn(
      preprocessing_fn,
      structured_inputs,
      context,
      evaluate_schema_overrides=True,
  )
  inferred_schema = schema_inference.infer_feature_schema_v2(
      concrete_transform_fn.structured_outputs,
      metadata_fn,
      evaluate_schema_overrides=True,
  )
  return dataset_metadata.DatasetMetadata(schema=inferred_schema)
492508
493509
def _validate_analyzers_fingerprint(
    baseline_analyzers_fingerprint: Mapping[
        str, graph_tools.AnalyzersFingerprint
    ],
    graph: tf.Graph,
    structured_inputs: Mapping[str, common_types.TensorType],
):
  """Checks the analyzers fingerprint of `graph` against a baseline.

  Args:
    baseline_analyzers_fingerprint: Fingerprints to compare against, keyed by
      analyzer.
    graph: Graph whose analyzers fingerprint is computed via
      `graph_tools.get_analyzers_fingerprint`.
    structured_inputs: Inputs passed along to that fingerprint computation.

  Raises:
    RuntimeError: If an analyzer found in `graph` is missing from the
      baseline, or if its `source_keys` differ from the baseline's.
  """
  current_fingerprint = graph_tools.get_analyzers_fingerprint(
      graph, structured_inputs
  )
  nondeterministic_order_msg = (
      'The order of analyzers in your `preprocessing_fn` appears to be '
      'non-deterministic. This can be fixed either by changing your '
      '`preprocessing_fn` such that tf.Transform analyzers are encountered '
      'in a deterministic order or by passing a unique name to each '
      'analyzer API call.'
  )
  for analyzer in current_fingerprint:
    # An analyzer unknown to the baseline gets a more specific error prefix.
    if analyzer not in baseline_analyzers_fingerprint:
      raise RuntimeError(
          f'Analyzer node ({analyzer}) not found in '
          f'{baseline_analyzers_fingerprint.keys()}. '
          + nondeterministic_order_msg
      )

    expected = baseline_analyzers_fingerprint[analyzer]
    observed = current_fingerprint[analyzer]

    if expected.source_keys != observed.source_keys:
      raise RuntimeError(nondeterministic_order_msg)

    # A differing path hash is only reported, never raised.
    if expected.unique_path_hash != observed.unique_path_hash:
      logging.warning(
          "Analyzer (%s) node's cache key varies on repeated tracing."
          ' This warning is safe to ignore if you either specify `name` for all'
          ' analyzers or if the order in which they are invoked is'
          ' deterministic. If not, please file a bug with details.',
          analyzer,
      )
523552
524553
525554def trace_and_write_v2_saved_model (
526555 saved_model_dir : str ,
527- preprocessing_fn : Callable [[Mapping [str , common_types .TensorType ]],
528- Mapping [str , common_types .TensorType ]],
529- input_signature : Mapping [str , tf .TypeSpec ], base_temp_dir : Optional [str ],
530- baseline_analyzers_fingerprint : Mapping [str ,
531- graph_tools .AnalyzersFingerprint ],
556+ preprocessing_fn : Callable [
557+ [Mapping [str , common_types .TensorType ]],
558+ Mapping [str , common_types .TensorType ],
559+ ],
560+ input_signature : Mapping [str , tf .TypeSpec ],
561+ base_temp_dir : Optional [str ],
562+ baseline_analyzers_fingerprint : Mapping [
563+ str , graph_tools .AnalyzersFingerprint
564+ ],
532565 tensor_replacement_map : Optional [Dict [str , tf .Tensor ]],
533- output_keys_to_name_map : Optional [Dict [str , str ]]):
566+ output_keys_to_name_map : Optional [Dict [str , str ]],
567+ save_options : Optional [tf .saved_model .SaveOptions ],
568+ ):
534569 """Writes out a SavedModelV2 with preprocessing_fn traced using tf.function.
535570
536571 The SavedModel written contains a method called `transform_fn` that
@@ -549,6 +584,7 @@ def trace_and_write_v2_saved_model(
549584 evaluated replacement tensors.
550585 output_keys_to_name_map: A map from output dictionary keys to the names of
551586 the tensors that they represent.
587+ save_options: The options to use when saving the saved_model.
552588
553589 Returns:
554590 A tuple containing a pair of `tf.ConcreteFunction`s:
@@ -562,7 +598,7 @@ def trace_and_write_v2_saved_model(
562598 """
563599 concrete_transform_fn = _trace_and_write_transform_fn (
564600 saved_model_dir , preprocessing_fn , input_signature , base_temp_dir ,
565- tensor_replacement_map , output_keys_to_name_map )
601+ tensor_replacement_map , output_keys_to_name_map , save_options )
566602 structured_inputs = tf2_utils .get_structured_inputs_from_func_graph (
567603 concrete_transform_fn .graph )
568604 _validate_analyzers_fingerprint (baseline_analyzers_fingerprint ,
@@ -632,7 +668,8 @@ def analyze_in_place(preprocessing_fn, force_tf_compat_v1, feature_specs,
632668 input_signature = type_specs ,
633669 base_temp_dir = None ,
634670 tensor_replacement_map = None ,
635- output_keys_to_name_map = None )
671+ output_keys_to_name_map = None ,
672+ save_options = None )
636673 _assert_no_analyzers_in_graph (concrete_transform_fn .graph )
637674 structured_inputs = tf2_utils .get_structured_inputs_from_func_graph (
638675 concrete_transform_fn .graph )
0 commit comments