Skip to content

Commit a2efaef

Browse files
zoyahavtfx-copybara
authored and committed
Switch from namedtuple to frozen dataclass for TFT graph elements while adding more pytyping for easier maintenance.
PiperOrigin-RevId: 548621263
1 parent ffd8301 commit a2efaef

File tree

3 files changed

+44
-88
lines changed

3 files changed

+44
-88
lines changed

tensorflow_transform/beam/analysis_graph_builder.py

Lines changed: 16 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
"""Functions to create the implementation graph."""
1515

1616
import collections
17+
import dataclasses
1718
import hashlib
18-
19-
from typing import Dict, Mapping, Collection, Optional, Tuple
19+
from typing import Collection, Dict, Mapping, Optional, Tuple
20+
from typing import OrderedDict as OrderedDictType
2021

2122
import tensorflow as tf
2223
from tensorflow_transform import analyzer_nodes
@@ -29,9 +30,6 @@
2930
from tensorflow_transform.beam import analyzer_cache
3031
from tensorflow_transform.beam import beam_nodes
3132
from tensorflow_transform.beam import combiner_packing_util
32-
# TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple`
33-
# once the Spark issue is resolved.
34-
from tfx_bsl.types import tfx_namedtuple
3533

3634

3735
# Used for debugging only. This will point to the most recent graph built.
@@ -109,11 +107,8 @@ def validate_value(self, value):
109107
assert isinstance(value, nodes.ValueNode)
110108

111109

112-
class _OptimizationView(
113-
tfx_namedtuple.namedtuple('_OptimizationView', [
114-
'prefer_fine_grained_view', 'flattened_view', 'fine_grained_view',
115-
'hashed_path'
116-
])):
110+
@dataclasses.dataclass(frozen=True)
111+
class _OptimizationView:
117112
"""A container for operation outputs during _OptimizeVisitor traversal.
118113
119114
This is used in order to maintain both a flattened view, and a fine grained
@@ -123,29 +118,15 @@ class _OptimizationView(
123118
`fine_grained_view` should be used. It should be set to true if the upstream
124119
view has cacheing operations that haven't been flattened yet.
125120
"""
121+
prefer_fine_grained_view: bool
122+
flattened_view: nodes.ValueNode
123+
fine_grained_view: Optional[OrderedDictType[str, nodes.ValueNode]]
124+
hashed_path: Optional[bytes]
126125

127-
def __init__(self, prefer_fine_grained_view, flattened_view,
128-
fine_grained_view, hashed_path):
129-
if prefer_fine_grained_view and not fine_grained_view:
126+
def __post_init__(self):
127+
if self.prefer_fine_grained_view and not self.fine_grained_view:
130128
raise ValueError(
131129
'Cannot prefer fine_grained_view when one is not provided')
132-
del hashed_path
133-
self._validate_flattened_view(flattened_view)
134-
self._validate_fine_grained_view(fine_grained_view)
135-
super().__init__()
136-
137-
def _validate_flattened_view(self, view):
138-
assert view is self.flattened_view
139-
assert view is not None
140-
assert isinstance(view, nodes.ValueNode), view
141-
142-
def _validate_fine_grained_view(self, view):
143-
assert view is self.fine_grained_view
144-
if view is None:
145-
return
146-
assert isinstance(view, collections.OrderedDict), view
147-
for value in view.values():
148-
assert isinstance(value, nodes.ValueNode), value
149130

150131

151132
class _OptimizeVisitor(nodes.Visitor):
@@ -287,8 +268,11 @@ def visit(self, operation_def, input_values):
287268
for view in input_values:
288269
disaggregated_input_values.extend(view.fine_grained_view.values())
289270

290-
# Checking that all cache has the same size.
291-
assert len({len(value) for value in disaggregated_input_values}) == 1
271+
# Each cache item should be a single ValueNode.
272+
assert all(
273+
isinstance(value, nodes.ValueNode)
274+
for value in disaggregated_input_values
275+
)
292276

