Skip to content

Commit fbd7d5f

Browse files
zoyahav authored and tfx-copybara committed
Adds tests for the TFT graph node counts when cache is enabled/disabled.
This change also adds the functionality to count these nodes. PiperOrigin-RevId: 512036530
1 parent 877e39b commit fbd7d5f

File tree

3 files changed

+141
-9
lines changed

3 files changed

+141
-9
lines changed

tensorflow_transform/beam/cached_impl_test.py

Lines changed: 106 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,16 @@ def _get_running_index(self):
464464
self._running_index += 1
465465
return self._running_index
466466

467+
def _publish_rendered_dot_graph_file_from_leaf_nodes(self, leaf_nodes):
  """Writes the dot graph defined by `leaf_nodes` into the base test dir.

  Args:
    leaf_nodes: The nodes handed to `nodes.get_dot_graph` to define the graph.

  Returns:
    The dot graph serialized as a string.
  """
  graph_as_dot = nodes.get_dot_graph(leaf_nodes).to_string()
  tf.io.gfile.makedirs(self.base_test_dir)
  # The running index is bumped on every call, so successive renderings
  # within the same test end up in distinct files.
  file_name = 'rendered_graph_{}.svg'.format(self._get_running_index())
  self.WriteRenderedDotFile(
      graph_as_dot,
      output_file=os.path.join(self.base_test_dir, file_name),
  )
  return graph_as_dot
