Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,28 @@ Alternatively, you can run the program interactively by typing image
paths in the terminal (one per line, type Control-D when you want the
model to run the input entered so far).

# Deploying model to Tensorflow Model Serving

To host models trained by this platform using Tensorflow Model Serving,
you would first need to convert the given checkpoint into the SavedModel
format. The SavedModel conversion script (`export_to_saved_model.py`)
allows you to restore an existing checkpoint, and export it as the
SavedModel, which is required format by Tensorflow Model Serving.

Please edit scripts parameters and execute the following:

cd src ; python export_to_saved_model.py

# Validation with Tensorflow Model Serving

We have created the example script (`tf_serving_predict.py`) , which
will use the specified image file and run it against the model hosted by
the Tensorflow Model Serving.

Please edit scripts parameters and execute the following:

cd src ; python tf_serving_predict.py

# Configuration

There are many command-line options to configure training
Expand Down Expand Up @@ -240,7 +262,7 @@ Please cite the following [paper](https://weinman.cs.grinnell.edu/pubs/weinman19

```text
@inproceedings{ weinman19deep,
author = {Jerod Weinman and Ziwen Chen and Ben Gafford and Nathan Gifford and Abyaya Lamsal and Liam Niehus-Staab},
author = {Jerod Weinman and Ziwen Chen and Ben Gafford and Nathan Gifford and Abyaya Lamsal and Liam Niehus-Staab and Igor Vishnevskiy},
title = {Deep Neural Networks for Text Detection and Recognition in Historical Maps},
booktitle = {Proc. IAPR International Conference on Document Analysis and Recognition},
month = {Sep.},
Expand Down
93 changes: 93 additions & 0 deletions src/export_to_saved_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# CNN-LSTM-CTC-OCR
# Copyright (C) 2017,2018 Jerod Weinman, Abyaya Lamsal, Benjamin Gafford
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

# validate.py - Run model directly on from paths to image filenames.
# NOTE: assumes mjsynth files are given by adding an extra row of padding

import tensorflow as tf

import model_fn as model_fn
import os

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string( 'ckpt_path','/media/path/to/models/',
"""Directory for model checkpoints""" )
tf.app.flags.DEFINE_string( 'checkpoint','model.ckpt-73152',
"""The checkpoint to export. Just the name of the checkpoint,
no path needed. Exemple: 'model.ckpt-104976'""" )
tf.app.flags.DEFINE_string( 'export_dir','/media/path/to/where/to/export/saved_model/',
"""Path to the directory where SavedModel is to be exported""" )
tf.app.flags.DEFINE_string( 'lexicon','',
"""File containing lexicon of image words""" )
tf.app.flags.DEFINE_float( 'lexicon_prior',None,
"""Prior bias [0,1] for lexicon word""" )

IMAGE = 'image'
WIDTH = 'width'
LABELS = 'labels'
LENGTH = 'length'
TEXT = 'text'


class ExportSavedModel():

def __get_config(self):
"""Setup session config to soften device placement"""

device_config=tf.ConfigProto(
allow_soft_placement=True,
log_device_placement=False )

custom_config = tf.estimator.RunConfig( session_config=device_config )

return custom_config

def __serving_input_receiver(self):
"""Placeholders function required by the Estimator to export SavedModel"""

image = tf.placeholder(dtype=tf.float32, shape=[1, 32, None, 1], name=IMAGE)
width = tf.placeholder(dtype=tf.int32, shape=[], name=WIDTH)
labels = tf.placeholder(dtype=tf.float32, shape=[None], name=LABELS)
length = tf.placeholder(dtype=tf.float32, shape=[], name=LENGTH)
text = tf.placeholder(dtype=tf.string, shape=[], name=TEXT)
features = {IMAGE: image,
WIDTH: width,
LABELS: labels,
LENGTH: length,
TEXT: text}
receiver_tensor = {IMAGE: image,
WIDTH: width}
return tf.estimator.export.ServingInputReceiver(features, receiver_tensor)

def export_saved_model(self):
"""Will restore the given checkpoint and export it into the SavedModel format"""

classifier = tf.estimator.Estimator(config=self.__get_config(),
model_fn=model_fn.predict_fn(
FLAGS.lexicon,
FLAGS.lexicon_prior),
model_dir=FLAGS.ckpt_path)

checkpoint = os.path.join(FLAGS.ckpt_path, FLAGS.checkpoint)

classifier.export_saved_model(FLAGS.export_dir,
serving_input_receiver_fn=self.__serving_input_receiver,
checkpoint_path=checkpoint)

if __name__ == '__main__':
export = ExportSavedModel()
export.export_saved_model()
107 changes: 107 additions & 0 deletions src/tf_serving_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# CNN-LSTM-CTC-OCR
# Copyright (C) 2017,2018 Jerod Weinman, Abyaya Lamsal, Benjamin Gafford
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

# validate.py - Run model directly on from paths to image filenames.
# NOTE: assumes mjsynth files are given by adding an extra row of padding

import numpy as np
from PIL import Image

from tensorflow import saved_model as sm

import charset as charset
import mjsynth

import datetime

import grpc
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc

global ctc_graph
ctc_graph = tf.compat.v1.get_default_graph

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string( 'server_host','localhost:8500',
"""IP:PORT of the Tensorflow Model Serving""" )
tf.app.flags.DEFINE_string( 'model_spec_name','crnn',
"""Name of the saved models batch as it is referenced in the "tf_server.conf""" )
tf.app.flags.DEFINE_integer( 'model_version',1616119579,
"""SavedModel version to call""" )
tf.app.flags.DEFINE_string( 'image_path','/media/test/images/test_image.png',
"""Image to be used as input for inference/prediction""" )
tf.app.flags.DEFINE_integer( 'image_base_height',32,
"""The height for all tensors to be reshaped to""" )

tf.logging.set_verbosity( tf.logging.INFO )

server = FLAGS.server_host
channel = grpc.insecure_channel(server)
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

request = predict_pb2.PredictRequest()
request.model_spec.signature_name = sm.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
request.model_spec.name = FLAGS.model_spec_name
request.model_spec.version.value = FLAGS.model_version

IMAGE = 'image'
WIDTH = 'width'
LABELS = 'labels'

class Predict():

def __get_image(self, filename):
"""Load image data for placement in graph"""

pil_image = Image.open(filename)
width, height = pil_image.size
new_width = FLAGS.image_base_height * width / height
pil_image = pil_image.resize((int(new_width), FLAGS.image_base_height), Image.ANTIALIAS)
image = np.array(pil_image)

# in mjsynth, all three channels are the same in these grayscale-cum-RGB data
image = image[:, :, :1] # so just extract first channel, preserving 3D shape
image = mjsynth.preprocess_image(image)
image = image.eval(session=tf.compat.v1.Session())

return image

def predict(self):
""" Will call Tensorflow Model Server for Prediction, using gRPC API. """

print("Prediction Start Time: ", datetime.datetime.now())
image = self.__get_image(FLAGS.image_path)

h, w, c = image.shape
image = image[np.newaxis, :, :, :]

request.inputs[IMAGE].CopyFrom(tf.compat.v1.make_tensor_proto(image, dtype=None, shape=image.shape))
request.inputs[WIDTH].CopyFrom(tf.compat.v1.make_tensor_proto(w))
result = stub.Predict(request, 10)

protobuf_response = result.outputs[LABELS]
ndarray_response = tf.make_ndarray(protobuf_response)
label = charset.label_to_string(ndarray_response[0])

print("Prediction End Time: ", datetime.datetime.now())
print("Predicted Text: ", label)

return label

if __name__ == '__main__':
pr = Predict()
pr.predict()