import apache_beam as beam
from apache_beam.metrics.metric import MetricsFilter
+from apache_beam.options import pipeline_options
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from googleapiclient import discovery
@@ -70,6 +71,16 @@ def _prepare_predict_examples(self, example_path):
      for example in self._predict_examples:
        output_file.write(example.SerializeToString())

+  def _get_results(self, prediction_log_path):
+    results = []
+    for f in tf.io.gfile.glob(prediction_log_path + '-?????-of-?????'):
+      record_iterator = tf.compat.v1.io.tf_record_iterator(path=f)
+      for record_string in record_iterator:
+        prediction_log = prediction_log_pb2.PredictionLog()
+        prediction_log.MergeFromString(record_string)
+        results.append(prediction_log)
+    return results
+

class RunOfflineInferenceTest(RunInferenceFixture):

@@ -219,16 +230,6 @@ def _run_inference_with_beam(self, example_path, inference_spec_type,
              prediction_log_path,
              coder=beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog)))

-  def _get_results(self, prediction_log_path):
-    results = []
-    for f in tf.io.gfile.glob(prediction_log_path + '-?????-of-?????'):
-      record_iterator = tf.compat.v1.io.tf_record_iterator(path=f)
-      for record_string in record_iterator:
-        prediction_log = prediction_log_pb2.PredictionLog()
-        prediction_log.MergeFromString(record_string)
-        results.append(prediction_log)
-    return results
-
  def testModelPathInvalid(self):
    example_path = self._get_output_data_dir('examples')
    self._prepare_predict_examples(example_path)
@@ -616,5 +617,128 @@ def test_request_body_with_binary_data(self):
    ], result)


+class RunInferenceCoreTest(RunInferenceFixture):
+
+  def _build_keras_model(self, add):
+    """Builds a dummy keras model with one input and output."""
+    inp = tf.keras.layers.Input((1,), name='input')
+    out = tf.keras.layers.Lambda(lambda x: x + add)(inp)
+    m = tf.keras.models.Model(inp, out)
+    return m
+
+  def _new_model(self, model_path, add):
+    """Exports a keras model in the SavedModel format."""
+    class WrapKerasModel(tf.keras.Model):
+      """Wrapper class to apply a signature to a keras model."""
+      def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+      @tf.function(input_signature=[
+          tf.TensorSpec(shape=[None], dtype=tf.string, name='inputs')
+      ])
+      def call(self, serialized_example):
+        features = {
+            'input': tf.compat.v1.io.FixedLenFeature(
+                [1],
+                dtype=tf.float32,
+                default_value=0
+            )
+        }
+        input_tensor_dict = tf.io.parse_example(serialized_example, features)
+        return self.model(input_tensor_dict)
+
+    model = self._build_keras_model(add)
+    wrapped_model = WrapKerasModel(model)
+    tf.compat.v1.keras.experimental.export_saved_model(
+        wrapped_model, model_path, serving_only=True
+    )
+    return self._get_saved_model_spec(model_path)
+
+  def _decode_value(self, pl):
+    """Returns output value from prediction log."""
+    out_tensor = pl.predict_log.response.outputs['output_1']
+    arr = tf.make_ndarray(out_tensor)
+    x = arr[0][0]
+    return x
+
+  def _make_example(self, x):
+    """Builds a TFExample object with a single value."""
+    feature = {}
+    feature['input'] = tf.train.Feature(
+        float_list=tf.train.FloatList(value=[x]))
+    ex = tf.train.Example(features=tf.train.Features(feature=feature))
+    return ex
+
+  def _get_saved_model_spec(self, model_path):
+    """Returns an InferenceSpecType object for a saved model path."""
+    return model_spec_pb2.InferenceSpecType(
+        saved_model_spec=model_spec_pb2.SavedModelSpec(
+            model_path=model_path))
+
+  def test_batch_queries_single_model(self):
+    spec = self._get_saved_model_spec('/example/model')
+    QUERIES = [(spec, self._make_example(i)) for i in range(100)]
+    CORRECT = {example.SerializeToString(): spec for spec, example in QUERIES}
+
+    def _check_batch(batch):
+      """Assert examples are grouped with the correct inference spec."""
+      spec, examples = batch
+      assert all([CORRECT[x.SerializeToString()] == spec for x in examples])
+
+    with beam.Pipeline() as p:
+      queries = p | 'Build queries' >> beam.Create(QUERIES)
+      batches = queries | '_BatchQueries' >> run_inference._BatchQueries()
+
+      _ = batches | 'Check' >> beam.Map(_check_batch)
+
+  # TODO(hgarrereyn): Switch _BatchElements to use GroupIntoBatches once
+  # BEAM-2717 is fixed so examples are grouped by inference spec key.
+  #
+  # def test_batch_queries_multiple_models(self):
+  #   spec1 = self._get_saved_model_spec('/example/model1')
+  #   spec2 = self._get_saved_model_spec('/example/model2')
+  #
+  #   QUERIES = []
+  #   for i in range(100):
+  #     QUERIES.append((spec1 if i % 2 == 0 else spec2, self._make_example(i)))
+  #
+  #   CORRECT = {example.SerializeToString(): spec for spec, example in QUERIES}
+  #
+  #   def _check_batch(batch):
+  #     """Assert examples are grouped with the correct inference spec."""
+  #     spec, examples = batch
+  #     assert all([CORRECT[x.SerializeToString()] == spec for x in examples])
+  #
+  #   with beam.Pipeline() as p:
+  #     queries = p | 'Build queries' >> beam.Create(QUERIES)
+  #     batches = queries | '_BatchQueries' >> run_inference._BatchQueries()
+  #
+  #     _ = batches | 'Check' >> beam.Map(_check_batch)
+
+  def test_inference_on_queries(self):
+    spec = self._new_model(self._get_output_data_dir('model1'), 100)
+    predictions_path = self._get_output_data_dir('predictions')
+    QUERIES = [(spec, self._make_example(i)) for i in range(10)]
+
+    options = pipeline_options.PipelineOptions(streaming=False)
+    with beam.Pipeline(options=options) as p:
+      _ = (
+          p
+          | 'Queries' >> beam.Create(QUERIES)
+          | '_RunInferenceCore' >> run_inference._RunInferenceCore()
+          | 'WritePredictions' >> beam.io.WriteToTFRecord(
+              predictions_path,
+              coder=beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog))
+      )
+
+    results = self._get_results(predictions_path)
+    values = [int(self._decode_value(x)) for x in results]
+    self.assertEqual(
+        values,
+        [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]
+    )
+
+
if __name__ == '__main__':
  tf.test.main()
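Note on the TODO in RunInferenceCoreTest: below is a minimal, hypothetical sketch of the kind of per-key batching that comment refers to, using Beam's GroupIntoBatches. The string keys and batch_size are illustrative stand-ins (in the real pipeline the key would be a serialized InferenceSpecType proto and the values tf.train.Example protos); none of this code is part of the change itself.

    import apache_beam as beam

    # Sketch only: group (key, value) queries into fixed-size per-key batches.
    with beam.Pipeline() as p:
      _ = (
          p
          | 'Create' >> beam.Create([
              ('model_a', 1), ('model_b', 2), ('model_a', 3), ('model_b', 4),
          ])
          | 'Batch' >> beam.GroupIntoBatches(batch_size=2)
          | 'Print' >> beam.Map(print)  # e.g. ('model_a', [1, 3])
      )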
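A related note on the relocated _get_results helper: it still relies on the deprecated tf.compat.v1.io.tf_record_iterator. A roughly equivalent reader built on tf.data.TFRecordDataset could look like the sketch below; this is an assumption about a possible follow-up, not code from this change.

    import tensorflow as tf
    from tensorflow_serving.apis import prediction_log_pb2

    def read_prediction_logs(prediction_log_path):
      """Reads PredictionLog protos from sharded TFRecord output files."""
      results = []
      for f in tf.io.gfile.glob(prediction_log_path + '-?????-of-?????'):
        # Eager iteration yields string tensors; .numpy() gives the raw bytes.
        for record in tf.data.TFRecordDataset(f):
          log = prediction_log_pb2.PredictionLog()
          log.ParseFromString(record.numpy())
          results.append(log)
      return results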