293277
next_inputs = nodes.apply_multi_output_operation(
294278
beam_nodes.Flatten,

tensorflow_transform/nodes.py

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,39 +26,10 @@
2626

2727
import abc
2828
import collections
29-
from typing import Collection, Optional, Tuple
29+
import dataclasses
30+
from typing import Any, Collection, Dict, List, Optional, Tuple
3031

3132
import pydot
32-
# TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple`
33-
# once the Spark issue is resolved.
34-
from tfx_bsl.types import tfx_namedtuple
35-
36-
37-
class ValueNode(
38-
tfx_namedtuple.namedtuple('ValueNode',
39-
['parent_operation', 'value_index'])):
40-
"""A placeholder that will ultimately be translated to a PCollection.
41-
42-
Attributes:
43-
parent_operation: The `OperationNode` that produces this value.
44-
value_index: The index of this value in the outputs of `parent_operation`.
45-
"""
46-
__slots__ = ()
47-
48-
def __init__(self, parent_operation, value_index: int):
49-
if not isinstance(parent_operation, OperationNode):
50-
raise TypeError(
51-
'parent_operation must be a OperationNode, got {} of type {}'.format(
52-
parent_operation, type(parent_operation)))
53-
num_outputs = parent_operation.operation_def.num_outputs
54-
if not (0 <= value_index and value_index < num_outputs):
55-
raise ValueError(
56-
'value_index was {} but parent_operation had {} outputs'.format(
57-
value_index, num_outputs))
58-
super().__init__()
59-
60-
def __iter__(self):
61-
raise ValueError('ValueNode is not iterable')
6233

6334

6435
class OperationDef(metaclass=abc.ABCMeta):
@@ -114,6 +85,28 @@ def cache_coder(self) -> Optional[object]:
11485
return None
11586

11687

88+
@dataclasses.dataclass(frozen=True)
89+
class ValueNode:
90+
"""A placeholder that will ultimately be translated to a PCollection.
91+
92+
Attributes:
93+
parent_operation: The `OperationNode` that produces this value.
94+
value_index: The index of this value in the outputs of `parent_operation`.
95+
"""
96+
97+
parent_operation: 'OperationNode'
98+
value_index: int
99+
100+
def __post_init__(self):
101+
num_outputs = self.parent_operation.operation_def.num_outputs
102+
if not (0 <= self.value_index and self.value_index < num_outputs):
103+
raise ValueError(
104+
'value_index was {} but parent_operation had {} outputs'.format(
105+
self.value_index, num_outputs
106+
)
107+
)
108+
109+
117110
class OperationNode:
118111
"""A placeholder that will ultimately be translated to a PTransform.
119112
@@ -229,8 +222,8 @@ def __init__(self, visitor: Visitor):
229222
Args:
230223
visitor: A `Visitor` object.
231224
"""
232-
self._cached_value_nodes_values = {}
233-
self._stack = []
225+
self._cached_value_nodes_values: Dict[ValueNode, Any] = {}
226+
self._stack: List[OperationNode] = []
234227
self._visitor = visitor
235228

236229
def visit_value_node(self, value_node: ValueNode):

tensorflow_transform/nodes_test.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -70,27 +70,6 @@ def testApplyOperationWithTupleOutput(self):
7070
self.assertEqual(op.inputs, (a, b))
7171
self.assertEqual(op.outputs, (b_copy, a_copy))
7272

73-
def testOperationNodeWithBadOperatonDef(self):
74-
with self.assertRaisesRegexp(
75-
TypeError, 'operation_def must be an OperationDef, got'):
76-
nodes.OperationNode('not a operation_def', ())
77-
78-
def testOperationNodeWithBadInput(self):
79-
a = nodes.apply_operation(_Constant, value='a', label='Constant[a]')
80-
with self.assertRaisesRegexp(
81-
TypeError, 'Inputs to Operation must be a ValueNode, got'):
82-
nodes.OperationNode(_Concat(label='Concat'), (a, 'not a value_node'))
83-
84-
def testOperationNodeWithBadInputs(self):
85-
with self.assertRaisesRegexp(
86-
TypeError, 'inputs must be a tuple, got'):
87-
nodes.OperationNode(_Concat(label='Concat'), 'not a tuple')
88-
89-
def testValueNodeWithBadParent(self):
90-
with self.assertRaisesRegexp(
91-
TypeError, 'parent_operation must be a OperationNode, got'):
92-
nodes.ValueNode('not an operation node', 0)
93-
9473
def testValueNodeWithNegativeValueIndex(self):
9574
a = nodes.apply_operation(_Constant, value='a', label='Constant[a]')
9675
with self.assertRaisesWithLiteralMatch(
@@ -183,7 +162,7 @@ def testTraverserOutputsNotATuple(self):
183162
mock_visitor = mock.MagicMock()
184163
mock_visitor.visit.side_effect = [42]
185164

186-
with self.assertRaisesRegexp(
165+
with self.assertRaisesRegex(
187166
ValueError, r'expected visitor to return a tuple, got'):
188167
nodes.Traverser(mock_visitor).visit_value_node(a)
189168

@@ -192,7 +171,7 @@ def testTraverserBadNumOutputs(self):
192171
mock_visitor = mock.MagicMock()
193172
mock_visitor.visit.side_effect = [('a', 'b')]
194173

195-
with self.assertRaisesRegexp(
174+
with self.assertRaisesRegex(
196175
ValueError, 'has 1 outputs but visitor returned 2 values: '):
197176
nodes.Traverser(mock_visitor).visit_value_node(a)
198177

0 commit comments

Comments
 (0)