Skip to content

Commit 2b437f1

Browse files
zoyahavtfx-copybara
authored andcommitted
Moving _graph_state declaration to class level of _RunMetaGraphDoFn
PiperOrigin-RevId: 544369780
1 parent eaf8de4 commit 2b437f1

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

tensorflow_transform/beam/impl.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ def __init__(self, saved_model_dir, input_tensor_names, exclude_outputs):
226226
super().__init__(saved_model_dir, input_tensor_names, outputs_tensor_keys,
227227
callable_get_outputs)
228228

229+
# Initialized in process().
230+
_graph_state: _GraphStateCommon
229231
# Initialized in setup().
230232
_tensor_adapter: TensorAdapter
231233
# i-th element in this list contains the index of the column corresponding
@@ -274,8 +276,6 @@ def __init__(
274276
# it across multiple threads in the current process.
275277
self._shared_graph_state_handle = shared_graph_state_handle
276278

277-
# Initialized in process().
278-
self._graph_state = None
279279
# Metrics.
280280
self._graph_load_seconds_distribution = beam.metrics.Metrics.distribution(
281281
beam_common.METRICS_NAMESPACE, 'graph_load_seconds')
@@ -315,8 +315,6 @@ def _get_passthrough_data_from_recordbatch(
315315
return result
316316

317317
def _handle_batch(self, batch):
318-
assert self._graph_state is not None
319-
320318
self._update_metrics(batch)
321319
# No need to remove (and cannot remove) the passthrough columns here:
322320
# 1) The TensorAdapter expects the RecordBatch to be of the same schema as
@@ -390,7 +388,7 @@ def process(self, batch, saved_model_dir):
390388
A representation of output features as a dict mapping keys (logical column
391389
names) to values.
392390
"""
393-
if self._graph_state is None:
391+
if not hasattr(self, '_graph_state'):
394392
# If available, acquire will return a cached _GraphStateCommon, since
395393
# calling _make_graph_state is expensive.
396394
self._graph_state = self._shared_graph_state_handle.acquire(

0 commit comments

Comments
 (0)