Commit 1df33ac (parent: 0550177)
Author: hgarrereyn

add tests for _BatchQueries and _RunInferenceCore
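For context, the new tests exercise tfx_bsl's internal query-based inference path: a query is a (InferenceSpecType, tf.train.Example) tuple, _BatchQueries groups queries into (spec, examples) batches, and _RunInferenceCore turns queries into PredictionLogs (the end-to-end test feeds it unbatched queries, so it evidently batches internally). Below is a minimal sketch of that flow, based only on the usage visible in this diff; the import paths and file paths are assumptions, and both transforms are private APIs, so treat this as illustrative rather than a supported interface:

import apache_beam as beam
import tensorflow as tf

from tfx_bsl.beam import run_inference  # assumed import path
from tfx_bsl.public.proto import model_spec_pb2  # assumed import path
from tensorflow_serving.apis import prediction_log_pb2  # assumed import path

# A query pairs an inference spec (which model to run) with one example.
spec = model_spec_pb2.InferenceSpecType(
    saved_model_spec=model_spec_pb2.SavedModelSpec(
        model_path='/example/model'))  # hypothetical model path
example = tf.train.Example()  # a real query would carry the model's features

with beam.Pipeline() as p:
  _ = (
      p
      | 'Queries' >> beam.Create([(spec, example)])
      | 'RunInference' >> run_inference._RunInferenceCore()
      | 'Write' >> beam.io.WriteToTFRecord(
          '/tmp/predictions',  # hypothetical output prefix
          coder=beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog)))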

1 file changed: +134 -10 lines changed

tfx_bsl/beam/run_inference_test.py

@@ -27,6 +27,7 @@
 
 import apache_beam as beam
 from apache_beam.metrics.metric import MetricsFilter
+from apache_beam.options import pipeline_options
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 from googleapiclient import discovery
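This import exists for the new test_inference_on_queries test at the bottom of the diff, which builds explicit options to pin the pipeline to batch mode:

options = pipeline_options.PipelineOptions(streaming=False)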
@@ -70,6 +71,16 @@ def _prepare_predict_examples(self, example_path):
       for example in self._predict_examples:
         output_file.write(example.SerializeToString())
 
+  def _get_results(self, prediction_log_path):
+    results = []
+    for f in tf.io.gfile.glob(prediction_log_path + '-?????-of-?????'):
+      record_iterator = tf.compat.v1.io.tf_record_iterator(path=f)
+      for record_string in record_iterator:
+        prediction_log = prediction_log_pb2.PredictionLog()
+        prediction_log.MergeFromString(record_string)
+        results.append(prediction_log)
+    return results
+
 
 class RunOfflineInferenceTest(RunInferenceFixture):

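Two notes on this hunk: the helper is a move rather than new logic (the identical method is deleted from RunOfflineInferenceTest in the next hunk, so both that class and the new RunInferenceCoreTest inherit it from RunInferenceFixture), and the '-?????-of-?????' glob matches Beam's default shard naming for WriteToTFRecord output, e.g. for a single-shard run (hypothetical prefix):

predictions-00000-of-00001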
@@ -219,16 +230,6 @@ def _run_inference_with_beam(self, example_path, inference_spec_type,
             prediction_log_path,
             coder=beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog)))
 
-  def _get_results(self, prediction_log_path):
-    results = []
-    for f in tf.io.gfile.glob(prediction_log_path + '-?????-of-?????'):
-      record_iterator = tf.compat.v1.io.tf_record_iterator(path=f)
-      for record_string in record_iterator:
-        prediction_log = prediction_log_pb2.PredictionLog()
-        prediction_log.MergeFromString(record_string)
-        results.append(prediction_log)
-    return results
-
   def testModelPathInvalid(self):
     example_path = self._get_output_data_dir('examples')
     self._prepare_predict_examples(example_path)
@@ -616,5 +617,128 @@ def test_request_body_with_binary_data(self):
     ], result)
 
 
+class RunInferenceCoreTest(RunInferenceFixture):
+
+  def _build_keras_model(self, add):
+    """Builds a dummy keras model with one input and output."""
+    inp = tf.keras.layers.Input((1,), name='input')
+    out = tf.keras.layers.Lambda(lambda x: x + add)(inp)
+    m = tf.keras.models.Model(inp, out)
+    return m
+
+  def _new_model(self, model_path, add):
+    """Exports a keras model in the SavedModel format."""
+    class WrapKerasModel(tf.keras.Model):
+      """Wrapper class to apply a signature to a keras model."""
+      def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+      @tf.function(input_signature=[
+          tf.TensorSpec(shape=[None], dtype=tf.string, name='inputs')
+      ])
+      def call(self, serialized_example):
+        features = {
+            'input': tf.compat.v1.io.FixedLenFeature(
+                [1],
+                dtype=tf.float32,
+                default_value=0
+            )
+        }
+        input_tensor_dict = tf.io.parse_example(serialized_example, features)
+        return self.model(input_tensor_dict)
+
+    model = self._build_keras_model(add)
+    wrapped_model = WrapKerasModel(model)
+    tf.compat.v1.keras.experimental.export_saved_model(
+        wrapped_model, model_path, serving_only=True
+    )
+    return self._get_saved_model_spec(model_path)
+
+  def _decode_value(self, pl):
+    """Returns output value from prediction log."""
+    out_tensor = pl.predict_log.response.outputs['output_1']
+    arr = tf.make_ndarray(out_tensor)
+    x = arr[0][0]
+    return x
+
+  def _make_example(self, x):
+    """Builds a TFExample object with a single value."""
+    feature = {}
+    feature['input'] = tf.train.Feature(
+        float_list=tf.train.FloatList(value=[x]))
+    ex = tf.train.Example(features=tf.train.Features(feature=feature))
+    return ex
+
+  def _get_saved_model_spec(self, model_path):
+    """Returns an InferenceSpecType object for a saved model path."""
+    return model_spec_pb2.InferenceSpecType(
+        saved_model_spec=model_spec_pb2.SavedModelSpec(
+            model_path=model_path))
+
+  def test_batch_queries_single_model(self):
+    spec = self._get_saved_model_spec('/example/model')
+    QUERIES = [(spec, self._make_example(i)) for i in range(100)]
+    CORRECT = {example.SerializeToString(): spec for spec, example in QUERIES}
+
+    def _check_batch(batch):
+      """Assert examples are grouped with the correct inference spec."""
+      spec, examples = batch
+      assert all([CORRECT[x.SerializeToString()] == spec for x in examples])
+
+    with beam.Pipeline() as p:
+      queries = p | 'Build queries' >> beam.Create(QUERIES)
+      batches = queries | '_BatchQueries' >> run_inference._BatchQueries()
+
+      _ = batches | 'Check' >> beam.Map(_check_batch)
+
+  # TODO(hgarrereyn): Switch _BatchElements to use GroupIntoBatches once
+  # BEAM-2717 is fixed so examples are grouped by inference spec key.
+  #
+  # def test_batch_queries_multiple_models(self):
+  #   spec1 = self._get_saved_model_spec('/example/model1')
+  #   spec2 = self._get_saved_model_spec('/example/model2')
+  #
+  #   QUERIES = []
+  #   for i in range(100):
+  #     QUERIES.append((spec1 if i % 2 == 0 else spec2, self._make_example(i)))
+  #
+  #   CORRECT = {example.SerializeToString(): spec for spec, example in QUERIES}
+  #
+  #   def _check_batch(batch):
+  #     """Assert examples are grouped with the correct inference spec."""
+  #     spec, examples = batch
+  #     assert all([CORRECT[x.SerializeToString()] == spec for x in examples])
+  #
+  #   with beam.Pipeline() as p:
+  #     queries = p | 'Build queries' >> beam.Create(QUERIES)
+  #     batches = queries | '_BatchQueries' >> run_inference._BatchQueries()
+  #
+  #     _ = batches | 'Check' >> beam.Map(_check_batch)
+
+  def test_inference_on_queries(self):
+    spec = self._new_model(self._get_output_data_dir('model1'), 100)
+    predictions_path = self._get_output_data_dir('predictions')
+    QUERIES = [(spec, self._make_example(i)) for i in range(10)]
+
+    options = pipeline_options.PipelineOptions(streaming=False)
+    with beam.Pipeline(options=options) as p:
+      _ = (
+          p
+          | 'Queries' >> beam.Create(QUERIES)
+          | '_RunInferenceCore' >> run_inference._RunInferenceCore()
+          | 'WritePredictions' >> beam.io.WriteToTFRecord(
+              predictions_path,
+              coder=beam.coders.ProtoCoder(prediction_log_pb2.PredictionLog))
+      )
+
+    results = self._get_results(predictions_path)
+    values = [int(self._decode_value(x)) for x in results]
+    self.assertEqual(
+        values,
+        [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]
+    )
+
+
 if __name__ == '__main__':
   tf.test.main()
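The expected values in test_inference_on_queries follow directly from the model: the Lambda layer adds 100 to each input, so examples 0..9 come back as PredictionLogs whose 'output_1' tensor holds 100..109. Here is a standalone sketch of the decoding step that _decode_value performs, using a hand-built PredictionLog instead of a real inference run (the tensorflow_serving import path is an assumption):

import tensorflow as tf
from tensorflow_serving.apis import prediction_log_pb2  # assumed import path

# Build a fake PredictionLog shaped like the pipeline's output.
log = prediction_log_pb2.PredictionLog()
log.predict_log.response.outputs['output_1'].CopyFrom(
    tf.make_tensor_proto([[105.0]]))  # shape [1, 1], as the Lambda model emits

# Mirrors _decode_value: fetch the tensor, convert to an ndarray, index in.
value = tf.make_ndarray(log.predict_log.response.outputs['output_1'])[0][0]
assert int(value) == 105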
