diff --git a/tfx_bsl/beam/run_inference.py b/tfx_bsl/beam/run_inference.py index 37f8e91f..0f37bc2a 100644 --- a/tfx_bsl/beam/run_inference.py +++ b/tfx_bsl/beam/run_inference.py @@ -44,8 +44,8 @@ from tfx_bsl.beam import shared from tfx_bsl.public.proto import model_spec_pb2 from tfx_bsl.telemetry import util -from typing import Any, Generator, Iterable, List, Mapping, Sequence, Text, \ - Tuple, Union +from typing import Any, Generator, Iterable, List, Mapping, Optional, \ + Sequence, Text, Tuple, Union # TODO(b/140306674): stop using the internal TF API. from tensorflow.python.saved_model import loader_impl @@ -86,6 +86,15 @@ Tuple[tf.train.Example, classification_pb2.Classifications]] +# Public facing type aliases +ExampleType = Union[tf.train.Example, tf.train.SequenceExample] +QueryType = Tuple[Union[model_spec_pb2.InferenceSpecType, None], ExampleType] + +_QueryBatchType = Tuple[ + Union[model_spec_pb2.InferenceSpecType, None], + List[ExampleType] +] + # TODO(b/151468119): Converts this into enum once we stop supporting Python 2.7 class OperationType(object): @@ -96,120 +105,258 @@ class OperationType(object): @beam.ptransform_fn -@beam.typehints.with_input_types(Union[tf.train.Example, - tf.train.SequenceExample]) +@beam.typehints.with_input_types(Union[ExampleType, QueryType]) @beam.typehints.with_output_types(prediction_log_pb2.PredictionLog) def RunInferenceImpl( # pylint: disable=invalid-name examples: beam.pvalue.PCollection, - inference_spec_type: model_spec_pb2.InferenceSpecType + inference_spec_type: Union[model_spec_pb2.InferenceSpecType, + beam.pvalue.PCollection] = None ) -> beam.pvalue.PCollection: """Implementation of RunInference API. + Note: inference with a model PCollection and inference on queries require + a Beam runner with stateful DoFn support. + Args: - examples: A PCollection containing examples. - inference_spec_type: Model inference endpoint. + examples: A PCollection containing examples. If inference_spec_type is + None, this is interpreted as a PCollection of queries: + (InferenceSpecType, Example) + inference_spec_type: Model inference endpoint. Can be one of: + - InferenceSpecType: specifies a fixed model to use for inference. + - PCollection[InferenceSpecType]: specifies a secondary PCollection of + models. Each example will use the most recent model for inference. + (requires stateful DoFn support) + - None: indicates that the primary PCollection contains + (InferenceSpecType, Example) tuples. (requires stateful DoFn support) Returns: A PCollection containing prediction logs. Raises: - ValueError; when operation is not supported. + ValueError: When operation is not supported. + NotImplementedError: If the selected API is not supported by the current + runner. 
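+
+  Example (an illustrative sketch; `examples`, `model_spec`, `model_specs`,
+  and `queries` are assumed to be PCollections/protos defined by the caller):
+
+    # Fixed model known at pipeline construction time.
+    predictions = examples | RunInferenceImpl(model_spec)
+
+    # Models supplied as a secondary PCollection (stateful DoFn runners only).
+    predictions = examples | RunInferenceImpl(model_specs)
+
+    # Per-example (InferenceSpecType, Example) query tuples.
+    predictions = queries | RunInferenceImpl()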
""" - logging.info('RunInference on model: %s', inference_spec_type) - - batched_examples = examples | 'BatchExamples' >> beam.BatchElements() - operation_type = _get_operation_type(inference_spec_type) - if operation_type == OperationType.CLASSIFICATION: - return batched_examples | 'Classify' >> _Classify(inference_spec_type) - elif operation_type == OperationType.REGRESSION: - return batched_examples | 'Regress' >> _Regress(inference_spec_type) - elif operation_type == OperationType.PREDICTION: - return batched_examples | 'Predict' >> _Predict(inference_spec_type) - elif operation_type == OperationType.MULTIHEAD: - return (batched_examples - | 'MultiInference' >> _MultiInference(inference_spec_type)) + predictions = None + if type(inference_spec_type) is model_spec_pb2.InferenceSpecType: + logging.info('RunInference on model: %s', inference_spec_type) + queries = examples | 'Format as queries' >> beam.Map(lambda x: (None, x)) + predictions = queries | '_RunInferenceCoreOnFixedModel' >> _RunInferenceCore( + fixed_inference_spec_type=inference_spec_type) + elif type(inference_spec_type) is beam.pvalue.PCollection: + if not _runner_supports_stateful_dofn(examples.pipeline.runner): + raise NotImplementedError( + 'Model streaming inference requires stateful DoFn support which is not' + 'provided by the current runner: %s' % repr(examples.pipeline.runner)) + + logging.info('RunInference on dynamic models') + queries = examples | 'Join examples' >> _TemporalJoin(inference_spec_type) + predictions = queries | '_RunInferenceCoreOnDynamicModel' >> _RunInferenceCore() + elif inference_spec_type is None: + if not _runner_supports_stateful_dofn(examples.pipeline.runner): + raise NotImplementedError( + 'Inference on queries requires stateful DoFn support which is not' + 'provided by the current runner: %s' % repr(examples.pipeline.runner)) + + logging.info('RunInference on queries') + predictions = examples | '_RunInferenceCoreOnQueries' >> _RunInferenceCore() else: - raise ValueError('Unsupported operation_type %s' % operation_type) + raise ValueError('Invalid type for inference_spec_type: %s' + % type(inference_spec_type)) - -_IOTensorSpec = collections.namedtuple( - '_IOTensorSpec', - ['input_tensor_alias', 'input_tensor_name', 'output_alias_tensor_names']) - -_Signature = collections.namedtuple('_Signature', ['name', 'signature_def']) + return predictions @beam.ptransform_fn -@beam.typehints.with_input_types(Union[tf.train.Example, - tf.train.SequenceExample]) +@beam.typehints.with_input_types(QueryType) @beam.typehints.with_output_types(prediction_log_pb2.PredictionLog) -def _Classify(pcoll: beam.pvalue.PCollection, # pylint: disable=invalid-name - inference_spec_type: model_spec_pb2.InferenceSpecType): - """Performs classify PTransform.""" - if _using_in_process_inference(inference_spec_type): - return (pcoll - | 'Classify' >> beam.ParDo( - _BatchClassifyDoFn(inference_spec_type, shared.Shared())) - | 'BuildPredictionLogForClassifications' >> beam.ParDo( - _BuildPredictionLogForClassificationsDoFn())) +def _RunInferenceCore( + queries: beam.pvalue.PCollection, + fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None +) -> beam.pvalue.PCollection: + """Runs inference on queries and returns prediction logs. + + This internal run inference implementation operates on queries. Internally, + these queries are grouped by model and inference runs in batches. 
If a + fixed_inference_spec_type is provided, this spec is used for all inference + requests which enables pre-configuring the model during pipeline + construction. If the fixed_inference_spec_type is not provided, each input + query must contain a valid InferenceSpecType and models will be loaded + dynamically at runtime. + + Args: + queries: A PCollection containing QueryType tuples. + fixed_inference_spec_type: An optional model inference endpoint. If + specified, this is "preloaded" during inference and models specified in + query tuples are ignored. This requires the InferenceSpecType to be known + at pipeline creation time. If this fixed_inference_spec_type is not + provided, each input query must contain a valid InferenceSpecType and + models will be loaded dynamically at runtime. + + Returns: + A PCollection containing prediction logs. + + Raises: + ValueError: when operation is not supported. + """ + batched_queries = None + if _runner_supports_stateful_dofn(queries.pipeline.runner): + batched_queries = queries | 'BatchQueries' >> _BatchQueries() else: - raise NotImplementedError + # If the current runner does not support stateful DoFn's, we fall back to + # a simpler batching operation that assumes all queries share the same + # inference spec. + batched_queries = queries | 'BatchQueriesSimple' >> _BatchQueriesSimple() + + predictions = None + + if fixed_inference_spec_type is None: + # operation type is determined at runtime + split = batched_queries | 'SplitByOperation' >> _SplitByOperation() + + predictions = [ + split[OperationType.CLASSIFICATION] | 'Classify' >> _Classify(), + split[OperationType.REGRESSION] | 'Regress' >> _Regress(), + split[OperationType.PREDICTION] | 'Predict' >> _Predict(), + split[OperationType.MULTIHEAD] | 'MultiInference' >> _MultiInference() + ] | beam.Flatten() + else: + # operation type is determined at pipeline construction time + operation_type = _get_operation_type(fixed_inference_spec_type) + + if operation_type == OperationType.CLASSIFICATION: + predictions = batched_queries | 'Classify' >> _Classify( + fixed_inference_spec_type=fixed_inference_spec_type) + elif operation_type == OperationType.REGRESSION: + predictions = batched_queries | 'Regress' >> _Regress( + fixed_inference_spec_type=fixed_inference_spec_type) + elif operation_type == OperationType.PREDICTION: + predictions = batched_queries | 'Predict' >> _Predict( + fixed_inference_spec_type=fixed_inference_spec_type) + elif operation_type == OperationType.MULTIHEAD: + predictions = (batched_queries | 'MultiInference' >> _MultiInference( + fixed_inference_spec_type=fixed_inference_spec_type)) + else: + raise ValueError('Unsupported operation_type %s' % operation_type) + + return predictions @beam.ptransform_fn -@beam.typehints.with_input_types(Union[tf.train.Example, - tf.train.SequenceExample]) -@beam.typehints.with_output_types(prediction_log_pb2.PredictionLog) -def _Regress(pcoll: beam.pvalue.PCollection, # pylint: disable=invalid-name - inference_spec_type: model_spec_pb2.InferenceSpecType): - """Performs regress PTransform.""" - if _using_in_process_inference(inference_spec_type): - return (pcoll - | 'Regress' >> beam.ParDo( - _BatchRegressDoFn(inference_spec_type, shared.Shared())) - | 'BuildPredictionLogForRegressions' >> beam.ParDo( - _BuildPredictionLogForRegressionsDoFn())) - else: - raise NotImplementedError +@beam.typehints.with_input_types(QueryType) +@beam.typehints.with_output_types(_QueryBatchType) +def _BatchQueries(queries: beam.pvalue.PCollection) -> 
beam.pvalue.PCollection: + """Groups queries into batches.""" + + def _add_key(query: QueryType) -> Tuple[bytes, QueryType]: + """Adds serialized proto as key for QueryType tuples.""" + inference_spec, example = query + key = (inference_spec.SerializeToString() if inference_spec else b'') + return (key, (inference_spec, example)) + + def _to_query_batch( + query_list: Tuple[bytes, List[QueryType]] + ) -> _QueryBatchType: + """Converts a list of queries to a logical _QueryBatch.""" + inference_spec = query_list[1][0][0] + examples = [x[1] for x in query_list[1]] + return (inference_spec, examples) + + batches = ( + queries + | 'Serialize inference_spec as key' >> beam.Map(_add_key) + # TODO(hgarrereyn): GroupIntoBatches with automatic batch sizes + | 'Batch' >> beam.GroupIntoBatches(1000) + | 'Convert to QueryBatch' >> beam.Map(_to_query_batch) + ) + return batches @beam.ptransform_fn -@beam.typehints.with_input_types(Union[tf.train.Example, - tf.train.SequenceExample]) -@beam.typehints.with_output_types(prediction_log_pb2.PredictionLog) -def _Predict(pcoll: beam.pvalue.PCollection, # pylint: disable=invalid-name - inference_spec_type: model_spec_pb2.InferenceSpecType): - """Performs predict PTransform.""" - if _using_in_process_inference(inference_spec_type): - predictions = ( - pcoll - | 'Predict' >> beam.ParDo( - _BatchPredictDoFn(inference_spec_type, shared.Shared()))) - else: - predictions = ( - pcoll - | 'RemotePredict' >> beam.ParDo( - _RemotePredictDoFn(inference_spec_type, pcoll.pipeline.options))) - return (predictions - | 'BuildPredictionLogForPredictions' >> beam.ParDo( - _BuildPredictionLogForPredictionsDoFn())) +@beam.typehints.with_input_types(QueryType) +@beam.typehints.with_output_types(_QueryBatchType) +def _BatchQueriesSimple( + queries: beam.pvalue.PCollection) -> beam.pvalue.PCollection: + """Groups queries into batches. + + This version of _BatchQueries uses beam.BatchElements and works in runners + that do not support stateful DoFn's. However, in this case we need to make + the assumption that all queries share the same inference_spec. + """ + + def _to_query_batch(query_list: List[QueryType]) -> _QueryBatchType: + """Converts a list of queries to a logical _QueryBatch.""" + inference_spec = query_list[0][0] + examples = [x[1] for x in query_list] + return (inference_spec, examples) + + batches = ( + queries + | 'Batch' >> beam.BatchElements() + | 'ToQueryBatch' >> beam.Map(_to_query_batch) + ) + return batches @beam.ptransform_fn -@beam.typehints.with_input_types(Union[tf.train.Example, - tf.train.SequenceExample]) -@beam.typehints.with_output_types(prediction_log_pb2.PredictionLog) -def _MultiInference(pcoll: beam.pvalue.PCollection, # pylint: disable=invalid-name - inference_spec_type: model_spec_pb2.InferenceSpecType): - """Performs multi inference PTransform.""" - if _using_in_process_inference(inference_spec_type): - return ( - pcoll - | 'MultiInference' >> beam.ParDo( - _BatchMultiInferenceDoFn(inference_spec_type, shared.Shared())) - | 'BuildMultiInferenceLog' >> beam.ParDo(_BuildMultiInferenceLogDoFn())) - else: - raise NotImplementedError +@beam.typehints.with_input_types(_QueryBatchType) +@beam.typehints.with_output_types(_QueryBatchType) +def _SplitByOperation(batches): + """A PTransform that splits a _QueryBatchType PCollection based on operation. 
+ + Benchmarks demonstrated that this transform was a bottleneck (comprising + nearly 25% of the total RunInference walltime) since looking up the operation + type requires reading the saved model signature from disk. To improve + performance, we use a caching layer inside each DoFn instance that saves a + mapping of: + + {inference_spec.SerializeToString(): operation_type} + + In practice this cache reduces _SplitByOperation walltime by more than 90%. + + Returns a DoOutputsTuple with keys: + - OperationType.CLASSIFICATION + - OperationType.REGRESSION + - OperationType.PREDICTION + - OperationType.MULTIHEAD + + Raises: + ValueError: If any inference_spec_type is None. + """ + class _SplitDoFn(beam.DoFn): + def __init__(self): + self._cache = {} + + def process(self, batch): + inference_spec, _ = batch + + if inference_spec is None: + raise ValueError("InferenceSpecType cannot be None.") + + key = inference_spec.SerializeToString() + operation_type = self._cache.get(key) + + if operation_type is None: + operation_type = _get_operation_type(inference_spec) + self._cache[key] = operation_type + + return [beam.pvalue.TaggedOutput(operation_type, batch)] + + return ( + batches + | 'SplitDoFn' >> beam.ParDo(_SplitDoFn()).with_outputs( + OperationType.CLASSIFICATION, + OperationType.REGRESSION, + OperationType.PREDICTION, + OperationType.MULTIHEAD + )) + + +_IOTensorSpec = collections.namedtuple( + '_IOTensorSpec', + ['input_tensor_alias', 'input_tensor_name', 'output_alias_tensor_names']) + +_Signature = collections.namedtuple('_Signature', ['name', 'signature_def']) @six.add_metaclass(abc.ABCMeta) @@ -219,14 +366,17 @@ class _BaseDoFn(beam.DoFn): class _MetricsCollector(object): """A collector for beam metrics.""" - def __init__(self, inference_spec_type: model_spec_pb2.InferenceSpecType): - operation_type = _get_operation_type(inference_spec_type) - proximity_descriptor = ( - _METRICS_DESCRIPTOR_IN_PROCESS - if _using_in_process_inference(inference_spec_type) else - _METRICS_DESCRIPTOR_CLOUD_AI_PREDICTION) - namespace = util.MakeTfxNamespace( - [_METRICS_DESCRIPTOR_INFERENCE, operation_type, proximity_descriptor]) + def __init__(self, operation_type: Text, proximity_descriptor: Text): + """Initializes a metrics collector. + + Args: + operation_type: A string describing the type of operation, e.g. + "CLASSIFICATION". + proximity_descriptor: A string describing the location of inference, + e.g. "InProcess". + """ + namespace = util.MakeTfxNamespace([ + _METRICS_DESCRIPTOR_INFERENCE, operation_type, proximity_descriptor]) # Metrics self._inference_counter = beam.metrics.Metrics.counter( @@ -249,21 +399,45 @@ def __init__(self, inference_spec_type: model_spec_pb2.InferenceSpecType): namespace, 'load_model_latency_milli_secs') # Metrics cache - self.load_model_latency_milli_secs_cache = None - self.model_byte_size_cache = None + self._load_model_latency_milli_secs_cache = None + self._model_byte_size_cache = None + + def commit_cached_metrics(self): + """Updates any cached metrics. - def update_metrics_with_cache(self): - if self.load_model_latency_milli_secs_cache is not None: + If there are no cached metrics, this has no effect. Cached metrics are + automatically cleared after use. 
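+
+      Illustrative usage (a sketch; `collector` is an assumed
+      _MetricsCollector instance and the values are arbitrary):
+
+        collector.update_model_load(
+            load_model_latency_milli_secs=250, model_byte_size=1 << 20)
+        collector.commit_cached_metrics()  # flushes and clears the cache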
+ """ + if self._load_model_latency_milli_secs_cache is not None: self._load_model_latency_milli_secs.update( - self.load_model_latency_milli_secs_cache) - self.load_model_latency_milli_secs_cache = None - if self.model_byte_size_cache is not None: - self._model_byte_size.update(self.model_byte_size_cache) - self.model_byte_size_cache = None - - def update(self, elements: List[Union[tf.train.Example, - tf.train.SequenceExample]], - latency_micro_secs: int) -> None: + self._load_model_latency_milli_secs_cache) + self._load_model_latency_milli_secs_cache = None + if self._model_byte_size_cache is not None: + self._model_byte_size.update(self._model_byte_size_cache) + self._model_byte_size_cache = None + + def update_model_load( + self, load_model_latency_milli_secs: int, model_byte_size: int): + """Updates model loading metrics. + + Note: To commit model loading metrics, you must call + commit_cached_metrics() after storing values with this method. + + Args: + load_model_latency_milli_secs: Model loading latency in milliseconds. + model_byte_size: Approximate model size in bytes. + """ + self._load_model_latency_milli_secs_cache = load_model_latency_milli_secs + self._model_byte_size_cache = model_byte_size + + def update_inference( + self, elements: List[ExampleType], latency_micro_secs: int) -> None: + """Updates inference metrics. + + Args: + elements: A list of examples used for inference. + latency_micro_secs: Total inference latency in microseconds. + """ self._inference_batch_latency_micro_secs.update(latency_micro_secs) self._num_instances.inc(len(elements)) self._inference_counter.inc(len(elements)) @@ -271,37 +445,35 @@ def update(self, elements: List[Union[tf.train.Example, self._inference_request_batch_byte_size.update( sum(element.ByteSize() for element in elements)) - def __init__(self, inference_spec_type: model_spec_pb2.InferenceSpecType): + def __init__(self, operation_type: Text, proximity_descriptor: Text): super(_BaseDoFn, self).__init__() self._clock = None - self._metrics_collector = self._MetricsCollector(inference_spec_type) + self._metrics_collector = self._MetricsCollector( + operation_type, proximity_descriptor) def setup(self): self._clock = _ClockFactory.make_clock() - def process( - self, elements: List[Union[tf.train.Example, tf.train.SequenceExample]] - ) -> Iterable[Any]: + def process(self, batch: _QueryBatchType) -> Iterable[Any]: + inference_spec, elements = batch batch_start_time = self._clock.get_current_time_in_microseconds() - outputs = self.run_inference(elements) + outputs = self.run_inference(inference_spec, elements) result = self._post_process(elements, outputs) - self._metrics_collector.update( + self._metrics_collector.update_inference( elements, self._clock.get_current_time_in_microseconds() - batch_start_time) return result - def finish_bundle(self): - self._metrics_collector.update_metrics_with_cache() - @abc.abstractmethod def run_inference( - self, elements: List[Union[tf.train.Example, tf.train.SequenceExample]] + self, + inference_spec: model_spec_pb2.InferenceSpecType, + elements: List[ExampleType] ) -> Union[Mapping[Text, np.ndarray], Sequence[Mapping[Text, Any]]]: raise NotImplementedError @abc.abstractmethod - def _post_process(self, elements: List[Union[tf.train.Example, - tf.train.SequenceExample]], + def _post_process(self, elements: List[ExampleType], outputs: Any) -> Iterable[Any]: raise NotImplementedError @@ -322,8 +494,7 @@ def _retry_on_unavailable_and_resource_error_filter(exception: Exception): exception.resp.status in 
(503, 429)) -@beam.typehints.with_input_types(List[Union[tf.train.Example, - tf.train.SequenceExample]]) +@beam.typehints.with_input_types(_QueryBatchType) # Using output typehints triggers NotImplementedError('BEAM-2717)' on # streaming mode on Dataflow runner. # TODO(b/151468119): Consider to re-batch with online serving request size @@ -349,16 +520,34 @@ class _RemotePredictDoFn(_BaseDoFn): without having access to cloud-hosted model's signatures. """ - def __init__(self, inference_spec_type: model_spec_pb2.InferenceSpecType, - pipeline_options: PipelineOptions): - super(_RemotePredictDoFn, self).__init__(inference_spec_type) + def __init__( + self, + pipeline_options: PipelineOptions, + fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None + ): + super(_RemotePredictDoFn, self).__init__( + OperationType.PREDICTION, _METRICS_DESCRIPTOR_CLOUD_AI_PREDICTION) + self._pipeline_options = pipeline_options + self._fixed_inference_spec_type = fixed_inference_spec_type + + self._ai_platform_prediction_model_spec = None + self._api_client = None + self._full_model_name = None + + def setup(self): + super(_RemotePredictDoFn, self).setup() + if self._fixed_inference_spec_type: + self._setup_model(self._fixed_inference_spec_type) + + def _setup_model( + self, inference_spec_type: model_spec_pb2.InferenceSpecType + ): self._ai_platform_prediction_model_spec = ( inference_spec_type.ai_platform_prediction_model_spec) - self._api_client = None project_id = ( inference_spec_type.ai_platform_prediction_model_spec.project_id or - pipeline_options.view_as(GoogleCloudOptions).project) + self._pipeline_options.view_as(GoogleCloudOptions).project) if not project_id: raise ValueError('Either a non-empty project id or project flag in ' ' beam pipeline options needs be provided.') @@ -377,8 +566,6 @@ def __init__(self, inference_spec_type: model_spec_pb2.InferenceSpecType, self._full_model_name = name_spec.format(project_id, model_name, version_name) - def setup(self): - super(_RemotePredictDoFn, self).setup() # TODO(b/151468119): Add tfx_bsl_version and tfx_bsl_py_version to # user agent once custom header is supported in googleapiclient. 
self._api_client = discovery.build('ml', 'v1') @@ -401,7 +588,7 @@ def _make_request(self, body: Mapping[Text, List[Any]]) -> http.HttpRequest: name=self._full_model_name, body=body) def _prepare_instances_dict( - self, elements: List[tf.train.Example] + self, elements: List[ExampleType] ) -> Generator[Mapping[Text, Any], None, None]: """Prepare instances by converting features to dictionary.""" for example in elements: @@ -423,14 +610,14 @@ def _prepare_instances_dict( yield instance def _prepare_instances_serialized( - self, elements: List[tf.train.Example] + self, elements: List[ExampleType] ) -> Generator[Mapping[Text, Text], None, None]: """Prepare instances by base64 encoding serialized examples.""" for example in elements: yield {'b64': base64.b64encode(example.SerializeToString()).decode()} def _prepare_instances( - self, elements: List[tf.train.Example] + self, elements: List[ExampleType] ) -> Generator[Mapping[Text, Any], None, None]: if self._ai_platform_prediction_model_spec.use_serialization_config: return self._prepare_instances_serialized(elements) @@ -465,16 +652,19 @@ def _parse_feature_content(values: Sequence[Any], attr_name: Text, return list(values) def run_inference( - self, elements: List[Union[tf.train.Example, tf.train.SequenceExample]] + self, + inference_spec: model_spec_pb2.InferenceSpecType, + elements: List[ExampleType] ) -> Sequence[Mapping[Text, Any]]: + if not self._fixed_inference_spec_type: + self._setup_model(inference_spec) body = {'instances': list(self._prepare_instances(elements))} request = self._make_request(body) response = self._execute_request(request) return response['predictions'] def _post_process( - self, elements: List[Union[tf.train.Example, tf.train.SequenceExample]], - outputs: Sequence[Mapping[Text, Any]] + self, elements: List[ExampleType], outputs: Sequence[Mapping[Text, Any]] ) -> Iterable[prediction_log_pb2.PredictLog]: result = [] for output in outputs: @@ -507,38 +697,67 @@ class _BaseBatchSavedModelDoFn(_BaseDoFn): def __init__( self, - inference_spec_type: model_spec_pb2.InferenceSpecType, shared_model_handle: shared.Shared, + fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None, + operation_type: Text = '' ): - super(_BaseBatchSavedModelDoFn, self).__init__(inference_spec_type) - self._inference_spec_type = inference_spec_type + super(_BaseBatchSavedModelDoFn, self).__init__( + operation_type, _METRICS_DESCRIPTOR_IN_PROCESS) self._shared_model_handle = shared_model_handle - self._model_path = inference_spec_type.saved_model_spec.model_path + self._fixed_inference_spec_type = fixed_inference_spec_type + + self._model_path = None self._tags = None - self._signatures = _get_signatures( - inference_spec_type.saved_model_spec.model_path, - inference_spec_type.saved_model_spec.signature_name, - _get_tags(inference_spec_type)) + self._signatures = None self._session = None self._io_tensor_spec = None def setup(self): """Load the model. - Note that worker may crash if exception is thrown in setup due to b/139207285. """ - super(_BaseBatchSavedModelDoFn, self).setup() - self._tags = _get_tags(self._inference_spec_type) + if self._fixed_inference_spec_type: + self._setup_model(self._fixed_inference_spec_type) + + def finish_bundle(self): + # If we are using a fixed model, _setup_model will be called in DoFn.setup + # and model loading metrics will be cached. To commit these metrics, we + # need to call _metrics_collector.commit_cached_metrics() once during the + # DoFn lifetime. 
DoFn.teardown() is not guaranteed to be called, so the + # next best option is to call this in finish_bundle(). + if self._fixed_inference_spec_type: + self._metrics_collector.commit_cached_metrics() + + def _setup_model( + self, inference_spec_type: model_spec_pb2.InferenceSpecType + ): + self._model_path = inference_spec_type.saved_model_spec.model_path + self._signatures = _get_signatures( + inference_spec_type.saved_model_spec.model_path, + inference_spec_type.saved_model_spec.signature_name, + _get_tags(inference_spec_type)) + + self._validate_model() + + self._tags = _get_tags(inference_spec_type) self._io_tensor_spec = self._pre_process() if self._has_tpu_tag(): # TODO(b/131873699): Support TPU inference. raise ValueError('TPU inference is not supported yet.') - self._session = self._load_model() + self._session = self._load_model(inference_spec_type) + + def _validate_model(self): + """Optional subclass model validation hook. + + Raises: + ValueError: if model is invalid. + """ + pass - def _load_model(self): + def _load_model(self, inference_spec_type: model_spec_pb2.InferenceSpecType): """Load a saved model into memory. Returns: @@ -554,15 +773,20 @@ def load(): tf.compat.v1.saved_model.loader.load(result, self._tags, self._model_path) end_time = self._clock.get_current_time_in_microseconds() memory_after = _get_current_process_memory_in_bytes() - self._metrics_collector.load_model_latency_milli_secs_cache = ( - (end_time - start_time) / _MILLISECOND_TO_MICROSECOND) - self._metrics_collector.model_byte_size_cache = ( - memory_after - memory_before) + + # Compute model loading metrics. + load_model_latency_milli_secs = ( + (end_time - start_time) / _MILLISECOND_TO_MICROSECOND) + model_byte_size = (memory_after - memory_before) + self._metrics_collector.update_model_load( + load_model_latency_milli_secs, model_byte_size) + return result if not self._model_path: raise ValueError('Model path is not valid.') - return self._shared_model_handle.acquire(load) + return self._shared_model_handle.acquire( + load, tag=inference_spec_type.SerializeToString().decode('latin-1')) def _pre_process(self) -> _IOTensorSpec: # Pre process functions will validate for each signature. 
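The `tag` argument above keys the shared model handle by the serialized inference spec, which is what lets a single DoFn instance serve batches for different models. A minimal sketch of that pattern under the assumptions noted in the comments (`spec_a`, `spec_b`, `_shared_handle`, and `_acquire_session` are hypothetical names introduced only for illustration):

# Sketch of the per-spec model caching used by _setup_model()/_load_model().
# Assumes the `shared.Shared().acquire` used in this change accepts a `tag`
# argument, and that `spec_a`/`spec_b` are InferenceSpecType protos with
# saved_model_spec.model_path set.
import tensorflow as tf
from tfx_bsl.beam import shared

_shared_handle = shared.Shared()

def _acquire_session(inference_spec_type):
  def load():
    sess = tf.compat.v1.Session(graph=tf.compat.v1.Graph())
    tf.compat.v1.saved_model.loader.load(
        sess, [tf.saved_model.SERVING],
        inference_spec_type.saved_model_spec.model_path)
    return sess
  # Keying the shared handle on the serialized spec means repeated batches for
  # the same spec reuse the cached session, while a new spec forces a reload.
  return _shared_handle.acquire(
      load, tag=inference_spec_type.SerializeToString().decode('latin-1'))

session_a = _acquire_session(spec_a)   # loads model A
session_a2 = _acquire_session(spec_a)  # reuses the cached model A
session_b = _acquire_session(spec_b)   # different tag, loads model B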
@@ -600,14 +824,19 @@ def _has_tpu_tag(self) -> bool: tf.saved_model.TPU in self._tags) def run_inference( - self, elements: List[Union[tf.train.Example, tf.train.SequenceExample]] + self, + inference_spec_type: model_spec_pb2.InferenceSpecType, + elements: List[ExampleType] ) -> Mapping[Text, np.ndarray]: + if not self._fixed_inference_spec_type: + self._setup_model(inference_spec_type) + self._metrics_collector.commit_cached_metrics() self._check_elements(elements) outputs = self._run_tf_operations(elements) return outputs def _run_tf_operations( - self, elements: List[Union[tf.train.Example, tf.train.SequenceExample]] + self, elements: List[ExampleType] ) -> Mapping[Text, np.ndarray]: input_values = [] for element in elements: @@ -619,93 +848,111 @@ def _run_tf_operations( raise RuntimeError('Output length does not match fetches') return result - def _check_elements( - self, elements: List[Union[tf.train.Example, - tf.train.SequenceExample]]) -> None: + def _check_elements(self, elements: List[ExampleType]) -> None: """Unimplemented.""" raise NotImplementedError -@beam.typehints.with_input_types(List[Union[tf.train.Example, - tf.train.SequenceExample]]) +@beam.typehints.with_input_types(_QueryBatchType) @beam.typehints.with_output_types(Tuple[tf.train.Example, classification_pb2.Classifications]) class _BatchClassifyDoFn(_BaseBatchSavedModelDoFn): """A DoFn that run inference on classification model.""" - def setup(self): + def __init__( + self, + shared_model_handle: shared.Shared, + fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None + ): + super(_BatchClassifyDoFn, self).__init__( + shared_model_handle, fixed_inference_spec_type, + OperationType.CLASSIFICATION) + + def _validate_model(self): signature_def = self._signatures[0].signature_def if signature_def.method_name != tf.saved_model.CLASSIFY_METHOD_NAME: raise ValueError( 'BulkInferrerClassifyDoFn requires signature method ' 'name %s, got: %s' % tf.saved_model.CLASSIFY_METHOD_NAME, signature_def.method_name) - super(_BatchClassifyDoFn, self).setup() def _check_elements( - self, elements: List[Union[tf.train.Example, - tf.train.SequenceExample]]) -> None: + self, elements: List[ExampleType]) -> None: if not all(isinstance(element, tf.train.Example) for element in elements): raise ValueError('Classify only supports tf.train.Example') def _post_process( - self, elements: Sequence[tf.train.Example], outputs: Mapping[Text, - np.ndarray] + self, elements: Sequence[ExampleType], outputs: Mapping[Text, np.ndarray] ) -> Iterable[Tuple[tf.train.Example, classification_pb2.Classifications]]: classifications = _post_process_classify( self._io_tensor_spec.output_alias_tensor_names, elements, outputs) return zip(elements, classifications) -@beam.typehints.with_input_types(List[Union[tf.train.Example, - tf.train.SequenceExample]]) +@beam.typehints.with_input_types(_QueryBatchType) @beam.typehints.with_output_types(Tuple[tf.train.Example, regression_pb2.Regression]) class _BatchRegressDoFn(_BaseBatchSavedModelDoFn): """A DoFn that run inference on regression model.""" - def setup(self): - super(_BatchRegressDoFn, self).setup() + def __init__( + self, + shared_model_handle: shared.Shared, + fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None + ): + super(_BatchRegressDoFn, self).__init__( + shared_model_handle, fixed_inference_spec_type, + OperationType.REGRESSION) + + def _validate_model(self): + signature_def = self._signatures[0].signature_def + if signature_def.method_name != tf.saved_model.REGRESS_METHOD_NAME: + 
raise ValueError( + 'BulkInferrerRegressDoFn requires signature method ' + 'name %s, got: %s' % tf.saved_model.REGRESS_METHOD_NAME, + signature_def.method_name) def _check_elements( - self, elements: List[Union[tf.train.Example, - tf.train.SequenceExample]]) -> None: + self, elements: List[ExampleType]) -> None: if not all(isinstance(element, tf.train.Example) for element in elements): raise ValueError('Regress only supports tf.train.Example') def _post_process( - self, elements: Sequence[tf.train.Example], outputs: Mapping[Text, - np.ndarray] + self, elements: Sequence[ExampleType], outputs: Mapping[Text, np.ndarray] ) -> Iterable[Tuple[tf.train.Example, regression_pb2.Regression]]: regressions = _post_process_regress(elements, outputs) return zip(elements, regressions) -@beam.typehints.with_input_types(List[Union[tf.train.Example, - tf.train.SequenceExample]]) +@beam.typehints.with_input_types(_QueryBatchType) @beam.typehints.with_output_types(prediction_log_pb2.PredictLog) class _BatchPredictDoFn(_BaseBatchSavedModelDoFn): """A DoFn that runs inference on predict model.""" - def setup(self): + def __init__( + self, + shared_model_handle: shared.Shared, + fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None + ): + super(_BatchPredictDoFn, self).__init__( + shared_model_handle, fixed_inference_spec_type, + OperationType.PREDICTION) + + def _validate_model(self): signature_def = self._signatures[0].signature_def if signature_def.method_name != tf.saved_model.PREDICT_METHOD_NAME: raise ValueError( 'BulkInferrerPredictDoFn requires signature method ' 'name %s, got: %s' % tf.saved_model.PREDICT_METHOD_NAME, signature_def.method_name) - super(_BatchPredictDoFn, self).setup() def _check_elements( - self, elements: List[Union[tf.train.Example, - tf.train.SequenceExample]]) -> None: + self, elements: List[ExampleType]) -> None: pass def _post_process( - self, elements: Union[Sequence[tf.train.Example], - Sequence[tf.train.SequenceExample]], - outputs: Mapping[Text, np.ndarray] + self, elements: Sequence[ExampleType], outputs: Mapping[Text, np.ndarray] ) -> Iterable[prediction_log_pb2.PredictLog]: input_tensor_alias = self._io_tensor_spec.input_tensor_alias signature_name = self._signatures[0].name @@ -741,22 +988,28 @@ def _post_process( return result -@beam.typehints.with_input_types(List[Union[tf.train.Example, - tf.train.SequenceExample]]) +@beam.typehints.with_input_types(_QueryBatchType) @beam.typehints.with_output_types(Tuple[tf.train.Example, inference_pb2.MultiInferenceResponse]) class _BatchMultiInferenceDoFn(_BaseBatchSavedModelDoFn): """A DoFn that runs inference on multi-head model.""" + def __init__( + self, + shared_model_handle: shared.Shared, + fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None + ): + super(_BatchMultiInferenceDoFn, self).__init__( + shared_model_handle, fixed_inference_spec_type, + OperationType.MULTIHEAD) + def _check_elements( - self, elements: List[Union[tf.train.Example, - tf.train.SequenceExample]]) -> None: + self, elements: List[ExampleType]) -> None: if not all(isinstance(element, tf.train.Example) for element in elements): raise ValueError('Multi inference only supports tf.train.Example') def _post_process( - self, elements: Sequence[tf.train.Example], outputs: Mapping[Text, - np.ndarray] + self, elements: Sequence[ExampleType], outputs: Mapping[Text, np.ndarray] ) -> Iterable[Tuple[tf.train.Example, inference_pb2.MultiInferenceResponse]]: classifications = None regressions = None @@ -862,6 +1115,99 @@ def process( yield 
result +def _BuildInferenceOperation( + name: str, + in_process_dofn: _BaseBatchSavedModelDoFn, + remote_dofn: Optional[_BaseDoFn], + build_prediction_log_dofn: beam.DoFn +): + """Construct an operation specific inference sub-pipeline. + + Args: + name: Name of the operation (e.g. "Classify"). + in_process_dofn: A _BaseBatchSavedModelDoFn class to use for in-process + inference. + remote_dofn: An optional DoFn that is used for remote inference. If not + provided, attempts at remote inference will throw a NotImplementedError. + build_prediction_log_dofn: A DoFn that can build prediction logs from the + output of `in_process_dofn` and `remote_dofn`. + + Returns: + A PTransform of the type (_QueryBatchType -> PredictionLog). + + Raises: + NotImplementedError: if remote inference is attempted and not supported. + """ + @beam.ptransform_fn + @beam.typehints.with_input_types(_QueryBatchType) + @beam.typehints.with_output_types(prediction_log_pb2.PredictionLog) + def _Op( + pcoll: beam.pvalue.PCollection, + fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None + ): # pylint: disable=invalid-name + raw_result = None + + if fixed_inference_spec_type is None: + tagged = pcoll | ('TagInferenceType%s' % name) >> _TagUsingInProcessInference() + + in_process_result = ( + tagged['in_process'] + | ('InProcess%s' % name) >> beam.ParDo( + in_process_dofn(shared.Shared()))) + + if remote_dofn: + remote_result = ( + tagged['remote'] + | ('Remote%s' % name) >> beam.ParDo( + remote_dofn(pcoll.pipeline.options))) + + raw_result = ( + [in_process_result, remote_result] + | 'FlattenResult' >> beam.Flatten()) + else: + tagged['remote'] | 'NotImplemented' >> _NotImplementedTransform( + 'Remote inference is not supported for operation type: %s' % name) + raw_result = in_process_result + else: + if _using_in_process_inference(fixed_inference_spec_type): + raw_result = ( + pcoll + | ('InProcess%s' % name) >> beam.ParDo(in_process_dofn( + shared.Shared(), + fixed_inference_spec_type=fixed_inference_spec_type))) + else: + if remote_dofn: + raw_result = ( + pcoll + | ('Remote%s' % name) >> beam.ParDo(remote_dofn( + pcoll.pipeline.options, + fixed_inference_spec_type=fixed_inference_spec_type))) + else: + raise NotImplementedError('Remote inference is not supported for' + 'operation type: %s' % name) + + return ( + raw_result + | ('BuildPredictionLogFor%s' % name) >> beam.ParDo( + build_prediction_log_dofn())) + + return _Op + + +_Classify = _BuildInferenceOperation( + 'Classify', _BatchClassifyDoFn, None, + _BuildPredictionLogForClassificationsDoFn) +_Regress = _BuildInferenceOperation( + 'Regress', _BatchRegressDoFn, None, + _BuildPredictionLogForRegressionsDoFn) +_Predict = _BuildInferenceOperation( + 'Predict', _BatchPredictDoFn, _RemotePredictDoFn, + _BuildPredictionLogForPredictionsDoFn) +_MultiInference = _BuildInferenceOperation( + 'MultiInference', _BatchMultiInferenceDoFn, None, + _BuildMultiInferenceLogDoFn) + + def _post_process_classify( output_alias_tensor_names: Mapping[Text, Text], elements: Sequence[tf.train.Example], outputs: Mapping[Text, np.ndarray] @@ -1070,6 +1416,27 @@ def _using_in_process_inference( return inference_spec_type.WhichOneof('type') == 'saved_model_spec' +@beam.ptransform_fn +@beam.typehints.with_input_types(_QueryBatchType) +@beam.typehints.with_output_types(_QueryBatchType) +def _TagUsingInProcessInference( + queries: beam.pvalue.PCollection) -> beam.pvalue.DoOutputsTuple: + """Tags each query batch with 'in_process' or 'remote'.""" + return queries | 'TagBatches' >> 
beam.Map( + lambda query: beam.pvalue.TaggedOutput( + 'in_process' if _using_in_process_inference(query[0]) else 'remote', query) + ).with_outputs('in_process', 'remote') + + +@beam.ptransform_fn +def _NotImplementedTransform( + pcoll: beam.pvalue.PCollection, message: Text = ''): + """Raises NotImplementedError for each value in the input PCollection.""" + def _raise(x): + raise NotImplementedError(message) + pcoll | beam.Map(_raise) + + def _get_signatures(model_path: Text, signatures: Sequence[Text], tags: Sequence[Text]) -> Sequence[_Signature]: """Returns a sequence of {model_signature_name: signature}.""" @@ -1161,6 +1528,13 @@ def _get_tags( return [tf.saved_model.SERVING] +def _runner_supports_stateful_dofn( + runner: beam.pipeline.PipelineRunner) -> bool: + """Returns True if the provided runner supports stateful DoFns.""" + # TODO: Implement. + return True + + def _is_darwin() -> bool: return sys.platform == 'darwin' @@ -1195,3 +1569,73 @@ def make_clock() -> _Clock: and not _is_cygwin()): return _FineGrainedClock() return _Clock() + + +class _TemporalJoinStream(object): + PRIMARY = 1 + SECONDARY = 2 + + +class _TemporalJoinDoFn(beam.DoFn): + """A stateful DoFn that joins two PCollection streams. + + CACHE: holds the most recent item from the secondary stream + EARLY_PRIMARY: holds any items from the primary stream received before the + first item from the secondary stream + """ + CACHE = beam.transforms.userstate.CombiningValueStateSpec( + 'cache', combine_fn=beam.combiners.ToListCombineFn()) + + EARLY_PRIMARY = beam.transforms.userstate.CombiningValueStateSpec( + 'early_primary', combine_fn=beam.combiners.ToListCombineFn()) + + def process( + self, + x, + cache=beam.DoFn.StateParam(CACHE), + early_primary=beam.DoFn.StateParam(EARLY_PRIMARY) + ): + _, tup = x + value, stream_type = tup + + if stream_type == _TemporalJoinStream.PRIMARY: + cached = cache.read() + if len(cached) == 0: + # accumulate in early_primary + early_primary.add(value) + else: + return [(cached[0], value)] + elif stream_type == _TemporalJoinStream.SECONDARY: + cache.clear() + cache.add(value) + + # emit any buffered primary elements with the new secondary value + primary = early_primary.read() + if len(primary) > 0: + early_primary.clear() + return [(value, e) for e in primary] + else: + return [] + + +@beam.ptransform_fn +def _TemporalJoin(primary, secondary): + """Performs a temporal join of two PCollections.
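+
+  Illustrative usage (a sketch; `examples` and `models` are assumed
+  PCollections):
+
+    joined = examples | _TemporalJoin(models)  # emits (model, example) pairs
+
+  For instance, with examples e1, e2, e3 and models m1, m2 arriving in the
+  order e1, m1, e2, m2, e3, the join emits (m1, e1), (m1, e2) and (m2, e3):
+  e1 is buffered until the first model arrives, and each later example is
+  paired with the most recent model.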
+ + Returns tuples of the type (b,a) where: + - a is from the primary stream + - b is from the secondary stream + - b is the most recent item at the time a is processed (or a was processed + before b and b is the first item in the stream) + """ + tag_primary = primary | 'primary' >> beam.Map( + lambda x: (x, _TemporalJoinStream.PRIMARY)) + tag_secondary = secondary | 'secondary' >> beam.Map( + lambda x: (x, _TemporalJoinStream.SECONDARY)) + + joined = [tag_primary, tag_secondary] \ + | beam.Flatten() \ + | 'Fake keys' >> beam.Map(lambda x: (0,x)) \ + | 'Join' >> beam.ParDo(_TemporalJoinDoFn()) + + return joined diff --git a/tfx_bsl/beam/run_inference_test.py b/tfx_bsl/beam/run_inference_test.py index 5fb9adad..5cd156c1 100644 --- a/tfx_bsl/beam/run_inference_test.py +++ b/tfx_bsl/beam/run_inference_test.py @@ -28,6 +28,10 @@ import apache_beam as beam from apache_beam.metrics.metric import MetricsFilter +from apache_beam.options import pipeline_options +from apache_beam.testing.test_stream import ElementEvent +from apache_beam.testing.test_stream import ProcessingTimeEvent +from apache_beam.testing.test_stream import TestStream from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from googleapiclient import discovery @@ -71,6 +75,73 @@ def _prepare_predict_examples(self, example_path): for example in self._predict_examples: output_file.write(example.SerializeToString()) + def _get_results(self, prediction_log_path): + results = [] + for f in tf.io.gfile.glob(prediction_log_path + '-?????-of-?????'): + record_iterator = tf.compat.v1.io.tf_record_iterator(path=f) + for record_string in record_iterator: + prediction_log = prediction_log_pb2.PredictionLog() + prediction_log.MergeFromString(record_string) + results.append(prediction_log) + return results + + def _build_keras_model(self, add): + """Builds a dummy keras model with one input and output.""" + inp = tf.keras.layers.Input((1,), name='input') + out = tf.keras.layers.Lambda(lambda x: x + add)(inp) + m = tf.keras.models.Model(inp, out) + return m + + def _new_model(self, model_path, add): + """Exports a keras model in the SavedModel format.""" + class WrapKerasModel(tf.keras.Model): + """Wrapper class to apply a signature to a keras model.""" + def __init__(self, model): + super().__init__() + self.model = model + + @tf.function(input_signature=[ + tf.TensorSpec(shape=[None], dtype=tf.string, name='inputs') + ]) + def call(self, serialized_example): + features = { + 'input': tf.compat.v1.io.FixedLenFeature( + [1], + dtype=tf.float32, + default_value=0 + ) + } + input_tensor_dict = tf.io.parse_example(serialized_example, features) + return {'output': self.model(input_tensor_dict)} + + model = self._build_keras_model(add) + wrapped_model = WrapKerasModel(model) + tf.compat.v1.keras.experimental.export_saved_model( + wrapped_model, model_path, serving_only=True + ) + return self._get_saved_model_spec(model_path) + + def _decode_value(self, pl): + """Returns output value from prediction log.""" + out_tensor = pl.predict_log.response.outputs['output'] + arr = tf.make_ndarray(out_tensor) + x = arr[0][0] + return x + + def _make_example(self, x): + """Builds a TFExample object with a single value.""" + feature = {} + feature['input'] = tf.train.Feature( + float_list=tf.train.FloatList(value=[x])) + ex = tf.train.Example(features=tf.train.Features(feature=feature)) + return ex + + def _get_saved_model_spec(self, model_path): + """Returns an InferenceSpecType object for a saved model path.""" + return 
model_spec_pb2.InferenceSpecType( + saved_model_spec=model_spec_pb2.SavedModelSpec( + model_path=model_path)) + class RunOfflineInferenceTest(RunInferenceFixture): @@ -220,16 +291,6 @@ def _run_inference_with_beam(self, example_path, inference_spec_type, prediction_log_path, coder=beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog))) - def _get_results(self, prediction_log_path): - results = [] - for f in tf.io.gfile.glob(prediction_log_path + '-?????-of-?????'): - record_iterator = tf.compat.v1.io.tf_record_iterator(path=f) - for record_string in record_iterator: - prediction_log = prediction_log_pb2.PredictionLog() - prediction_log.MergeFromString(record_string) - results.append(prediction_log) - return results - def testModelPathInvalid(self): example_path = self._get_output_data_dir('examples') self._prepare_predict_examples(example_path) @@ -609,7 +670,9 @@ def test_request_body_with_binary_data(self): project_id='test_project', model_name='test_model', version_name='test_version')) - remote_predict = run_inference._RemotePredictDoFn(inference_spec_type, None) + remote_predict = run_inference._RemotePredictDoFn( + None, fixed_inference_spec_type=inference_spec_type) + remote_predict._setup_model(remote_predict._fixed_inference_spec_type) result = list(remote_predict._prepare_instances([example])) self.assertEqual(result, [ { @@ -638,12 +701,208 @@ def test_request_serialized_example(self): model_name='test_model', version_name='test_version', use_serialization_config=True)) - remote_predict = run_inference._RemotePredictDoFn(inference_spec_type, None) + remote_predict = run_inference._RemotePredictDoFn( + None, fixed_inference_spec_type=inference_spec_type) + remote_predict._setup_model(remote_predict._fixed_inference_spec_type) result = list(remote_predict._prepare_instances([example])) self.assertEqual(result, [{ 'b64': base64.b64encode(example.SerializeToString()).decode() }]) +class RunInferenceCoreTest(RunInferenceFixture): + + def test_batch_queries_multiple_models(self): + spec1 = self._get_saved_model_spec('/example/model1') + spec2 = self._get_saved_model_spec('/example/model2') + + queries = [] + for i in range(100): + queries.append((spec1 if i % 2 == 0 else spec2, self._make_example(i))) + + correct = {example.SerializeToString(): spec for spec, example in queries} + + def _check_batch(batch): + """Assert examples are grouped with the correct inference spec.""" + spec, examples = batch + assert all([correct[x.SerializeToString()] == spec for x in examples]) + + with beam.Pipeline() as p: + queries = p | 'Build queries' >> beam.Create(queries) + batches = queries | '_BatchQueries' >> run_inference._BatchQueries() + + _ = batches | 'Check' >> beam.Map(_check_batch) + + def test_inference_on_queries(self): + spec = self._new_model(self._get_output_data_dir('model1'), 100) + predictions_path = self._get_output_data_dir('predictions') + queries = [(spec, self._make_example(i)) for i in range(10)] + + options = pipeline_options.PipelineOptions(streaming=False) + with beam.Pipeline(options=options) as p: + _ = ( + p + | 'Queries' >> beam.Create(queries) \ + | '_RunInferenceCore' >> run_inference._RunInferenceCore() \ + | 'WritePredictions' >> beam.io.WriteToTFRecord( + predictions_path, + coder=beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog)) + ) + + results = self._get_results(predictions_path) + values = [int(self._decode_value(x)) for x in results] + self.assertEqual( + values, + [100,101,102,103,104,105,106,107,108,109] + ) + + +class 
RunStreamingModelInferenceTest(RunInferenceFixture): + + def setUp(self): + super(RunStreamingModelInferenceTest, self).setUp() + + def testLateModel(self): + """Single model specified dynamically.""" + inference_spec_type = self._new_model( + self._get_output_data_dir('model'), + 100 + ) + predictions_path = self._get_output_data_dir('predictions') + + with beam.Pipeline() as p: + model = p | 'Create model' >> beam.Create([inference_spec_type]) + examples = ( + p + | 'Create examples' >> beam.Create(range(20)) \ + | 'Convert to TFExample' >> beam.Map(self._make_example) + ) + _ = ( + examples + | 'RunInference' >> run_inference.RunInferenceImpl(model) \ + | 'WritePredictions' >> beam.io.WriteToTFRecord( + predictions_path, + coder=beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog)) + ) + + results = self._get_results(predictions_path) + values = set([int(self._decode_value(x)) for x in results]) + self.assertEqual(values, set(range(100,120))) + + def testSeveralModels(self): + """Several models specified dynamically.""" + spec_1 = self._new_model(self._get_output_data_dir('model1'), 100) + spec_2 = self._new_model(self._get_output_data_dir('model2'), 200) + spec_3 = self._new_model(self._get_output_data_dir('model3'), 300) + + predictions_path = self._get_output_data_dir('predictions') + + h = TestStreamHelper() + h.add_stream( + 'examples', + range(20), + range(20) + ) + h.add_stream( + 'models', + [spec_1, spec_2, spec_3], + [4.5, 10.5, 15.5] + ) + stream = h.build() + + # TODO(hgarrereyn): this test doesn't work in streaming mode because + # records are never written + options = pipeline_options.PipelineOptions(streaming=False) + with beam.Pipeline(options=options) as p: + s = p | 'Stream' >> stream + models = s['models'] + examples = ( + s['examples'] + | 'Convert to TFExample' >> beam.Map(self._make_example) + ) + _ = ( + examples + | 'RunInference' >> run_inference.RunInferenceImpl(models) \ + | 'WritePredictions' >> beam.io.WriteToTFRecord( + predictions_path, + coder=beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog)) + ) + + results = self._get_results(predictions_path) + values = set([int(self._decode_value(x)) for x in results]) + self.assertEqual( + values, + set([ + 100,101,102,103,104,105,106,107,108,109,110, + 211,212,213,214,215, + 316,317,318,319 + ]) + ) + + def testQueries(self): + """Several models specified in queries.""" + spec_1 = self._new_model(self._get_output_data_dir('model1'), 100) + spec_2 = self._new_model(self._get_output_data_dir('model2'), 200) + spec_3 = self._new_model(self._get_output_data_dir('model3'), 300) + + predictions_path = self._get_output_data_dir('predictions') + + queries = [ + (spec_1, self._make_example(0)), + (spec_2, self._make_example(1)), + (spec_3, self._make_example(2)), + (spec_1, self._make_example(3)), + (spec_2, self._make_example(4)), + (spec_3, self._make_example(5)), + (spec_1, self._make_example(6)), + (spec_2, self._make_example(7)), + (spec_3, self._make_example(8)), + (spec_1, self._make_example(9)), + (spec_2, self._make_example(10)), + (spec_3, self._make_example(11)), + ] + + # TODO(hgarrereyn): this test doesn't work in streaming mode because + # records are never written + options = pipeline_options.PipelineOptions(streaming=False) + with beam.Pipeline(options=options) as p: + _ = ( + p + | 'Queries' >> beam.Create(queries) \ + | 'RunInference' >> run_inference.RunInferenceImpl() \ + | 'WritePredictions' >> beam.io.WriteToTFRecord( + predictions_path, + 
coder=beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog)) + ) + + results = self._get_results(predictions_path) + values = set([int(self._decode_value(x)) for x in results]) + self.assertEqual( + values, + set([ + 100,103,106,109, + 201,204,207,210, + 302,305,308,311 + ]) + ) + + +class TestStreamHelper(object): + """Helper object to build a test stream with tagged outputs.""" + def __init__(self): + self.events = [] + self.tags = set([None]) + + def add_stream(self, tag, values, timestamps): + self.tags.add(tag) + for v,ts in zip(values, timestamps): + self.events.append(ElementEvent( + tag=tag, timestamped_values=[beam.window.TimestampedValue(v,ts)])) + + def build(self): + events = sorted(self.events, key=lambda x: x.timestamped_values[0].timestamp) + return TestStream(events=events, output_tags=self.tags) + + if __name__ == '__main__': tf.test.main() diff --git a/tfx_bsl/public/beam/run_inference.py b/tfx_bsl/public/beam/run_inference.py index d27ab453..9f60dc96 100644 --- a/tfx_bsl/public/beam/run_inference.py +++ b/tfx_bsl/public/beam/run_inference.py @@ -22,18 +22,20 @@ import apache_beam as beam import tensorflow as tf from tfx_bsl.beam import run_inference +from tfx_bsl.beam.run_inference import ExampleType +from tfx_bsl.beam.run_inference import QueryType from tfx_bsl.public.proto import model_spec_pb2 from typing import Union from tensorflow_serving.apis import prediction_log_pb2 @beam.ptransform_fn -@beam.typehints.with_input_types(Union[tf.train.Example, - tf.train.SequenceExample]) +@beam.typehints.with_input_types(Union[ExampleType, QueryType]) @beam.typehints.with_output_types(prediction_log_pb2.PredictionLog) def RunInference( # pylint: disable=invalid-name examples: beam.pvalue.PCollection, - inference_spec_type: model_spec_pb2.InferenceSpecType + inference_spec_type: Union[model_spec_pb2.InferenceSpecType, + beam.pvalue.PCollection] = None ) -> beam.pvalue.PCollection: """Run inference with a model. @@ -44,14 +46,40 @@ def RunInference( # pylint: disable=invalid-name `ai_platform_prediction_model_spec` field is set in `inference_spec_type`. + The inference model can be specified in three ways: + 1. For a fixed inference model, provide an InferenceSpecType object as a + fixed parameter. This inference model will be used for all examples. + This form of the API has some performance benefits when running inference + locally as it is easier to cache the inference model. + 2. In a pipeline where the inference model may be updated at runtime, you can + specify a PCollection of InferenceSpecType objects as a side-input. Each + example will be joined with the most recent inference spec from this + PCollection (based on processing time). Any examples that arrive before + the first model will be buffered until a model is available. + 3. For finer control, you can run inference on a stream of query tuples: + (InferenceSpecType, Example) where each tuple specifies an example and a + model to use for inference. Internally, queries with the same inference + spec will be grouped together and inference will operate on batches of + examples. To use this api, don't provide an inference_spec_type parameter. + + Note: Options 2 and 3 above both require a Beam runner with stateful DoFn + support. You can reference the compatability matrix to determine a suitable + runner: (https://beam.apache.org/documentation/runners/capability-matrix). + TODO(b/131873699): Add support for the following features: 1. Bytes as Input. 2. PTable Input. - 3. Models as SideInput. 
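+
+  Example (an illustrative sketch; `examples`, `model_spec`, `model_specs`,
+  and `queries` are placeholders for user-defined PCollections/protos):
+
+    # 1. Fixed model.
+    predictions = examples | RunInference(model_spec)
+
+    # 2. Models supplied as a PCollection (stateful-DoFn runners only).
+    predictions = examples | RunInference(model_specs)
+
+    # 3. (InferenceSpecType, Example) query tuples.
+    predictions = queries | RunInference()
+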
Args: - examples: A PCollection containing examples. - inference_spec_type: Model inference endpoint. + examples: A PCollection containing examples. If inference_spec_type is + None, this is interpreted as a PCollection of queries: + (InferenceSpecType, Example) + inference_spec_type: Model inference endpoint. Can be one of: + - InferenceSpecType: specifies a fixed model to use for inference. + - PCollection[InferenceSpecType]: specifies a secondary PCollection of + models. Each example will use the most recent model for inference. + - None: indicates that the primary PCollection contains + (InferenceSpecType, Example) tuples. Returns: A PCollection containing prediction logs.