From c5c76f67ab916db2534c5068882708a116817ad8 Mon Sep 17 00:00:00 2001 From: Xiechang Date: Tue, 15 Aug 2017 16:41:02 +0800 Subject: [PATCH 1/3] lastbatch_err --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 35d4a4f..d800112 100644 --- a/main.py +++ b/main.py @@ -106,7 +106,7 @@ def train(train_dir=None, val_dir=None, mode='train'): model.seq_len: val_seq_len} dense_decoded, lastbatch_err, lr = \ - sess.run([model.dense_decoded, model.lrn_rate], + sess.run([model.dense_decoded, model.cost, model.lrn_rate], val_feed) # print the decode result From f03b26689966d218c8b8ecc0cec0466a95ea3f5a Mon Sep 17 00:00:00 2001 From: Xiechang Date: Tue, 15 Aug 2017 17:10:04 +0800 Subject: [PATCH 2/3] support different image shape in _build_model --- cnn_lstm_otc_ocr.py | 40 ++++++++++++++-------------------------- 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/cnn_lstm_otc_ocr.py b/cnn_lstm_otc_ocr.py index bb3a22d..c1db85c 100644 --- a/cnn_lstm_otc_ocr.py +++ b/cnn_lstm_otc_ocr.py @@ -26,33 +26,21 @@ def build_graph(self): self.merged_summay = tf.summary.merge_all() def _build_model(self): - filters = [64, 128, 128, FLAGS.max_stepsize] + filters = [1, 64, 128, 128, FLAGS.max_stepsize] strides = [1, 2] - + k_size = [3, 2] + x = self.inputs + feature_w = FLAGS.image_width + feature_h = FLAGS.image_height with tf.variable_scope('cnn'): - with tf.variable_scope('unit-1'): - x = self._conv2d(self.inputs, 'cnn-1', 3, 1, filters[0], strides[0]) - x = self._batch_norm('bn1', x) - x = self._leaky_relu(x, 0.01) - x = self._max_pool(x, 2, strides[1]) - - with tf.variable_scope('unit-2'): - x = self._conv2d(x, 'cnn-2', 3, filters[0], filters[1], strides[0]) - x = self._batch_norm('bn2', x) - x = self._leaky_relu(x, 0.01) - x = self._max_pool(x, 2, strides[1]) - - with tf.variable_scope('unit-3'): - x = self._conv2d(x, 'cnn-3', 3, filters[1], filters[2], strides[0]) - x = self._batch_norm('bn3', x) - x = self._leaky_relu(x, 0.01) - x = self._max_pool(x, 2, strides[1]) - - with tf.variable_scope('unit-4'): - x = self._conv2d(x, 'cnn-4', 3, filters[2], filters[3], strides[0]) - x = self._batch_norm('bn4', x) - x = self._leaky_relu(x, 0.01) - x = self._max_pool(x, 2, strides[1]) + for i in range(0, 4): + with tf.variable_scope('unit-%d' % (i + 1)): + x = self._conv2d(x, 'cnn-%d' % (i + 1), k_size[0], filters[i], filters[i+1], strides[0]) + x = self._batch_norm('bn%d' % (i + 1), x) + x = self._leaky_relu(x, 0.01) + x = self._max_pool(x, k_size[1], strides[1]) + feature_h = (feature_h + 1) // 2 + feature_w = (feature_w + 1) // 2 with tf.variable_scope('lstm'): # [batch_size, max_stepsize, num_features] @@ -60,7 +48,7 @@ def _build_model(self): x = tf.transpose(x, [0, 2, 1]) # batch_size * 64 * 48 # shp = x.get_shape().as_list() # x.set_shape([FLAGS.batch_size, filters[3], shp[1]]) - x.set_shape([FLAGS.batch_size, filters[3], 48]) + x.set_shape([FLAGS.batch_size, filters[3], feature_w * feature_h]) # tf.nn.rnn_cell.RNNCell, tf.nn.rnn_cell.GRUCell cell = tf.contrib.rnn.LSTMCell(FLAGS.num_hidden, state_is_tuple=True) From 8940718471137f9c7e5ed15ac91a00899d755b9d Mon Sep 17 00:00:00 2001 From: Xiechang Date: Tue, 15 Aug 2017 17:11:24 +0800 Subject: [PATCH 3/3] fix bugs --- cnn_lstm_otc_ocr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cnn_lstm_otc_ocr.py b/cnn_lstm_otc_ocr.py index c1db85c..b091b47 100644 --- a/cnn_lstm_otc_ocr.py +++ b/cnn_lstm_otc_ocr.py @@ -44,11 +44,11 @@ def _build_model(self): with tf.variable_scope('lstm'): # [batch_size, max_stepsize, num_features] - x = tf.reshape(x, [FLAGS.batch_size, -1, filters[3]]) + x = tf.reshape(x, [FLAGS.batch_size, -1, filters[4]]) x = tf.transpose(x, [0, 2, 1]) # batch_size * 64 * 48 # shp = x.get_shape().as_list() # x.set_shape([FLAGS.batch_size, filters[3], shp[1]]) - x.set_shape([FLAGS.batch_size, filters[3], feature_w * feature_h]) + x.set_shape([FLAGS.batch_size, filters[4], feature_w * feature_h]) # tf.nn.rnn_cell.RNNCell, tf.nn.rnn_cell.GRUCell cell = tf.contrib.rnn.LSTMCell(FLAGS.num_hidden, state_is_tuple=True)