@@ -464,6 +464,16 @@ def _get_running_index(self):
464464 self ._running_index += 1
465465 return self ._running_index
466466
467+ def _publish_rendered_dot_graph_file_from_leaf_nodes (self , leaf_nodes ):
468+ dot_string = nodes .get_dot_graph (leaf_nodes ).to_string ()
469+ tf .io .gfile .makedirs (self .base_test_dir )
470+ output_file = os .path .join (
471+ self .base_test_dir ,
472+ 'rendered_graph_{}.svg' .format (self ._get_running_index ()),
473+ )
474+ self .WriteRenderedDotFile (dot_string , output_file = output_file )
475+ return dot_string
476+
467477 def _publish_rendered_dot_graph_file (self ,
468478 preprocessing_fn ,
469479 feature_spec ,
@@ -486,15 +496,11 @@ def _publish_rendered_dot_graph_file(self,
486496 structured_outputs ,
487497 dataset_keys , pcoll_cache_dict )
488498 sort_value_node_values = lambda d : sorted (d .values (), key = str )
489- dot_string = nodes .get_dot_graph ([transform_fn_future ] +
490- sort_value_node_values (cache_output_dict ) +
491- sideeffects ).to_string ()
492- tf .io .gfile .makedirs (self .base_test_dir )
493- output_file = os .path .join (
494- self .base_test_dir ,
495- 'rendered_graph_{}.svg' .format (self ._get_running_index ()))
496- self .WriteRenderedDotFile (dot_string , output_file = output_file )
497- return dot_string
499+ return self ._publish_rendered_dot_graph_file_from_leaf_nodes (
500+ [transform_fn_future ]
501+ + sort_value_node_values (cache_output_dict )
502+ + sideeffects
503+ )
498504
499505 _RunPipelineResult = tfx_namedtuple .namedtuple ( # pylint: disable=invalid-name
500506 '_RunPipelineResult' , ['cache_output' , 'pipeline' ])
@@ -1409,6 +1415,97 @@ def preprocessing_fn(inputs):
14091415 self .assertMetricsCounterEqual (p .metrics , 'analysis_input_bytes_from_cache' ,
14101416 0 )
14111417
1418+ @tft_unit .parameters (
1419+ dict (
1420+ num_non_packed_analyzers = 0 ,
1421+ num_packed_analyzers = 0 ,
1422+ num_input_datasets = 10 ,
1423+ expected_node_count_with_cache = 1 ,
1424+ expected_node_count_without_cache = 1 ,
1425+ ),
1426+ dict (
1427+ num_non_packed_analyzers = 2 ,
1428+ num_packed_analyzers = 0 ,
1429+ num_input_datasets = 1 ,
1430+ expected_node_count_with_cache = 29 ,
1431+ expected_node_count_without_cache = 24 ,
1432+ ),
1433+ dict (
1434+ num_non_packed_analyzers = 0 ,
1435+ num_packed_analyzers = 2 ,
1436+ num_input_datasets = 1 ,
1437+ expected_node_count_with_cache = 24 ,
1438+ expected_node_count_without_cache = 19 ,
1439+ ),
1440+ dict (
1441+ num_non_packed_analyzers = 2 ,
1442+ num_packed_analyzers = 0 ,
1443+ num_input_datasets = 10 ,
1444+ expected_node_count_with_cache = 101 ,
1445+ expected_node_count_without_cache = 24 ,
1446+ ),
1447+ dict (
1448+ num_non_packed_analyzers = 0 ,
1449+ num_packed_analyzers = 2 ,
1450+ num_input_datasets = 10 ,
1451+ expected_node_count_with_cache = 87 ,
1452+ expected_node_count_without_cache = 19 ,
1453+ ),
1454+ )
1455+ def test_node_count (
1456+ self ,
1457+ num_non_packed_analyzers ,
1458+ num_packed_analyzers ,
1459+ num_input_datasets ,
1460+ expected_node_count_with_cache ,
1461+ expected_node_count_without_cache ,
1462+ ):
1463+ dataset_keys = [str (x ) for x in range (num_input_datasets )]
1464+ specs = impl_helper .get_type_specs_from_feature_specs (
1465+ {'x' : tf .io .FixedLenFeature ([], tf .int64 )}
1466+ )
1467+
1468+ def preprocessing_fn (inputs ):
1469+ for _ in range (num_packed_analyzers ):
1470+ tft .mean (inputs ['x' ])
1471+ for _ in range (num_non_packed_analyzers ):
1472+ tft .vocabulary (inputs ['x' ])
1473+ return inputs
1474+
1475+ def get_graph_leaf_nodes (cache_enabled ):
1476+ graph , structured_inputs , structured_outputs = (
1477+ impl_helper .trace_preprocessing_function (
1478+ preprocessing_fn ,
1479+ specs ,
1480+ use_tf_compat_v1 = False ,
1481+ base_temp_dir = self .base_test_dir ,
1482+ )
1483+ )
1484+ (transform_fn_future , cache_output_dict , sideeffects ) = (
1485+ analysis_graph_builder .build (
1486+ graph ,
1487+ structured_inputs ,
1488+ structured_outputs ,
1489+ dataset_keys ,
1490+ cache_dict = {} if cache_enabled else None ,
1491+ )
1492+ )
1493+ cache_output_nodes = (
1494+ list (cache_output_dict .values ()) if cache_output_dict else []
1495+ )
1496+ return [transform_fn_future ] + cache_output_nodes + sideeffects
1497+
1498+ with_cache_graph = get_graph_leaf_nodes (True )
1499+ without_cache_graph = get_graph_leaf_nodes (False )
1500+ node_count_with_cache = nodes .count_graph_nodes (with_cache_graph )
1501+ node_count_without_cache = nodes .count_graph_nodes (without_cache_graph )
1502+ self ._publish_rendered_dot_graph_file_from_leaf_nodes (without_cache_graph )
1503+ self ._publish_rendered_dot_graph_file_from_leaf_nodes (with_cache_graph )
1504+ self .assertEqual (
1505+ (expected_node_count_without_cache , expected_node_count_with_cache ),
1506+ (node_count_without_cache , node_count_with_cache ),
1507+ )
1508+
14121509
14131510if __name__ == '__main__' :
14141511 tft_unit .main ()
0 commit comments