diff --git a/README.md b/README.md
index d9d873f..8141553 100644
--- a/README.md
+++ b/README.md
@@ -173,6 +173,28 @@ Alternatively, you can run the program interactively by typing image
paths in the terminal (one per line, type Control-D when you want the
model to run the input entered so far).
+# Deploying a model to TensorFlow Serving
+
+To host models trained by this platform with TensorFlow Serving, you
+must first convert a training checkpoint into the SavedModel format,
+which is the format TensorFlow Serving requires. The conversion script
+(`export_to_saved_model.py`) restores an existing checkpoint and
+exports it as a SavedModel.
+
+Edit the script's flag defaults and execute the following:
+
+    cd src ; python export_to_saved_model.py
+
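+The flags can also be overridden on the command line; for example (the
+paths here are illustrative):
+
+    cd src ; python export_to_saved_model.py \
+        --ckpt_path=/path/to/models/ \
+        --checkpoint=model.ckpt-73152 \
+        --export_dir=/path/to/saved_model/
+
+The exported SavedModel can then be hosted with the standard
+`tensorflow_model_server` binary; a minimal sketch follows (the model
+name and paths are illustrative, and the name and port must match the
+flags used by `tf_serving_predict.py` below):
+
+    tensorflow_model_server --port=8500 \
+        --model_name=crnn \
+        --model_base_path=/path/to/saved_model/
+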
+# Validation with TensorFlow Serving
+
+The example script `tf_serving_predict.py` reads the specified image
+file and runs it against the model hosted by TensorFlow Serving.
+
+Edit the script's flag defaults and execute the following:
+
+    cd src ; python tf_serving_predict.py
+
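+The flags can likewise be overridden on the command line; for example
+(the server address, model name, and image path are illustrative):
+
+    cd src ; python tf_serving_predict.py \
+        --server_host=localhost:8500 \
+        --model_spec_name=crnn \
+        --model_version=1616119579 \
+        --image_path=/path/to/test_image.png
+
+Note that `model_version` must match the timestamped subdirectory that
+`export_to_saved_model.py` created under the export directory.
+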
# Configuration
There are many command-line options to configure training
@@ -240,7 +262,7 @@ Please cite the following [paper](https://weinman.cs.grinnell.edu/pubs/weinman19
```text
@inproceedings{ weinman19deep,
- author = {Jerod Weinman and Ziwen Chen and Ben Gafford and Nathan Gifford and Abyaya Lamsal and Liam Niehus-Staab},
+ author = {Jerod Weinman and Ziwen Chen and Ben Gafford and Nathan Gifford and Abyaya Lamsal and Liam Niehus-Staab and Igor Vishnevskiy},
title = {Deep Neural Networks for Text Detection and Recognition in Historical Maps},
booktitle = {Proc. IAPR International Conference on Document Analysis and Recognition},
month = {Sep.},
diff --git a/src/export_to_saved_model.py b/src/export_to_saved_model.py
new file mode 100755
index 0000000..18a1028
--- /dev/null
+++ b/src/export_to_saved_model.py
@@ -0,0 +1,93 @@
+# CNN-LSTM-CTC-OCR
+# Copyright (C) 2017,2018 Jerod Weinman, Abyaya Lamsal, Benjamin Gafford
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+# export_to_saved_model.py - Restore a training checkpoint and export it
+# in the SavedModel format required by TensorFlow Serving.
+
+import tensorflow as tf
+
+import model_fn
+import os
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string( 'ckpt_path','/media/path/to/models/',
+ """Directory for model checkpoints""" )
+tf.app.flags.DEFINE_string( 'checkpoint','model.ckpt-73152',
+                            """Name of the checkpoint to export (no path
+                            needed). Example: 'model.ckpt-104976'""" )
+tf.app.flags.DEFINE_string( 'export_dir','/media/path/to/where/to/export/saved_model/',
+ """Path to the directory where SavedModel is to be exported""" )
+tf.app.flags.DEFINE_string( 'lexicon','',
+ """File containing lexicon of image words""" )
+tf.app.flags.DEFINE_float( 'lexicon_prior',None,
+ """Prior bias [0,1] for lexicon word""" )
+
+IMAGE = 'image'
+WIDTH = 'width'
+LABELS = 'labels'
+LENGTH = 'length'
+TEXT = 'text'
+
+
+class ExportSavedModel:
+
+ def __get_config(self):
+ """Setup session config to soften device placement"""
+
+ device_config=tf.ConfigProto(
+ allow_soft_placement=True,
+ log_device_placement=False )
+
+ custom_config = tf.estimator.RunConfig( session_config=device_config )
+
+ return custom_config
+
+ def __serving_input_receiver(self):
+ """Placeholders function required by the Estimator to export SavedModel"""
+
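+        # Only 'image' and 'width' are supplied by callers at serving time
+        # (they are the receiver tensors below); the remaining placeholders
+        # only satisfy the feature set the model function expects.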
+ image = tf.placeholder(dtype=tf.float32, shape=[1, 32, None, 1], name=IMAGE)
+ width = tf.placeholder(dtype=tf.int32, shape=[], name=WIDTH)
+ labels = tf.placeholder(dtype=tf.float32, shape=[None], name=LABELS)
+ length = tf.placeholder(dtype=tf.float32, shape=[], name=LENGTH)
+ text = tf.placeholder(dtype=tf.string, shape=[], name=TEXT)
+ features = {IMAGE: image,
+ WIDTH: width,
+ LABELS: labels,
+ LENGTH: length,
+ TEXT: text}
+ receiver_tensor = {IMAGE: image,
+ WIDTH: width}
+ return tf.estimator.export.ServingInputReceiver(features, receiver_tensor)
+
+ def export_saved_model(self):
+ """Will restore the given checkpoint and export it into the SavedModel format"""
+
+ classifier = tf.estimator.Estimator(config=self.__get_config(),
+ model_fn=model_fn.predict_fn(
+ FLAGS.lexicon,
+ FLAGS.lexicon_prior),
+ model_dir=FLAGS.ckpt_path)
+
+ checkpoint = os.path.join(FLAGS.ckpt_path, FLAGS.checkpoint)
+
+ classifier.export_saved_model(FLAGS.export_dir,
+ serving_input_receiver_fn=self.__serving_input_receiver,
+ checkpoint_path=checkpoint)
+
+if __name__ == '__main__':
+ export = ExportSavedModel()
+ export.export_saved_model()
diff --git a/src/tf_serving_predict.py b/src/tf_serving_predict.py
new file mode 100755
index 0000000..750e0c2
--- /dev/null
+++ b/src/tf_serving_predict.py
@@ -0,0 +1,107 @@
+# CNN-LSTM-CTC-OCR
+# Copyright (C) 2017,2018 Jerod Weinman, Abyaya Lamsal, Benjamin Gafford
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+# tf_serving_predict.py - Send an image to a model hosted by the
+# TensorFlow Model Server over the gRPC predict API and print the
+# decoded text.
+# NOTE: images are preprocessed the same way as mjsynth training data
+
+import numpy as np
+from PIL import Image
+
+from tensorflow import saved_model as sm
+
+import charset
+import mjsynth
+
+import datetime
+
+import grpc
+import tensorflow as tf
+from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc
+
+ctc_graph = tf.compat.v1.get_default_graph()
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string( 'server_host','localhost:8500',
+                            """host:port of the TensorFlow Model Server""" )
+tf.app.flags.DEFINE_string( 'model_spec_name','crnn',
+                            """Model name as referenced in the model server
+                            configuration (e.g., tf_server.conf)""" )
+tf.app.flags.DEFINE_integer( 'model_version',1616119579,
+ """SavedModel version to call""" )
+tf.app.flags.DEFINE_string( 'image_path','/media/test/images/test_image.png',
+ """Image to be used as input for inference/prediction""" )
+tf.app.flags.DEFINE_integer( 'image_base_height',32,
+ """The height for all tensors to be reshaped to""" )
+
+tf.logging.set_verbosity( tf.logging.INFO )
+
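+# Open a gRPC channel to the model server and build a reusable request
+# addressed to the named model, version, and default serving signature.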
+server = FLAGS.server_host
+channel = grpc.insecure_channel(server)
+stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
+
+request = predict_pb2.PredictRequest()
+request.model_spec.signature_name = sm.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+request.model_spec.name = FLAGS.model_spec_name
+request.model_spec.version.value = FLAGS.model_version
+
+IMAGE = 'image'
+WIDTH = 'width'
+LABELS = 'labels'
+
+class Predict:
+
+ def __get_image(self, filename):
+ """Load image data for placement in graph"""
+
+ pil_image = Image.open(filename)
+ width, height = pil_image.size
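+        # scale the width to preserve aspect ratio at the fixed model height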
+ new_width = FLAGS.image_base_height * width / height
+ pil_image = pil_image.resize((int(new_width), FLAGS.image_base_height), Image.ANTIALIAS)
+ image = np.array(pil_image)
+
+ # in mjsynth, all three channels are the same in these grayscale-cum-RGB data
+ image = image[:, :, :1] # so just extract first channel, preserving 3D shape
+ image = mjsynth.preprocess_image(image)
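+        # preprocess_image returns a tensor; evaluate it in a throwaway
+        # session to get a NumPy array for the request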
+ image = image.eval(session=tf.compat.v1.Session())
+
+ return image
+
+ def predict(self):
+ """ Will call Tensorflow Model Server for Prediction, using gRPC API. """
+
+ print("Prediction Start Time: ", datetime.datetime.now())
+ image = self.__get_image(FLAGS.image_path)
+
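+        # add a batch dimension; the exported signature expects [1, H, W, C]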
+ h, w, c = image.shape
+ image = image[np.newaxis, :, :, :]
+
+ request.inputs[IMAGE].CopyFrom(tf.compat.v1.make_tensor_proto(image, dtype=None, shape=image.shape))
+ request.inputs[WIDTH].CopyFrom(tf.compat.v1.make_tensor_proto(w))
+        result = stub.Predict(request, 10)  # 10-second RPC timeout
+
+ protobuf_response = result.outputs[LABELS]
+ ndarray_response = tf.make_ndarray(protobuf_response)
+ label = charset.label_to_string(ndarray_response[0])
+
+ print("Prediction End Time: ", datetime.datetime.now())
+ print("Predicted Text: ", label)
+
+ return label
+
+if __name__ == '__main__':
+ pr = Predict()
+ pr.predict()