2323from tensorflow_transform import tf2_utils
2424from tensorflow_transform .beam import analysis_graph_builder
2525from tensorflow_transform .beam import analyzer_cache
26- from tensorflow_transform import test_case
26+ from tensorflow_transform . beam import tft_unit
2727# TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple`
2828# once the Spark issue is resolved.
2929from tfx_bsl .types import tfx_namedtuple
@@ -396,17 +396,21 @@ def __new__(cls):
396396]
397397
398398
399- class AnalysisGraphBuilderTest (test_case .TransformTestCase ):
399+ class AnalysisGraphBuilderTest (tft_unit .TransformTestCase ):
400400
401- @test_case .named_parameters (
402- * test_case .cross_named_parameters (_ANALYZE_TEST_CASES , [
403- dict (testcase_name = 'tf_compat_v1' , use_tf_compat_v1 = True ),
404- dict (testcase_name = 'tf2' , use_tf_compat_v1 = False )
405- ]))
401+ @tft_unit .named_parameters (
402+ * tft_unit .cross_named_parameters (
403+ _ANALYZE_TEST_CASES ,
404+ [
405+ dict (testcase_name = 'tf_compat_v1' , use_tf_compat_v1 = True ),
406+ dict (testcase_name = 'tf2' , use_tf_compat_v1 = False ),
407+ ],
408+ )
409+ )
406410 def test_build (self , feature_spec , preprocessing_fn , expected_dot_graph_str ,
407411 expected_dot_graph_str_tf2 , use_tf_compat_v1 ):
408412 if not use_tf_compat_v1 :
409- test_case .skip_if_not_tf2 ('Tensorflow 2.x required' )
413+ tft_unit .skip_if_not_tf2 ('Tensorflow 2.x required' )
410414 specs = (
411415 feature_spec if use_tf_compat_v1 else
412416 impl_helper .get_type_specs_from_feature_specs (feature_spec ))
@@ -430,48 +434,54 @@ def test_build(self, feature_spec, preprocessing_fn, expected_dot_graph_str,
430434 second = (expected_dot_graph_str
431435 if use_tf_compat_v1 else expected_dot_graph_str_tf2 ))
432436
433- @test_case .named_parameters (* test_case .cross_named_parameters (
434- [
435- dict (
436- testcase_name = 'one_dataset_cached_single_phase' ,
437- preprocessing_fn = _preprocessing_fn_with_one_analyzer ,
438- full_dataset_keys = ['a' , 'b' ],
439- cached_dataset_keys = ['a' ],
440- expected_dataset_keys = ['b' ],
441- ),
442- dict (
443- testcase_name = 'all_datasets_cached_single_phase' ,
444- preprocessing_fn = _preprocessing_fn_with_one_analyzer ,
445- full_dataset_keys = ['a' , 'b' ],
446- cached_dataset_keys = ['a' , 'b' ],
447- expected_dataset_keys = [],
448- ),
449- dict (
450- testcase_name = 'mixed_single_phase' ,
451- preprocessing_fn = lambda d : dict ( # pylint: disable=g-long-lambda
452- list (_preprocessing_fn_with_chained_ptransforms (d ).items ()) +
453- list (_preprocessing_fn_with_one_analyzer (d ).items ())),
454- full_dataset_keys = ['a' , 'b' ],
455- cached_dataset_keys = ['a' , 'b' ],
456- expected_dataset_keys = ['a' , 'b' ],
457- ),
458- dict (
459- testcase_name = 'multi_phase' ,
460- preprocessing_fn = _preprocessing_fn_with_two_phases ,
461- full_dataset_keys = ['a' , 'b' ],
462- cached_dataset_keys = ['a' , 'b' ],
463- expected_dataset_keys = ['a' , 'b' ],
464- )
465- ],
466- [
467- dict (testcase_name = 'tf_compat_v1' , use_tf_compat_v1 = True ),
468- dict (testcase_name = 'tf2' , use_tf_compat_v1 = False )
469- ]))
437+ @tft_unit .named_parameters (
438+ * tft_unit .cross_named_parameters (
439+ [
440+ dict (
441+ testcase_name = 'one_dataset_cached_single_phase' ,
442+ preprocessing_fn = _preprocessing_fn_with_one_analyzer ,
443+ full_dataset_keys = ['a' , 'b' ],
444+ cached_dataset_keys = ['a' ],
445+ expected_dataset_keys = ['b' ],
446+ ),
447+ dict (
448+ testcase_name = 'all_datasets_cached_single_phase' ,
449+ preprocessing_fn = _preprocessing_fn_with_one_analyzer ,
450+ full_dataset_keys = ['a' , 'b' ],
451+ cached_dataset_keys = ['a' , 'b' ],
452+ expected_dataset_keys = [],
453+ ),
454+ dict (
455+ testcase_name = 'mixed_single_phase' ,
456+ preprocessing_fn = lambda d : dict ( # pylint: disable=g-long-lambda
457+ list (
458+ _preprocessing_fn_with_chained_ptransforms (d ).items ()
459+ )
460+ + list (_preprocessing_fn_with_one_analyzer (d ).items ())
461+ ),
462+ full_dataset_keys = ['a' , 'b' ],
463+ cached_dataset_keys = ['a' , 'b' ],
464+ expected_dataset_keys = ['a' , 'b' ],
465+ ),
466+ dict (
467+ testcase_name = 'multi_phase' ,
468+ preprocessing_fn = _preprocessing_fn_with_two_phases ,
469+ full_dataset_keys = ['a' , 'b' ],
470+ cached_dataset_keys = ['a' , 'b' ],
471+ expected_dataset_keys = ['a' , 'b' ],
472+ ),
473+ ],
474+ [
475+ dict (testcase_name = 'tf_compat_v1' , use_tf_compat_v1 = True ),
476+ dict (testcase_name = 'tf2' , use_tf_compat_v1 = False ),
477+ ],
478+ )
479+ )
470480 def test_get_analysis_dataset_keys (self , preprocessing_fn , full_dataset_keys ,
471481 cached_dataset_keys , expected_dataset_keys ,
472482 use_tf_compat_v1 ):
473483 if not use_tf_compat_v1 :
474- test_case .skip_if_not_tf2 ('Tensorflow 2.x required' )
484+ tft_unit .skip_if_not_tf2 ('Tensorflow 2.x required' )
475485 full_dataset_keys = list (
476486 map (analyzer_cache .DatasetKey , full_dataset_keys ))
477487 cached_dataset_keys = map (analyzer_cache .DatasetKey , cached_dataset_keys )
@@ -499,18 +509,16 @@ def test_get_analysis_dataset_keys(self, preprocessing_fn, full_dataset_keys,
499509 full_dataset_keys ,
500510 input_cache ,
501511 force_tf_compat_v1 = use_tf_compat_v1 ))
502-
503- dot_string = nodes .get_dot_graph ([analysis_graph_builder ._ANALYSIS_GRAPH
504- ]).to_string ()
505- self .WriteRenderedDotFile (dot_string )
512+ self .DebugPublishLatestsRenderedTFTGraph ()
506513 self .assertCountEqual (expected_dataset_keys , dataset_keys )
507514
508- @test_case .named_parameters (
515+ @tft_unit .named_parameters (
509516 dict (testcase_name = 'tf_compat_v1' , use_tf_compat_v1 = True ),
510- dict (testcase_name = 'tf2' , use_tf_compat_v1 = False ))
517+ dict (testcase_name = 'tf2' , use_tf_compat_v1 = False ),
518+ )
511519 def test_get_analysis_cache_entry_keys (self , use_tf_compat_v1 ):
512520 if not use_tf_compat_v1 :
513- test_case .skip_if_not_tf2 ('Tensorflow 2.x required' )
521+ tft_unit .skip_if_not_tf2 ('Tensorflow 2.x required' )
514522 full_dataset_keys = map (analyzer_cache .DatasetKey , ['a' , 'b' ])
515523 def preprocessing_fn (inputs ):
516524 return {'x' : tft .scale_to_0_1 (inputs ['x' ])}
@@ -531,10 +539,7 @@ def mocked_make_cache_entry_key(_):
531539 specs ,
532540 full_dataset_keys ,
533541 force_tf_compat_v1 = use_tf_compat_v1 ))
534-
535- dot_string = nodes .get_dot_graph ([analysis_graph_builder ._ANALYSIS_GRAPH
536- ]).to_string ()
537- self .WriteRenderedDotFile (dot_string )
542+ self .DebugPublishLatestsRenderedTFTGraph ()
538543 self .assertCountEqual (cache_entry_keys , [mocked_cache_entry_key ])
539544
540545 def test_duplicate_label_error (self ):
@@ -575,4 +580,4 @@ class _Analyzer(
575580
576581
577582if __name__ == '__main__' :
578- test_case .main ()
583+ tft_unit .main ()
0 commit comments