476+
467477
def _publish_rendered_dot_graph_file(self,
468478
preprocessing_fn,
469479
feature_spec,
@@ -486,15 +496,11 @@ def _publish_rendered_dot_graph_file(self,
486496
structured_outputs,
487497
dataset_keys, pcoll_cache_dict)
488498
sort_value_node_values = lambda d: sorted(d.values(), key=str)
489-
dot_string = nodes.get_dot_graph([transform_fn_future] +
490-
sort_value_node_values(cache_output_dict) +
491-
sideeffects).to_string()
492-
tf.io.gfile.makedirs(self.base_test_dir)
493-
output_file = os.path.join(
494-
self.base_test_dir,
495-
'rendered_graph_{}.svg'.format(self._get_running_index()))
496-
self.WriteRenderedDotFile(dot_string, output_file=output_file)
497-
return dot_string
499+
return self._publish_rendered_dot_graph_file_from_leaf_nodes(
500+
[transform_fn_future]
501+
+ sort_value_node_values(cache_output_dict)
502+
+ sideeffects
503+
)
498504

499505
_RunPipelineResult = tfx_namedtuple.namedtuple( # pylint: disable=invalid-name
500506
'_RunPipelineResult', ['cache_output', 'pipeline'])
@@ -1409,6 +1415,97 @@ def preprocessing_fn(inputs):
14091415
self.assertMetricsCounterEqual(p.metrics, 'analysis_input_bytes_from_cache',
14101416
0)
14111417

1418+
@tft_unit.parameters(
    dict(
        num_non_packed_analyzers=0,
        num_packed_analyzers=0,
        num_input_datasets=10,
        expected_node_count_with_cache=1,
        expected_node_count_without_cache=1,
    ),
    dict(
        num_non_packed_analyzers=2,
        num_packed_analyzers=0,
        num_input_datasets=1,
        expected_node_count_with_cache=29,
        expected_node_count_without_cache=24,
    ),
    dict(
        num_non_packed_analyzers=0,
        num_packed_analyzers=2,
        num_input_datasets=1,
        expected_node_count_with_cache=24,
        expected_node_count_without_cache=19,
    ),
    dict(
        num_non_packed_analyzers=2,
        num_packed_analyzers=0,
        num_input_datasets=10,
        expected_node_count_with_cache=101,
        expected_node_count_without_cache=24,
    ),
    dict(
        num_non_packed_analyzers=0,
        num_packed_analyzers=2,
        num_input_datasets=10,
        expected_node_count_with_cache=87,
        expected_node_count_without_cache=19,
    ),
)
def test_node_count(
    self,
    num_non_packed_analyzers,
    num_packed_analyzers,
    num_input_datasets,
    expected_node_count_with_cache,
    expected_node_count_without_cache,
):
  """Checks the TFT graph node count with the cache enabled and disabled.

  The preprocessing function applies `tft.mean` `num_packed_analyzers` times
  and `tft.vocabulary` `num_non_packed_analyzers` times to a single int64
  feature. The analysis graph is then built over `num_input_datasets` dataset
  keys, once with an empty cache dict and once with no cache dict at all.
  """
  dataset_keys = [str(index) for index in range(num_input_datasets)]
  feature_spec = {'x': tf.io.FixedLenFeature([], tf.int64)}
  specs = impl_helper.get_type_specs_from_feature_specs(feature_spec)

  def preprocessing_fn(inputs):
    # The analyzer outputs are deliberately discarded; only the graph nodes
    # that the calls create matter for this test.
    for _ in range(num_packed_analyzers):
      tft.mean(inputs['x'])
    for _ in range(num_non_packed_analyzers):
      tft.vocabulary(inputs['x'])
    return inputs

  def get_graph_leaf_nodes(cache_enabled):
    """Traces `preprocessing_fn` and returns the analysis graph leaf nodes."""
    traced = impl_helper.trace_preprocessing_function(
        preprocessing_fn,
        specs,
        use_tf_compat_v1=False,
        base_temp_dir=self.base_test_dir,
    )
    graph, structured_inputs, structured_outputs = traced
    # An empty dict turns caching on; None turns it off.
    cache_dict = {} if cache_enabled else None
    built = analysis_graph_builder.build(
        graph,
        structured_inputs,
        structured_outputs,
        dataset_keys,
        cache_dict=cache_dict,
    )
    transform_fn_future, cache_output_dict, sideeffects = built
    cache_output_nodes = []
    if cache_output_dict:
      cache_output_nodes.extend(cache_output_dict.values())
    return [transform_fn_future] + cache_output_nodes + sideeffects

  leaves_with_cache = get_graph_leaf_nodes(True)
  leaves_without_cache = get_graph_leaf_nodes(False)
  actual_with_cache = nodes.count_graph_nodes(leaves_with_cache)
  actual_without_cache = nodes.count_graph_nodes(leaves_without_cache)
  # Publish both renderings before asserting so they are available for
  # inspection even when the counts do not match.
  self._publish_rendered_dot_graph_file_from_leaf_nodes(leaves_without_cache)
  self._publish_rendered_dot_graph_file_from_leaf_nodes(leaves_with_cache)
  # Compared as a pair so that a failure reports both counts at once.
  expected_counts = (
      expected_node_count_without_cache,
      expected_node_count_with_cache,
  )
  actual_counts = (actual_without_cache, actual_with_cache)
  self.assertEqual(expected_counts, actual_counts)
1508+
14121509

14131510
if __name__ == '__main__':
14141511
tft_unit.main()

tensorflow_transform/nodes.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,3 +371,36 @@ def get_dot_graph(leaf_nodes: Collection[ValueNode]) -> pydot.Dot:
371371
for value_node in leaf_nodes:
372372
traverser.visit_value_node(value_node)
373373
return visitor.get_dot_graph()
374+
375+
376+
class _CountGraphNodes(Visitor):
  """Visitor which counts the graph nodes.

  `num_nodes` grows by one on every call to `visit`.
  """

  # Class-level default of zero. The first `+=` in `visit` rebinds the name on
  # the instance, so separate visitor instances never share a running count.
  num_nodes = 0

  def visit(self, operation_def: OperationDef, _) -> Tuple[int, ...]:
    """Counts one node and returns a placeholder value for each output.

    Args:
      operation_def: The operation being visited; only `num_outputs` is read.
      _: Unused second argument (presumably the values computed for the
        operation's inputs -- confirm against the `Visitor` interface).

    Returns:
      A tuple holding `operation_def.num_outputs` ones.
    """
    self.num_nodes += 1
    # One placeholder per output; the values themselves carry no meaning.
    return (1,) * operation_def.num_outputs

  def validate_value(self, value: int) -> None:
    """Accepts every value: the placeholders need no validation."""
    pass
387+
388+
389+
def count_graph_nodes(leaf_nodes: Collection[ValueNode]) -> int:
  """Returns how many TFT graph nodes the given leaf nodes define.

  Only TFT graph nodes are counted; Beam nodes that are constructed directly
  are not included.

  Args:
    leaf_nodes: The leaf `ValueNode`s that define the graph. The graph consists
      of the transitive parents of these leaf nodes.

  Returns:
    The number of TFT graph nodes.
  """
  counter = _CountGraphNodes()
  graph_walker = Traverser(counter)
  for leaf in leaf_nodes:
    graph_walker.visit_value_node(leaf)
  return counter.num_nodes

tensorflow_transform/test_case.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ def wrapper(fn):
8585
arg_names = arg_names[1:]
8686

8787
def to_arg_dict(testcase):
88+
if isinstance(testcase, dict):
89+
return testcase
8890
testcase = tuple(testcase)
8991
if len(testcase) != len(arg_names):
9092
raise ValueError(

0 commit comments

Comments
 (0)