diff --git a/README.md b/README.md index d9d873f..8141553 100644 --- a/README.md +++ b/README.md @@ -173,6 +173,28 @@ Alternatively, you can run the program interactively by typing image paths in the terminal (one per line, type Control-D when you want the model to run the input entered so far). +# Deploying model to Tensorflow Model Serving + +To host models trained by this platform using Tensorflow Model Serving, +you would first need to convert the given checkpoint into the SavedModel +format. The SavedModel conversion script (`export_to_saved_model.py`) +allows you to restore an existing checkpoint, and export it as the +SavedModel, which is required format by Tensorflow Model Serving. + +Please edit scripts parameters and execute the following: + + cd src ; python export_to_saved_model.py + +# Validation with Tensorflow Model Serving + +We have created the example script (`tf_serving_predict.py`) , which +will use the specified image file and run it against the model hosted by +the Tensorflow Model Serving. + +Please edit scripts parameters and execute the following: + + cd src ; python tf_serving_predict.py + # Configuration There are many command-line options to configure training @@ -240,7 +262,7 @@ Please cite the following [paper](https://weinman.cs.grinnell.edu/pubs/weinman19 ```text @inproceedings{ weinman19deep, - author = {Jerod Weinman and Ziwen Chen and Ben Gafford and Nathan Gifford and Abyaya Lamsal and Liam Niehus-Staab}, + author = {Jerod Weinman and Ziwen Chen and Ben Gafford and Nathan Gifford and Abyaya Lamsal and Liam Niehus-Staab and Igor Vishnevskiy}, title = {Deep Neural Networks for Text Detection and Recognition in Historical Maps}, booktitle = {Proc. IAPR International Conference on Document Analysis and Recognition}, month = {Sep.}, diff --git a/src/export_to_saved_model.py b/src/export_to_saved_model.py new file mode 100755 index 0000000..18a1028 --- /dev/null +++ b/src/export_to_saved_model.py @@ -0,0 +1,93 @@ +# CNN-LSTM-CTC-OCR +# Copyright (C) 2017,2018 Jerod Weinman, Abyaya Lamsal, Benjamin Gafford +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +# validate.py - Run model directly on from paths to image filenames. +# NOTE: assumes mjsynth files are given by adding an extra row of padding + +import tensorflow as tf + +import model_fn as model_fn +import os + +FLAGS = tf.app.flags.FLAGS + +tf.app.flags.DEFINE_string( 'ckpt_path','/media/path/to/models/', + """Directory for model checkpoints""" ) +tf.app.flags.DEFINE_string( 'checkpoint','model.ckpt-73152', + """The checkpoint to export. Just the name of the checkpoint, + no path needed. Exemple: 'model.ckpt-104976'""" ) +tf.app.flags.DEFINE_string( 'export_dir','/media/path/to/where/to/export/saved_model/', + """Path to the directory where SavedModel is to be exported""" ) +tf.app.flags.DEFINE_string( 'lexicon','', + """File containing lexicon of image words""" ) +tf.app.flags.DEFINE_float( 'lexicon_prior',None, + """Prior bias [0,1] for lexicon word""" ) + +IMAGE = 'image' +WIDTH = 'width' +LABELS = 'labels' +LENGTH = 'length' +TEXT = 'text' + + +class ExportSavedModel(): + + def __get_config(self): + """Setup session config to soften device placement""" + + device_config=tf.ConfigProto( + allow_soft_placement=True, + log_device_placement=False ) + + custom_config = tf.estimator.RunConfig( session_config=device_config ) + + return custom_config + + def __serving_input_receiver(self): + """Placeholders function required by the Estimator to export SavedModel""" + + image = tf.placeholder(dtype=tf.float32, shape=[1, 32, None, 1], name=IMAGE) + width = tf.placeholder(dtype=tf.int32, shape=[], name=WIDTH) + labels = tf.placeholder(dtype=tf.float32, shape=[None], name=LABELS) + length = tf.placeholder(dtype=tf.float32, shape=[], name=LENGTH) + text = tf.placeholder(dtype=tf.string, shape=[], name=TEXT) + features = {IMAGE: image, + WIDTH: width, + LABELS: labels, + LENGTH: length, + TEXT: text} + receiver_tensor = {IMAGE: image, + WIDTH: width} + return tf.estimator.export.ServingInputReceiver(features, receiver_tensor) + + def export_saved_model(self): + """Will restore the given checkpoint and export it into the SavedModel format""" + + classifier = tf.estimator.Estimator(config=self.__get_config(), + model_fn=model_fn.predict_fn( + FLAGS.lexicon, + FLAGS.lexicon_prior), + model_dir=FLAGS.ckpt_path) + + checkpoint = os.path.join(FLAGS.ckpt_path, FLAGS.checkpoint) + + classifier.export_saved_model(FLAGS.export_dir, + serving_input_receiver_fn=self.__serving_input_receiver, + checkpoint_path=checkpoint) + +if __name__ == '__main__': + export = ExportSavedModel() + export.export_saved_model() diff --git a/src/tf_serving_predict.py b/src/tf_serving_predict.py new file mode 100755 index 0000000..750e0c2 --- /dev/null +++ b/src/tf_serving_predict.py @@ -0,0 +1,107 @@ +# CNN-LSTM-CTC-OCR +# Copyright (C) 2017,2018 Jerod Weinman, Abyaya Lamsal, Benjamin Gafford +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +# validate.py - Run model directly on from paths to image filenames. +# NOTE: assumes mjsynth files are given by adding an extra row of padding + +import numpy as np +from PIL import Image + +from tensorflow import saved_model as sm + +import charset as charset +import mjsynth + +import datetime + +import grpc +import tensorflow as tf +from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc + +global ctc_graph +ctc_graph = tf.compat.v1.get_default_graph + +FLAGS = tf.app.flags.FLAGS + +tf.app.flags.DEFINE_string( 'server_host','localhost:8500', + """IP:PORT of the Tensorflow Model Serving""" ) +tf.app.flags.DEFINE_string( 'model_spec_name','crnn', + """Name of the saved models batch as it is referenced in the "tf_server.conf""" ) +tf.app.flags.DEFINE_integer( 'model_version',1616119579, + """SavedModel version to call""" ) +tf.app.flags.DEFINE_string( 'image_path','/media/test/images/test_image.png', + """Image to be used as input for inference/prediction""" ) +tf.app.flags.DEFINE_integer( 'image_base_height',32, + """The height for all tensors to be reshaped to""" ) + +tf.logging.set_verbosity( tf.logging.INFO ) + +server = FLAGS.server_host +channel = grpc.insecure_channel(server) +stub = prediction_service_pb2_grpc.PredictionServiceStub(channel) + +request = predict_pb2.PredictRequest() +request.model_spec.signature_name = sm.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY +request.model_spec.name = FLAGS.model_spec_name +request.model_spec.version.value = FLAGS.model_version + +IMAGE = 'image' +WIDTH = 'width' +LABELS = 'labels' + +class Predict(): + + def __get_image(self, filename): + """Load image data for placement in graph""" + + pil_image = Image.open(filename) + width, height = pil_image.size + new_width = FLAGS.image_base_height * width / height + pil_image = pil_image.resize((int(new_width), FLAGS.image_base_height), Image.ANTIALIAS) + image = np.array(pil_image) + + # in mjsynth, all three channels are the same in these grayscale-cum-RGB data + image = image[:, :, :1] # so just extract first channel, preserving 3D shape + image = mjsynth.preprocess_image(image) + image = image.eval(session=tf.compat.v1.Session()) + + return image + + def predict(self): + """ Will call Tensorflow Model Server for Prediction, using gRPC API. """ + + print("Prediction Start Time: ", datetime.datetime.now()) + image = self.__get_image(FLAGS.image_path) + + h, w, c = image.shape + image = image[np.newaxis, :, :, :] + + request.inputs[IMAGE].CopyFrom(tf.compat.v1.make_tensor_proto(image, dtype=None, shape=image.shape)) + request.inputs[WIDTH].CopyFrom(tf.compat.v1.make_tensor_proto(w)) + result = stub.Predict(request, 10) + + protobuf_response = result.outputs[LABELS] + ndarray_response = tf.make_ndarray(protobuf_response) + label = charset.label_to_string(ndarray_response[0]) + + print("Prediction End Time: ", datetime.datetime.now()) + print("Predicted Text: ", label) + + return label + +if __name__ == '__main__': + pr = Predict() + pr.predict()