@@ -226,6 +226,8 @@ def __init__(self, saved_model_dir, input_tensor_names, exclude_outputs):
       super().__init__(saved_model_dir, input_tensor_names, outputs_tensor_keys,
                        callable_get_outputs)

+  # Initialized in process().
+  _graph_state: _GraphStateCommon
   # Initialized in setup().
   _tensor_adapter: TensorAdapter
   # i-th element in this list contains the index of the column corresponding
@@ -274,8 +276,6 @@ def __init__(
     # it across multiple threads in the current process.
     self._shared_graph_state_handle = shared_graph_state_handle

-    # Initialized in process().
-    self._graph_state = None
     # Metrics.
     self._graph_load_seconds_distribution = beam.metrics.Metrics.distribution(
         beam_common.METRICS_NAMESPACE, 'graph_load_seconds')
@@ -315,8 +315,6 @@ def _get_passthrough_data_from_recordbatch(
     return result

   def _handle_batch(self, batch):
-    assert self._graph_state is not None
-
     self._update_metrics(batch)
     # No need to remove (and cannot remove) the passthrough columns here:
     # 1) The TensorAdapter expects the RecordBatch to be of the same schema as
@@ -390,7 +388,7 @@ def process(self, batch, saved_model_dir):
       A representation of output features as a dict mapping keys (logical column
       names) to values.
     """
-    if self._graph_state is None:
+    if not hasattr(self, '_graph_state'):
       # If available, acquire will return a cached _GraphStateCommon, since
       # calling _make_graph_state is expensive.
       self._graph_state = self._shared_graph_state_handle.acquire(
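Below the diff, a minimal, self-contained sketch of the lazy-initialization pattern this change adopts: the attribute is declared only as a class-level type annotation and created on first use via hasattr(), instead of being assigned a None sentinel in __init__ and checked against None. The names ExpensiveState, Worker, _state, and _make_state are hypothetical stand-ins for illustration, not tensorflow_transform APIs.

class ExpensiveState:
  """Stands in for _GraphStateCommon: costly to build, cheap to reuse."""

  def __init__(self, tag: str):
    self.tag = tag


class Worker:
  # Annotation only: no class attribute is created, so hasattr() is False
  # until process() assigns the instance attribute on first use.
  _state: ExpensiveState

  def _make_state(self) -> ExpensiveState:
    # Placeholder for the expensive work (e.g. loading a SavedModel graph).
    return ExpensiveState('loaded')

  def process(self, element):
    if not hasattr(self, '_state'):
      # First call on this instance: build (or fetch a cached) state.
      self._state = self._make_state()
    return (element, self._state.tag)


if __name__ == '__main__':
  w = Worker()
  print(w.process('batch-0'))  # state built on the first call
  print(w.process('batch-1'))  # state reused afterwards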